Line data Source code
1 : //
2 : // Copyright (c) 2026 Steve Gerbino
3 : //
4 : // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 : // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 : //
7 : // Official repository: https://github.com/cppalliance/capy
8 : //
9 :
10 : #ifndef BOOST_CAPY_WHEN_ALL_HPP
11 : #define BOOST_CAPY_WHEN_ALL_HPP
12 :
13 : #include <boost/capy/detail/config.hpp>
14 : #include <boost/capy/concept/executor.hpp>
15 : #include <boost/capy/concept/io_awaitable.hpp>
16 : #include <boost/capy/ex/any_coro.hpp>
17 : #include <boost/capy/ex/any_executor_ref.hpp>
18 : #include <boost/capy/ex/frame_allocator.hpp>
19 : #include <boost/capy/task.hpp>
20 :
21 : #include <array>
22 : #include <atomic>
23 : #include <exception>
24 : #include <optional>
25 : #if BOOST_CAPY_HAS_STOP_TOKEN
26 : #include <stop_token>
27 : #endif
28 : #include <tuple>
29 : #include <type_traits>
30 : #include <utility>
31 :
32 : namespace boost {
33 : namespace capy {
34 :
35 : namespace detail {
36 :
37 : /** Type trait to filter void types from a tuple.
38 :
39 : Void-returning tasks do not contribute a value to the result tuple.
40 : This trait computes the filtered result type.
41 :
42 : Example: filter_void_tuple_t<int, void, string> = tuple<int, string>
43 : */
44 : template<typename T>
45 : using wrap_non_void_t = std::conditional_t<std::is_void_v<T>, std::tuple<>, std::tuple<T>>;
46 :
47 : template<typename... Ts>
48 : using filter_void_tuple_t = decltype(std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
49 :
50 : /** Holds the result of a single task within when_all.
51 : */
52 : template<typename T>
53 : struct result_holder
54 : {
55 : std::optional<T> value_;
56 :
57 45 : void set(T v)
58 : {
59 45 : value_ = std::move(v);
60 45 : }
61 :
62 38 : T get() &&
63 : {
64 38 : return std::move(*value_);
65 : }
66 : };
67 :
68 : /** Specialization for void tasks - no value storage needed.
69 : */
70 : template<>
71 : struct result_holder<void>
72 : {
73 : };
74 :
75 : /** Shared state for when_all operation.
76 :
77 : @tparam Ts The result types of the tasks.
78 : */
79 : template<typename... Ts>
80 : struct when_all_state
81 : {
82 : static constexpr std::size_t task_count = sizeof...(Ts);
83 :
84 : // Completion tracking - when_all waits for all children
85 : std::atomic<std::size_t> remaining_count_;
86 :
87 : // Result storage in input order
88 : std::tuple<result_holder<Ts>...> results_;
89 :
90 : // Runner handles - destroyed in await_resume while allocator is valid
91 : std::array<any_coro, task_count> runner_handles_{};
92 :
93 : // Exception storage - first error wins, others discarded
94 : std::atomic<bool> has_exception_{false};
95 : std::exception_ptr first_exception_;
96 :
97 : #if BOOST_CAPY_HAS_STOP_TOKEN
98 : // Stop propagation - on error, request stop for siblings
99 : std::stop_source stop_source_;
100 :
101 : // Connects parent's stop_token to our stop_source
102 : struct stop_callback_fn
103 : {
104 : std::stop_source* source_;
105 1 : void operator()() const { source_->request_stop(); }
106 : };
107 : using stop_callback_t = std::stop_callback<stop_callback_fn>;
108 : std::optional<stop_callback_t> parent_stop_callback_;
109 : #endif
110 :
111 : // Parent resumption
112 : any_coro continuation_;
113 : any_executor_ref caller_ex_;
114 :
115 24 : when_all_state()
116 24 : : remaining_count_(task_count)
117 : {
118 24 : }
119 :
120 24 : ~when_all_state()
121 : {
122 85 : for(auto h : runner_handles_)
123 61 : if(h)
124 61 : h.destroy();
125 24 : }
126 :
127 : /** Capture an exception (first one wins).
128 : */
129 11 : void capture_exception(std::exception_ptr ep)
130 : {
131 11 : bool expected = false;
132 11 : if(has_exception_.compare_exchange_strong(
133 : expected, true, std::memory_order_relaxed))
134 8 : first_exception_ = ep;
135 11 : }
136 :
137 : /** Signal that a task has completed.
138 :
139 : The last child to complete triggers resumption of the parent.
140 : */
141 61 : any_coro signal_completion()
142 : {
143 61 : auto remaining = remaining_count_.fetch_sub(1, std::memory_order_acq_rel);
144 61 : if(remaining == 1)
145 24 : return caller_ex_.dispatch(continuation_);
146 37 : return std::noop_coroutine();
147 : }
148 :
149 : };
150 :
151 : /** Wrapper coroutine that intercepts task completion.
152 :
153 : This runner awaits its assigned task and stores the result in
154 : the shared state, or captures the exception and requests stop.
155 : */
156 : template<typename T, typename... Ts>
157 : struct when_all_runner
158 : {
159 : struct promise_type : frame_allocating_base
160 : {
161 : when_all_state<Ts...>* state_ = nullptr;
162 : any_executor_ref ex_;
163 : #if BOOST_CAPY_HAS_STOP_TOKEN
164 : std::stop_token stop_token_;
165 : #endif
166 :
167 61 : when_all_runner get_return_object()
168 : {
169 61 : return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
170 : }
171 :
172 61 : std::suspend_always initial_suspend() noexcept
173 : {
174 61 : return {};
175 : }
176 :
177 61 : auto final_suspend() noexcept
178 : {
179 : struct awaiter
180 : {
181 : promise_type* p_;
182 :
183 61 : bool await_ready() const noexcept
184 : {
185 61 : return false;
186 : }
187 :
188 61 : any_coro await_suspend(any_coro) noexcept
189 : {
190 : // Signal completion; last task resumes parent
191 61 : return p_->state_->signal_completion();
192 : }
193 :
194 0 : void await_resume() const noexcept
195 : {
196 0 : }
197 : };
198 61 : return awaiter{this};
199 : }
200 :
201 50 : void return_void()
202 : {
203 50 : }
204 :
205 11 : void unhandled_exception()
206 : {
207 11 : state_->capture_exception(std::current_exception());
208 : #if BOOST_CAPY_HAS_STOP_TOKEN
209 : // Request stop for sibling tasks
210 11 : state_->stop_source_.request_stop();
211 : #endif
212 11 : }
213 :
214 : template<class Awaitable>
215 : struct transform_awaiter
216 : {
217 : std::decay_t<Awaitable> a_;
218 : promise_type* p_;
219 :
220 61 : bool await_ready()
221 : {
222 61 : return a_.await_ready();
223 : }
224 :
225 61 : auto await_resume()
226 : {
227 61 : return a_.await_resume();
228 : }
229 :
230 : template<class Promise>
231 61 : auto await_suspend(std::coroutine_handle<Promise> h)
232 : {
233 : #if BOOST_CAPY_HAS_STOP_TOKEN
234 61 : return a_.await_suspend(h, p_->ex_, p_->stop_token_);
235 : #else
236 : return a_.await_suspend(h, p_->ex_, std::stop_token{});
237 : #endif
238 : }
239 : };
240 :
241 : template<class Awaitable>
242 61 : auto await_transform(Awaitable&& a)
243 : {
244 : using A = std::decay_t<Awaitable>;
245 : if constexpr (IoAwaitable<A, any_executor_ref>)
246 : {
247 : return transform_awaiter<Awaitable>{
248 122 : std::forward<Awaitable>(a), this};
249 : }
250 : else
251 : {
252 : return make_affine(std::forward<Awaitable>(a), ex_);
253 : }
254 61 : }
255 : };
256 :
257 : std::coroutine_handle<promise_type> h_;
258 :
259 61 : explicit when_all_runner(std::coroutine_handle<promise_type> h)
260 61 : : h_(h)
261 : {
262 61 : }
263 :
264 : #if defined(__clang__) && __clang_major__ == 14 && !defined(__apple_build_version__)
265 : // Clang 14 has a bug where it calls the move constructor for coroutine
266 : // return objects even though they should be constructed in-place via RVO.
267 : // This happens when returning a non-movable type from a coroutine.
268 : when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
269 : #endif
270 :
271 : // Non-copyable, non-movable - release() is always called immediately
272 : when_all_runner(when_all_runner const&) = delete;
273 : when_all_runner& operator=(when_all_runner const&) = delete;
274 :
275 : #if !defined(__clang__) || __clang_major__ != 14 || defined(__apple_build_version__)
276 : when_all_runner(when_all_runner&&) = delete;
277 : #endif
278 :
279 : when_all_runner& operator=(when_all_runner&&) = delete;
280 :
281 61 : auto release() noexcept
282 : {
283 61 : return std::exchange(h_, nullptr);
284 : }
285 : };
286 :
287 : /** Create a runner coroutine for a single task.
288 :
289 : Task is passed directly to ensure proper coroutine frame storage.
290 : */
291 : template<std::size_t Index, typename T, typename... Ts>
292 : when_all_runner<T, Ts...>
293 61 : make_when_all_runner(task<T> inner, when_all_state<Ts...>* state)
294 : {
295 : if constexpr (std::is_void_v<T>)
296 : co_await std::move(inner);
297 : else
298 : std::get<Index>(state->results_).set(co_await std::move(inner));
299 122 : }
300 :
301 : /** Internal awaitable that launches all runner coroutines and waits.
302 :
303 : This awaitable is used inside the when_all coroutine to handle
304 : the concurrent execution of child tasks.
305 : */
306 : template<typename... Ts>
307 : class when_all_launcher
308 : {
309 : std::tuple<task<Ts>...>* tasks_;
310 : when_all_state<Ts...>* state_;
311 :
312 : public:
313 24 : when_all_launcher(
314 : std::tuple<task<Ts>...>* tasks,
315 : when_all_state<Ts...>* state)
316 24 : : tasks_(tasks)
317 24 : , state_(state)
318 : {
319 24 : }
320 :
321 24 : bool await_ready() const noexcept
322 : {
323 24 : return sizeof...(Ts) == 0;
324 : }
325 :
326 : #if BOOST_CAPY_HAS_STOP_TOKEN
327 : template<typename Ex>
328 24 : any_coro await_suspend(any_coro continuation, Ex const& caller_ex, std::stop_token parent_token = {})
329 : {
330 24 : state_->continuation_ = continuation;
331 24 : state_->caller_ex_ = caller_ex;
332 :
333 : // Forward parent's stop requests to children
334 24 : if(parent_token.stop_possible())
335 : {
336 8 : state_->parent_stop_callback_.emplace(
337 : parent_token,
338 4 : typename when_all_state<Ts...>::stop_callback_fn{&state_->stop_source_});
339 :
340 4 : if(parent_token.stop_requested())
341 1 : state_->stop_source_.request_stop();
342 : }
343 :
344 : // Launch all tasks concurrently
345 24 : auto token = state_->stop_source_.get_token();
346 48 : [&]<std::size_t... Is>(std::index_sequence<Is...>) {
347 24 : (..., launch_one<Is>(caller_ex, token));
348 24 : }(std::index_sequence_for<Ts...>{});
349 :
350 : // Let signal_completion() handle resumption
351 48 : return std::noop_coroutine();
352 24 : }
353 : #else
354 : template<typename Ex>
355 : any_coro await_suspend(any_coro continuation, Ex const& caller_ex)
356 : {
357 : state_->continuation_ = continuation;
358 : state_->caller_ex_ = caller_ex;
359 :
360 : // Launch all tasks concurrently
361 : [&]<std::size_t... Is>(std::index_sequence<Is...>) {
362 : (..., launch_one<Is>(caller_ex));
363 : }(std::index_sequence_for<Ts...>{});
364 :
365 : // Let signal_completion() handle resumption
366 : return std::noop_coroutine();
367 : }
368 : #endif
369 :
370 24 : void await_resume() const noexcept
371 : {
372 : // Results are extracted by the when_all coroutine from state
373 24 : }
374 :
375 : private:
376 : #if BOOST_CAPY_HAS_STOP_TOKEN
377 : template<std::size_t I, typename Ex>
378 61 : void launch_one(Ex const& caller_ex, std::stop_token token)
379 : {
380 61 : auto runner = make_when_all_runner<I>(
381 61 : std::move(std::get<I>(*tasks_)), state_);
382 :
383 61 : auto h = runner.release();
384 61 : h.promise().state_ = state_;
385 61 : h.promise().ex_ = caller_ex;
386 61 : h.promise().stop_token_ = token;
387 :
388 61 : any_coro ch{h};
389 61 : state_->runner_handles_[I] = ch;
390 61 : caller_ex.dispatch(ch).resume();
391 61 : }
392 : #else
393 : template<std::size_t I, typename Ex>
394 : void launch_one(Ex const& caller_ex)
395 : {
396 : auto runner = make_when_all_runner<I>(
397 : std::move(std::get<I>(*tasks_)), state_);
398 :
399 : auto h = runner.release();
400 : h.promise().state_ = state_;
401 : h.promise().ex_ = caller_ex;
402 :
403 : any_coro ch{h};
404 : state_->runner_handles_[I] = ch;
405 : caller_ex.dispatch(ch).resume();
406 : }
407 : #endif
408 : };
409 :
410 : /** Compute the result type for when_all.
411 :
412 : Returns void when all tasks are void (P2300 aligned),
413 : otherwise returns a tuple with void types filtered out.
414 : */
415 : template<typename... Ts>
416 : using when_all_result_t = std::conditional_t<
417 : std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
418 : void,
419 : filter_void_tuple_t<Ts...>>;
420 :
421 : /** Helper to extract a single result, returning empty tuple for void.
422 : This is a separate function to work around a GCC-11 ICE that occurs
423 : when using nested immediately-invoked lambdas with pack expansion.
424 : */
425 : template<std::size_t I, typename... Ts>
426 40 : auto extract_single_result(when_all_state<Ts...>& state)
427 : {
428 : using T = std::tuple_element_t<I, std::tuple<Ts...>>;
429 : if constexpr (std::is_void_v<T>)
430 2 : return std::tuple<>();
431 : else
432 38 : return std::make_tuple(std::move(std::get<I>(state.results_)).get());
433 : }
434 :
435 : /** Extract results from state, filtering void types.
436 : */
437 : template<typename... Ts>
438 15 : auto extract_results(when_all_state<Ts...>& state)
439 : {
440 30 : return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
441 15 : return std::tuple_cat(extract_single_result<Is>(state)...);
442 30 : }(std::index_sequence_for<Ts...>{});
443 : }
444 :
445 : } // namespace detail
446 :
447 : /** Wait for all tasks to complete concurrently.
448 :
449 : @par Example
450 : @code
451 : task<void> example() {
452 : auto [a, b] = co_await when_all(
453 : fetch_int(), // task<int>
454 : fetch_string() // task<std::string>
455 : );
456 : }
457 : @endcode
458 :
459 : @param tasks The tasks to execute concurrently.
460 : @return A task yielding a tuple of results (void types filtered out).
461 :
462 : Key features:
463 : @li All child tasks are launched concurrently
464 : @li Results are collected in input order
465 : @li First error is captured; subsequent errors are discarded
466 : @li On error, stop is requested for all siblings
467 : @li Completes only after all children have completed
468 : @li Void tasks do not contribute to the result tuple
469 : @li Properly propagates frame allocators to all child coroutines
470 : */
471 : template<typename... Ts>
472 : [[nodiscard]] task<detail::when_all_result_t<Ts...>>
473 24 : when_all(task<Ts>... tasks)
474 : {
475 : using result_type = detail::when_all_result_t<Ts...>;
476 :
477 : // State is stored in the coroutine frame, using the frame allocator
478 : detail::when_all_state<Ts...> state;
479 :
480 : // Store tasks in the frame
481 : std::tuple<task<Ts>...> task_tuple(std::move(tasks)...);
482 :
483 : // Launch all tasks and wait for completion
484 : co_await detail::when_all_launcher<Ts...>(&task_tuple, &state);
485 :
486 : // Propagate first exception if any.
487 : // Safe without explicit acquire: capture_exception() is sequenced-before
488 : // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
489 : // last task's decrement that resumes this coroutine.
490 : if(state.first_exception_)
491 : std::rethrow_exception(state.first_exception_);
492 :
493 : // Extract and return results
494 : if constexpr (std::is_void_v<result_type>)
495 : co_return;
496 : else
497 : co_return detail::extract_results(state);
498 48 : }
499 :
500 : // For backwards compatibility and type queries, expose result type computation
501 : template<typename... Ts>
502 : using when_all_result_type = detail::when_all_result_t<Ts...>;
503 :
504 : } // namespace capy
505 : } // namespace boost
506 :
507 : #endif
|