--- /dev/null
+(*
+ * Copyright (C) 2009 Citrix Ltd.
+ * Author Prashanth Mundkur <firstname.lastname@citrix.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU Lesser General Public License as published
+ * by the Free Software Foundation; version 2.1 only. with the special
+ * exception on linking described in file LICENSE.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Lesser General Public License for more details.
+ *)
+
+let verbose = ref true
+
+let dbg fmt =
+ let logger s = if !verbose then Printf.printf "%s\n" s in
+ Printf.ksprintf logger fmt
+
+type error = Unix.error * string * string
+
+type handle = Unix.file_descr
+
+type conn_callbacks =
+{
+ accept_callback : handle -> Unix.file_descr -> Unix.sockaddr -> unit;
+ connect_callback : handle -> unit;
+ recv_callback : handle -> string -> (* offset *) int -> (* length *) int -> unit;
+ send_done_callback : handle -> unit;
+ closed_callback : handle -> unit;
+ error_callback : handle -> error -> unit;
+}
+
+type conn_status =
+ | Connecting
+ | Listening
+ | Connected
+
+type conn_state =
+{
+ callbacks : conn_callbacks;
+ mutable status : conn_status;
+ mutable send_done_enabled : bool;
+ mutable recv_enabled : bool;
+
+ send_buf : Buffer.t;
+}
+
+module ConnMap = Map.Make (struct type t = Unix.file_descr let compare = compare end)
+
+(* A module that supports finding a timer by handle as well as by expiry time. *)
+module Timers = struct
+ type handle = int
+
+ type 'a entry =
+ {
+ handle: handle;
+ expires_at: float;
+ value: 'a;
+ }
+
+ module Timers_by_handle = Map.Make (struct type t = handle let compare = compare end)
+ module Timers_by_expiry = Map.Make (struct type t = float let compare = compare end)
+
+ type 'a t =
+ {
+ mutable by_handle: ('a entry) Timers_by_handle.t;
+ mutable by_expiry: (('a entry) list) Timers_by_expiry.t;
+ }
+
+ let create () = { by_handle = Timers_by_handle.empty;
+ by_expiry = Timers_by_expiry.empty }
+
+ let is_empty t = Timers_by_handle.is_empty t.by_handle
+
+ let next_handle = ref 0
+
+ let add_timer t at v =
+ incr next_handle;
+ let e = { handle = !next_handle; expires_at = at; value = v } in
+ t.by_handle <- Timers_by_handle.add e.handle e t.by_handle;
+ let es = try Timers_by_expiry.find e.expires_at t.by_expiry with Not_found -> [] in
+ t.by_expiry <- Timers_by_expiry.add e.expires_at (e :: es) t.by_expiry;
+ e.handle
+
+ let remove_timer t handle =
+ let e = Timers_by_handle.find handle t.by_handle in
+ let es = Timers_by_expiry.find e.expires_at t.by_expiry in
+ let es = List.filter (fun e' -> e'.handle <> handle) es in
+ t.by_handle <- Timers_by_handle.remove handle t.by_handle;
+ t.by_expiry <- (match es with
+ | [] -> Timers_by_expiry.remove e.expires_at t.by_expiry
+ | _ -> Timers_by_expiry.add e.expires_at es t.by_expiry
+ )
+
+ exception Found of float
+
+ (* Should only be called on a non-empty Timer set; otherwise,
+ Not_found is raised. *)
+ let get_first_expiry_time t =
+ try
+ (* This should give the earliest expiry time,
+ since iteration is done in increasing order. *)
+ Timers_by_expiry.iter (fun tim -> raise (Found tim)) t.by_expiry;
+ raise Not_found
+ with Found tim -> tim
+
+ let extract_timers_at t tim =
+ try
+ let es = Timers_by_expiry.find tim t.by_expiry in
+ t.by_expiry <- Timers_by_expiry.remove tim t.by_expiry;
+ t.by_handle <- List.fold_left (fun byh e ->
+ Timers_by_handle.remove e.handle byh
+ ) t.by_handle es;
+ List.map (fun e -> e.value) es
+ with Not_found -> []
+end
+
+type timer_callbacks =
+{
+ expiry_callback : unit -> unit
+}
+
+type t =
+{
+ mutable conns: conn_state ConnMap.t;
+ mutable timers: timer_callbacks Timers.t;
+ (* select state *)
+ readers: Unixext.Fdset.t;
+ writers: Unixext.Fdset.t;
+ excepts: Unixext.Fdset.t;
+}
+
+let create () =
+{ conns = ConnMap.empty;
+ timers = Timers.create ();
+ readers = Unixext.Fdset.create ();
+ writers = Unixext.Fdset.create ();
+ excepts = Unixext.Fdset.create ();
+}
+
+(* connections *)
+
+let register_conn t fd ?(enable_send_done=false) ?(enable_recv=true) callbacks =
+ let conn_state = { callbacks = callbacks;
+ status = Connected;
+ send_done_enabled = enable_send_done;
+ recv_enabled = enable_recv;
+ send_buf = Buffer.create 16;
+ }
+ in
+ t.conns <- ConnMap.add fd conn_state t.conns;
+ Unix.set_nonblock fd;
+ if conn_state.recv_enabled then
+ Unixext.Fdset.set t.readers fd;
+ fd
+
+let remove_conn t handle =
+ Unixext.Fdset.clear t.readers handle;
+ Unixext.Fdset.clear t.writers handle;
+ t.conns <- ConnMap.remove handle t.conns
+
+let connect t handle addr =
+ let conn_state = ConnMap.find handle t.conns in
+ conn_state.status <- Connecting;
+ try
+ Unix.connect handle addr;
+ conn_state.status <- Connected;
+ conn_state.callbacks.connect_callback handle
+ with
+ | Unix.Unix_error (Unix.EINPROGRESS, _, _) ->
+ Unixext.Fdset.set t.readers handle;
+ Unixext.Fdset.set t.writers handle
+ | Unix.Unix_error (ec, f, s) ->
+ conn_state.callbacks.error_callback handle (ec, f, s)
+
+let listen t handle =
+ let conn_state = ConnMap.find handle t.conns in
+ Unixext.Fdset.set t.readers handle;
+ conn_state.recv_enabled <- true;
+ conn_state.status <- Listening
+
+let enable_send_done t handle =
+ let conn_state = ConnMap.find handle t.conns in
+ conn_state.send_done_enabled <- true
+
+let disable_send_done t handle =
+ let conn_state = ConnMap.find handle t.conns in
+ conn_state.send_done_enabled <- false
+
+let enable_recv t handle =
+ let conn_state = ConnMap.find handle t.conns in
+ conn_state.recv_enabled <- true;
+ if conn_state.status = Connected then
+ Unixext.Fdset.set t.readers handle
+
+let disable_recv t handle =
+ let conn_state = ConnMap.find handle t.conns in
+ conn_state.recv_enabled <- false;
+ if conn_state.status = Connected then
+ Unixext.Fdset.clear t.readers handle
+
+let send t handle s =
+ let conn_state = ConnMap.find handle t.conns in
+ Buffer.add_string conn_state.send_buf s;
+ Unixext.Fdset.set t.writers handle
+
+let has_pending_send t handle =
+ let conn_state = ConnMap.find handle t.conns in
+ Buffer.length conn_state.send_buf > 0
+
+(* timers *)
+
+type timer = Timers.handle
+
+let start_timer t interval callback =
+ let at = Unix.gettimeofday () +. interval in
+ Timers.add_timer t.timers at callback
+
+let cancel_timer t timer =
+ Timers.remove_timer t.timers timer
+
+(* event dispatch *)
+
+let buf = String.create 512
+let buflen = String.length buf
+
+let dispatch_read t fd cs =
+ match cs.status with
+ | Connecting ->
+ (match Unix.getsockopt_error fd with
+ | None ->
+ cs.status <- Connected;
+ if not cs.recv_enabled then
+ Unixext.Fdset.clear t.readers fd;
+ cs.callbacks.connect_callback fd
+ | Some err ->
+ cs.callbacks.error_callback fd (err, "connect", "")
+ )
+ | Listening ->
+ (try
+ let afd, aaddr = Unix.accept fd in
+ cs.callbacks.accept_callback fd afd aaddr
+ with
+ | Unix.Unix_error (Unix.EWOULDBLOCK, _, _)
+ | Unix.Unix_error (Unix.ECONNABORTED, _, _)
+ | Unix.Unix_error (Unix.EINTR, _, _)
+ -> ()
+ )
+ | Connected ->
+ if cs.recv_enabled then
+ try
+ let read_bytes = Unix.read fd buf 0 buflen in
+ if read_bytes = 0 then
+ cs.callbacks.closed_callback fd
+ else
+ cs.callbacks.recv_callback fd buf 0 read_bytes
+ with
+ | Unix.Unix_error (Unix.EWOULDBLOCK, _, _)
+ | Unix.Unix_error (Unix.EAGAIN, _, _)
+ | Unix.Unix_error (Unix.EINTR, _, _) ->
+ ()
+ else
+ Unixext.Fdset.clear t.readers fd
+let do_send fd cs =
+ let payload = Buffer.contents cs.send_buf in
+ let payload_len = String.length payload in
+ match Unix.write fd payload 0 payload_len with
+ | 0 -> ()
+ | sent ->
+ Buffer.clear cs.send_buf;
+ Buffer.add_substring cs.send_buf payload sent (payload_len - sent)
+
+let dispatch_write t fd cs =
+ match cs.status with
+ | Connecting ->
+ (match Unix.getsockopt_error fd with
+ | None ->
+ cs.status <- Connected;
+ if cs.recv_enabled then
+ Unixext.Fdset.set t.readers fd
+ else
+ Unixext.Fdset.clear t.readers fd;
+ cs.callbacks.connect_callback fd
+ | Some err ->
+ cs.callbacks.error_callback fd (err, "connect", "")
+ )
+ | Listening ->
+ (* This should never happen, since listening sockets
+ are not set for writing. But, to avoid a busy
+ select loop in case this socket keeps firing for
+ writes, we disable the write watch. *)
+ Unixext.Fdset.clear t.writers fd
+ | Connected ->
+ do_send fd cs;
+ if Buffer.length cs.send_buf = 0 then begin
+ if cs.send_done_enabled then
+ cs.callbacks.send_done_callback fd;
+ Unixext.Fdset.clear t.writers fd
+ end
+
+let dispatch_timers t current_time =
+ let break = ref false in
+ while (not (Timers.is_empty t.timers) && not !break) do
+ let first_expired = Timers.get_first_expiry_time t.timers in
+ if first_expired > current_time then
+ break := true
+ else begin
+ let cbs = Timers.extract_timers_at t.timers first_expired in
+ List.iter (fun cb -> cb.expiry_callback ()) cbs
+ end
+ done
+
+let dispatch t interval =
+ let ctime = Unix.gettimeofday () in
+ let interval =
+ if Timers.is_empty t.timers then interval
+ else
+ (* the blocking interval for select is the
+ smaller of the specified interval, and the
+ interval before which the earliest timer
+ expires.
+ *)
+ let block_until = if interval > 0.0 then ctime +. interval else ctime in
+ let first_expiry = Timers.get_first_expiry_time t.timers in
+ let block_until = (if first_expiry < block_until then first_expiry else block_until) in
+ let interval = block_until -. ctime in
+ if interval < 0.0 then 0.0 else interval
+ in
+ let events =
+ try Some (Unixext.Fdset.select t.readers t.writers t.excepts interval)
+ with Unix.Unix_error (Unix.EINTR, _, _) -> None
+ in
+ (match events with
+ | Some (r, w, _) ->
+ ConnMap.iter (fun fd cs ->
+ if Unixext.Fdset.is_set r fd then
+ dispatch_read t fd cs;
+ if Unixext.Fdset.is_set w fd then
+ dispatch_write t fd cs
+ ) t.conns
+ | None -> ()
+ );
+
+ dispatch_timers t ctime