janestreet / async_extra

Jane Street Capital's asynchronous execution library (extra)
MIT License
15 stars 8 forks source link

Tcp.Server.create with pre-bound socket #5

Closed madroach closed 10 months ago

madroach commented 8 years ago

Hi,

I usually run my servers as non-root users. To listen on privileged ports (<1024), the server needs root privileges when binding. This is not possible with Tcp.Server.create, because I cannot pass an already-bound socket to it. Below is a patch that adds this feature without breaking the existing interface of the Tcp module, together with a Cohttp server demonstrating how to use it. The new functionality is used in the call to Tcp.Where_to_listen.bind in the start_server function.

(* Bound before [open Core.Std] so that [Unix] here is still the OCaml
   standard library module: [setgroups] is not available through Core.Std. *)
let setgroups groups = Unix.setgroups groups

open Core.Std
open Async.Std
open Cohttp_async

(* [drop_privileges ?user ?group ?chroot ()] gives up root privileges.
   It is meant to be called after privileged resources (e.g. a socket bound
   to a port < 1024) have already been acquired.

   Steps, performed strictly in this order:
   - resolve [user] / [group] to numeric ids.  This happens before [chroot],
     presumably because the password/group databases may not be reachable
     inside the chroot;
   - [chroot] into [chroot] (if given), then [chdir "/"];
   - clear the supplementary groups, then [setgid], then [setuid];
   - try to regain root and verify that this fails.

   If [user] is given without [group], the gid is taken from the user's
   password entry, so that the process does not keep root's gid.

   Raises [Failure] if a user/group name cannot be resolved, or if after
   dropping there are still supplementary groups or the real/effective ids
   differ from the requested ones.  Errors from the underlying Unix calls
   propagate unchanged. *)
let drop_privileges ?user ?group ?chroot () =
  (* Shadow Async.Std.Unix with Core's blocking Unix: each step below must
     have completed before the next one starts. *)
  let module Unix = Core.Std.Unix in
  let passwd =
    Option.map user ~f:(fun user ->
      try Unix.Passwd.getbyname_exn user with _ ->
        failwith ("Could not get UID of " ^ user))
  in
  let uid = Option.map passwd ~f:(fun pw -> pw.Unix.Passwd.uid) in
  let gid =
    match group with
    | Some group ->
      Some
        (try (Unix.Group.getbyname_exn group).Unix.Group.gid with _ ->
           failwith ("Could not get GID of " ^ group))
    | None ->
      (* Without an explicit group, fall back to the user's primary group;
         otherwise the process would keep root's gid after [setuid]. *)
      Option.map passwd ~f:(fun pw -> pw.Unix.Passwd.gid)
  in
  Option.iter chroot ~f:Unix.chroot;
  Unix.chdir "/";
  setgroups [||];
  Option.iter gid ~f:Unix.setgid;
  Option.iter uid ~f:Unix.setuid;
  (* Deliberate best-effort attempt to regain root: failing is the expected
     outcome and is ignored here; the checks below decide the result. *)
  (try Unix.setgid 0 with _ -> ());
  (try Unix.setuid 0 with _ -> ());
  let matches expected actual =
    Option.for_all expected ~f:(fun e -> e = actual)
  in
  let dropped =
    Array.length (Unix.getgroups ()) = 0
    && matches uid (Unix.getuid ()) && matches uid (Unix.geteuid ())
    && matches gid (Unix.getgid ()) && matches gid (Unix.getegid ())
  in
  if not dropped then failwith "Could not drop privileges"
;;

(* Request handler for the demo server.

   The path "/test" is answered with the value of the [hello] query
   parameter, or with a notice that the parameter is missing.  Every other
   path gets a [`Not_found] response. *)
let handler ~body:_ _sock req =
  let uri = Cohttp.Request.uri req in
  match Uri.path uri with
  | "/test" ->
    let reply =
      match Uri.get_query_param uri "hello" with
      | Some v -> "hello: " ^ v
      | None -> "No param hello supplied"
    in
    Server.respond_with_string reply
  | _ -> Server.respond_with_string ~code:`Not_found "Route not found"
;;

(* Command entry point.  It binds [port] while still privileged, then drops
   to user/group "www" inside a "/var/www" chroot.  Finally it serves
   [handler] on the already-bound socket.  The returned deferred never
   becomes determined, so the program keeps serving. *)
let start_server port () =
  let where_to_listen = Tcp.on_port port in
  Tcp.Where_to_listen.bind where_to_listen >>= fun bound ->
  (* Privileges are dropped only after the (possibly privileged) port has
     been bound.  The server below reuses that socket instead of binding
     again. *)
  drop_privileges ~user:"www" ~group:"www" ~chroot:"/var/www" ();
  eprintf "Listening for HTTP on port %d\n" port;
  eprintf "Try 'curl http://localhost:%d/test?hello=xyz'\n%!" port;
  Cohttp_async.Server.create ~on_handler_error:`Raise bound handler
  >>= fun _server -> Deferred.never ()
;;

(* Program entry: parse "-p PORT" (default 80) and run [start_server]
   under the Async scheduler. *)
let () =
  let spec =
    Command.Spec.(
      empty
      +> flag "-p" (optional_with_default 80 int)
           ~doc:"int Source port to listen on")
  in
  Command.run
    (Command.async_basic
       ~summary:"Start a hello world Async server"
       spec
       start_server)
--- async_extra/src/tcp.mli Sat Oct 10 19:55:23 2015
+++ async_extra.113.00.00/src/tcp.mli   Sat Oct 10 19:51:03 2015
@@ -87,12 +87,18 @@
     -> listening_on : ('address -> 'listening_on)
     -> ('address, 'listening_on) t

+  val bind : ('address, 'listening_on) t -> ('address, 'listening_on) t Deferred.t
+
   val address : ('address, _) t -> 'address
 end

 val on_port              : int ->    Where_to_listen.inet
 val on_port_chosen_by_os :           Where_to_listen.inet
 val on_file              : string -> Where_to_listen.unix
+val on_socket
+  :  Socket.Address.t Socket.Type.t
+  -> ([ `Bound ], Socket.Address.t) Socket.t
+  -> (Socket.Address.t, string) Where_to_listen.t

 (** A [Server.t] represents a TCP server listening on a socket. *)
 module Server : sig
--- async_extra/src/tcp.ml  Sat Oct 10 19:55:23 2015
+++ async_extra.113.00.00/src/tcp.ml    Sat Oct 10 19:51:03 2015
@@ -151,7 +151,8 @@
 module Where_to_listen = struct

   type ('address, 'listening_on) t =
-    { socket_type  : 'address Socket.Type.t
+    { socket       : ([ `Bound ], 'address) Socket.t Option.t
+    ; socket_type  : 'address Socket.Type.t
     ; address      : 'address
     ; listening_on : ('address -> 'listening_on) sexp_opaque
     }
@@ -161,13 +162,23 @@
   type unix = (Socket.Address.Unix.t, string) t with sexp_of

   let create ~socket_type ~address ~listening_on =
-    { socket_type; address; listening_on }
+    { socket = None; socket_type; address; listening_on }
+
+  let bind = function
+    | { socket = Some _; _ } as where_to_listen -> return where_to_listen
+    | { socket = None; socket_type; address; _ } as where_to_listen ->
+      let socket = create_socket socket_type in
+      close_sock_on_error socket (fun () ->
+          Socket.setopt socket Socket.Opt.reuseaddr true;
+          Socket.bind socket address)
+      >>| (fun socket -> { where_to_listen with socket = Some socket })
   ;;
 end

 let on_port port =
   { Where_to_listen.
-    socket_type  = Socket.Type.tcp
+    socket       = None
+  ; socket_type  = Socket.Type.tcp
   ; address      = Socket.Address.Inet.create_bind_any ~port
   ; listening_on = function `Inet (_, port) -> port
   }
@@ -178,12 +189,21 @@

 let on_file path =
   { Where_to_listen.
-    socket_type  = Socket.Type.unix
+    socket       = None
+  ; socket_type  = Socket.Type.unix
   ; address      = Socket.Address.Unix.create path
   ; listening_on = fun _ -> path
   }
 ;;

+let on_socket socket_type socket =
+  { Where_to_listen.
+    socket       = Some socket
+  ; socket_type
+  ; address      = Socket.getsockname socket
+  ; listening_on = fun address -> Socket.Address.to_string address
+  }
+
 module Server = struct

   type ('address, 'listening_on)  t =
@@ -297,11 +317,11 @@
     if max_connections <= 0
     then failwiths "Tcp.Server.creater got negative [max_connections]" max_connections
            sexp_of_int;
-    let socket = create_socket where_to_listen.socket_type in
+    Where_to_listen.bind where_to_listen
+    >>= fun where_to_listen ->
+    let socket = Option.value_exn ~message:"not reached" where_to_listen.socket in
     close_sock_on_error socket (fun () ->
-      Socket.setopt socket Socket.Opt.reuseaddr true;
-      Socket.bind socket where_to_listen.address
-      >>| Socket.listen ?max_pending_connections)
+        return (Socket.listen ?max_pending_connections socket))
     >>| fun socket ->
     let t =
       { socket