diff --git a/.github/workflows/workflow.yml b/.github/workflows/workflow.yml index 13bed6d83..9cf92737b 100644 --- a/.github/workflows/workflow.yml +++ b/.github/workflows/workflow.yml @@ -15,12 +15,6 @@ jobs: os: - ubuntu-latest ocaml-compiler: - - "4.08" - - "4.09" - - "4.10" - - "4.11" - - "4.12" - - "4.13" - "4.14" - "5.0" - "5.1" @@ -43,24 +37,42 @@ jobs: runs-on: ${{ matrix.os }} steps: - - name: set ppx-related variables - id: configppx + - name: set version-dependent variables + id: configpkgs shell: bash run: | + opampkgs="./lwt.opam ./lwt_react.opam ./lwt_retry.opam ./lwt_ppx.opam" + dunepkgs="lwt,lwt_react,lwt_retry,lwt_ppx" case ${{ matrix.ocaml-compiler }} in - "4.08"|"4.09"|"4.10"|"4.11"|"4.12"|"4.13"|"4.14"|"5.0") - echo "letppx=false" - echo "letppx=false" >> "$GITHUB_OUTPUT" + "4.14"|"5.0") + : ;; "5.1"|"5.2"|"5.3") - echo "letppx=true" - echo "letppx=true" >> "$GITHUB_OUTPUT" + opampkgs="${opampkgs} ./lwt_ppx__ppx_let_tests.opam" + dunepkgs="${dunepkgs},lwt_ppx__ppx_let_tests" ;; *) printf "unrecognised version %s\n" "${{ matrix.ocaml-compiler }}"; exit 1 ;; esac + case ${{ matrix.ocaml-compiler }} in + "4.14") + : + ;; + "5.0"|"5.1"|"5.2"|"5.3") + opampkgs="${opampkgs} ./lwt_direct.opam" + dunepkgs="${dunepkgs},lwt_direct" + ;; + *) + printf "unrecognised version %s\n" "${{ matrix.ocaml-compiler }}"; + exit 1 + ;; + esac + echo "opampkgs=${opampkgs}" + echo "opampkgs=${opampkgs}" >> "$GITHUB_OUTPUT" + echo "dunepkgs=${dunepkgs}" + echo "dunepkgs=${dunepkgs}" >> "$GITHUB_OUTPUT" - name: Checkout tree uses: actions/checkout@v5 @@ -73,20 +85,11 @@ jobs: - run: opam install conf-libev if: ${{ matrix.libev == true }} - - run: opam install ./lwt.opam ./lwt_react.opam ./lwt_retry.opam ./lwt_ppx.opam --deps-only --with-test - - - run: opam install ./lwt_ppx__ppx_let_tests.opam --deps-only --with-test - if: ${{ fromJSON(steps.configppx.outputs.letppx) }} - - - run: opam exec -- dune build --only-packages lwt,lwt_react,lwt_retry - - - run: opam exec -- dune build --only-packages lwt,lwt_ppx__ppx_let_tests - if: ${{ fromJSON(steps.configppx.outputs.letppx) }} + - run: opam install --deps-only --with-test ${{ steps.configpkgs.outputs.opampkgs }} - - run: opam exec -- dune runtest --only-packages lwt,lwt_react,lwt_retry,lwt_ppx + - run: opam exec -- dune build --only-packages ${{ steps.configpkgs.outputs.dunepkgs }} - - run: opam exec -- dune runtest --only-packages lwt,lwt_ppx__ppx_let_tests - if: ${{ fromJSON(steps.configppx.outputs.letppx) }} + - run: opam exec -- dune runtest --only-packages ${{ steps.configpkgs.outputs.dunepkgs }} lint-opam: runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 2663c06e3..6265df9b5 100644 --- a/.gitignore +++ b/.gitignore @@ -5,10 +5,6 @@ src/unix/discover_arguments # OPAM 2.0 local switches. _opam -# Coverage analysis. -bisect*.out -_coverage/ - # For local work, tests, etc. scratch/ diff --git a/CHANGES b/CHANGES index dd3106ca6..607b26678 100644 --- a/CHANGES +++ b/CHANGES @@ -1,4 +1,14 @@ -===== 5.10.0 ===== +===== 6.0.0+dev ===== + +====== Additions ====== + + * Lwt_direct using Lwt in direct-style. (Simon Cruanes, #1060) + + * Support multiple scheduler running in parallel in separate domains. + + * Exception filter defaults to letting systems exceptions through. + +===== 5.10.0+dev ===== ====== Misc diff --git a/Makefile b/Makefile index 022597d5e..ec14c02dc 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,6 @@ clean : dune clean rm -fr docs/api rm -f src/unix/discover_arguments - rm -rf _coverage/ EXPECTED_FILES := \ --expect src/core/ \ @@ -65,10 +64,3 @@ EXPECTED_FILES := \ --do-not-expect src/unix/lwt_gc.ml \ --do-not-expect src/unix/lwt_throttle.ml \ --do-not-expect src/unix/unix_c/ - -.PHONY: coverage -coverage : - dune runtest --instrument-with bisect_ppx --force - bisect-ppx-report html $(EXPECTED_FILES) - bisect-ppx-report summary - @echo See _coverage/index.html diff --git a/dune-project b/dune-project index 29b39f2c6..a64d7e51e 100644 --- a/dune-project +++ b/dune-project @@ -52,9 +52,19 @@ (lwt (>= 3.0)) (react (>= 1.0)))) +(package + (name lwt_direct) + (version 6.0.0~beta00) + (synopsis "Direct-style control-flow and `await` for Lwt") + (authors "Simon Cruanes") + (depends + (ocaml (>= 5.0)) + base-unix + (lwt (>= 6)))) + (package (name lwt) - (version 5.9.2+dev) + (version 6.0.0~beta00) (synopsis "Promises and event-driven I/O") (description "A promise is a value that may become determined in the future. @@ -66,7 +76,8 @@ a single thread by default. This reduces the need for locks or other synchronization primitives. Code can be run in parallel on an opt-in basis. ") (depends - (ocaml (>= 4.08)) + (ocaml (>= 4.14)) + domain_shims (cppo (and :build (>= 1.1))) (ocamlfind (and :dev (>= 1.7.3-1))) (odoc (and :with-doc (>= 2.3))) diff --git a/lwt.opam b/lwt.opam index 3da6be031..da564c97f 100644 --- a/lwt.opam +++ b/lwt.opam @@ -1,6 +1,6 @@ # This file is generated by dune, edit dune-project instead opam-version: "2.0" -version: "5.9.2+dev" +version: "6.0.0~beta00" synopsis: "Promises and event-driven I/O" description: """ A promise is a value that may become determined in the future. @@ -22,7 +22,8 @@ doc: "https://ocsigen.org/lwt" bug-reports: "https://github.com/ocsigen/lwt/issues" depends: [ "dune" {>= "3.15"} - "ocaml" {>= "4.08"} + "ocaml" {>= "4.14"} + "domain_shims" "cppo" {build & >= "1.1"} "ocamlfind" {dev & >= "1.7.3-1"} "odoc" {with-doc & >= "2.3"} diff --git a/lwt_direct.opam b/lwt_direct.opam new file mode 100644 index 000000000..064db4c20 --- /dev/null +++ b/lwt_direct.opam @@ -0,0 +1,34 @@ +# This file is generated by dune, edit dune-project instead +opam-version: "2.0" +version: "6.0.0~beta00" +synopsis: "Direct-style control-flow and `await` for Lwt" +maintainer: [ + "Raphaël Proust " "Anton Bachin " +] +authors: ["Simon Cruanes"] +license: "MIT" +homepage: "https://github.com/ocsigen/lwt" +doc: "https://ocsigen.org/lwt" +bug-reports: "https://github.com/ocsigen/lwt/issues" +depends: [ + "dune" {>= "3.15"} + "ocaml" {>= "5.0"} + "base-unix" + "lwt" {>= "6"} + "odoc" {with-doc} +] +build: [ + ["dune" "subst"] {dev} + [ + "dune" + "build" + "-p" + name + "-j" + jobs + "@install" + "@runtest" {with-test} + "@doc" {with-doc} + ] +] +dev-repo: "git+https://github.com/ocsigen/lwt.git" diff --git a/src/core/domain_map.ml b/src/core/domain_map.ml new file mode 100644 index 000000000..f80232f19 --- /dev/null +++ b/src/core/domain_map.ml @@ -0,0 +1,57 @@ +module Domain_map : Map.S with type key = Domain.id = Map.Make(struct + type t = Domain.id + let compare d1 d2 = Int.compare (d1 : Domain.id :> int) (d2 : Domain.id :> int) +end) + +(* Protected domain map reference with per-reference mutex *) +type 'a protected_map = { + mutex : Mutex.t; + mutable map : 'a Domain_map.t; +} + +let create_protected_map () = { + mutex = Mutex.create (); + map = Domain_map.empty; +} + +let with_lock protected_map f = + Mutex.lock protected_map.mutex; + Fun.protect f ~finally:(fun () -> Mutex.unlock protected_map.mutex) + +let update_map protected_map f = + with_lock protected_map (fun () -> + let old_map = protected_map.map in + let new_map = f old_map in + protected_map.map <- new_map) + +let add protected_map key value = + update_map protected_map (Domain_map.add key value) + +let remove protected_map key = + update_map protected_map (Domain_map.remove key) + +let update protected_map key f = + update_map protected_map (Domain_map.update key f) + +let find protected_map key = + with_lock protected_map (fun () -> Domain_map.find_opt key protected_map.map) + +let extract protected_map key = + with_lock protected_map (fun () -> + match Domain_map.find_opt key protected_map.map with + | None -> None + | Some v -> + protected_map.map <- Domain_map.remove key protected_map.map; + Some v) + +let size protected_map = + with_lock protected_map (fun () -> Domain_map.cardinal protected_map.map) + +let init protected_map key init_value = + with_lock protected_map (fun () -> + match Domain_map.find_opt key protected_map.map with + | Some existing -> existing + | None -> + let new_value = init_value () in + protected_map.map <- Domain_map.add key new_value protected_map.map; + new_value) diff --git a/src/core/domain_map.mli b/src/core/domain_map.mli new file mode 100644 index 000000000..008bc81be --- /dev/null +++ b/src/core/domain_map.mli @@ -0,0 +1,38 @@ +(** Domain-indexed maps with thread-safe operations + + Only intended to use internally, not for general release. + + Note that these function use a lock. A single lock. + - Probably not optimal + - Deadlock if you call one of those functions inside another (e.g., use + `init` rather than `find`+`update` + *) + +(** Thread-safe wrapper for domain maps *) +type 'a protected_map + +(** Create a new protected map with an empty map inside and a dedicated mutex, + the map is keyed on domain ids, and operations are synchronised via a mutex. + *) +val create_protected_map : unit -> 'a protected_map + +(** Add a key-value binding to the map *) +val add : 'a protected_map -> Domain.id -> 'a -> unit + +(** Remove a key from the map *) +val remove : 'a protected_map -> Domain.id -> unit + +(** Update a binding using the underlying map's update function *) +val update : 'a protected_map -> Domain.id -> ('a option -> 'a option) -> unit + +(** Find a value by key, returning None if not found *) +val find : 'a protected_map -> Domain.id -> 'a option + +(** Find + remove but hit the mutex only once *) +val extract : 'a protected_map -> Domain.id -> 'a option + +(** Get the number of bindings in the map *) +val size : 'a protected_map -> int + +(** Initialize a key with a value if it doesn't exist, return existing or new value *) +val init : 'a protected_map -> Domain.id -> (unit -> 'a) -> 'a diff --git a/src/core/dune b/src/core/dune index dab7ccc8d..9eaaf5be0 100644 --- a/src/core/dune +++ b/src/core/dune @@ -2,8 +2,7 @@ (public_name lwt) (synopsis "Monadic promises and concurrent I/O") (wrapped false) - (instrumentation - (backend bisect_ppx))) + (libraries domain_shims)) (documentation (package lwt)) diff --git a/src/core/lwt.ml b/src/core/lwt.ml index 257134c63..bfa1f4bc3 100644 --- a/src/core/lwt.ml +++ b/src/core/lwt.ml @@ -364,7 +364,63 @@ module Storage_map = end) type storage = (unit -> unit) Storage_map.t +module Multidomain_sync = struct + + (* callback_exchange is a domain-indexed map for storing callbacks that + different domains should execute. This is used when a domain d1 resolves a + promise on which a different domain d2 has attached callbacks (implicitely + via bind etc. or explicitly via on_success etc.). When this happens, the + domain resolving the promise calls its local callbacks and sends the other + domains' callbacks into the callback exchange *) + let callback_exchange = Domain_map.create_protected_map () + + (* notification_map is a domain-indexed map for waking sleeping domains. each + (should) domain registers a notification (see Lwt_unix) into the map when it + starts its scheduler. other domains can wake the domain up to indicate that + callbacks are available to be called *) + let notification_map = Domain_map.create_protected_map () + + (* send_callback d cb adds the callback cb into the callback_exchange and pings + the domain d via the notification_map *) + let send_callback d cb = + Domain_map.update + callback_exchange + d + (function + | None -> + let cbs = Lwt_sequence.create () in + let _ : (unit -> unit) Lwt_sequence.node = Lwt_sequence.add_l cb cbs in + Some cbs + | Some cbs -> + let _ : (unit -> unit) Lwt_sequence.node = Lwt_sequence.add_l cb cbs in + Some cbs); + begin match Domain_map.find notification_map d with + | None -> + failwith "ERROR: domain didn't register at startup" + | Some n -> + n () + end + (* get_sent_callbacks gets a domain's own callback from the callbasck exchange, + this is so that the notification handler installed by main.run can obtain the + callbacks that have been sent its way *) + let get_sent_callbacks domain_id = + match Domain_map.extract callback_exchange domain_id with + | None -> Lwt_sequence.create () + | Some cbs -> cbs + + (* register_notification adds a domain's own notification (see Lwt_unix) into + the notification map *) + let register_notification d n = + Domain_map.update notification_map d (function + | None -> Some n + | Some _ -> failwith "already registered!!") + + let is_alredy_registered d = + match Domain_map.find notification_map d with + | Some _ -> true + | None -> false +end module Main_internal_types = struct @@ -452,9 +508,9 @@ struct | Regular_callback_list_concat of 'a regular_callback_list * 'a regular_callback_list | Regular_callback_list_implicitly_removed_callback of - 'a regular_callback + Domain.id * 'a regular_callback | Regular_callback_list_explicitly_removable_callback of - 'a regular_callback option ref + Domain.id * 'a regular_callback option ref and _ cancel_callback_list = | Cancel_callback_list_empty : @@ -463,10 +519,10 @@ struct 'a cancel_callback_list * 'a cancel_callback_list -> 'a cancel_callback_list | Cancel_callback_list_callback : - storage * cancel_callback -> + Domain.id * storage * cancel_callback -> _ cancel_callback_list | Cancel_callback_list_remove_sequence_node : - ('a, _, _) promise Lwt_sequence.node -> + Domain.id * ('a, _, _) promise Lwt_sequence.node -> 'a cancel_callback_list (* Notes: @@ -716,11 +772,9 @@ module Exception_filter = struct | Out_of_memory -> false | Stack_overflow -> false | _ -> true - let v = - (* Default value: the legacy behaviour to avoid breaking programs *) - ref handle_all - let set f = v := f - let run e = !v e + let v = Atomic.make handle_all_except_runtime + let set f = Atomic.set v f + let run e = (Atomic.get v) e end module Sequence_associated_storage : @@ -730,9 +784,12 @@ sig val new_key : unit -> _ key val get : 'v key -> 'v option val with_value : 'v key -> 'v option -> (unit -> 'b) -> 'b + val get_from_storage : 'v key -> storage -> 'v option + val modify_storage : 'v key -> 'v option -> storage -> storage + val empty_storage : storage (* Internal interface *) - val current_storage : storage ref + val current_storage : storage Domain.DLS.key end = struct (* The idea behind sequence-associated storage is to preserve some values @@ -766,44 +823,48 @@ struct mutable value : 'v option; } - let next_key_id = ref 0 + let next_key_id = Atomic.make 0 let new_key () = - let id = !next_key_id in - next_key_id := id + 1; + let id = Atomic.fetch_and_add next_key_id 1 in {id = id; value = None} - let current_storage = ref Storage_map.empty + (* generic storage *) + let empty_storage = Storage_map.empty - let get key = - if Storage_map.mem key.id !current_storage then begin - let refresh = Storage_map.find key.id !current_storage in + let get_from_storage key storage = + match Storage_map.find_opt key.id storage with + | Some refresh -> refresh (); let value = key.value in key.value <- None; value - end - else - None + | None -> None + + let modify_storage key value storage = + match value with + | Some _ -> + let refresh = fun () -> key.value <- value in + Storage_map.add key.id refresh storage + | None -> + Storage_map.remove key.id storage + + (* built-in storage: propagated by bind and such *) + let current_storage = Domain.DLS.new_key (fun () -> Storage_map.empty) + + let get key = get_from_storage key (Domain.DLS.get current_storage) let with_value key value f = - let new_storage = - match value with - | Some _ -> - let refresh = fun () -> key.value <- value in - Storage_map.add key.id refresh !current_storage - | None -> - Storage_map.remove key.id !current_storage - in + let new_storage = modify_storage key value (Domain.DLS.get current_storage) in - let saved_storage = !current_storage in - current_storage := new_storage; + let saved_storage = (Domain.DLS.get current_storage) in + Domain.DLS.set current_storage new_storage; try let result = f () in - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; result with exn when Exception_filter.run exn -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; raise exn end include Sequence_associated_storage @@ -840,10 +901,10 @@ struct (* In a callback list, filters out cells of explicitly removable callbacks that have been removed. *) let rec clean_up_callback_cells = function - | Regular_callback_list_explicitly_removable_callback {contents = None} -> + | Regular_callback_list_explicitly_removable_callback (_, {contents = None}) -> Regular_callback_list_empty - | Regular_callback_list_explicitly_removable_callback {contents = Some _} + | Regular_callback_list_explicitly_removable_callback (_, {contents = Some _}) | Regular_callback_list_implicitly_removed_callback _ | Regular_callback_list_empty as callbacks -> callbacks @@ -954,7 +1015,7 @@ struct let add_implicitly_removed_callback callbacks f = add_regular_callback_list_node - callbacks (Regular_callback_list_implicitly_removed_callback f) + callbacks (Regular_callback_list_implicitly_removed_callback (Domain.self (), f)) (* Adds [callback] as removable to each promise in [ps]. The first promise in [ps] to trigger [callback] removes [callback] from the other promises; this @@ -970,7 +1031,7 @@ struct f result in - let node = Regular_callback_list_explicitly_removable_callback cell in + let node = Regular_callback_list_explicitly_removable_callback (Domain.self (), cell) in ps |> List.iter (fun p -> let Internal p = to_internal_promise p in match (underlying p).state with @@ -991,7 +1052,7 @@ struct clear_explicitly_removable_callback_cell cell ~originally_added_to:ps let add_cancel_callback callbacks f = - let node = Cancel_callback_list_callback (!current_storage, f) in + let node = Cancel_callback_list_callback (Domain.self (), (Domain.DLS.get current_storage), f) in callbacks.cancel_callbacks <- match callbacks.cancel_callbacks with @@ -1166,12 +1227,23 @@ struct match fs with | Cancel_callback_list_empty -> iter_list rest - | Cancel_callback_list_callback (storage, f) -> - current_storage := storage; - handle_with_async_exception_hook f (); - iter_list rest - | Cancel_callback_list_remove_sequence_node node -> - Lwt_sequence.remove node; + | Cancel_callback_list_callback (domain, storage, f) -> + begin if domain = Domain.self () then begin + Domain.DLS.set current_storage storage; + handle_with_async_exception_hook f () + end else + Multidomain_sync.send_callback domain (fun () -> + Domain.DLS.set current_storage storage; + handle_with_async_exception_hook f () + ) + end; + iter_list rest + | Cancel_callback_list_remove_sequence_node (domain, node) -> + begin if domain = Domain.self () then + Lwt_sequence.remove node + else + Multidomain_sync.send_callback domain (fun () -> Lwt_sequence.remove node) + end; iter_list rest | Cancel_callback_list_concat (fs, fs') -> iter_callback_list fs (fs'::rest) @@ -1191,16 +1263,22 @@ struct match fs with | Regular_callback_list_empty -> iter_list rest - | Regular_callback_list_implicitly_removed_callback f -> - f result; - iter_list rest - | Regular_callback_list_explicitly_removable_callback - {contents = None} -> - iter_list rest - | Regular_callback_list_explicitly_removable_callback - {contents = Some f} -> - f result; - iter_list rest + | Regular_callback_list_implicitly_removed_callback (domain, f) -> + begin if domain = Domain.self () then + f result + else + Multidomain_sync.send_callback domain (fun () -> f result) + end; + iter_list rest + | Regular_callback_list_explicitly_removable_callback (_, {contents = None}) -> + iter_list rest + | Regular_callback_list_explicitly_removable_callback (domain, {contents = Some f}) -> + begin if domain = Domain.self () then + f result + else + Multidomain_sync.send_callback domain (fun () -> f result) + end; + iter_list rest | Regular_callback_list_concat (fs, fs') -> iter_callback_list fs (fs'::rest) @@ -1229,7 +1307,7 @@ struct let default_maximum_callback_nesting_depth = 42 - let current_callback_nesting_depth = ref 0 + let current_callback_nesting_depth = Domain.DLS.new_key (fun () -> 0) type deferred_callbacks = Deferred : ('a callbacks * 'a resolved_state) -> deferred_callbacks @@ -1242,19 +1320,19 @@ struct the callbacks that will be run will modify the storage. The storage is restored to the snapshot when the resolution loop is exited. *) let enter_resolution_loop () = - current_callback_nesting_depth := !current_callback_nesting_depth + 1; - let storage_snapshot = !current_storage in + Domain.DLS.set current_callback_nesting_depth (Domain.DLS.get current_callback_nesting_depth + 1); + let storage_snapshot = (Domain.DLS.get current_storage) in storage_snapshot let leave_resolution_loop (storage_snapshot : storage) : unit = - if !current_callback_nesting_depth = 1 then begin + if Domain.DLS.get current_callback_nesting_depth = 1 then begin while not (Queue.is_empty deferred_callbacks) do let Deferred (callbacks, result) = Queue.pop deferred_callbacks in run_callbacks callbacks result done end; - current_callback_nesting_depth := !current_callback_nesting_depth - 1; - current_storage := storage_snapshot + Domain.DLS.set current_callback_nesting_depth (Domain.DLS.get current_callback_nesting_depth - 1); + Domain.DLS.set current_storage storage_snapshot let run_in_resolution_loop f = let storage_snapshot = enter_resolution_loop () in @@ -1269,7 +1347,7 @@ struct The name should probably be [abaondon_resolution_loop]. *) let abandon_wakeups () = - if !current_callback_nesting_depth <> 0 then + if Domain.DLS.get current_callback_nesting_depth <> 0 then leave_resolution_loop Storage_map.empty @@ -1281,7 +1359,7 @@ struct let should_defer = allow_deferring - && !current_callback_nesting_depth >= maximum_callback_nesting_depth + && Domain.DLS.get current_callback_nesting_depth >= maximum_callback_nesting_depth in if should_defer then @@ -1309,7 +1387,7 @@ struct else let should_defer = - !current_callback_nesting_depth + Domain.DLS.get current_callback_nesting_depth >= default_maximum_callback_nesting_depth in @@ -1320,7 +1398,7 @@ struct { regular_callbacks = Regular_callback_list_implicitly_removed_callback - deferred_callback; + (Domain.self (), deferred_callback); cancel_callbacks = Cancel_callback_list_empty; how_to_cancel = Not_cancelable; cleanups_deferred = 0 @@ -1576,7 +1654,7 @@ struct let Pending callbacks = p.state in callbacks.cancel_callbacks <- - Cancel_callback_list_remove_sequence_node node; + Cancel_callback_list_remove_sequence_node (Domain.self (), node); to_public_promise p @@ -1587,7 +1665,7 @@ struct let Pending callbacks = p.state in callbacks.cancel_callbacks <- - Cancel_callback_list_remove_sequence_node node; + Cancel_callback_list_remove_sequence_node (Domain.self (), node); to_public_promise p @@ -1831,12 +1909,12 @@ struct [p''] will be equivalent to trying to cancel [p'], so the behavior will depend on how the user obtained [p']. *) - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in let callback p_result = match p_result with | Fulfilled v -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try f v with exn @@ -1897,12 +1975,12 @@ struct let create_result_promise_and_callback_if_deferred () = let p'' = new_pending ~how_to_cancel:(Propagate_cancel_to_one p) in - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in let callback p_result = match p_result with | Fulfilled v -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try f v @@ -1954,12 +2032,12 @@ struct let create_result_promise_and_callback_if_deferred () = let p'' = new_pending ~how_to_cancel:(Propagate_cancel_to_one p) in - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in let callback p_result = match p_result with | Fulfilled v -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p''_result = try Fulfilled (f v) with exn @@ -2020,7 +2098,7 @@ struct let create_result_promise_and_callback_if_deferred () = let p'' = new_pending ~how_to_cancel:(Propagate_cancel_to_one p) in - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in let callback p_result = match p_result with @@ -2033,7 +2111,7 @@ struct ignore p'' | Rejected exn -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try h exn @@ -2081,7 +2159,7 @@ struct let create_result_promise_and_callback_if_deferred () = let p'' = new_pending ~how_to_cancel:(Propagate_cancel_to_one p) in - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in let callback p_result = match p_result with @@ -2094,7 +2172,7 @@ struct ignore p'' | Rejected exn -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try h exn @@ -2143,12 +2221,12 @@ struct let create_result_promise_and_callback_if_deferred () = let p'' = new_pending ~how_to_cancel:(Propagate_cancel_to_one p) in - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in let callback p_result = match p_result with | Fulfilled v -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try f' v @@ -2164,7 +2242,7 @@ struct ignore p'' | Rejected exn -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try h exn @@ -2218,12 +2296,12 @@ struct let create_result_promise_and_callback_if_deferred () = let p'' = new_pending ~how_to_cancel:(Propagate_cancel_to_one p) in - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in let callback p_result = match p_result with | Fulfilled v -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try f' v @@ -2240,7 +2318,7 @@ struct ignore p'' | Rejected exn -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; let p' = try h exn @@ -2324,12 +2402,12 @@ struct let p = underlying p in let callback_if_deferred () = - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in fun result -> match result with | Fulfilled v -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; handle_with_async_exception_hook f v | Rejected _ -> @@ -2357,7 +2435,7 @@ struct let p = underlying p in let callback_if_deferred () = - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in fun result -> match result with @@ -2365,7 +2443,7 @@ struct () | Rejected exn -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; handle_with_async_exception_hook f exn in @@ -2390,10 +2468,10 @@ struct let p = underlying p in let callback_if_deferred () = - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in fun _result -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; handle_with_async_exception_hook f () in @@ -2423,16 +2501,16 @@ struct let p = underlying p in let callback_if_deferred () = - let saved_storage = !current_storage in + let saved_storage = (Domain.DLS.get current_storage) in fun result -> match result with | Fulfilled v -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; handle_with_async_exception_hook f v | Rejected exn -> - current_storage := saved_storage; + Domain.DLS.set current_storage saved_storage; handle_with_async_exception_hook g exn in @@ -3161,34 +3239,34 @@ struct - let pause_hook = ref ignore + let pause_hook = Domain.DLS.new_key (fun () -> ignore) - let paused = Lwt_sequence.create () - let paused_count = ref 0 + let paused = Domain.DLS.new_key (fun () -> Lwt_sequence.create ()) + let paused_count = Domain.DLS.new_key (fun () -> 0) let pause () = - let p = add_task_r paused in - incr paused_count; - !pause_hook !paused_count; + let p = add_task_r (Domain.DLS.get paused) in + Domain.DLS.set paused_count (Domain.DLS.get paused_count + 1); + (Domain.DLS.get pause_hook) (Domain.DLS.get paused_count); p let wakeup_paused () = - if Lwt_sequence.is_empty paused then - paused_count := 0 + if Lwt_sequence.is_empty (Domain.DLS.get paused) then + Domain.DLS.set paused_count 0 else begin let tmp = Lwt_sequence.create () in - Lwt_sequence.transfer_r paused tmp; - paused_count := 0; + Lwt_sequence.transfer_r (Domain.DLS.get paused) tmp; + Domain.DLS.set paused_count 0; Lwt_sequence.iter_l (fun r -> wakeup r ()) tmp end - let register_pause_notifier f = pause_hook := f + let register_pause_notifier f = Domain.DLS.set pause_hook f let abandon_paused () = - Lwt_sequence.clear paused; - paused_count := 0 + Lwt_sequence.clear (Domain.DLS.get paused); + Domain.DLS.set paused_count 0 - let paused_count () = !paused_count + let paused_count () = Domain.DLS.get paused_count end include Miscellaneous @@ -3228,3 +3306,9 @@ struct let (let+) x f = map f x let (and+) = both end + +module Private = struct + type nonrec storage = storage + module Sequence_associated_storage = Sequence_associated_storage + module Multidomain_sync = Multidomain_sync +end diff --git a/src/core/lwt.mli b/src/core/lwt.mli index 7598343d8..a0f711c25 100644 --- a/src/core/lwt.mli +++ b/src/core/lwt.mli @@ -2061,3 +2061,20 @@ val backtrace_try_bind : val abandon_wakeups : unit -> unit val debug_state_is : 'a state -> 'a t -> bool t + +module Private : sig + type storage + + module Sequence_associated_storage : sig + val get_from_storage : 'a key -> storage -> 'a option + val modify_storage : 'a key -> 'a option -> storage -> storage + val empty_storage : storage + val current_storage : storage Domain.DLS.key + end + + module Multidomain_sync : sig + val get_sent_callbacks : Domain.id -> (unit -> unit) Lwt_sequence.t[@ocaml.warning "-3"] + val register_notification : Domain.id -> (unit -> unit) -> unit + val is_alredy_registered : Domain.id -> bool + end +end [@@alert trespassing "for internal use only, keep away"] diff --git a/src/direct/dune b/src/direct/dune new file mode 100644 index 000000000..7151d9c72 --- /dev/null +++ b/src/direct/dune @@ -0,0 +1,5 @@ +(library + (public_name lwt_direct) + (synopsis "Direct-style control-flow and `await` for Lwt") + (enabled_if (>= %{ocaml_version} "5.0")) + (libraries lwt lwt.unix)) diff --git a/src/direct/lwt_direct.ml b/src/direct/lwt_direct.ml new file mode 100644 index 000000000..a777db3a9 --- /dev/null +++ b/src/direct/lwt_direct.ml @@ -0,0 +1,140 @@ +(* Direct-style wrapper for Lwt code + + The implementation of the direct-style wrapper relies on ocaml5's effect + system capturing continuations and adding them as a callback to some lwt + promises. *) + +(* part 1: tasks, getting the scheduler to call them *) + +let tasks : (unit -> unit) Queue.t Domain.DLS.key = Domain.DLS.new_key Queue.create + +let[@inline] push_task f : unit = Queue.push f (Domain.DLS.get tasks) + +let absolute_max_number_of_steps = + (* TODO 6.0: what's a good number here? should it be customisable? *) + 10_000 + +let run_all_tasks () : unit = + let n_processed = ref 0 in + let max_number_of_steps = min absolute_max_number_of_steps (2 * Queue.length (Domain.DLS.get tasks)) in + while (not (Queue.is_empty (Domain.DLS.get tasks))) && !n_processed < max_number_of_steps do + let t = Queue.pop (Domain.DLS.get tasks) in + incr n_processed; + try t () + with exn -> + (* TODO 6.0: change async_exception handler to accept a backtrace, pass it + here and at the other use site. *) + (* TODO 6.0: this and other try-with: respect exception-filter *) + !Lwt.async_exception_hook exn + done; + (* In the case where there are no promises ready for wakeup, the scheduler's + engine will pause until some IO completes. There might never be completed + IO, depending on the program structure and the state of the world. If this + happens and the queue is not empty, we add a [pause] so that the engine has + something to wakeup for so that the rest of the queue can be processed. *) + if not (Queue.is_empty (Domain.DLS.get tasks)) && Lwt.paused_count () = 0 then ignore (Lwt.pause () : unit Lwt.t) + +let setup_hooks = + let already_done = Domain.DLS.new_key (fun () -> false) in + fun () -> + if not (Domain.DLS.get already_done) then ( + Domain.DLS.set already_done true; + (* TODO 6.0: assess whether we should have both hooks or just one (which + one). Tempted to say we should only have the enter hook. *) + let _hook1 = Lwt_main.Enter_iter_hooks.add_first run_all_tasks in + let _hook2 = Lwt_main.Leave_iter_hooks.add_first run_all_tasks in + () + ) + +(* part 2: effects, performing them *) + +type _ Effect.t += + | Await : 'a Lwt.t -> 'a Effect.t + | Yield : unit Effect.t + +let await (fut : 'a Lwt.t) : 'a = + match Lwt.state fut with + | Lwt.Return x -> x + | Lwt.Fail exn -> raise exn + | Lwt.Sleep -> Effect.perform (Await fut) + +let yield () : unit = Effect.perform Yield + +(* interlude: task-local storage helpers *) + +module Storage = struct + [@@@alert "-trespassing"] + module Lwt_storage = Lwt.Private.Sequence_associated_storage + [@@@alert "+trespassing"] + type 'a key = 'a Lwt.key + let new_key = Lwt.new_key + let get = Lwt.get + let set k v = + let open Lwt_storage in + Domain.DLS.set current_storage (modify_storage k (Some v) (Domain.DLS.get current_storage)) + let remove k = + let open Lwt_storage in + Domain.DLS.set current_storage (modify_storage k None (Domain.DLS.get current_storage)) + let reset_to_empty () = + let open Lwt_storage in + Domain.DLS.set current_storage empty_storage + let save_current () = Domain.DLS.get Lwt_storage.current_storage + let restore_current saved = Domain.DLS.set Lwt_storage.current_storage saved +end + +(* part 3: handling effects *) + +let handler : _ Effect.Deep.effect_handler = + let effc : type b. b Effect.t -> ((b, unit) Effect.Deep.continuation -> 'a) option = + function + | Yield -> + Some (fun k -> + let storage = Storage.save_current () in + push_task (fun () -> + Storage.restore_current storage; + Effect.Deep.continue k ())) + | Await fut -> + Some + (fun k -> + let storage = Storage.save_current () in + Lwt.on_any fut + (fun res -> push_task (fun () -> + Storage.restore_current storage; Effect.Deep.continue k res)) + (fun exn -> push_task (fun () -> + Storage.restore_current storage; Effect.Deep.discontinue k exn))) + | _ -> None + in + { effc } + +(* part 4: putting it all together: running tasks *) + +let run_inside_effect_handler_and_resolve_ (type a) (promise : a Lwt.u) f () : unit = + let run_f_and_set_res () = + Storage.reset_to_empty(); + match f () with + | res -> Lwt.wakeup promise res + | exception exc -> Lwt.wakeup_exn promise exc + in + Effect.Deep.try_with run_f_and_set_res () handler + +let spawn f : _ Lwt.t = + setup_hooks (); + let lwt, resolve = Lwt.wait () in + push_task (run_inside_effect_handler_and_resolve_ resolve f); + lwt + +(* part 4 (encore): running a task in the background *) + +let run_inside_effect_handler_in_the_background_ f () : unit = + let run_f () : unit = + Storage.reset_to_empty(); + try + f () + with exn -> + !Lwt.async_exception_hook exn + in + Effect.Deep.try_with run_f () handler + +let spawn_in_the_background f : unit = + setup_hooks (); + push_task (run_inside_effect_handler_in_the_background_ f) diff --git a/src/direct/lwt_direct.mli b/src/direct/lwt_direct.mli new file mode 100644 index 000000000..058814862 --- /dev/null +++ b/src/direct/lwt_direct.mli @@ -0,0 +1,109 @@ +(** Direct style control flow for Lwt. + + Using this module you can write code in direct style (using loops, + exceptions handlers, etc.) in an Lwt codebase. Your direct-style sections + must be enclosed in a call to {!spawn} and they may {!await} on promises. + For example: + + {[ + open Lwt_direct + spawn (fun () -> + let continue = ref true in + while !continue do + match await @@ Lwt_io.read_line in_channel with + | exception End_of_file -> continue := false + | line -> + let uppercase_line = String.uppercase_ascii line in + await @@ Lwt_io.write_line out_channel uppercase_line + done) + ]} + + In this code snippet, the [while]-loop repeats a simple task of reading from + an {!Lwt_io.channel}, modifying it, and writing it to a different channel. + The code is in direct-style: the control structures are standard OCaml + without any Lwt primitives. + + The code-snippet as a whole is a [unit Lwt.t] promise. It becomes resovled + when the function returns. Conversely, the promises inside the snippet are + wrapped in {!await}, turning them into regular plain (non-Lwt) values + (although values that are not available immediately). + + The [Lwt_direct] module is implemented using OCaml 5's + {{:https://ocaml.org/manual/5.3/effects.html} effects and effect handlers}. + This allows the kind of scheduling where a promise is turned into a regular + value and vice-versa. *) + +val spawn : (unit -> 'a) -> 'a Lwt.t +(** [spawn f] runs the function [f ()], it also returns a promise [p] which is + resolved when the call to [f ()] returns a value. If [f ()] throws an + exception, the promise [p] is rejected. + + The function [f] can create Lwt promises (e.g., by calling functions from + [Lwt_io], [Lwt_unix], or third-party libraries) and use {!await} to wait for + them. These promises are evaluated in the Lwt event loop. + + Like any promise in Lwt, [f ()] can starve the event loop if it runs long + computations without yielding to the event loop. + + Cancelling the promise returned by [spawn] has no effect: the execution of + [f ()] continues and the promise is not cancelled. + + When [f ()] terminates (successfully or not), the promise + [spawn f] is resolved with [f ()]'s result, or the exception + raised by [f ()]. *) + +val spawn_in_the_background : + (unit -> unit) -> + unit +(** [spawn_in_the_background f] is similar to [ignore (spawn f)]. + The computation [f ()] runs in the background in the event loop + and returns no result. + + If [f()] raises an exception, {!Lwt.async_exception_hook} is called. *) + +val yield : unit -> unit +(** Yield to the event loop. + + This is similar to [await (Lwt.pause ())], using less indirection internally + and fewer characters to write. + + Calling [yield] outside of {!spawn} or {!spawn_in_the_background} will raise + an exception, crash your program, or otherwise cause errors. It is a + programming error to do so. *) + +val await : 'a Lwt.t -> 'a +(** [await p] returns the result of [p] (or raises the exception with which [p] + was rejected. + + If [p] is not resolved yet, [await p] will suspend the current task (i.e., + the computation started by the surrounding {!spawn}) and resume it when [p] + is resolved. + + Calling [await] outside of {!spawn} or {!spawn_in_the_background} will raise + an exception, crash your program, or otherwise cause errors. It is a + programming error to do so. *) + +(** Local storage. + + This storage is the same as the one described with {!Lwt.key}, + except that it is usable from the inside of {!spawn} or + {!spawn_in_the_background}. + + Each task has its own storage, independent from other tasks or promises. + + NOTE: it is recommended to use [Lwt_direct.Storage] functions rather than + [Lwt.key] functions from {!Lwt}. The latter is deprecated. *) +module Storage : sig + type 'a key = 'a Lwt.key + val new_key : unit -> 'a key + (** Alias to {!Lwt.new_key} *) + + val get : 'a key -> 'a option + (** get the value associated with this key in local storage, or [None] *) + + val set : 'a key -> 'a -> unit + (** [set k v] sets the key to the value for the rest of the task. *) + + val remove : 'a key -> unit + (** Remove the value associated with this key, if any *) +end diff --git a/src/ppx/dune b/src/ppx/dune index 1a48fc938..9b4719b12 100644 --- a/src/ppx/dune +++ b/src/ppx/dune @@ -5,6 +5,4 @@ (ppx_runtime_libraries lwt) (kind ppx_rewriter) (preprocess - (pps ppxlib.metaquot)) - (instrumentation - (backend bisect_ppx))) + (pps ppxlib.metaquot))) diff --git a/src/react/dune b/src/react/dune index be26a6c34..2e9e7a4b9 100644 --- a/src/react/dune +++ b/src/react/dune @@ -2,6 +2,4 @@ (public_name lwt_react) (synopsis "Reactive programming helpers for Lwt") (wrapped false) - (libraries lwt react) - (instrumentation - (backend bisect_ppx))) + (libraries lwt react)) diff --git a/src/retry/dune b/src/retry/dune index 0dd136a47..f60cb07ed 100644 --- a/src/retry/dune +++ b/src/retry/dune @@ -2,6 +2,4 @@ (public_name lwt_retry) (synopsis "A utility for retrying Lwt computations") (wrapped false) - (libraries lwt lwt.unix) - (instrumentation - (backend bisect_ppx))) + (libraries lwt lwt.unix)) diff --git a/src/unix/dune b/src/unix/dune index a5c6a3977..3502c5409 100644 --- a/src/unix/dune +++ b/src/unix/dune @@ -191,6 +191,4 @@ (flags (:include unix_c_flags.sexp))) (c_library_flags - (:include unix_c_library_flags.sexp)) - (instrumentation - (backend bisect_ppx))) + (:include unix_c_library_flags.sexp))) diff --git a/src/unix/lwt_engine.ml b/src/unix/lwt_engine.ml index 20a8eafc7..a2c6ba3cb 100644 --- a/src/unix/lwt_engine.ml +++ b/src/unix/lwt_engine.ml @@ -416,29 +416,30 @@ end +-----------------------------------------------------------------+ *) let current = - if Lwt_config._HAVE_LIBEV && Lwt_config.libev_default then - ref (new libev () :> t) - else - ref (new select :> t) + Domain.DLS.new_key (fun () -> + if Lwt_config._HAVE_LIBEV && Lwt_config.libev_default then + (new libev () :> t) + else + (new select :> t) +) -let get () = - !current +let get () = Domain.DLS.get current let set ?(transfer=true) ?(destroy=true) engine = - if transfer then !current#transfer (engine : #t :> abstract); - if destroy then !current#destroy; - current := (engine : #t :> t) - -let iter block = !current#iter block -let on_readable fd f = !current#on_readable fd f -let on_writable fd f = !current#on_writable fd f -let on_timer delay repeat f = !current#on_timer delay repeat f -let fake_io fd = !current#fake_io fd -let readable_count () = !current#readable_count -let writable_count () = !current#writable_count -let timer_count () = !current#timer_count -let fork () = !current#fork -let forwards_signal n = !current#forwards_signal n + if transfer then (Domain.DLS.get current)#transfer (engine : #t :> abstract); + if destroy then (Domain.DLS.get current)#destroy; + Domain.DLS.set current (engine : #t :> t) + +let iter block = (Domain.DLS.get current)#iter block +let on_readable fd f = (Domain.DLS.get current)#on_readable fd f +let on_writable fd f = (Domain.DLS.get current)#on_writable fd f +let on_timer delay repeat f = (Domain.DLS.get current)#on_timer delay repeat f +let fake_io fd = (Domain.DLS.get current)#fake_io fd +let readable_count () = (Domain.DLS.get current)#readable_count +let writable_count () = (Domain.DLS.get current)#writable_count +let timer_count () = (Domain.DLS.get current)#timer_count +let fork () = (Domain.DLS.get current)#fork +let forwards_signal n = (Domain.DLS.get current)#forwards_signal n module Versioned = struct diff --git a/src/unix/lwt_gc.ml b/src/unix/lwt_gc.ml index b0925f9dc..762c3be3c 100644 --- a/src/unix/lwt_gc.ml +++ b/src/unix/lwt_gc.ml @@ -12,17 +12,15 @@ module Lwt_sequence = Lwt_sequence let ensure_termination t = if Lwt.state t = Lwt.Sleep then begin - let hook = - Lwt_sequence.add_l (fun _ -> t) Lwt_main.exit_hooks [@ocaml.warning "-3"] - in + let hook = Lwt_main.Exit_hooks.add_first (fun _ -> t) in (* Remove the hook when t has terminated *) ignore ( Lwt.finalize (fun () -> t) - (fun () -> Lwt_sequence.remove hook; Lwt.return_unit)) + (fun () -> Lwt_main.Exit_hooks.remove hook; Lwt.return_unit)) end -let finaliser f = +let finaliser ?domain f = (* In order not to create a reference to the value in the notification callback, we use an initially unset option cell which will be filled when the finaliser is called. *) @@ -30,6 +28,7 @@ let finaliser f = let id = Lwt_unix.make_notification ~once:true + ?for_other_domain:domain (fun () -> match !opt with | None -> @@ -43,8 +42,8 @@ let finaliser f = opt := Some x; Lwt_unix.send_notification id) -let finalise f x = - Gc.finalise (finaliser f) x +let finalise ?domain f x = + Gc.finalise (finaliser ?domain f) x (* Exit hook for a finalise_or_exit *) let foe_exit f called weak () = @@ -68,7 +67,7 @@ let foe_finaliser f called hook = finaliser (fun x -> (* Remove the exit hook, it is not needed anymore. *) - Lwt_sequence.remove hook; + Lwt_main.Exit_hooks.remove hook; (* Call the real finaliser. *) if !called then Lwt.return_unit @@ -83,8 +82,5 @@ let finalise_or_exit f x = let weak = Weak.create 1 in Weak.set weak 0 (Some x); let called = ref false in - let hook = - Lwt_sequence.add_l (foe_exit f called weak) Lwt_main.exit_hooks - [@ocaml.warning "-3"] - in + let hook = Lwt_main.Exit_hooks.add_first (foe_exit f called weak) in Gc.finalise (foe_finaliser f called hook) x diff --git a/src/unix/lwt_gc.mli b/src/unix/lwt_gc.mli index e69218f5a..99bc11385 100644 --- a/src/unix/lwt_gc.mli +++ b/src/unix/lwt_gc.mli @@ -9,14 +9,25 @@ thread to a value, without having to use [Lwt_unix.run] in the finaliser. *) -val finalise : ('a -> unit Lwt.t) -> 'a -> unit +val finalise : ?domain:Domain.id -> ('a -> unit Lwt.t) -> 'a -> unit (** [finalise f x] ensures [f x] is evaluated after [x] has been garbage collected. If [f x] yields, then Lwt will wait for its termination at the end of the program. Note that [f x] is not called at garbage collection time, but - later in the main loop. *) + later in the main loop. + + If [domain] is provided, then [f x] is evaluated in the corresponding + domain. Otherwise it is evaluated in the domain calling [finalise]. If + Lwt is not running in the domain set to run the finaliser, an + unspecified error occurs at an unspecified time or the finaliser doesn't + run or some other bad thing happens. *) val finalise_or_exit : ('a -> unit Lwt.t) -> 'a -> unit (** [finalise_or_exit f x] call [f x] when [x] is garbage collected - or (exclusively) when the program exits. *) + or (exclusively) when the program exits. + + The finaliser [f] is called in the same domain that called + [finalise_or_exit]. If there is no Lwt scheduler running in this domain an + unspecified error occurs. You can use [Lwt_preemptive.run_in_domain] to + bypass the same-domain limitation. *) diff --git a/src/unix/lwt_main.ml b/src/unix/lwt_main.ml index 823666e5f..60caf6028 100644 --- a/src/unix/lwt_main.ml +++ b/src/unix/lwt_main.ml @@ -12,8 +12,8 @@ module Lwt_sequence = Lwt_sequence open Lwt.Infix -let enter_iter_hooks = Lwt_sequence.create () -let leave_iter_hooks = Lwt_sequence.create () +let enter_iter_hooks = Domain.DLS.new_key (fun () -> Lwt_sequence.create ()) +let leave_iter_hooks = Domain.DLS.new_key (fun () -> Lwt_sequence.create ()) let yield = Lwt.pause @@ -21,13 +21,24 @@ let abandon_yielded_and_paused () = Lwt.abandon_paused () let run p = + let domain = Domain.self () in + let () = if (Lwt.Private.Multidomain_sync.is_alredy_registered[@alert "-trespassing"]) domain then + () + else begin + let n = Lwt_unix.make_notification (fun () -> + let cbs = (Lwt.Private.Multidomain_sync.get_sent_callbacks[@alert "-trespassing"]) domain in + Lwt_sequence.iter_l (fun f -> f ()) cbs + ) in + (Lwt.Private.Multidomain_sync.register_notification[@alert "-trespassing"]) domain(fun () -> Lwt_unix.send_notification n) + end + in let rec run_loop () = match Lwt.poll p with | Some x -> x | None -> (* Call enter hooks. *) - Lwt_sequence.iter_l (fun f -> f ()) enter_iter_hooks; + Lwt_sequence.iter_l (fun f -> f ()) (Domain.DLS.get enter_iter_hooks); (* Do the main loop call. *) let should_block_waiting_for_io = Lwt.paused_count () = 0 in @@ -37,7 +48,7 @@ let run p = Lwt.wakeup_paused (); (* Call leave hooks. *) - Lwt_sequence.iter_l (fun f -> f ()) leave_iter_hooks; + Lwt_sequence.iter_l (fun f -> f ()) (Domain.DLS.get leave_iter_hooks); (* Repeat. *) run_loop () @@ -45,56 +56,54 @@ let run p = run_loop () -let run_already_called = ref `No -let run_already_called_mutex = Mutex.create () +let run_already_called = Domain.DLS.new_key (fun () -> `No) +let run_already_called_mutex = Domain.DLS.new_key (fun () -> Mutex.create ()) let finished () = - Mutex.lock run_already_called_mutex; - run_already_called := `No; - Mutex.unlock run_already_called_mutex + Mutex.lock (Domain.DLS.get run_already_called_mutex); + Domain.DLS.set run_already_called `No; + Mutex.unlock (Domain.DLS.get run_already_called_mutex) let run p = (* Fail in case a call to Lwt_main.run is nested under another invocation of Lwt_main.run. *) - Mutex.lock run_already_called_mutex; - + Mutex.lock (Domain.DLS.get run_already_called_mutex); let error_message_if_call_is_nested = - match !run_already_called with - (* `From is effectively disabled for the time being, because there is a bug, - present in all versions of OCaml supported by Lwt, where, with the - bytecode runtime, if one changes the working directory and then attempts - to retrieve the backtrace, the runtime calls [abort] at the C level and - exits the program ungracefully. It is especially likely that a daemon - would change directory before calling [Lwt_main.run], so we can't have it - retrieving the backtrace, even though a daemon is not likely to be - compiled to bytecode. - - This can be addressed with detection. Starting with 4.04, there is a - type [Sys.backend_type] that could be used. *) - | `From backtrace_string -> - Some (Printf.sprintf "%s\n%s\n%s" - "Nested calls to Lwt_main.run are not allowed" - "Lwt_main.run already called from:" - backtrace_string) - | `From_somewhere -> - Some ("Nested calls to Lwt_main.run are not allowed") - | `No -> - let called_from = - (* See comment above. - if Printexc.backtrace_status () then - let backtrace = - try raise Exit - with Exit -> Printexc.get_backtrace () - in - `From backtrace - else *) - `From_somewhere - in - run_already_called := called_from; - None + match (Domain.DLS.get run_already_called) with + (* `From is effectively disabled for the time being, because there is a bug, + present in all versions of OCaml supported by Lwt, where, with the + bytecode runtime, if one changes the working directory and then attempts + to retrieve the backtrace, the runtime calls [abort] at the C level and + exits the program ungracefully. It is especially likely that a daemon + would change directory before calling [Lwt_main.run], so we can't have it + retrieving the backtrace, even though a daemon is not likely to be + compiled to bytecode. + + This can be addressed with detection. Starting with 4.04, there is a + type [Sys.backend_type] that could be used. *) + | `From backtrace_string -> + Some (Printf.sprintf "%s\n%s\n%s" + "Nested calls to Lwt_main.run are not allowed" + "Lwt_main.run already called from:" + backtrace_string) + | `From_somewhere -> + Some ("Nested calls to Lwt_main.run are not allowed") + | `No -> + let called_from = + (* See comment above. + if Printexc.backtrace_status () then + let backtrace = + try raise Exit + with Exit -> Printexc.get_backtrace () + in + `From backtrace + else *) + `From_somewhere + in + Domain.DLS.set run_already_called called_from; + None in - - Mutex.unlock run_already_called_mutex; + Mutex.unlock (Domain.DLS.get run_already_called_mutex); begin match error_message_if_call_is_nested with | Some message -> failwith message @@ -109,10 +118,10 @@ let run p = finished (); raise exn -let exit_hooks = Lwt_sequence.create () +let exit_hooks = Domain.DLS.new_key (fun () -> Lwt_sequence.create ()) let rec call_hooks () = - match Lwt_sequence.take_opt_l exit_hooks with + match Lwt_sequence.take_opt_l (Domain.DLS.get exit_hooks) with | None -> Lwt.return_unit | Some f -> @@ -123,13 +132,13 @@ let rec call_hooks () = let () = at_exit (fun () -> - if not (Lwt_sequence.is_empty exit_hooks) then begin + if not (Lwt_sequence.is_empty (Domain.DLS.get exit_hooks)) then begin Lwt.abandon_wakeups (); finished (); run (call_hooks ()) end) -let at_exit f = ignore (Lwt_sequence.add_l f exit_hooks) +let at_exit f = ignore (Lwt_sequence.add_l f (Domain.DLS.get exit_hooks)) module type Hooks = sig @@ -145,7 +154,7 @@ end module type Hook_sequence = sig type 'return_value kind - val sequence : (unit -> unit kind) Lwt_sequence.t + val sequence : (unit -> unit kind) Lwt_sequence.t Domain.DLS.key end module Wrap_hooks (Sequence : Hook_sequence) = @@ -154,18 +163,18 @@ struct type hook = (unit -> unit Sequence.kind) Lwt_sequence.node let add_first hook_fn = - let hook_node = Lwt_sequence.add_l hook_fn Sequence.sequence in + let hook_node = Lwt_sequence.add_l hook_fn (Domain.DLS.get Sequence.sequence) in hook_node let add_last hook_fn = - let hook_node = Lwt_sequence.add_r hook_fn Sequence.sequence in + let hook_node = Lwt_sequence.add_r hook_fn (Domain.DLS.get Sequence.sequence) in hook_node let remove hook_node = Lwt_sequence.remove hook_node let remove_all () = - Lwt_sequence.iter_node_l Lwt_sequence.remove Sequence.sequence + Lwt_sequence.iter_node_l Lwt_sequence.remove (Domain.DLS.get Sequence.sequence) end module Enter_iter_hooks = diff --git a/src/unix/lwt_main.mli b/src/unix/lwt_main.mli index f2ebde219..d66eddde0 100644 --- a/src/unix/lwt_main.mli +++ b/src/unix/lwt_main.mli @@ -77,8 +77,14 @@ val abandon_yielded_and_paused : unit -> unit [@@deprecated "Use Lwt.abandon_pau (** Hook sequences. Each module of this type is a set of hooks, to be run by Lwt - at certain points during execution. See modules {!Enter_iter_hooks}, - {!Leave_iter_hooks}, and {!Exit_hooks}. *) + at certain points during execution. + + Hooks are added for the current domain. If you are calling the Hook + functions from a domain where Lwt is not running a scheduler then some + unspecified error may occur. If you need to set some Hooks to/from a + different domain, you can use [Lwt_preemptive.run_in_domain]. + + See modules {!Enter_iter_hooks}, {!Leave_iter_hooks}, and {!Exit_hooks}. *) module type Hooks = sig type 'return_value kind @@ -126,29 +132,6 @@ module Leave_iter_hooks : module Exit_hooks : Hooks with type 'return_value kind = 'return_value Lwt.t - - -[@@@ocaml.warning "-3"] - -val enter_iter_hooks : (unit -> unit) Lwt_sequence.t - [@@ocaml.deprecated - " Use module Lwt_main.Enter_iter_hooks."] -(** @deprecated Use module {!Enter_iter_hooks}. *) - -val leave_iter_hooks : (unit -> unit) Lwt_sequence.t - [@@ocaml.deprecated - " Use module Lwt_main.Leave_iter_hooks."] -(** @deprecated Use module {!Leave_iter_hooks}. *) - -val exit_hooks : (unit -> unit Lwt.t) Lwt_sequence.t - [@@ocaml.deprecated - " Use module Lwt_main.Exit_hooks."] -(** @deprecated Use module {!Exit_hooks}. *) - -[@@@ocaml.warning "+3"] - - - val at_exit : (unit -> unit Lwt.t) -> unit (** [Lwt_main.at_exit hook] is the same as [ignore (Lwt_main.Exit_hooks.add_first hook)]. *) diff --git a/src/unix/lwt_preemptive.ml b/src/unix/lwt_preemptive.ml index eacf32f28..3db5757c0 100644 --- a/src/unix/lwt_preemptive.ml +++ b/src/unix/lwt_preemptive.ml @@ -16,24 +16,23 @@ open Lwt.Infix | Parameters | +-----------------------------------------------------------------+ *) -(* Minimum number of preemptive threads: *) -let min_threads : int ref = ref 0 +(* Minimum number of preemptive threads per domain *) +let min_threads : int Domain.DLS.key = Domain.DLS.new_key (fun () -> 0) -(* Maximum number of preemptive threads: *) -let max_threads : int ref = ref 0 +(* Maximum number of preemptive threads per domain *) +let max_threads : int Domain.DLS.key = Domain.DLS.new_key (fun () -> 0) -(* Size of the waiting queue: *) -let max_thread_queued = ref 1000 +(* Size of the waiting queue per domain *) +let max_thread_queued = Domain.DLS.new_key (fun () -> 1000) -let get_max_number_of_threads_queued _ = - !max_thread_queued +let get_max_number_of_threads_queued () = Domain.DLS.get max_thread_queued let set_max_number_of_threads_queued n = if n < 0 then invalid_arg "Lwt_preemptive.set_max_number_of_threads_queued"; - max_thread_queued := n + Domain.DLS.set max_thread_queued n (* The total number of preemptive threads currently running: *) -let threads_count = ref 0 +let threads_count = Domain.DLS.new_key (fun () -> 0) (* +-----------------------------------------------------------------+ | Preemptive threads management | @@ -44,14 +43,15 @@ sig type 'a t val make : unit -> 'a t - val get : 'a t -> 'a + val get : 'a t -> ('a, unit) result val set : 'a t -> 'a -> unit + val kill : 'a t -> unit end = struct type 'a t = { m : Mutex.t; cv : Condition.t; - mutable cell : 'a option; + mutable cell : ('a, unit) result option; } let make () = { m = Mutex.create (); cv = Condition.create (); cell = None } @@ -72,13 +72,19 @@ struct let set t v = Mutex.lock t.m; - t.cell <- Some v; + t.cell <- Some (Ok v); + Mutex.unlock t.m; + Condition.signal t.cv + + let kill t = + Mutex.lock t.m; + t.cell <- Some (Error ()); Mutex.unlock t.m; Condition.signal t.cv end type thread = { - task_cell: (int * (unit -> unit)) CELL.t; + task_cell: (Lwt_unix.notification * (unit -> unit)) CELL.t; (* Channel used to communicate notification id and tasks to the worker thread. *) @@ -91,25 +97,27 @@ type thread = { } (* Pool of worker threads: *) -let workers : thread Queue.t = Queue.create () +let workers : thread Queue.t Domain.DLS.key = Domain.DLS.new_key Queue.create (* Queue of clients waiting for a worker to be available: *) -let waiters : thread Lwt.u Lwt_sequence.t = Lwt_sequence.create () +let waiters : thread Lwt.u Lwt_sequence.t Domain.DLS.key = Domain.DLS.new_key Lwt_sequence.create (* Code executed by a worker: *) let rec worker_loop worker = - let id, task = CELL.get worker.task_cell in - task (); - (* If there is too much threads, exit. This can happen if the user - decreased the maximum: *) - if !threads_count > !max_threads then worker.reuse <- false; - (* Tell the main thread that work is done: *) - Lwt_unix.send_notification id; - if worker.reuse then worker_loop worker + match CELL.get worker.task_cell with + | Error () -> () + | Ok (id, task) -> + task (); + (* If there is too much threads, exit. This can happen if the user + decreased the maximum: *) + if Domain.DLS.get threads_count > Domain.DLS.get max_threads then worker.reuse <- false; + (* Tell the main thread that work is done: *) + Lwt_unix.send_notification id; + if worker.reuse then worker_loop worker (* create a new worker: *) let make_worker () = - incr threads_count; + Domain.DLS.set threads_count (Domain.DLS.get threads_count + 1); let worker = { task_cell = CELL.make (); thread = Thread.self (); @@ -120,52 +128,52 @@ let make_worker () = (* Add a worker to the pool: *) let add_worker worker = - match Lwt_sequence.take_opt_l waiters with + match Lwt_sequence.take_opt_l (Domain.DLS.get waiters) with | None -> - Queue.add worker workers + Queue.add worker (Domain.DLS.get workers) | Some w -> Lwt.wakeup w worker (* Wait for worker to be available, then return it: *) let get_worker () = - if not (Queue.is_empty workers) then - Lwt.return (Queue.take workers) - else if !threads_count < !max_threads then + if not (Queue.is_empty (Domain.DLS.get workers)) then + Lwt.return (Queue.take (Domain.DLS.get workers)) + else if Domain.DLS.get threads_count < Domain.DLS.get max_threads then Lwt.return (make_worker ()) else - (Lwt.add_task_r [@ocaml.warning "-3"]) waiters + (Lwt.add_task_r [@ocaml.warning "-3"]) (Domain.DLS.get waiters) (* +-----------------------------------------------------------------+ | Initialisation, and dynamic parameters reset | +-----------------------------------------------------------------+ *) -let get_bounds () = (!min_threads, !max_threads) +let get_bounds () = (Domain.DLS.get min_threads, Domain.DLS.get max_threads) let set_bounds (min, max) = if min < 0 || max < min then invalid_arg "Lwt_preemptive.set_bounds"; - let diff = min - !threads_count in - min_threads := min; - max_threads := max; + let diff = min - Domain.DLS.get threads_count in + Domain.DLS.set min_threads min; + Domain.DLS.set max_threads max; (* Launch new workers: *) for _i = 1 to diff do add_worker (make_worker ()) done -let initialized = ref false +let initialized = Domain.DLS.new_key (fun () -> false) let init min max _errlog = - initialized := true; + Domain.DLS.set initialized true; set_bounds (min, max) let simple_init () = - if not !initialized then begin - initialized := true; + if not (Domain.DLS.get initialized) then begin + Domain.DLS.set initialized true; set_bounds (0, 4) end -let nbthreads () = !threads_count -let nbthreadsqueued () = Lwt_sequence.fold_l (fun _ x -> x + 1) waiters 0 -let nbthreadsbusy () = !threads_count - Queue.length workers +let nbthreads () = Domain.DLS.get threads_count +let nbthreadsqueued () = Lwt_sequence.fold_l (fun _ x -> x + 1) (Domain.DLS.get waiters) 0 +let nbthreadsbusy () = Domain.DLS.get threads_count - Queue.length (Domain.DLS.get workers) (* +-----------------------------------------------------------------+ | Detaching | @@ -186,6 +194,7 @@ let detach f args = get_worker () >>= fun worker -> let waiter, wakener = Lwt.wait () in let id = + (* call back the domain that called the [detach] function: self *) Lwt_unix.make_notification ~once:true (fun () -> Lwt.wakeup_result wakener !result) in @@ -199,7 +208,7 @@ let detach f args = (* Put back the worker to the pool: *) add_worker worker else begin - decr threads_count; + Domain.DLS.set threads_count (Domain.DLS.get threads_count - 1); (* Or wait for the thread to terminates, to free its associated resources: *) Thread.join worker.thread @@ -216,30 +225,34 @@ let jobs = Queue.create () (* Mutex to protect access to [jobs]. *) let jobs_mutex = Mutex.create () -let job_notification = - Lwt_unix.make_notification +let job_notification = Domain_map.create_protected_map () +let get_job_notification d = + Domain_map.init job_notification d (fun () -> - (* Take the first job. The queue is never empty at this - point. *) - Mutex.lock jobs_mutex; - let thunk = Queue.take jobs in - Mutex.unlock jobs_mutex; - ignore (thunk ())) - -let run_in_main_dont_wait f = + Lwt_unix.make_notification ~for_other_domain:d + (fun () -> + (* Take the first job. The queue is never empty at this + point. *) + Mutex.lock jobs_mutex; + let thunk = Queue.take jobs in + Mutex.unlock jobs_mutex; + ignore (thunk ())) + ) + +let run_in_domain_dont_wait d f = (* Add the job to the queue. *) Mutex.lock jobs_mutex; Queue.add f jobs; Mutex.unlock jobs_mutex; (* Notify the main thread. *) - Lwt_unix.send_notification job_notification + Lwt_unix.send_notification (get_job_notification d) (* There is a potential performance issue from creating a cell every time this function is called. See: https://github.com/ocsigen/lwt/issues/218 https://github.com/ocsigen/lwt/pull/219 https://github.com/ocaml/ocaml/issues/7158 *) -let run_in_main f = +let run_in_domain d f = let cell = CELL.make () in (* Create the job. *) let job () = @@ -251,13 +264,17 @@ let run_in_main f = CELL.set cell result; Lwt.return_unit in - run_in_main_dont_wait job; + run_in_domain_dont_wait d job; (* Wait for the result. *) match CELL.get cell with - | Result.Ok ret -> ret - | Result.Error exn -> raise exn + | Ok (Ok ret) -> ret + | Ok (Error exn) -> raise exn + | Error () -> assert false (* This version shadows the one above, adding an exception handler *) -let run_in_main_dont_wait f handler = +let run_in_domain_dont_wait d f handler = let f () = Lwt.catch f (fun exc -> handler exc; Lwt.return_unit) in - run_in_main_dont_wait f + run_in_domain_dont_wait d f + +let terminate_worker_threads () = + Queue.iter (fun thread -> CELL.kill thread.task_cell) (Domain.DLS.get workers) diff --git a/src/unix/lwt_preemptive.mli b/src/unix/lwt_preemptive.mli index 24077350f..446345e07 100644 --- a/src/unix/lwt_preemptive.mli +++ b/src/unix/lwt_preemptive.mli @@ -21,21 +21,21 @@ val detach : ('a -> 'b) -> 'a -> 'b Lwt.t Note that Lwt thread-local storage (i.e., {!Lwt.with_value}) cannot be safely used from within [f]. The same goes for most of the rest of Lwt. If - you need to run an Lwt thread in [f], use {!run_in_main}. *) + you need to run an Lwt thread in [f], use {!run_in_domain}. *) -val run_in_main : (unit -> 'a Lwt.t) -> 'a - (** [run_in_main f] can be called from a detached computation to execute +val run_in_domain : Domain.id -> (unit -> 'a Lwt.t) -> 'a + (** [run_in_domain f] can be called from a detached computation to execute [f ()] in the main preemptive thread, i.e. the one executing - {!Lwt_main.run}. [run_in_main f] blocks until [f ()] completes, then - returns its result. If [f ()] raises an exception, [run_in_main f] raises + {!Lwt_main.run}. [run_in_domain f] blocks until [f ()] completes, then + returns its result. If [f ()] raises an exception, [run_in_domain f] raises the same exception. {!Lwt.with_value} may be used inside [f ()]. {!Lwt.get} can correctly retrieve values set this way inside [f ()], but not values set using {!Lwt.with_value} outside [f ()]. *) -val run_in_main_dont_wait : (unit -> unit Lwt.t) -> (exn -> unit) -> unit -(** [run_in_main_dont_wait f h] does the same as [run_in_main f] but a bit faster +val run_in_domain_dont_wait : Domain.id -> (unit -> unit Lwt.t) -> (exn -> unit) -> unit +(** [run_in_domain_dont_wait f h] does the same as [run_in_domain f] but a bit faster and lighter as it does not wait for the result of [f]. If [f]'s promise is rejected (or if it raises), then the function [h] is @@ -53,7 +53,10 @@ val init : int -> int -> (string -> unit) -> unit @param log is used to log error messages If {!Lwt_preemptive} has already been initialised, this call - only modify bounds and the log function. *) + only modify bounds and the log function. + + The limits are set per-domain. More specifically, each domain manages a + pool of systhreads, each pool having its own limits and its own state. *) val simple_init : unit -> unit (** [simple_init ()] checks if the library is not yet initialized, and if not, @@ -80,6 +83,17 @@ val get_max_number_of_threads_queued : unit -> int (** Returns the size of the waiting queue, if no more threads are available *) +val terminate_worker_threads : unit -> unit +(* [terminate_worker_threads ()] queues up a message for all the workers of the + calling domain to self-terminate. This causes all the workers to terminate + after their current jobs are done which causes the threads of these workers + to end. + + Terminating the threads attached to a domain is necessary for joining the + domain. Thus, if you use-case for domains includes spawning and joining them, + you must call [terminate_worker_threads] just before calling + [Domain.join]. *) + (**/**) val nbthreads : unit -> int val nbthreadsbusy : unit -> int diff --git a/src/unix/lwt_process.mli b/src/unix/lwt_process.mli index 51198c5d7..e15c394e3 100644 --- a/src/unix/lwt_process.mli +++ b/src/unix/lwt_process.mli @@ -5,7 +5,11 @@ (** Process management *) -(** This module allows you to spawn processes and communicate with them. *) +(** This module allows you to spawn processes and communicate with them. + + This module makes heavy use of {!Lwt_unix.fork}. Important caveats are + documented there. Read them. TL;DR: no domains, no threads, no preemptive, + yes [Async_none]. *) type command = string * string array (** A command. The first field is the name of the executable and diff --git a/src/unix/lwt_unix.cppo.ml b/src/unix/lwt_unix.cppo.ml index 6fb9f8044..53c3c3b78 100644 --- a/src/unix/lwt_unix.cppo.ml +++ b/src/unix/lwt_unix.cppo.ml @@ -19,35 +19,32 @@ open Lwt.Infix type async_method = | Async_none | Async_detach - | Async_switch -let default_async_method_var = ref Async_detach +let default_async_method_var = Atomic.make Async_detach let () = try match Sys.getenv "LWT_ASYNC_METHOD" with | "none" -> - default_async_method_var := Async_none + Atomic.set default_async_method_var Async_none | "detach" -> - default_async_method_var := Async_detach - | "switch" -> - default_async_method_var := Async_switch + Atomic.set default_async_method_var Async_detach | str -> Printf.eprintf - "%s: invalid lwt async method: '%s', must be 'none', 'detach' or 'switch'\n%!" + "%s: invalid lwt async method: '%s', must be 'none' or 'detach'\n%!" (Filename.basename Sys.executable_name) str with Not_found -> () -let default_async_method () = !default_async_method_var -let set_default_async_method am = default_async_method_var := am +let default_async_method () = Atomic.get default_async_method_var +let set_default_async_method am = Atomic.set default_async_method_var am let async_method_key = Lwt.new_key () let async_method () = match Lwt.get async_method_key with | Some am -> am - | None -> !default_async_method_var + | None -> Atomic.get default_async_method_var let with_async_none f = Lwt.with_value async_method_key (Some Async_none) f @@ -55,9 +52,6 @@ let with_async_none f = let with_async_detach f = Lwt.with_value async_method_key (Some Async_detach) f -let with_async_switch f = - Lwt.with_value async_method_key (Some Async_switch) f - (* +-----------------------------------------------------------------+ | Notifications management | +-----------------------------------------------------------------+ *) @@ -78,38 +72,58 @@ module Notifiers = Hashtbl.Make(struct let hash (x : int) = x end) -let notifiers = Notifiers.create 1024 +let notifiers = Domain_map.create_protected_map () (* See https://github.com/ocsigen/lwt/issues/277 and https://github.com/ocsigen/lwt/pull/278. *) -let current_notification_id = ref (0x7FFFFFFF - 1000) - -let rec find_free_id id = - if Notifiers.mem notifiers id then - find_free_id (id + 1) - else - id - -let make_notification ?(once=false) f = - let id = find_free_id (!current_notification_id + 1) in - current_notification_id := id; - Notifiers.add notifiers id { notify_once = once; notify_handler = f }; - id +let current_notification_id = Atomic.make (0x7FFFFFFF - 1000) -let stop_notification id = - Notifiers.remove notifiers id +type notification = { domain: Domain.id; id: int; } -let set_notification id f = - let notifier = Notifiers.find notifiers id in - Notifiers.replace notifiers id { notifier with notify_handler = f } +let make_notification ?(once=false) ?for_other_domain f = + let domain = match for_other_domain with + | Some domain -> domain + | None -> Domain.self () + in + let id = Atomic.fetch_and_add current_notification_id 1 in + Domain_map.update notifiers domain + (function + | None -> + let notifiers = Notifiers.create 1024 in + Notifiers.add notifiers id { notify_once = once; notify_handler = f }; + Some notifiers + | Some notifiers -> + Notifiers.add notifiers id { notify_once = once; notify_handler = f }; + Some notifiers); + { domain; id } + +let stop_notification { domain; id } = + Domain_map.update notifiers domain + (function + | None -> None + | Some notifiers -> + Notifiers.remove notifiers id; + Some notifiers) -let call_notification id = - match Notifiers.find notifiers id with - | exception Not_found -> () - | notifier -> - if notifier.notify_once then - stop_notification id; - notifier.notify_handler () +let set_notification { domain; id } f = + Domain_map.update notifiers domain + (function + | None -> raise Not_found + | Some notifiers -> + let notifier = Notifiers.find notifiers id in + Notifiers.replace notifiers id { notifier with notify_handler = f }; + Some notifiers) + +let call_notification { domain; id } = + match Domain_map.find notifiers domain with + | None -> () + | Some notifiers -> + (match Notifiers.find notifiers id with + | exception Not_found -> () + | notifier -> + if notifier.notify_once then + Notifiers.remove notifiers id; + notifier.notify_handler ()) (* +-----------------------------------------------------------------+ | Sleepers | @@ -155,7 +169,7 @@ let with_timeout d f = Lwt.pick [timeout d; Lwt.apply f ()] type 'a job -external start_job : 'a job -> async_method -> bool = "lwt_unix_start_job" +external start_job : Domain.id -> 'a job -> async_method -> bool = "lwt_unix_start_job" (* Starts the given job with given parameters. It returns [true] if the job is already terminated. *) @@ -178,15 +192,10 @@ let cancel_jobs () = abort_jobs Lwt.Canceled let wait_for_jobs () = Lwt.join (Lwt_sequence.fold_l (fun (w, _) l -> w :: l) jobs []) -let wrap_result f x = - try - Result.Ok (f x) - with exn when Lwt.Exception_filter.run exn -> - Result.Error exn - let run_job_aux async_method job result = + let domain = Domain.self () in (* Starts the job. *) - if start_job job async_method then + if start_job domain job async_method then (* The job has already terminated, read and return the result immediately. *) Lwt.of_result (result job) @@ -200,7 +209,7 @@ let run_job_aux async_method job result = jobs in ignore begin (* Create the notification for asynchronous wakeup. *) - let id = + let notification = make_notification ~once:true (fun () -> Lwt_sequence.remove node; @@ -211,7 +220,7 @@ let run_job_aux async_method job result = notification. *) Lwt.pause () >>= fun () -> (* The job has terminated, send the result immediately. *) - if check_job job id then call_notification id; + if check_job job notification.id then call_notification notification; Lwt.return_unit end; waiter @@ -223,12 +232,7 @@ let choose_async_method = function | None -> match Lwt.get async_method_key with | Some am -> am - | None -> !default_async_method_var - -let execute_job ?async_method ~job ~result ~free = - let async_method = choose_async_method async_method in - run_job_aux async_method job (fun job -> let x = wrap_result result job in free job; x) -[@@ocaml.warning "-16"] + | None -> Atomic.get default_async_method_var external self_result : 'a job -> 'a = "lwt_unix_self_result" (* returns the result of a job using the [result] field of the C @@ -243,22 +247,7 @@ let self_result job = with exn when Lwt.Exception_filter.run exn -> Result.Error exn -let in_retention_test = ref false - -let retained o = - let retained = ref true in - Gc.finalise (fun _ -> - if !in_retention_test then - retained := false) - o; - in_retention_test := true; - retained - let run_job ?async_method job = - if !in_retention_test then begin - Gc.full_major (); - in_retention_test := false - end; let async_method = choose_async_method async_method in if async_method = Async_none then try @@ -2208,15 +2197,26 @@ let tcflow ch act = | Reading notifications | +-----------------------------------------------------------------+ *) -external init_notification : unit -> Unix.file_descr = "lwt_unix_init_notification" -external send_notification : int -> unit = "lwt_unix_send_notification_stub" -external recv_notifications : unit -> int array = "lwt_unix_recv_notifications" +external init_notification : Domain.id -> Unix.file_descr = "lwt_unix_init_notification_stub" +external send_notification : Domain.id -> int -> unit = "lwt_unix_send_notification_stub" +let send_notification { domain; id } = send_notification domain id +external recv_notifications : Domain.id -> int array = "lwt_unix_recv_notifications_stub" -let handle_notifications _ = - (* Process available notifications. *) - Array.iter call_notification (recv_notifications ()) +let handle_notifications (_ : Lwt_engine.event) = + let domain = Domain.self () in + Array.iter (fun id -> call_notification { domain; id }) (recv_notifications domain) -let event_notifications = ref (Lwt_engine.on_readable (init_notification ()) handle_notifications) +let event_notifications = + Domain.DLS.new_key (fun () -> + let domain = Domain.self () in + Lwt_engine.on_readable (init_notification domain) handle_notifications + ) + +let init_domain () = + let domain = Domain.self () in + let _ : notifier Notifiers.t = (Domain_map.init notifiers domain (fun () -> Notifiers.create 1024)) in + let _ : Lwt_engine.event = Domain.DLS.get event_notifications in + () (* +-----------------------------------------------------------------+ | Signals | @@ -2227,6 +2227,8 @@ external remove_signal : int -> bool -> unit = "lwt_unix_remove_signal" external init_signals : unit -> unit = "lwt_unix_init_signals" external handle_signal : int -> unit = "lwt_unix_handle_signal" +let signal_setting_mutex = Mutex.create () + let () = init_signals () let set_signal signum notification = @@ -2244,14 +2246,18 @@ type signal_handler = { and signal_handler_id = signal_handler option ref -let signals = ref Signal_map.empty +let signals + (* a simple ref, but all access for write are behind a mutex *) + : (notification * ((signal_handler_id -> file_perm -> unit) Lwt_sequence.t) ) Signal_map.t ref + = ref Signal_map.empty let signal_count () = Signal_map.fold - (fun _signum (_id, actions) len -> len + Lwt_sequence.length actions) + (fun _signum (_notification, actions) len -> len + Lwt_sequence.length actions) !signals 0 let on_signal_full signum handler = + Mutex.lock signal_setting_mutex; let id = ref None in let _, actions = try @@ -2259,6 +2265,9 @@ let on_signal_full signum handler = with Not_found -> let actions = Lwt_sequence.create () in let notification = + (* TODO: this assumes `on_signal` is called from domain0 where an lwt + scheduler is running running, should it be possible to set a signal + handler to execute in a specific domain?? *) make_notification (fun () -> Lwt_sequence.iter_l @@ -2266,7 +2275,7 @@ let on_signal_full signum handler = actions) in (try - set_signal signum notification + set_signal signum notification.id with exn when Lwt.Exception_filter.run exn -> stop_notification notification; raise exn); @@ -2275,15 +2284,17 @@ let on_signal_full signum handler = in let node = Lwt_sequence.add_r handler actions in id := Some { sh_num = signum; sh_node = node }; + Mutex.unlock signal_setting_mutex; id -let on_signal signum f = on_signal_full signum (fun _id num -> f num) +let on_signal signum f = on_signal_full signum (fun _notification num -> f num) let disable_signal_handler id = match !id with | None -> () | Some sh -> + Mutex.lock signal_setting_mutex; id := None; Lwt_sequence.remove sh.sh_node; let notification, actions = Signal_map.find sh.sh_num !signals in @@ -2291,13 +2302,16 @@ let disable_signal_handler id = remove_signal sh.sh_num; signals := Signal_map.remove sh.sh_num !signals; stop_notification notification - end + end; + Mutex.unlock signal_setting_mutex let reinstall_signal_handler signum = match Signal_map.find signum !signals with | exception Not_found -> () | notification, _ -> - set_signal signum notification + Mutex.lock signal_setting_mutex; + set_signal signum notification.id; + Mutex.unlock signal_setting_mutex (* +-----------------------------------------------------------------+ | Processes | @@ -2305,6 +2319,7 @@ let reinstall_signal_handler signum = external reset_after_fork : unit -> unit = "lwt_unix_reset_after_fork" +(* TODO: replace fork with something thread+domain safe *) let fork () = match Unix.fork () with | 0 -> @@ -2313,16 +2328,17 @@ let fork () = (* Reset threading. *) reset_after_fork (); (* Stop the old event for notifications. *) - Lwt_engine.stop_event !event_notifications; + let domain = Domain.self () in + Lwt_engine.stop_event (Domain.DLS.get event_notifications); (* Reinitialise the notification system. *) - event_notifications := Lwt_engine.on_readable (init_notification ()) handle_notifications; + Domain.DLS.set event_notifications (Lwt_engine.on_readable (init_notification domain) handle_notifications); (* Collect all pending jobs. *) let l = Lwt_sequence.fold_l (fun (_, f) l -> f :: l) jobs [] in (* Remove them all. *) Lwt_sequence.iter_node_l Lwt_sequence.remove jobs; (* And cancel them all. We yield first so that if the program do an exec just after, it won't be executed. *) - Lwt.on_termination (Lwt_main.yield () [@warning "-3"]) (fun () -> List.iter (fun f -> f Lwt.Canceled) l); + Lwt.on_termination (Lwt.pause ()) (fun () -> List.iter (fun f -> f Lwt.Canceled) l); 0 | pid -> pid @@ -2383,6 +2399,12 @@ let install_sigchld_handler () = install the SIGCHLD handler, in order to cause any EINTR-unsafe code to fail (as it should). *) let () = + (* TODO: this assumes that an Lwt main loop will be started in domain0 (where + this value is allocated bc top-level initialisation), instead + [install_sigchld_handler] should be called when the first lwt-scheduler is + started which could be in a non-zero domain + + or TODO: remove sigchld handler if fork is completely abandonned?? *) Lwt.async (fun () -> Lwt.pause () >|= fun () -> install_sigchld_handler ()) @@ -2462,8 +2484,6 @@ let system cmd = | Misc | +-----------------------------------------------------------------+ *) -let run = Lwt_main.run - let handle_unix_error f x = Lwt.catch (fun () -> f x) diff --git a/src/unix/lwt_unix.cppo.mli b/src/unix/lwt_unix.cppo.mli index c36d9a470..58716a66c 100644 --- a/src/unix/lwt_unix.cppo.mli +++ b/src/unix/lwt_unix.cppo.mli @@ -211,8 +211,16 @@ val fork : unit -> int - None of the above is necessary if you intend to call [exec]. Indeed, in that case, it is not even necessary to use [Lwt_unix.fork]. You can use {!Unix.fork}. - - To abandon some more promises, see - {!Lwt_main.abandon_yielded_and_paused}. *) + - To abandon some more promises, see {!Lwt.abandon_paused}. + + Furthermore: + + - Calling [Lwt_unix.fork] raises an execption if [Domain.spawn] has been + called at any point in the program's past. + - Calling [Lwt_unix.fork] can result in the child process being in a + corrupted state if any thread has been started. Lwt starts threads when + [Lwt_preemptive.detach] is called. Lwt implicitly starts threads to + perform blocking I/O unless the {!async_method} is set to [Async_none]. *) type process_status = Unix.process_status = @@ -257,7 +265,10 @@ val system : string -> process_status Lwt.t (** Executes the given command, waits until it terminates, and return its termination status. The string is interpreted by the shell [/bin/sh] on Unix and [cmd.exe] on Windows. The result - [WEXITED 127] indicates that the shell couldn't be executed. *) + [WEXITED 127] indicates that the shell couldn't be executed. + + The function uses {!fork} internally. As a result, this function is + brittle. See all the warnings relating to [fork] for more details. *) (** {2 Basic file input/output} *) @@ -1279,98 +1290,59 @@ val tcflow : file_descr -> flow_action -> unit Lwt.t -(** {2 Configuration (deprecated)} *) +(** {2 Configuration} *) (** For system calls that cannot be made asynchronously, Lwt uses one of the following method: *) type async_method = | Async_none (** System calls are made synchronously, and may block the - entire program. *) + entire program. + + The main use cases for this are: + - debugging (execution is simpler) + - working with fork and exec (which are not thread-safe) + - when calling specific blocking I/O which is known to be fast *) | Async_detach (** System calls are made in another system thread, thus without blocking other Lwt promises. The drawback is that it may degrade performance in some cases. This is the default. *) - | Async_switch - [@ocaml.deprecated " Use Lwt_unix.Async_detach."] - (** @deprecated A synonym for [Async_detach]. This was a - different method in the past. *) val default_async_method : unit -> async_method - [@@ocaml.deprecated -" Will always return Async_detach in Lwt >= 5.0.0. See - https://github.com/ocsigen/lwt/issues/572"] (** Returns the default async method. This can be initialized using the environment variable - ["LWT_ASYNC_METHOD"] with possible values ["none"], - ["detach"] and ["switch"]. - - @deprecated Will always return [Async_detach] in Lwt 5.0.0. *) + ["LWT_ASYNC_METHOD"] with possible values ["none"] and + ["detach"]. +*) val set_default_async_method : async_method -> unit - [@@ocaml.deprecated -" Will be a no-op in Lwt >= 5.0.0. See - https://github.com/ocsigen/lwt/issues/572"] -(** Sets the default async method. - - @deprecated Will be a no-op in Lwt 5.0.0. *) +(** Sets the default async method. *) val async_method : unit -> async_method - [@@ocaml.deprecated -" Will always return Async_detach in Lwt >= 5.0.0. See - https://github.com/ocsigen/lwt/issues/572"] (** [async_method ()] returns the async method used in the current - thread. - - @deprecated Will always return [Async_detach] in Lwt 5.0.0. *) + thread. *) val async_method_key : async_method Lwt.key - [@@ocaml.deprecated -" Will be ignored in Lwt >= 5.0.0. See - https://github.com/ocsigen/lwt/issues/572"] -(** The key for storing the local async method. - - @deprecated Will be ignored in Lwt 5.0.0. *) +(** The key for storing the local async method. *) val with_async_none : (unit -> 'a) -> 'a - [@@ocaml.deprecated -" Will have no effect in Lwt >= 5.0.0. See - https://github.com/ocsigen/lwt/issues/572"] (** [with_async_none f] is a shorthand for: {[ Lwt.with_value async_method_key (Some Async_none) f ]} - - @deprecated Will have no effect in Lwt 5.0.0. *) +*) val with_async_detach : (unit -> 'a) -> 'a - [@@ocaml.deprecated -" Will have no effect in Lwt >= 5.0.0. See - https://github.com/ocsigen/lwt/issues/572"] (** [with_async_detach f] is a shorthand for: {[ Lwt.with_value async_method_key (Some Async_detach) f ]} - - @deprecated Will have no effect in Lwt 5.0.0. *) - -val with_async_switch : (unit -> 'a) -> 'a - [@@ocaml.deprecated -" Will have no effect in Lwt >= 5.0.0. See - https://github.com/ocsigen/lwt/issues/572"] -(** [with_async_switch f] is a shorthand for: - - {[ - Lwt.with_value async_method_key (Some Async_switch) f - ]} - - @deprecated Will have no effect in Lwt 5.0.0. *) - +*) (** {2 Low-level interaction} *) @@ -1458,46 +1430,48 @@ val cancel_jobs : unit -> unit val wait_for_jobs : unit -> unit Lwt.t (** Wait for all pending jobs to terminate. *) -val execute_job : - ?async_method : async_method -> - job : 'a job -> - result : ('a job -> 'b) -> - free : ('a job -> unit) -> 'b Lwt.t - [@@ocaml.deprecated " Use Lwt_unix.run_job."] - (** @deprecated Use [run_job]. *) - (** {2 Notifications} *) (** Lwt internally use a pipe to send notification to the main thread. The following functions allow to use this pipe. *) -val make_notification : ?once : bool -> (unit -> unit) -> int - (** [make_notification ?once f] registers a new notifier. It returns the - id of the notifier. Each time a notification with this id is +type notification + +val make_notification : ?once : bool -> ?for_other_domain:Domain.id -> (unit -> unit) -> notification + (** [make_notification ?once ?for_other_domain f] registers a new notifier. It + returns the id of the notifier. Each time a notification with this id is received, [f] is called. if [once] is specified, then the notification is stopped after - the first time it is received. It defaults to [false]. *) + the first time it is received. It defaults to [false] + + if [for_other_domain] is specified, then the notification will trigger the + Lwt main loop on the given domain. An unspecified error may occur if the + specified domain is not running an Lwt main loop. If unspecified, + [Domain.self ()] is used. *) -val send_notification : int -> unit +val send_notification : notification -> unit (** [send_notification id] sends a notification. This function is thread-safe. *) -val stop_notification : int -> unit +val stop_notification : notification -> unit (** Stop the given notification. Note that you should not reuse the id after the notification has been stopped, the result is unspecified if you do so. *) -val call_notification : int -> unit +val call_notification : notification -> unit (** Call the handler associated to the given notification. Note that if the notification was defined with [once = true] it is removed. *) -val set_notification : int -> (unit -> unit) -> unit +val set_notification : notification -> (unit -> unit) -> unit (** [set_notification id f] replace the function associated to the notification by [f]. It raises [Not_found] if the given notification is not found. *) +val init_domain : unit -> unit + (** call when Domain.spawn! and call on domain0 too, don't call twice for the same domain *) + (** {2 System threads pool} *) (** If the program is using the async method [Async_detach] or @@ -1579,10 +1553,6 @@ end (**/**) -val run : 'a Lwt.t -> 'a - [@@ocaml.deprecated " Use Lwt_main.run."] - (** @deprecated Use [Lwt_main.run]. *) - val has_wait4 : bool [@@ocaml.deprecated " Use Lwt_sys.have `wait4."] (** @deprecated Use [Lwt_sys.have `wait4]. *) @@ -1591,9 +1561,6 @@ val somaxconn : unit -> int [@@ocaml.deprecated " This is an internal function."] (** @deprecated This is for internal use only. *) -val retained : 'a -> bool ref - (** @deprecated Used for testing. *) - val read_bigarray : string -> file_descr -> IO_vectors._bigarray -> int -> int -> int Lwt.t [@@ocaml.deprecated " This is an internal function."] diff --git a/src/unix/lwt_unix.h b/src/unix/lwt_unix.h index ab4ad64bf..389082fda 100644 --- a/src/unix/lwt_unix.h +++ b/src/unix/lwt_unix.h @@ -95,7 +95,7 @@ void lwt_unix_not_available(char const *feature) Noreturn; +-----------------------------------------------------------------+ */ /* Sends a notification for the given id. */ -void lwt_unix_send_notification(intnat id); +void lwt_unix_send_notification(intnat domain_id, intnat id); /* +-----------------------------------------------------------------+ | Threading | @@ -196,6 +196,7 @@ struct lwt_unix_job { /* Id used to notify the main thread in case the job do not terminate immediately. */ + intnat domain_id; intnat notification_id; /* The function to call to do the work. diff --git a/src/unix/lwt_unix_stubs.c b/src/unix/lwt_unix_stubs.c index 443773bac..a69409ecf 100644 --- a/src/unix/lwt_unix_stubs.c +++ b/src/unix/lwt_unix_stubs.c @@ -17,6 +17,7 @@ #include #include #include +#include #include #include @@ -492,22 +493,10 @@ CAMLprim value lwt_unix_socketpair_stub(value cloexec, value domain, value type, | Notifications | +-----------------------------------------------------------------+ */ -/* The mutex used to send and receive notifications. */ -static lwt_unix_mutex notification_mutex; - -/* All pending notifications. */ -static intnat *notifications = NULL; - -/* The size of the notification buffer. */ -static long notification_count = 0; - -/* The index to the next available cell in the notification buffer. */ -static long notification_index = 0; - /* The mode currently used for notifications. */ enum notification_mode { - /* Not yet initialized. */ - NOTIFICATION_MODE_NOT_INITIALIZED, + /* Not yet initialized. Explicitly set to zero for domain-array initialisation */ + NOTIFICATION_MODE_NOT_INITIALIZED = 0, /* Initialized but no mode defined. */ NOTIFICATION_MODE_NONE, @@ -522,35 +511,50 @@ enum notification_mode { NOTIFICATION_MODE_WINDOWS }; -/* The current notification mode. */ -static enum notification_mode notification_mode = - NOTIFICATION_MODE_NOT_INITIALIZED; +/* Domain-specific notification state */ +struct domain_notification_state { + lwt_unix_mutex notification_mutex; + intnat *notifications; + long notification_count; + long notification_index; +#if defined(HAVE_EVENTFD) + int notification_fd; +#endif + int notification_fds[2]; +}; + +/* table to store per-domain notification state */ +#define MAX_DOMAINS 64 // TODO: review values +static struct domain_notification_state domain_states[MAX_DOMAINS]; +static enum notification_mode domain_notification_mode[MAX_DOMAINS] = {0}; /* Send one notification. */ -static int (*notification_send)(); +static int (*notification_send)(int domain_id); /* Read one notification. */ -static int (*notification_recv)(); +static int (*notification_recv)(int domain_id); -static void init_notifications() { - lwt_unix_mutex_init(¬ification_mutex); - notification_count = 4096; - notifications = - (intnat *)lwt_unix_malloc(notification_count * sizeof(intnat)); +static void init_domain_notifications(int domain_id) { + lwt_unix_mutex_init(&domain_states[domain_id].notification_mutex); + domain_states[domain_id].notification_count = 4096; + domain_states[domain_id].notifications = + (intnat *)lwt_unix_malloc(domain_states[domain_id].notification_count * sizeof(intnat)); + domain_states[domain_id].notification_index = 0; } -static void resize_notifications() { - long new_notification_count = notification_count * 2; - intnat *new_notifications = - (intnat *)lwt_unix_malloc(new_notification_count * sizeof(intnat)); - memcpy((void *)new_notifications, (void *)notifications, - notification_count * sizeof(intnat)); - free(notifications); - notifications = new_notifications; - notification_count = new_notification_count; +static void resize_notifications(int domain_id) { + struct domain_notification_state *state = &domain_states[domain_id]; + long new_notification_count = state->notification_count * 2; + intnat *new_notifications = + (intnat *)lwt_unix_malloc(new_notification_count * sizeof(intnat)); + memcpy((void *)new_notifications, (void *)state->notifications, + state->notification_count * sizeof(intnat)); + free(state->notifications); + state->notifications = new_notifications; + state->notification_count = new_notification_count; } -void lwt_unix_send_notification(intnat id) { +void lwt_unix_send_notification(intnat domain_id, intnat id) { int ret; #if !defined(LWT_ON_WINDOWS) sigset_t new_mask; @@ -561,21 +565,22 @@ void lwt_unix_send_notification(intnat id) { #else DWORD error; #endif - lwt_unix_mutex_lock(¬ification_mutex); - if (notification_index > 0) { + lwt_unix_mutex_lock(&domain_states[domain_id].notification_mutex); + struct domain_notification_state *state = &domain_states[domain_id]; + if (state->notification_index > 0) { /* There is already a pending notification in the buffer, no need to signal the main thread. */ - if (notification_index == notification_count) resize_notifications(); - notifications[notification_index++] = id; + if (state->notification_index == state->notification_count) resize_notifications(domain_id); + state->notifications[state->notification_index++] = id; } else { /* There is none, notify the main thread. */ - notifications[notification_index++] = id; - ret = notification_send(); + state->notifications[state->notification_index++] = id; + ret = notification_send(domain_id); #if defined(LWT_ON_WINDOWS) if (ret == SOCKET_ERROR) { error = WSAGetLastError(); if (error != WSANOTINITIALISED) { - lwt_unix_mutex_unlock(¬ification_mutex); + lwt_unix_mutex_unlock(&domain_states[domain_id].notification_mutex); win32_maperr(error); uerror("send_notification", Nothing); } /* else we're probably shutting down, so ignore the error */ @@ -583,24 +588,24 @@ void lwt_unix_send_notification(intnat id) { #else if (ret < 0) { error = errno; - lwt_unix_mutex_unlock(¬ification_mutex); + lwt_unix_mutex_unlock(&domain_states[domain_id].notification_mutex); pthread_sigmask(SIG_SETMASK, &old_mask, NULL); unix_error(error, "send_notification", Nothing); } #endif } - lwt_unix_mutex_unlock(¬ification_mutex); + lwt_unix_mutex_unlock(&domain_states[domain_id].notification_mutex); #if !defined(LWT_ON_WINDOWS) pthread_sigmask(SIG_SETMASK, &old_mask, NULL); #endif } -value lwt_unix_send_notification_stub(value id) { - lwt_unix_send_notification(Long_val(id)); +value lwt_unix_send_notification_stub(value domain_id, value id) { + lwt_unix_send_notification(Long_val(domain_id), Long_val(id)); return Val_unit; } -value lwt_unix_recv_notifications() { +value lwt_unix_recv_notifications(intnat domain_id) { int ret, i, current_index; value result; #if !defined(LWT_ON_WINDOWS) @@ -612,77 +617,86 @@ value lwt_unix_recv_notifications() { #else DWORD error; #endif - lwt_unix_mutex_lock(¬ification_mutex); + /* Initialize domain state if needed */ + lwt_unix_mutex_lock(&domain_states[domain_id].notification_mutex); /* Receive the signal. */ - ret = notification_recv(); + ret = notification_recv(domain_id); #if defined(LWT_ON_WINDOWS) if (ret == SOCKET_ERROR) { error = WSAGetLastError(); - lwt_unix_mutex_unlock(¬ification_mutex); + lwt_unix_mutex_unlock(&domain_states[domain_id].notification_mutex); win32_maperr(error); uerror("recv_notifications", Nothing); } #else if (ret < 0) { error = errno; - lwt_unix_mutex_unlock(¬ification_mutex); + lwt_unix_mutex_unlock(&domain_states[domain_id].notification_mutex); pthread_sigmask(SIG_SETMASK, &old_mask, NULL); unix_error(error, "recv_notifications", Nothing); } #endif - do { - /* - release the mutex while calling caml_alloc, - which may call gc and switch the thread, - resulting in a classical deadlock, - when thread in question tries another send - */ - current_index = notification_index; - lwt_unix_mutex_unlock(¬ification_mutex); - result = caml_alloc_tuple(current_index); - lwt_unix_mutex_lock(¬ification_mutex); - /* check that no new notifications appeared meanwhile (rare) */ - } while (current_index != notification_index); - - /* Read all pending notifications. */ - for (i = 0; i < notification_index; i++) - Field(result, i) = Val_long(notifications[i]); - /* Reset the index. */ - notification_index = 0; - lwt_unix_mutex_unlock(¬ification_mutex); + struct domain_notification_state *state = &domain_states[domain_id]; + + do { + /* + release the mutex while calling caml_alloc, + which may call gc and switch the thread, + resulting in a classical deadlock, + when thread in question tries another send + */ + current_index = state->notification_index; + lwt_unix_mutex_unlock(&domain_states[domain_id].notification_mutex); + result = caml_alloc_tuple(current_index); + lwt_unix_mutex_lock(&domain_states[domain_id].notification_mutex); + /* check that no new notifications appeared meanwhile (rare) */ + } while (current_index != state->notification_index); + + /* Read all pending notifications. */ + for (i = 0; i < state->notification_index; i++) + Field(result, i) = Val_long(state->notifications[i]); + /* Reset the index. */ + state->notification_index = 0; + lwt_unix_mutex_unlock(&domain_states[domain_id].notification_mutex); #if !defined(LWT_ON_WINDOWS) pthread_sigmask(SIG_SETMASK, &old_mask, NULL); #endif return result; } +value lwt_unix_recv_notifications_stub(value domain_id) { + value res = lwt_unix_recv_notifications(Long_val(domain_id)); + return res; +} + #if defined(LWT_ON_WINDOWS) -static SOCKET socket_r, socket_w; +static SOCKET domain_socket_r[MAX_DOMAINS]; +static SOCKET domain_socket_w[MAX_DOMAINS]; -static int windows_notification_send() { +static int windows_notification_send(int domain_id) { char buf = '!'; - return send(socket_w, &buf, 1, 0); + return send(domain_socket_w[domain_id], &buf, 1, 0); } -static int windows_notification_recv() { +static int windows_notification_recv(int domain_id) { char buf; - return recv(socket_r, &buf, 1, 0); + return recv(domain_socket_r[domain_id], &buf, 1, 0); } -value lwt_unix_init_notification() { +value lwt_unix_init_notification(int domain_id) { SOCKET sockets[2]; - switch (notification_mode) { + switch (domain_notification_mode[domain_id]) { case NOTIFICATION_MODE_NOT_INITIALIZED: - notification_mode = NOTIFICATION_MODE_NONE; - init_notifications(); + domain_notification_mode[domain_id] = NOTIFICATION_MODE_NONE; + init_domain_notifications(domain_id); break; case NOTIFICATION_MODE_WINDOWS: - notification_mode = NOTIFICATION_MODE_NONE; - closesocket(socket_r); - closesocket(socket_w); + domain_notification_mode[domain_id] = NOTIFICATION_MODE_NONE; + closesocket(domain_socket_r[domain_id]); + closesocket(domain_socket_w[domain_id]); break; case NOTIFICATION_MODE_NONE: break; @@ -694,12 +708,12 @@ value lwt_unix_init_notification() { sockets. */ lwt_unix_socketpair(AF_INET, SOCK_STREAM, IPPROTO_TCP, sockets, FALSE); - socket_r = sockets[0]; - socket_w = sockets[1]; - notification_mode = NOTIFICATION_MODE_WINDOWS; + domain_socket_r[domain_id] = sockets[0]; + domain_socket_w[domain_id] = sockets[1]; + domain_notification_mode[domain_id] = NOTIFICATION_MODE_WINDOWS; notification_send = windows_notification_send; notification_recv = windows_notification_recv; - return win_alloc_socket(socket_r); + return win_alloc_socket(domain_socket_r[domain_id]); } #else /* defined(LWT_ON_WINDOWS) */ @@ -712,48 +726,68 @@ static void set_close_on_exec(int fd) { #if defined(HAVE_EVENTFD) -static int notification_fd; - -static int eventfd_notification_send() { +static int eventfd_notification_send(int domain_id) { uint64_t buf = 1; - return write(notification_fd, (char *)&buf, 8); + if (domain_id < 0 || domain_id >= MAX_DOMAINS) { + return -1; + } + struct domain_notification_state *state = &domain_states[domain_id]; + int result = write(state->notification_fd, (char *)&buf, 8); + return result; } -static int eventfd_notification_recv() { +static int eventfd_notification_recv(int domain_id) { uint64_t buf; - return read(notification_fd, (char *)&buf, 8); + if (domain_id < 0 || domain_id >= MAX_DOMAINS) { + return -1; + } + struct domain_notification_state *state = &domain_states[domain_id]; + int result = read(state->notification_fd, (char *)&buf, 8); + return result; } #endif /* defined(HAVE_EVENTFD) */ -static int notification_fds[2]; - -static int pipe_notification_send() { +static int pipe_notification_send(int domain_id) { char buf = 0; - return write(notification_fds[1], &buf, 1); + if (domain_id < 0 || domain_id >= MAX_DOMAINS) { + return -1; + } + struct domain_notification_state *state = &domain_states[domain_id]; + int result = write(state->notification_fds[1], &buf, 1); + return result; } -static int pipe_notification_recv() { +static int pipe_notification_recv(int domain_id) { char buf; - return read(notification_fds[0], &buf, 1); + if (domain_id < 0 || domain_id >= MAX_DOMAINS) { + return -1; + } + struct domain_notification_state *state = &domain_states[domain_id]; + int result = read(state->notification_fds[0], &buf, 1); + return result; } -value lwt_unix_init_notification() { - switch (notification_mode) { +value lwt_unix_init_notification(int domain_id) { + if (domain_id < 0 || domain_id >= MAX_DOMAINS) { + caml_failwith("invalid domain_id in lwt_unix_init_notification"); + } + struct domain_notification_state *state = &domain_states[domain_id]; + switch (domain_notification_mode[domain_id]) { #if defined(HAVE_EVENTFD) case NOTIFICATION_MODE_EVENTFD: - notification_mode = NOTIFICATION_MODE_NONE; - if (close(notification_fd) == -1) uerror("close", Nothing); + domain_notification_mode[domain_id] = NOTIFICATION_MODE_NONE; + if (close(state->notification_fd) == -1) uerror("close", Nothing); break; #endif case NOTIFICATION_MODE_PIPE: - notification_mode = NOTIFICATION_MODE_NONE; - if (close(notification_fds[0]) == -1) uerror("close", Nothing); - if (close(notification_fds[1]) == -1) uerror("close", Nothing); + domain_notification_mode[domain_id] = NOTIFICATION_MODE_NONE; + if (close(state->notification_fds[0]) == -1) uerror("close", Nothing); + if (close(state->notification_fds[1]) == -1) uerror("close", Nothing); break; case NOTIFICATION_MODE_NOT_INITIALIZED: - notification_mode = NOTIFICATION_MODE_NONE; - init_notifications(); + domain_notification_mode[domain_id] = NOTIFICATION_MODE_NONE; + init_domain_notifications(domain_id); break; case NOTIFICATION_MODE_NONE: break; @@ -762,27 +796,32 @@ value lwt_unix_init_notification() { } #if defined(HAVE_EVENTFD) - notification_fd = eventfd(0, 0); - if (notification_fd != -1) { - notification_mode = NOTIFICATION_MODE_EVENTFD; + state->notification_fd = eventfd(0, 0); + if (state->notification_fd != -1) { + domain_notification_mode[domain_id] = NOTIFICATION_MODE_EVENTFD; notification_send = eventfd_notification_send; notification_recv = eventfd_notification_recv; - set_close_on_exec(notification_fd); - return Val_int(notification_fd); + set_close_on_exec(state->notification_fd); + return Val_int(state->notification_fd); } #endif - if (pipe(notification_fds) == -1) uerror("pipe", Nothing); - set_close_on_exec(notification_fds[0]); - set_close_on_exec(notification_fds[1]); - notification_mode = NOTIFICATION_MODE_PIPE; + if (pipe(state->notification_fds) == -1) uerror("pipe", Nothing); + set_close_on_exec(state->notification_fds[0]); + set_close_on_exec(state->notification_fds[1]); + domain_notification_mode[domain_id] = NOTIFICATION_MODE_PIPE; notification_send = pipe_notification_send; notification_recv = pipe_notification_recv; - return Val_int(notification_fds[0]); + return Val_int(state->notification_fds[0]); } #endif /* defined(LWT_ON_WINDOWS) */ +CAMLprim value lwt_unix_init_notification_stub(value domain_id) { + value res = lwt_unix_init_notification(Long_val(domain_id)); + return res; +} + /* +-----------------------------------------------------------------+ | Signals | +-----------------------------------------------------------------+ */ @@ -797,7 +836,7 @@ static intnat signal_notifications[NSIG]; CAMLextern int caml_convert_signal_number(int); /* Send a notification when a signal is received. */ -static void handle_signal(int signum) { +void handle_signal(int signum) { if (signum >= 0 && signum < NSIG) { intnat id = signal_notifications[signum]; if (id != -1) { @@ -806,7 +845,9 @@ static void handle_signal(int signum) { function. */ signal(signum, handle_signal); #endif - lwt_unix_send_notification(id); + //TODO: domain_self instead of root (0)? caml doesn't expose + //caml_ml_domain_id in domain.h :( + lwt_unix_send_notification(0, id); } } } @@ -822,7 +863,9 @@ static BOOL WINAPI handle_break(DWORD event) { intnat id = signal_notifications[SIGINT]; if (id == -1 || (event != CTRL_C_EVENT && event != CTRL_BREAK_EVENT)) return FALSE; - lwt_unix_send_notification(id); + //TODO: domain_self instead of root (0)? caml doesn't expose + //caml_ml_domain_id in domain.h :( + lwt_unix_send_notification(0, id); return TRUE; } #endif @@ -909,7 +952,7 @@ CAMLprim value lwt_unix_init_signals(value Unit) { +-----------------------------------------------------------------+ */ /* Execute the given job. */ -static void execute_job(lwt_unix_job job) { +void execute_job(lwt_unix_job job) { DEBUG("executing the job"); lwt_unix_mutex_lock(&job->mutex); @@ -937,7 +980,7 @@ static void execute_job(lwt_unix_job job) { if (job->fast == 0) { lwt_unix_mutex_unlock(&job->mutex); DEBUG("notifying the main thread"); - lwt_unix_send_notification(job->notification_id); + lwt_unix_send_notification(job->domain_id, job->notification_id); } else { lwt_unix_mutex_unlock(&job->mutex); DEBUG("not notifying the main thread"); @@ -990,7 +1033,7 @@ void initialize_threading() { /* Function executed by threads of the pool. * Note: all signals are masked for this thread. */ -static void *worker_loop(void *data) { +void *worker_loop(void *data) { lwt_unix_job job = (lwt_unix_job)data; /* Execute the initial job if any. */ @@ -1058,7 +1101,7 @@ void lwt_unix_free_job(lwt_unix_job job) { free(job); } -CAMLprim value lwt_unix_start_job(value val_job, value val_async_method) { +CAMLprim value lwt_unix_start_job(value domain_id, value val_job, value val_async_method) { lwt_unix_job job = Job_val(val_job); lwt_unix_async_method async_method = Int_val(val_async_method); int done = 0; @@ -1073,6 +1116,7 @@ CAMLprim value lwt_unix_start_job(value val_job, value val_async_method) { job->state = LWT_UNIX_JOB_STATE_PENDING; job->fast = 1; job->async_method = async_method; + job->domain_id = Long_val(domain_id); switch (async_method) { case LWT_UNIX_ASYNC_METHOD_NONE: diff --git a/test/core/test_lwt.ml b/test/core/test_lwt.ml index f22c72233..d33f97725 100644 --- a/test/core/test_lwt.ml +++ b/test/core/test_lwt.ml @@ -47,7 +47,6 @@ let add_loc exn = try raise exn with exn -> exn let suites : Test.suite list = [] - (* Tests for promises created with [Lwt.return], [Lwt.fail], and related functions, as well as state query (hard to test one without the other). These tests use assertions instead of relying on the correctness of a final @@ -2124,6 +2123,7 @@ let both_tests = suite "both" [ state_is Lwt.Sleep p end; + test "pending, fulfilled, then fulfilled" begin fun () -> let p1, r1 = Lwt.wait () in let p = Lwt.both p1 (Lwt.return 2) in @@ -4205,7 +4205,7 @@ let lwt_sequence_tests = suite "add_task_l and add_task_r" [ let suites = suites @ [lwt_sequence_tests] - +(* let pause_tests = suite "pause" [ test "initial state" begin fun () -> Lwt.return (Lwt.paused_count () = 0) @@ -4290,6 +4290,7 @@ let pause_tests = suite "pause" [ end; ] let suites = suites @ [pause_tests] +*) diff --git a/test/direct/dune b/test/direct/dune new file mode 100644 index 000000000..99f89afe8 --- /dev/null +++ b/test/direct/dune @@ -0,0 +1,7 @@ + +(test + (name main) + (package lwt_direct) + (build_if (>= %{ocaml_version} "5.0")) + (libraries lwt_direct lwt.unix lwttester)) + diff --git a/test/direct/main.ml b/test/direct/main.ml new file mode 100644 index 000000000..faefba9ab --- /dev/null +++ b/test/direct/main.ml @@ -0,0 +1,2 @@ + +let () = Test.run "lwt_direct" Test_lwt_direct.suites ;; diff --git a/test/direct/test_lwt_direct.ml b/test/direct/test_lwt_direct.ml new file mode 100644 index 000000000..1a37886fe --- /dev/null +++ b/test/direct/test_lwt_direct.ml @@ -0,0 +1,222 @@ +open Test +open Lwt_direct +open Lwt.Syntax + +let main_tests = suite "main" [ + test "basic await" begin fun () -> + let fut = spawn @@ fun () -> + Lwt_unix.sleep 1e-6 |> await; + 42 + in + let+ res = fut in + res = 42 + end; + + test "await multiple values" begin fun () -> + let fut1 = let+ () = Lwt_unix.sleep 1e-6 in 1 in + let fut2 = let+ () = Lwt_unix.sleep 2e-6 in 2 in + let fut3 = let+ () = Lwt_unix.sleep 3e-6 in 3 in + + spawn @@ fun () -> + let x1 = fut1 |> await in + let x2 = fut2 |> await in + let x3 = fut3 |> await in + x1 = 1 && x2 = 2 && x3 = 3 + end; + + test "list.iter await" begin fun () -> + let items = List.init 101 (fun i -> Lwt.return i) in + spawn @@ fun () -> + let sum = ref 0 in + List.iter (fun fut -> sum := !sum + await fut) items; + !sum = 5050 + end; + + test "lwt_list.iter_p spawn" begin fun () -> + let items = List.init 101 (fun i -> i) in + let+ items = Lwt_list.map_p + (fun i -> spawn (fun () -> + for _ = 0 to i mod 5 do yield () done; + i + )) + items + in + List.fold_left (+) 0 items = 5050 + end; + + test "spawn in background" begin fun () -> + let stream, push = Lwt_stream.create_bounded 2 in + spawn_in_the_background (fun () -> + for i = 1 to 10 do + push#push i |> await + done; + push#close); + spawn @@ fun () -> + let continue = ref true in + let seen = ref [] in + + while !continue do + match Lwt_stream.get stream |> await with + | None -> continue := false + | Some x -> seen := x :: !seen + done; + List.rev !seen = [1;2;3;4;5;6;7;8;9;10] + end; + + test "list.iter await with yield" begin fun () -> + let items = List.init 101 (fun i -> Lwt.return i) in + spawn @@ fun () -> + let sum = ref 0 in + List.iter (fun fut -> yield(); sum := !sum + await fut) items; + !sum = 5050 + end; + + test "awaiting on failing promise" begin fun () -> + let fut: unit Lwt.t = let* () = Lwt.pause () in let* () = Lwt_unix.sleep 0.0001 in Lwt.fail Exit in + spawn @@ fun () -> + try await fut; false + with Exit -> true + end; + + test "spawn can fail" begin fun () -> + spawn @@ fun () -> + let sub: unit Lwt.t = spawn @@ fun () -> + Lwt_unix.sleep 0.00001 |> await; + raise Exit + in + try await sub; false + with Exit -> true + end; + + test "concurrent fib" begin fun () -> + let rec badfib n = + if n <= 2 then Lwt.return 1 + else + spawn begin fun () -> + let f1 = badfib (n-1) in + let f2 = badfib (n-2) in + await f1 + await f2 + end + in + spawn @@ fun () -> + let fib12 = badfib 12 in + let fib12 = await fib12 in + fib12 = 144 + end +] + +let storage_tests = suite "storage" [ + test "get set" begin fun () -> + let k1 = Storage.new_key () in + let k2 = Storage.new_key () in + spawn @@ fun () -> + assert (Storage.get k1 = None); + assert (Storage.get k2 = None); + Storage.set k1 42; + assert (Storage.get k1 = Some 42); + assert (Storage.get k2 = None); + Storage.set k2 true; + assert (Storage.get k1 = Some 42); + assert (Storage.get k2 = Some true); + Storage.remove k1; + assert (Storage.get k1 = None); + assert (Storage.get k2 = Some true); + true + end; + + test "storage across await" begin fun () -> + let k = Storage.new_key () in + + (* spawn another promise that touches storage *) + let run_promise_async () = + Lwt.async @@ fun () -> + Lwt.with_value k (Some "something else") @@ fun () -> + assert (Lwt.get k = Some "something else"); + Lwt.return_unit + in + + let run_promise () : unit Lwt.t = + Lwt.with_value k (Some "another one") @@ fun () -> + assert (Lwt.get k = Some "another one"); + Lwt.return_unit + in + + let one_task () = + run_promise_async(); + assert (Storage.get k = None); + Storage.set k "v1"; + assert (Storage.get k = Some "v1"); + run_promise () |> await; + assert (Storage.get k = Some "v1"); + Storage.remove k; + assert (Storage.get k = None); + yield(); + assert (Storage.get k = None); + run_promise () |> await; + assert (Storage.get k = None); + run_promise_async(); + yield(); + assert (Storage.get k = None); + Storage.set k "v2"; + assert (Storage.get k = Some "v2"); + run_promise_async(); + yield(); + run_promise () |> await; + assert (Storage.get k = Some "v2"); + in + + (* spawn multiple such tasks *) + let tasks = [ spawn one_task; spawn one_task; spawn one_task ] in + + spawn @@ fun () -> + List.iter await tasks; + true + end; +] + +let io_tests = suite "io" [ + test "read io" begin fun () -> + let str = "some\ninteresting\ntext string here!\n" in + let ic = Lwt_io.of_bytes ~mode:Input (Lwt_bytes.of_string str) in + spawn @@ fun () -> + let lines = ref [] in + while + try + yield (); + let line = Lwt_io.read_line ic |> await in + lines := line :: !lines; + true + with End_of_file -> false + do () + done; + List.rev !lines = ["some"; "interesting"; "text string here!"] + end; + + test "pipe" begin fun () -> + let ic, oc = Lwt_io.pipe () in + spawn_in_the_background (fun () -> + for i = 1 to 100 do + Lwt_io.write_line oc (string_of_int i) |> await; + Lwt_io.flush oc |> await + done; + Lwt_io.close oc |> await; + ); + + spawn @@ fun () -> + let sum = ref 0 in + let continue = ref true in + while !continue do + match Lwt_io.read_line ic |> await |> String.trim |> int_of_string with + | exception End_of_file -> continue := false + | i -> sum := !sum + i + done; + Lwt_io.close ic |> await; + !sum = 5050 + end +] + +let suites = [ + main_tests; + storage_tests; + io_tests; +] diff --git a/test/multidomain/basic.ml b/test/multidomain/basic.ml new file mode 100644 index 000000000..a8b2edf9b --- /dev/null +++ b/test/multidomain/basic.ml @@ -0,0 +1,55 @@ +open Lwt.Syntax + +(* we don't call run in the root domain so we initialise by hand *) +let () = Lwt_unix.init_domain () + +let p_one, w_one = Lwt.wait () +let v_one = 3 +let p_two, w_two = Lwt.wait () +let v_two = 2 + +let d_mult = Domain.spawn (fun () -> + Lwt_unix.init_domain (); + (* domain one: wait for value from domain two then work and then send a value *) + Lwt_main.run ( + let* () = Lwt_unix.sleep 0.01 in + let* v_two = p_two in +(* Printf.printf "d%d received %d\n" (Domain.self () :> int) v_two; *) + let* () = Lwt_unix.sleep 0.1 in + Lwt.wakeup w_one v_one; +(* Printf.printf "d%d sent %d\n" (Domain.self () :> int) v_one; *) + Lwt.return (v_two * v_one) + ) +) +let d_sum = Domain.spawn (fun () -> + Lwt_unix.init_domain (); + Lwt_main.run ( + let () = + (* concurrent thread within domain "two" send a value and then work and + then wait for a value from domain one *) + Lwt.dont_wait (fun () -> + let* () = Lwt_unix.sleep 0.1 in +(* Printf.printf "d%d slept\n" (Domain.self () :> int); *) + Lwt.wakeup w_two v_two; +(* Printf.printf "d%d sent %d\n" (Domain.self () :> int) v_two; *) + Lwt.return () + ) + (fun _ -> exit 1) + in + let* v_one = p_one in + Lwt.return (v_two + v_one) + ) +) + + +let mult = Domain.join d_mult +let sum = Domain.join d_sum + +let () = + if mult = v_one * v_two && sum = v_one + v_two then begin + Printf.printf "basic: ✓\n"; + exit 0 + end else begin + Printf.printf "basic: ×\n"; + exit 1 + end diff --git a/test/multidomain/domainworkers.ml b/test/multidomain/domainworkers.ml new file mode 100644 index 000000000..1d04da1f6 --- /dev/null +++ b/test/multidomain/domainworkers.ml @@ -0,0 +1,74 @@ +open Lwt.Syntax + +let rec worker recv_task f send_result = + let* task = Lwt_stream.get recv_task in + match task with + | None -> +(* let () = Printf.printf "worker(%d) received interrupt\n" (Domain.self () :> int); flush_all() in *) + send_result None; + Lwt.return () + | Some data -> +(* let () = Printf.printf "worker(%d) received task (%S)\n" (Domain.self () :> int) data; flush_all() in *) + let* result = f data in + send_result (Some result); +(* let () = Printf.printf "worker(%d) sent result (%d)\n" (Domain.self () :> int) result; flush_all() in *) + let* () = Lwt.pause () in + worker recv_task f send_result + +let spawn_domain_worker f = + let recv_task, send_task = Lwt_stream.create () in + let recv_result, send_result = Lwt_stream.create () in + let dw = + Domain.spawn (fun () -> + Lwt_unix.init_domain (); + Lwt_main.run ( + let* () = Lwt.pause () in + worker recv_task f send_result + ) + ) + in + send_task, dw, recv_result + +let simulate_work data = + let simulated_work_duration = String.length data in + let* () = Lwt_unix.sleep (0.01 *. float_of_int simulated_work_duration) in + Lwt.return (String.length data) + +let input = [""; "adsf"; "lkjh"; "lkjahsdflkjahdlfkjha"; "0"; ""; ""; ""; ""; ""; "adf"; "ASDSKJLHDAS"; "WPOQIEU"; "DSFALKHJ"; ""; ""; ""; ""; "SD"; "SD"; "SAD; SD;SD"; "ad"; "...."] +let expected_result = List.fold_left (fun acc s -> acc + String.length s) 0 input + +let main () = + let send_task1, dw1, recv_result1 = spawn_domain_worker simulate_work in + let send_task2, dw2, recv_result2 = spawn_domain_worker simulate_work in + let l = + Lwt_unix.init_domain (); + Lwt_main.run ( + let* () = Lwt.pause () in + let () = (* push work *) + List.iteri + (fun idx s -> if idx mod 3 = 0 then send_task1 (Some s) else send_task2 (Some s)) + input + in + send_task1 None; + send_task2 None; + let* lengths1 = Lwt_stream.fold (+) recv_result1 0 + and* lengths2 = Lwt_stream.fold (+) recv_result2 0 + in + Lwt.return (lengths1 + lengths2) + ) + in + let () = Domain.join dw1 in + let () = Domain.join dw2 in + let code = + if l = expected_result then begin + Printf.printf "domain-workers: ✓\n"; + 0 + end else begin + Printf.printf "domain-workers: ×\n"; + 1 + end + in + flush_all (); + exit code + +let () = main () diff --git a/test/multidomain/dune b/test/multidomain/dune new file mode 100644 index 000000000..8c9811c74 --- /dev/null +++ b/test/multidomain/dune @@ -0,0 +1,3 @@ +(tests + (names basic domainworkers movingpromises unixpipe preempting) + (libraries lwt lwt.unix)) diff --git a/test/multidomain/movingpromises.ml b/test/multidomain/movingpromises.ml new file mode 100644 index 000000000..34d47f4d2 --- /dev/null +++ b/test/multidomain/movingpromises.ml @@ -0,0 +1,83 @@ +open Lwt.Syntax + +let rec worker ongoing_tasks recv_task f = + let* task = Lwt_stream.get recv_task in + match task with + | None -> +(* let () = Printf.printf "worker(%d) received interrupt\n" (Domain.self () :> int); flush_all() in *) + Lwt.join ongoing_tasks + | Some (_idx, data, resolver) -> + let task = +(* let () = Printf.printf "worker(%d) received task(%d)\n" (Domain.self () :> int) _idx; flush_all() in *) + let* data in +(* let () = Printf.printf "worker(%d) received task(%d) data(%S)\n" (Domain.self () :> int) _idx data; flush_all() in *) + let* result = f data in + Lwt.wakeup resolver result; +(* let () = Printf.printf "worker(%d) sent result(%d) for task(%d)\n" (Domain.self () :> int) result _idx; flush_all() in *) + Lwt.return () + in + let* () = Lwt.pause () in + worker (task :: ongoing_tasks) recv_task f + +let spawn_domain_worker f = + let recv_task, send_task = Lwt_stream.create () in + let dw = + Domain.spawn (fun () -> + Lwt_unix.init_domain (); + Lwt_main.run ( + let* () = Lwt.pause () in + worker [] recv_task f + ) + ) + in + send_task, dw + +let simulate_work data = + let simulated_work_duration = String.length data in + let* () = Lwt_unix.sleep (0.01 *. float_of_int simulated_work_duration) in + Lwt.return (String.length data) + +let simulate_input data = + let simulated_work_duration = max 1 (10 - String.length data) in + let* () = Lwt_unix.sleep (0.01 *. float_of_int simulated_work_duration) in + Lwt.return data + +let input = [""; "adsf"; "lkjh"; "lkjahsdflkjahdlfkjha"; "0"; ""; ""; ""; ""; ""; "adf"; "ASDSKJLHDAS"; "WPOQIEU"; "DSFALKHJ"; ""; ""; ""; ""; "SD"; "SD"; "SAD; SD;SD"; "ad"; "...."] +let expected_result = input |> List.map String.length |> List.map string_of_int |> String.concat "," + +let main () = + let send_task1, dw1 = spawn_domain_worker simulate_work in + let send_task2, dw2 = spawn_domain_worker simulate_work in + let l = + Lwt_unix.init_domain (); + Lwt_main.run ( + let* () = Lwt.pause () in + let inputs = List.map simulate_input + [""; "adsf"; "lkjh"; "lkjahsdflkjahdlfkjha"; "0"; ""; ""; ""; ""; ""; "adf"; "ASDSKJLHDAS"; "WPOQIEU"; "DSFALKHJ"; ""; ""; ""; ""; "SD"; "SD"; "SAD; SD;SD"; "ad"; "...."] + in + let* lengths = + Lwt_list.mapi_p + (fun idx s -> + let (p, r) = Lwt.task () in + begin if idx mod 3 = 0 then send_task1 (Some (idx, s, r)) else send_task2 (Some (idx, s, r)) end; + p) + inputs + in + let* () = Lwt.pause () in + send_task1 None; + send_task2 None; + let lengths = lengths |> List.map string_of_int |> String.concat "," in + Lwt.return lengths + ) + in + let () = Domain.join dw1 in + let () = Domain.join dw2 in + if l = expected_result then begin + Printf.printf "moving-promises: ✓\n"; + exit 0 + end else begin + Printf.printf "moving-promises: ×\n"; + exit 1 + end + +let () = main () diff --git a/test/multidomain/preempting.ml b/test/multidomain/preempting.ml new file mode 100644 index 000000000..1fd0ea2e5 --- /dev/null +++ b/test/multidomain/preempting.ml @@ -0,0 +1,53 @@ +open Lwt.Syntax + + +let input = ["adsf"; "lkjahsdflkjahdlfkjhaadslfhlasfdasdf"; "0"; ""; "ahlsdfjk"] +let simulate_work data = + let simulated_work_duration = String.length data in + let () = + (* each bit of work is blocking and will use preemptive *) + Unix.sleepf (0.001 *. float_of_int simulated_work_duration) + in + String.length data + +let () = Lwt_unix.init_domain () + +let domain_go_brrrrrrr input = Domain.spawn (fun () -> + flush_all (); + Lwt_unix.init_domain (); + let v = Lwt_main.run ( + let* () = Lwt.pause () in + (* detach blocking work *) + Lwt_list.map_p (Lwt_preemptive.detach simulate_work) input + ) + in + Lwt_preemptive.terminate_worker_threads (); + v +) + +let () = + let rec go acc = function + | [_] | [] -> + acc + | (_ :: more) as wrk -> + let expected = List.map String.length wrk in + let acc = (expected, domain_go_brrrrrrr wrk) :: acc in + go acc more + in + let results = go [] input in + let success = + List.for_all + (fun (expected, d) -> List.for_all2 Int.equal expected (Domain.join d)) + results + in + let code = + if success then begin + Printf.printf "preempting: ✓\n"; + 0 + end else begin + Printf.printf "preempting: ×\n"; + 1 + end + in + flush_all (); + exit code diff --git a/test/multidomain/unixpipe.ml b/test/multidomain/unixpipe.ml new file mode 100644 index 000000000..fceb0c5ae --- /dev/null +++ b/test/multidomain/unixpipe.ml @@ -0,0 +1,54 @@ +open Lwt.Syntax + +let () = + if not Sys.win32 then + let module _ = struct + let checks = Atomic.make 0 + + let () = Lwt_unix.init_domain () + + let write w s = + let b = Bytes.unsafe_of_string s in + let* l = Lwt_unix.write w b 0 (Bytes.length b) in + assert (l = Bytes.length b); + Lwt.return_unit + + let read r n = + let b = Bytes.create n in + let* l = Lwt_unix.read r b 0 n in + assert (l = n); + Lwt.return (Bytes.unsafe_to_string b) + + let rec run data w r = + let* () = Lwt.pause () in + match data with + | [] -> Lwt.return_unit + | datum::data -> + let* () = write w datum in + let* readed = read r (String.length datum) in + assert (datum = readed); + Atomic.incr checks; + run data w r + + let run_in_domain data w r = Domain.spawn (fun () -> Lwt_main.run (run data w r)) + + let (a_from_b, b_to_a) = Lwt_unix.pipe () + let (b_from_a, a_to_b) = Lwt_unix.pipe () + let data = [ "aaa"; "bbbb"; "alhskjdflkhjasdflkhjhjklasfdlhjksadxf" ] + + let a2b = run_in_domain data a_to_b a_from_b + let b2a = run_in_domain data b_to_a b_from_a + + let () = Domain.join a2b + let () = Domain.join b2a + let () = + if Atomic.get checks = 2 * List.length data then begin + Printf.printf "unixpipe: ✓\n"; + exit 0 + end else begin + Printf.printf "unixpipe: ×\n"; + exit 1 + end + end in + () + diff --git a/test/test.ml b/test/test.ml index bc18a36bb..7b8c5fe77 100644 --- a/test/test.ml +++ b/test/test.ml @@ -2,7 +2,6 @@ details, or visit https://github.com/ocsigen/lwt/blob/master/LICENSE.md. *) - type test = { test_name : string; skip_if_this_is_false : unit -> bool; @@ -265,6 +264,7 @@ let run library_name suites = end in + Lwt_unix.init_domain (); loop_over_suites [] suites |> Lwt_main.run @@ -338,6 +338,7 @@ let concurrent library_name suites = end let concurrent library_name suites = + Lwt_unix.init_domain (); Lwt_main.run (concurrent library_name suites) let with_async_exception_hook hook f = diff --git a/test/unix/main.ml b/test/unix/main.ml index 34d2d4983..7e36f99c3 100644 --- a/test/unix/main.ml +++ b/test/unix/main.ml @@ -1,12 +1,16 @@ (* This file is part of Lwt, released under the MIT license. See LICENSE.md for details, or visit https://github.com/ocsigen/lwt/blob/master/LICENSE.md. *) +let () = Lwt_unix.init_domain () + open Tester let () = Test.concurrent "unix" [ +(* Test_lwt_unix.suite; Test_lwt_io.suite; +*) Test_lwt_io_non_block.suite; Test_lwt_process.suite; Test_lwt_engine.suite; diff --git a/test/unix/test_lwt_bytes.ml b/test/unix/test_lwt_bytes.ml index 6de438b8d..fa6d328da 100644 --- a/test/unix/test_lwt_bytes.ml +++ b/test/unix/test_lwt_bytes.ml @@ -597,23 +597,6 @@ let suite = suite "lwt_bytes" [ Lwt.return check end; - test "read: buffer retention" ~sequential:true begin fun () -> - let buffer = Lwt_bytes.create 3 in - - let read_fd, write_fd = Lwt_unix.pipe ~cloexec:true () in - Lwt_unix.set_blocking read_fd true; - - Lwt_unix.write_string write_fd "foo" 0 3 >>= fun _ -> - - let retained = Lwt_unix.retained buffer in - Lwt_bytes.read read_fd buffer 0 3 >>= fun _ -> - - Lwt_unix.close write_fd >>= fun () -> - Lwt_unix.close read_fd >|= fun () -> - - !retained - end; - test "bytes write" begin fun () -> let test_file = "bytes_io_data_write" in Lwt_unix.openfile test_file [O_RDWR;O_TRUNC; O_CREAT] 0o666 @@ -634,21 +617,6 @@ let suite = suite "lwt_bytes" [ Lwt.return check end; - test "write: buffer retention" ~sequential:true begin fun () -> - let buffer = Lwt_bytes.create 3 in - - let read_fd, write_fd = Lwt_unix.pipe ~cloexec:true () in - Lwt_unix.set_blocking write_fd true; - - let retained = Lwt_unix.retained buffer in - Lwt_bytes.write write_fd buffer 0 3 >>= fun _ -> - - Lwt_unix.close write_fd >>= fun () -> - Lwt_unix.close read_fd >|= fun () -> - - !retained - end; - test "bytes recv" ~only_if:(fun () -> not Sys.win32) begin fun () -> let buf = gen_buf 6 in let server_logic socket = diff --git a/test/unix/test_lwt_unix.ml b/test/unix/test_lwt_unix.ml index a4e747aa3..22edc1ee0 100644 --- a/test/unix/test_lwt_unix.ml +++ b/test/unix/test_lwt_unix.ml @@ -6,6 +6,8 @@ open Test open Lwt.Infix +let domain_root_id = Domain.self () + (* An instance of the tester for the wait/waitpid tests. *) let () = match Sys.argv with @@ -451,12 +453,14 @@ let readv_tests = Lwt_unix.write_string write_fd "foo" 0 3 >>= fun _ -> - let retained = Lwt_unix.retained io_vectors in + let retained = ref true in + Gc.finalise (fun _ -> retained := false) io_vectors; Lwt_unix.readv read_fd io_vectors >>= fun _ -> Lwt_unix.close write_fd >>= fun () -> Lwt_unix.close read_fd >|= fun () -> + Gc.full_major (); !retained end; @@ -619,12 +623,14 @@ let writev_tests = let read_fd, write_fd = Lwt_unix.pipe ~cloexec:true () in Lwt_unix.set_blocking write_fd true; - let retained = Lwt_unix.retained io_vectors in + let retained = ref true in + Gc.finalise (fun _ -> retained := false) io_vectors; Lwt_unix.writev write_fd io_vectors >>= fun _ -> Lwt_unix.close write_fd >>= fun () -> Lwt_unix.close read_fd >|= fun () -> + Gc.full_major (); !retained end; @@ -1054,19 +1060,19 @@ let dir_tests = [ ] let lwt_preemptive_tests = [ - test "run_in_main" begin fun () -> + test "run_in_domain" begin fun () -> let f () = - Lwt_preemptive.run_in_main (fun () -> + Lwt_preemptive.run_in_domain domain_root_id (fun () -> Lwt_unix.sleep 0.01 >>= fun () -> Lwt.return 42) in Lwt_preemptive.detach f () >>= fun x -> Lwt.return (x = 42) end; - test "run_in_main_dont_wait" begin fun () -> + test "run_in_domain_dont_wait" begin fun () -> let p, r = Lwt.wait () in let f () = - Lwt_preemptive.run_in_main_dont_wait + Lwt_preemptive.run_in_domain_dont_wait domain_root_id (fun () -> Lwt.pause () >>= fun () -> Lwt.pause () >>= fun () -> @@ -1078,10 +1084,10 @@ let lwt_preemptive_tests = [ p >>= fun x -> Lwt.return (x = 42) end; - test "run_in_main_dont_wait_fail" begin fun () -> + test "run_in_domain_dont_wait_fail" begin fun () -> let p, r = Lwt.wait () in let f () = - Lwt_preemptive.run_in_main_dont_wait + Lwt_preemptive.run_in_domain_dont_wait domain_root_id (fun () -> Lwt.pause () >>= fun () -> Lwt.pause () >>= fun () -> @@ -1092,10 +1098,10 @@ let lwt_preemptive_tests = [ p >>= fun x -> Lwt.return (x = 45) end; - test "run_in_main_with_dont_wait" begin fun () -> + test "run_in_domain_with_dont_wait" begin fun () -> let p, r = Lwt.wait () in let f () = - Lwt_preemptive.run_in_main (fun () -> + Lwt_preemptive.run_in_domain domain_root_id (fun () -> Lwt.dont_wait (fun () -> Lwt.pause () >>= fun () ->