From 96a4e55fcf9ca194157605188ef64df9a8bfae14 Mon Sep 17 00:00:00 2001 From: Robert Leahy Date: Tue, 23 Jun 2026 06:56:57 -0400 Subject: [PATCH] when_all: Conditional Stop Source This commit implements P4217R1 and P4269R0 by making the following changes to when_all: - when_all() is now expression-equivalent to just() - when_all(s) is now expression-equivalent to auto(s) (i.e. it decay- copies s) - The operation state for when_all(s0, ..., sN) now only contains a stop source if: - Any of s0, ..., sN can complete unsuccessfully in the ambient environment, or - Decay-copying the values sent by any of s0, ... sN can throw an exception The last bullet prevents when_all from "hallucinating" set_stopped: Previously when_all unconditionally managed a stop source, and injected stop tokens therefrom into its children. Consider when_all(s0, s1) where s0 and s1: - Have only set_value_t() as a completion signature in the ambient environment, and - Have set_value_t() and set_stopped_t() as completion signatures in an environment with an inplace_stop_token when_all(s0, s1) plainly does not need a stop source because neither s0 nor s1 will ever complete unsuccessfully (see bullet one above). However when_all(s0, s1) would previously create a stop source anyway, and inject tokens therefrom into s0 and s1. This would cause s0 and s1 to report an additional stopped completion (see bullet two above) which would cause when_all to report set_value_t() and set_stopped_t() as completion signatures even though the latter is never actually emitted (because s0 and s1 don't fail except by stop request, but the only way for a stop request to be emitted is for either of them to fail). --- include/stdexec/__detail/__when_all.hpp | 180 ++++++++++++------ test/stdexec/algos/adaptors/test_when_all.cpp | 136 +++++++++++++ 2 files changed, 256 insertions(+), 60 deletions(-) diff --git a/include/stdexec/__detail/__when_all.hpp b/include/stdexec/__detail/__when_all.hpp index ea6844419..6f5be1edf 100644 --- a/include/stdexec/__detail/__when_all.hpp +++ b/include/stdexec/__detail/__when_all.hpp @@ -51,17 +51,15 @@ namespace STDEXEC //! their value datums. //! //! @c when_all is the canonical *parallel composition* primitive in the - //! sender model. You give it one or more senders; it returns a single + //! sender model. You give it zero or more senders; it returns a single //! sender that, when connected and started, starts *all* of the input //! senders concurrently. When every input has completed, @c when_all's //! sender completes with a value tuple that is the concatenation of every //! input's value datums. //! - //! If any one input fails or is stopped, @c when_all requests stop on the - //! others (via an internal @c inplace_stop_source) and completes with - //! that error (or with @c set_stopped). This makes @c when_all naturally - //! fail-fast: as soon as one branch has gone bad, the rest are asked to - //! wind down. + //! If any one input can fail or be stopped, @c when_all uses an internal + //! @c inplace_stop_source. An unhappy completion requests stop on the + //! other inputs before the combined operation completes. //! //! @code{.cpp} //! auto s = stdexec::when_all( @@ -101,8 +99,7 @@ namespace STDEXEC //! set_value_t(V1..., V2..., ..., Vn...) // concatenation of every input //! set_error_t(Eij)... // union across all inputs //! set_error_t(std::exception_ptr) // added if any decay-copy may throw - //! set_stopped_t() // added if any input has it, - //! // or if cancellation may happen + //! set_stopped_t() // added if any input has it //! @endcode //! //! The value datums of each input are decay-copied into the resulting @@ -163,19 +160,34 @@ namespace STDEXEC //! input has completed. //! //! @tparam _Senders A pack of types each satisfying @c stdexec::sender. - //! Must be non-empty. Each must have exactly one - //! @c set_value_t completion signature in the - //! ambient environment. + //! Each must have exactly one @c set_value_t completion + //! signature in the ambient environment. //! //! @param __sndrs The senders to compose. Forwarded into the result. //! - //! @returns A sender that, when connected and started, concurrently - //! starts every input and value-completes with the - //! concatenation of the input's value datums. - template - constexpr auto operator()(_Senders&&... __sndrs) const -> __well_formed_sender auto + //! @returns @c just() for no inputs, the input sender for one input, or a + //! sender that concurrently starts every input and concatenates + //! their value datums for two or more inputs. + constexpr auto operator()() const noexcept + { + return just(); + } + + template + constexpr auto operator()(_Sender&& __sndr) const noexcept(__nothrow_decay_copyable<_Sender>) { - return __make_sexpr(__(), static_cast<_Senders&&>(__sndrs)...); + return static_cast<_Sender&&>(__sndr); + } + + template + constexpr auto operator()(_Sender0&& __sndr0, _Sender1&& __sndr1, _Senders&&... __sndrs) const + noexcept(__nothrow_decay_copyable<_Sender0, _Sender1, _Senders...>) -> __well_formed_sender + auto + { + return __make_sexpr(__(), + static_cast<_Sender0&&>(__sndr0), + static_cast<_Sender1&&>(__sndr1), + static_cast<_Senders&&>(__sndrs)...); } }; @@ -394,8 +406,8 @@ namespace STDEXEC } template - using __env_t = decltype(__when_all::__mk_env(__declval<_Env>(), - __declval())); + using __stoppable_env_t = decltype(__when_all::__mk_env(__declval<_Env>(), + __declval())); template concept __max1_sender = @@ -423,6 +435,17 @@ namespace STDEXEC using __nothrow_decay_copyable_results_t = STDEXEC::__nothrow_decay_copyable_results_t<__completion_signatures_of_t<_Sender, _Env...>>; + template + inline constexpr bool __can_fail = !__never_sends + || sends_stopped<_Sender, _Env> + || !__nothrow_decay_copyable_results_t<_Sender, _Env>::value; + + template + inline constexpr bool __uses_stop_source = (__can_fail<_Senders, _Env> || ...); + + template + using __env_t = __if_c<__uses_stop_source<_Env, _Senders...>, __stoppable_env_t<_Env>, _Env>; + template struct __completions { @@ -460,6 +483,23 @@ namespace STDEXEC __concat_completion_signatures_t>...>; }; + template + struct __completions_for; + + template <> + struct __completions_for<> + { + template + using __f = __completions<>::template __f<_Senders...>; + }; + + template + struct __completions_for<_Env> + { + template + using __f = __completions<__env_t<_Env, _Senders...>>::template __f<_Senders...>; + }; + template constexpr void __set_values(_Receiver& __rcvr, _ValuesTuple& __values) noexcept { @@ -472,29 +512,33 @@ namespace STDEXEC static_cast<_ValuesTuple&&>(__values)); } - template - using __values_opt_tuple_t = - value_types_of_t<_Sender, __env_t<_Env>, __decayed_tuple, __optional>; + template + using __values_opt_tuple_t = value_types_of_t<_Sender, _ChildEnv, __decayed_tuple, __optional>; - template >... _Senders> + template + requires(__max1_sender<_Senders, __env_t<_Env, _Senders...>> && ...) struct __traits { + using __child_env = __env_t<_Env, _Senders...>; + // tuple>, optional>, ...> - using __values_tuple = __minvoke< - __mwith_default<__mtransform<__mbind_front_q<__values_opt_tuple_t, _Env>, __q<__tuple>>, - __ignore>, - _Senders...>; + using __values_tuple = + __minvoke<__mwith_default< + __mtransform<__mbind_front_q<__values_opt_tuple_t, __child_env>, __q<__tuple>>, + __ignore>, + _Senders...>; using __collect_errors = __mbind_front_q<__mset_insert, __mset<>>; using __errors_list = __minvoke<__mconcat<>, - __if<__mand<__nothrow_decay_copyable_results_t<_Senders, _Env>...>, + __if<__mand<__nothrow_decay_copyable_results_t<_Senders, __child_env>...>, __mlist<>, __mlist>, - __error_types_of_t<_Senders, __env_t<_Env>, __q<__mlist>>...>; + __error_types_of_t<_Senders, __child_env, __q<__mlist>>...>; - using __errors_variant = __mapply<__q<__uniqued_variant>, __errors_list>; + using __errors_variant = __mapply<__q<__uniqued_variant>, __errors_list>; + static constexpr bool __uses_stop_source = __when_all::__uses_stop_source<_Env, _Senders...>; }; struct _INVALID_ARGUMENTS_TO_WHEN_ALL_ @@ -515,7 +559,10 @@ namespace STDEXEC // error state, which trumps cancellation.) if (__state_->__state_.compare_exchange_strong(__expected, __stopped)) { - __state_->__stop_source_.request_stop(); + if constexpr (_State::__uses_stop_source) + { + __state_->__stop_source_.request_stop(); + } } // Arrive in order to decrement the count again and complete if needed. @@ -525,12 +572,19 @@ namespace STDEXEC _State* __state_; }; - template + template struct __state { - using __receiver_t = _Receiver; + using __receiver_t = _Receiver; + static constexpr bool __uses_stop_source = _UsesStopSource; using __stop_callback_t = stop_callback_for_t>, __forward_stop_request<__state>>; + using __stop_source_t = __if_c<_UsesStopSource, inplace_stop_source, __empty>; + using __on_stop_t = __if_c<_UsesStopSource, __optional<__stop_callback_t>, __empty>; constexpr void __arrive() noexcept { @@ -543,7 +597,10 @@ namespace STDEXEC constexpr void __complete() noexcept { // Stop callback is no longer needed. Destroy it. - __on_stop_.reset(); + if constexpr (_UsesStopSource) + { + __on_stop_.reset(); + } // All child operations have completed and arrived at the barrier. switch (__state_.load(__std::memory_order_relaxed)) { @@ -579,13 +636,15 @@ namespace STDEXEC STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS _Receiver __rcvr_; __std::atomic __count_; - inplace_stop_source __stop_source_{}; + STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS + __stop_source_t __stop_source_{}; // Could be non-atomic here and atomic_ref everywhere except __completion_fn - __std::atomic<__state_t> __state_{__started}; - _ErrorsVariant __errors_{__no_init}; + __std::atomic<__state_t> __state_{__started}; + _ErrorsVariant __errors_{__no_init}; + STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS + _ValuesTuple __values_{}; STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS - _ValuesTuple __values_{}; - __optional<__stop_callback_t> __on_stop_{}; + __on_stop_t __on_stop_{}; }; template @@ -622,18 +681,12 @@ namespace STDEXEC } }; - // A when_all with no senders completes inline with no values. template <> struct __attrs<> { + template [[nodiscard]] - constexpr auto query(__get_completion_behavior_t) const noexcept - { - return __completion_behavior::__inline_completion; - } - - [[nodiscard]] - constexpr auto query(__get_completion_behavior_t) const noexcept + constexpr auto query(__get_completion_behavior_t<_Tag>) const noexcept { return __completion_behavior::__inline_completion; } @@ -642,17 +695,18 @@ namespace STDEXEC template static constexpr auto __mk_state_fn(_Receiver&& __rcvr) noexcept { - return [&]<__max1_sender<__env_t>>... _Child>(__ignore, - __ignore, - _Child&&...) noexcept + return [&](__ignore, __ignore, _Child&&...) noexcept + requires(__max1_sender<_Child, __env_t, _Child...>> && ...) { using _Traits = __traits, _Child...>; using _ErrorsVariant = _Traits::__errors_variant; using _ValuesTuple = _Traits::__values_tuple; + using _ChildEnv = _Traits::__child_env; using _State = __state<_ErrorsVariant, _ValuesTuple, _Receiver, - (sends_stopped<_Child, env_of_t<_Receiver>> || ...)>; + (sends_stopped<_Child, _ChildEnv> || ...), + _Traits::__uses_stop_source>; return _State{static_cast<_Receiver&&>(__rcvr), sizeof...(_Child)}; }; } @@ -663,7 +717,7 @@ namespace STDEXEC struct __when_all_impl : __sexpr_defaults { template - using __completions_t = __children_of<_Self, __when_all::__completions<__env_t<_Env>...>>; + using __completions_t = __children_of<_Self, __when_all::__completions_for<_Env...>>; static constexpr auto __get_attrs = [](__ignore, __ignore, _Child const &...) noexcept @@ -694,9 +748,15 @@ namespace STDEXEC } static constexpr auto __get_env = [](__ignore, _State const & __state) noexcept - -> __env_t> { - return __when_all::__mk_env(STDEXEC::get_env(__state.__rcvr_), __state.__stop_source_); + if constexpr (_State::__uses_stop_source) + { + return __when_all::__mk_env(STDEXEC::get_env(__state.__rcvr_), __state.__stop_source_); + } + else + { + return STDEXEC::get_env(__state.__rcvr_); + } }; static constexpr auto __get_state = @@ -711,19 +771,18 @@ namespace STDEXEC [](_State& __state, _Operations&... __child_ops) noexcept -> void { - // register stop callback: - __state.__on_stop_.emplace(get_stop_token(STDEXEC::get_env(__state.__rcvr_)), - __forward_stop_request<_State>{&__state}); - (STDEXEC::start(__child_ops), ...); - if constexpr (sizeof...(__child_ops) == 0) + if constexpr (_State::__uses_stop_source) { - __state.__complete(); + __state.__on_stop_.emplace(get_stop_token(STDEXEC::get_env(__state.__rcvr_)), + __forward_stop_request<_State>{&__state}); } + (STDEXEC::start(__child_ops), ...); }; template static constexpr void __set_error(_State& __state, _Error&& __err) noexcept { + static_assert(_State::__uses_stop_source); // Transition to the "error" state and switch on the prior state. // TODO: What memory orderings are actually needed here? switch (__state.__state_.exchange(__error)) @@ -769,6 +828,7 @@ namespace STDEXEC } else if constexpr (__same_as<_Set, set_stopped_t>) { + static_assert(_State::__uses_stop_source); __state_t __expected = __started; // Transition to the "stopped" state if and only if we're in the // "started" state. (If this fails, it's because we're in an diff --git a/test/stdexec/algos/adaptors/test_when_all.cpp b/test/stdexec/algos/adaptors/test_when_all.cpp index a5f55df97..bd64442d3 100644 --- a/test/stdexec/algos/adaptors/test_when_all.cpp +++ b/test/stdexec/algos/adaptors/test_when_all.cpp @@ -33,6 +33,40 @@ namespace ex = STDEXEC; namespace { + struct stop_sensitive_sender + { + using sender_concept = ex::sender_tag; + + template + static consteval auto get_completion_signatures() + { + if constexpr (ex::unstoppable_token>) + { + return ex::completion_signatures{}; + } + else + { + return ex::completion_signatures{}; + } + } + + template + struct operation + { + Receiver receiver_; + + void start() & noexcept + { + ex::set_value(std::move(receiver_)); + } + }; + + template + auto connect(Receiver receiver) const noexcept -> operation + { + return {std::move(receiver)}; + } + }; TEST_CASE("when_all returns a sender", "[adaptors][when_all]") { @@ -42,6 +76,34 @@ namespace (void) snd; } + TEST_CASE("when_all coalesces empty and unary calls", "[adaptors][when_all]") + { + using empty_t = decltype(ex::when_all()); + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(!exec::sender_for); + STATIC_REQUIRE(noexcept(ex::when_all())); + + auto child = ex::just(42); + using unary_t = decltype(ex::when_all(child)); + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(!exec::sender_for); + STATIC_REQUIRE(noexcept(ex::when_all(child))); + + auto multi = ex::when_all(ex::just(1), ex::just(2)); + STATIC_REQUIRE(exec::sender_for); + } + + TEST_CASE("unary when_all preserves every completion channel", "[adaptors][when_all]") + { + wait_for_value(ex::when_all(ex::just(42)), 42); + + auto error_op = ex::connect(ex::when_all(ex::just_error(42)), expect_error_receiver{42}); + ex::start(error_op); + + auto stopped_op = ex::connect(ex::when_all(ex::just_stopped()), expect_stopped_receiver{}); + ex::start(stopped_op); + } + TEST_CASE("when_all with environment returns a sender", "[adaptors][when_all]") { auto snd = ex::when_all(ex::just(3), ex::just(0.1415)); @@ -453,6 +515,80 @@ namespace ex::start(op); } + TEST_CASE("infallible when_all children retain the receiver stop token", "[adaptors][when_all]") + { + auto observes_stop_possible = ex::read_env(ex::get_stop_token) + | ex::then([](auto token) noexcept + { return token.stop_possible(); }); + + auto unstoppable = ex::when_all(observes_stop_possible, observes_stop_possible); + static_assert(set_equivalent>, + ex::completion_signatures>); + wait_for_value(std::move(unstoppable), false, false); + + ex::inplace_stop_source source; + auto stoppable = ex::when_all(observes_stop_possible, observes_stop_possible); + auto env = ex::prop(ex::get_stop_token, source.get_token()); + static_assert(set_equivalent, + ex::completion_signatures>); + auto op = ex::connect(std::move(stoppable), expect_value_receiver{env_tag{}, env, true, true}); + ex::start(op); + } + + TEST_CASE("when_all publishes environment-sensitive completion signatures", + "[adaptors][when_all]") + { + auto snd = ex::when_all(stop_sensitive_sender{}, stop_sensitive_sender{}); + static_assert(set_equivalent>, + ex::completion_signatures>); + + ex::inplace_stop_source source; + auto env = ex::prop(ex::get_stop_token, source.get_token()); + static_assert( + set_equivalent, + ex::completion_signatures>); + } + + TEST_CASE("when_all publishes storage failure as exception_ptr", "[adaptors][when_all]") + { + auto snd = ex::when_all(ex::just(), ex::just(potentially_throwing{})); + static_assert(set_equivalent>, + ex::completion_signatures>); + + ex::inplace_stop_source source; + auto env = ex::prop(ex::get_stop_token, source.get_token()); + static_assert(set_equivalent, + ex::completion_signatures>); + } + + TEST_CASE("fallible when_all children receive an internal stop token", "[adaptors][when_all]") + { + bool observed_stop_possible = false; + auto observer = ex::read_env(ex::get_stop_token) + | ex::then([&](auto token) noexcept + { observed_stop_possible = token.stop_possible(); }); + auto snd = ex::when_all(std::move(observer), ex::just_error(42)); + auto op = ex::connect(std::move(snd), expect_error_receiver{42}); + ex::start(op); + CHECK(observed_stop_possible); + } + + TEST_CASE("potentially throwing when_all result storage uses an internal stop token", + "[adaptors][when_all]") + { + bool observed_stop_possible = false; + auto observer = ex::read_env(ex::get_stop_token) + | ex::then([&](auto token) noexcept + { observed_stop_possible = token.stop_possible(); }); + auto snd = ex::when_all(std::move(observer), ex::just(potentially_throwing{})) + | ex::then([](potentially_throwing) noexcept {}); + auto op = ex::connect(std::move(snd), expect_void_receiver{}); + ex::start(op); + CHECK(observed_stop_possible); + } + TEST_CASE("when_all handles stop requests from the environment correctly", "[adaptors][when_all]") { auto snd = ex::when_all(completes_if(false), completes_if(false));