
Andrew Lock


A deep-dive into the new Task.WaitAsync() API in .NET 6

In this post I look at how the new Task.WaitAsync() API is implemented in .NET 6, including the internal types it relies on.

Adding a timeout or cancellation support to await Task

In my previous post, I showed how you could "cancel" an await Task call for a Task that didn't directly support cancellation by using the new WaitAsync() API in .NET 6.

I used WaitAsync() in that post to improve the code that waits for the IHostApplicationLifetime.ApplicationStarted event to fire. The final code I settled on is shown below:

static async Task<bool> WaitForAppStartup(IHostApplicationLifetime lifetime, CancellationToken stoppingToken)
{
    try
    {
        // Create a TaskCompletionSource which completes when 
        // the lifetime.ApplicationStarted token fires
        var tcs = new TaskCompletionSource();
        using var _ = lifetime.ApplicationStarted.Register(() => tcs.SetResult());

        // Wait for the TaskCompletionSource Task, _or_ the stopping Token to fire
        // using the new .NET 6 API, WaitAsync()
        await tcs.Task.WaitAsync(stoppingToken).ConfigureAwait(false);
        return true;
    }
    catch(TaskCanceledException)
    {
        // stoppingToken fired
        return false;
    }
}
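To try the pattern without a generic host, the sketch below swaps IHostApplicationLifetime for plain CancellationTokenSource instances. This is an assumption made purely for the demo: a token source stands in for lifetime.ApplicationStarted, another for stoppingToken, and the helper name WaitForSignal is hypothetical. It shows the two outcomes of the method above: true when the "started" signal wins, false when the stopping token fires first.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Same shape as WaitForAppStartup, but the "started" signal is a plain token
// (hypothetical stand-in for lifetime.ApplicationStarted).
static async Task<bool> WaitForSignal(CancellationToken started, CancellationToken stoppingToken)
{
    try
    {
        var tcs = new TaskCompletionSource();
        using var _ = started.Register(() => tcs.SetResult());
        await tcs.Task.WaitAsync(stoppingToken).ConfigureAwait(false);
        return true;
    }
    catch (TaskCanceledException)
    {
        // stoppingToken fired before the "started" signal
        return false;
    }
}

// Case 1: the "started" signal fires first.
using var appStarted = new CancellationTokenSource();
using var neverStops = new CancellationTokenSource();
Task<bool> startedFirst = WaitForSignal(appStarted.Token, neverStops.Token);
appStarted.Cancel();
bool result1 = await startedFirst;

// Case 2: the stopping token fires while "started" never does.
using var neverStarts = new CancellationTokenSource();
using var stopping = new CancellationTokenSource();
Task<bool> stoppedFirst = WaitForSignal(neverStarts.Token, stopping.Token);
stopping.Cancel();
bool result2 = await stoppedFirst;

Console.WriteLine($"{result1} {result2}"); // True False
```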

In this post, I look at how the .NET 6 API Task.WaitAsync() is actually implemented.

Diving into the Task.WaitAsync implementation

For the rest of the post I'm going to walk through the implementation behind the API. There's nothing very surprising there, but I haven't looked much at the code behind Task and its kin, so it was interesting to see some of the details.

Task.WaitAsync() was introduced in this PR by Stephen Toub.

We'll start with the Task.WaitAsync methods:

public class Task
{
    public Task WaitAsync(CancellationToken cancellationToken) 
        => WaitAsync(Timeout.UnsignedInfinite, cancellationToken);

    public Task WaitAsync(TimeSpan timeout) 
        => WaitAsync(ValidateTimeout(timeout, ExceptionArgument.timeout), default);

    public Task WaitAsync(TimeSpan timeout, CancellationToken cancellationToken)
        => WaitAsync(ValidateTimeout(timeout, ExceptionArgument.timeout), cancellationToken);
}
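A quick way to see the three public overloads side by side is to wait on a task that never completes on its own, so only the timeout or the token can end the wait. In this sketch the never-completing TaskCompletionSource and the Outcome helper are assumptions made for the demo. The timeout path surfaces a TimeoutException, while the token path surfaces a TaskCanceledException (caught here via its base type, OperationCanceledException).

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper: turn the way a task finishes into a short label.
static async Task<string> Outcome(Task waiting)
{
    try { await waiting; return "completed"; }
    catch (TimeoutException) { return "timed out"; }
    catch (OperationCanceledException) { return "canceled"; }
}

// A task that never completes on its own.
var neverCompletes = new TaskCompletionSource().Task;

// WaitAsync(TimeSpan): the timeout elapses first.
string a = await Outcome(neverCompletes.WaitAsync(TimeSpan.FromMilliseconds(50)));

// WaitAsync(CancellationToken): the token fires after 50 ms.
using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(50));
string b = await Outcome(neverCompletes.WaitAsync(cts.Token));

// WaitAsync(TimeSpan, CancellationToken): whichever fires first wins (here, the token).
using var cts2 = new CancellationTokenSource(TimeSpan.FromMilliseconds(50));
string c = await Outcome(neverCompletes.WaitAsync(TimeSpan.FromSeconds(30), cts2.Token));

Console.WriteLine($"{a}, {b}, {c}"); // timed out, canceled, canceled
```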

These three methods all delegate to a different, private WaitAsync overload (shown shortly) that takes a timeout in milliseconds. The CancellationToken-only overload passes Timeout.UnsignedInfinite directly; the two TimeSpan overloads calculate the value with the ValidateTimeout method, shown below, which checks that the timeout is in the allowed range (throwing an ArgumentOutOfRangeException if not) and converts it to a uint of milliseconds.

internal static uint ValidateTimeout(TimeSpan timeout, ExceptionArgument argument)
{
    long totalMilliseconds = (long)timeout.TotalMilliseconds;
    if (totalMilliseconds < -1 || totalMilliseconds > Timer.MaxSupportedTimeout)
    {
        ThrowHelper.ThrowArgumentOutOfRangeException(argument, ExceptionResource.Task_InvalidTimerTimeSpan);
    }

    return (uint)totalMilliseconds;
}
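The conversion can be reproduced with public types to see what it produces. This is a sketch, not the BCL code: Timer.MaxSupportedTimeout and Timeout.UnsignedInfinite are internal, so the constant values below (0xFFFFFFFE and uint.MaxValue) are assumptions mirrored from the .NET source, and the exception is thrown directly rather than via ThrowHelper. The detail worth noticing is that Timeout.InfiniteTimeSpan (-1 ms) wraps around to uint.MaxValue, which is the "infinite" sentinel the private overload checks for.

```csharp
using System;
using System.Threading;

// Assumed mirrors of internal BCL constants.
const uint UnsignedInfinite = uint.MaxValue;   // Timeout.UnsignedInfinite == unchecked((uint)-1)
const uint MaxSupportedTimeout = 0xFFFFFFFE;   // Timer.MaxSupportedTimeout

uint ValidateTimeoutSketch(TimeSpan timeout)
{
    long totalMilliseconds = (long)timeout.TotalMilliseconds;
    if (totalMilliseconds < -1 || totalMilliseconds > MaxSupportedTimeout)
    {
        throw new ArgumentOutOfRangeException(nameof(timeout));
    }

    // -1 (infinite) wraps around to uint.MaxValue; other values are unchanged.
    return unchecked((uint)totalMilliseconds);
}

uint twoSeconds = ValidateTimeoutSketch(TimeSpan.FromSeconds(2));
uint infinite = ValidateTimeoutSketch(Timeout.InfiniteTimeSpan);

bool rejected = false;
try { ValidateTimeoutSketch(TimeSpan.FromMilliseconds(-2)); }
catch (ArgumentOutOfRangeException) { rejected = true; }

Console.WriteLine($"{twoSeconds} {infinite == UnsignedInfinite} {rejected}"); // 2000 True True
```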

Now we come to the WaitAsync method that all the public APIs delegate to. I've annotated the method below:

private Task WaitAsync(uint millisecondsTimeout, CancellationToken cancellationToken)
{
    // If the task has already completed, or if we don't have a timeout OR a cancellation token
    // then there's nothing we can do, and WaitAsync is a noop that returns the original Task
    if (IsCompleted || (!cancellationToken.CanBeCanceled && millisecondsTimeout == Timeout.UnsignedInfinite))
    {
        return this;
    }

    // If the cancellation token has already fired, we can immediately return a cancelled Task
    if (cancellationToken.IsCancellationRequested)
    {
        return FromCanceled(cancellationToken);
    }

    // If the timeout is 0, then we will immediately return a faulted Task
    if (millisecondsTimeout == 0)
    {
        return FromException(new TimeoutException());
    }

    // The CancellationPromise<T> is where most of the heavy lifting happens
    return new CancellationPromise<VoidTaskResult>(this, millisecondsTimeout, cancellationToken);
}
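Each fast path is observable from user code. The sketch below checks them on .NET 6; bear in mind that getting the same instance back is a consequence of the implementation shown above rather than a documented guarantee, so treat it as illustrative and don't build logic on it.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

var pending = new TaskCompletionSource().Task;

// 1. Already-completed task: the original instance comes back.
bool completedIsSame = ReferenceEquals(
    Task.CompletedTask.WaitAsync(TimeSpan.FromSeconds(1)), Task.CompletedTask);

// 2. Token that can't be cancelled plus an infinite timeout: nothing to do, same instance.
bool noOpIsSame = ReferenceEquals(pending.WaitAsync(CancellationToken.None), pending);

// 3. Token already cancelled: an already-cancelled task, no timer or registration needed.
using var cts = new CancellationTokenSource();
cts.Cancel();
bool preCanceled = pending.WaitAsync(cts.Token).IsCanceled;

// 4. Zero timeout: an already-faulted task wrapping a TimeoutException.
Task zero = pending.WaitAsync(TimeSpan.Zero);
bool zeroFaulted = zero.IsFaulted && zero.Exception!.InnerException is TimeoutException;

Console.WriteLine($"{completedIsSame} {noOpIsSame} {preCanceled} {zeroFaulted}"); // True True True True
```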

Most of this method checks whether we can take a fast path and avoid the extra work involved in creating a CancellationPromise<T>; if not, then we need to dive into it. Before we do, it's worth addressing the VoidTaskResult generic parameter used with the returned CancellationPromise<T>.

VoidTaskResult is an internal nested type of Task, which is used a little like the unit type in functional programming; it indicates that you can ignore the T.

// Special internal struct that we use to signify that we are not interested in
// a Task<VoidTaskResult>'s result.
internal struct VoidTaskResult { }

Using VoidTaskResult means more of the implementation of Task and Task<T> can be shared. In this case, the same CancellationPromise<T> implementation is used both by Task.WaitAsync() (shown above) and by the generic versions of those methods exposed by Task<TResult>.
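VoidTaskResult itself is internal, but the idea can be sketched with the public, empty System.ValueTuple struct playing the same "result nobody reads" role. In this sketch, ObserveAsync is a hypothetical generic implementation written once for Task<TResult>; a non-generic Task is adapted into a Task<ValueTuple> so it can reuse that same code.

```csharp
using System;
using System.Threading.Tasks;

int observed = 0;

// The single, shared generic implementation (hypothetical): await and count.
async Task<TResult> ObserveAsync<TResult>(Task<TResult> task)
{
    TResult result = await task.ConfigureAwait(false);
    observed++;
    return result;
}

// Adapt a non-generic Task into Task<ValueTuple>, a unit-typed result,
// much as the BCL uses VoidTaskResult internally.
async Task<ValueTuple> AsUnitTask(Task task)
{
    await task.ConfigureAwait(false);
    return default;
}

Task ObserveAsyncNonGeneric(Task task) => ObserveAsync(AsUnitTask(task));

int answer = await ObserveAsync(Task.FromResult(42));
await ObserveAsyncNonGeneric(Task.Delay(10));

Console.WriteLine($"{answer} {observed}"); // 42 2
```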

With that out of the way, let's look at the implementation of CancellationPromise<T> to see how the magic happens.

Under the hood of CancellationPromise<T>

There are quite a few types involved in CancellationPromise<T> that you probably won't be familiar with unless you regularly browse the .NET source code, so we'll take this one slowly.

First of all, we have the type signature for the nested type CancellationPromise<T>:

public class Task
{
    private protected sealed class CancellationPromise<TResult> : Task<TResult>, ITaskCompletionAction
    {
        // ...
    }
}

There are a few things to note in the signature alone:

  • private protected: this modifier means that the CancellationPromise<T> type can only be accessed from within Task itself and from types that derive from Task in the same assembly. That means you can't use it directly in your own code.
  • Task<TResult>: CancellationPromise<T> derives from Task<TResult>. For the most part it's a "normal" task that can be cancelled, completed, or faulted just like any other Task.
  • ITaskCompletionAction: this is an internal interface that allows you to register a lightweight action to run when a Task completes. This is similar to a standard continuation created with ContinueWith, but with lower overhead. Again, it's internal, so you can't use it in your own types. We'll look at it in more depth shortly.
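ITaskCompletionAction isn't available outside the BCL, so the nearest public equivalent is a continuation registered with ContinueWith and TaskContinuationOptions.ExecuteSynchronously. The sketch below runs a small action when a source task completes, using a static lambda plus a state argument (the same capture-free style CancellationPromise<T> uses for its callbacks). One difference worth keeping in mind: ContinueWith allocates and returns a continuation Task, which is part of the overhead the internal interface avoids.

```csharp
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;

var source = new TaskCompletionSource<int>();
var log = new List<string>();

// Static lambda: everything it needs arrives through 'state', so nothing is captured.
Task continuation = source.Task.ContinueWith(
    static (completed, state) =>
    {
        var messages = (List<string>)state!;
        messages.Add($"source finished with status {completed.Status}");
    },
    log,
    CancellationToken.None,
    TaskContinuationOptions.ExecuteSynchronously,
    TaskScheduler.Default);

source.SetResult(42);
await continuation;

Console.WriteLine(log[0]); // source finished with status RanToCompletion
```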

We've looked at the signature; now let's look at its private fields. The descriptions for these in the source cover them pretty well, I think:

/// <summary>The source task.  It's stored so that we can remove the continuation from it upon timeout or cancellation.</summary>
private readonly Task _task;
/// <summary>Cancellation registration used to unregister from the token source upon timeout or the task completing.</summary>
private readonly CancellationTokenRegistration _registration;
/// <summary>The timer used to implement the timeout.  It's stored so that it's rooted and so that we can dispose it upon cancellation or the task completing.</summary>
private readonly TimerQueueTimer? _timer;

So we have 3 fields:

  • The original Task on which we called WaitAsync()
  • The cancellation token registration received when we registered with the CancellationToken. If the default cancellation token was used, this will be a "dummy" default instance.
  • The timer used to implement the timeout behaviour (if required).

Note that the _timer field is of type TimerQueueTimer. This is another internal type, this time part of the overall Timer implementation. We're going deep enough as it is in this post, so I'll only briefly touch on how it's used below. For now it's enough to know that it behaves similarly to a regular System.Threading.Timer.

So, CancellationPromise<T> is a class that derives from Task<T> and maintains a reference to the original Task, a CancellationTokenRegistration, and a TimerQueueTimer.

The CancellationPromise constructor

Let's look at the constructor now. We'll take it in four bite-size chunks. First off, the arguments passed in from Task.WaitAsync() have some debug assertions applied, and then the original Task is stored in _task. Finally, the CancellationPromise<T> instance is registered as a completion action for the source Task (we'll come back to what this means shortly).

internal CancellationPromise(Task source, uint millisecondsDelay, CancellationToken token)
{
    Debug.Assert(source != null);
    Debug.Assert(millisecondsDelay != 0);

    // Register with the target task.
    _task = source;
    source.AddCompletionAction(this);

    // ... rest of the constructor covered shortly
}

Next we have the timeout configuration. This creates a TimerQueueTimer and passes in a callback to be executed once after millisecondsDelay (it does not execute periodically). A static lambda is used to avoid capturing state; instead, the state is passed as the second argument to the TimerQueueTimer. The callback tries to mark the CancellationPromise<T> as faulted by setting a TimeoutException (remember that CancellationPromise<T> is itself a Task), and then does some cleanup we'll see later.

Note also that flowExecutionContext is false, which avoids capturing and restoring the execution context for performance reasons. For more about execution context, see this post by Stephen Toub.

// Register with a timer if it's needed.
if (millisecondsDelay != Timeout.UnsignedInfinite)
{
    _timer = new TimerQueueTimer(static state =>
    {
        var thisRef = (CancellationPromise<TResult>)state!;
        if (thisRef.TrySetException(new TimeoutException()))
        {
            thisRef.Cleanup();
        }
    }, 
    state: this, 
    dueTime: millisecondsDelay, 
    period: Timeout.UnsignedInfinite, 
    flowExecutionContext: false);
}
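The public System.Threading.Timer can be configured in the same style: a static lambda that receives its target through the state argument, a one-shot due time, and Timeout.Infinite as the period so it never repeats. In this sketch a TaskCompletionSource is a stand-in (an assumption for the demo) for the CancellationPromise<T> task. The public Timer has no flowExecutionContext parameter; creating it inside ExecutionContext.SuppressFlow() is one public way to get a similar effect.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

var promise = new TaskCompletionSource();   // stand-in for the CancellationPromise<T> task

Timer timer;
// Suppress ExecutionContext flow while creating the timer, roughly
// what flowExecutionContext: false does for the internal TimerQueueTimer.
using (ExecutionContext.SuppressFlow())
{
    timer = new Timer(
        static state =>
        {
            var tcs = (TaskCompletionSource)state!;
            // TrySet* so the timeout loses gracefully if something else completed the task first.
            tcs.TrySetException(new TimeoutException());
        },
        promise,            // state: passed in rather than captured
        50,                 // due time in milliseconds
        Timeout.Infinite);  // period: fire once, never repeat
}

bool timedOut;
try
{
    await promise.Task;
    timedOut = false;
}
catch (TimeoutException)
{
    timedOut = true;
}

timer.Dispose();
Console.WriteLine(timedOut); // True
```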

After configuring the timeout, the constructor configures the CancellationToken support. This similarly registers a callback to fire when the provided CancellationToken is cancelled. Note that again this uses UnsafeRegister() (instead of the normal Register()) to avoid flowing the execution context into the callback.

// Register with the cancellation token.
_registration = token.UnsafeRegister(static (state, cancellationToken) =>
{
    var thisRef = (CancellationPromise<TResult>)state!;
    if (thisRef.TrySetCanceled(cancellationToken))
    {
        thisRef.Cleanup();
    }
}, this);
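CancellationToken.UnsafeRegister() is public, including the (state, token) callback overload used above, so this step can be tried directly. Again, a TaskCompletionSource is an assumed stand-in for the promise. Disposing the returned CancellationTokenRegistration is what Cleanup() relies on to unhook the callback; calling it twice here just illustrates that repeated disposal is harmless.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

var promise = new TaskCompletionSource();   // stand-in for the CancellationPromise<T> task
using var cts = new CancellationTokenSource();

// UnsafeRegister: the callback runs without flowing the registering ExecutionContext.
// The static lambda receives both the state object and the token that fired.
CancellationTokenRegistration registration = cts.Token.UnsafeRegister(static (state, token) =>
{
    var tcs = (TaskCompletionSource)state!;
    tcs.TrySetCanceled(token);
}, promise);

cts.Cancel();   // callbacks run synchronously on this thread
bool canceled = promise.Task.IsCanceled;

// Unregistering is safe to repeat.
registration.Dispose();
registration.Dispose();

Console.WriteLine(canceled); // True
```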

Finally, the constructor does some housekeeping. This handles the case where the source Task completes while the constructor is still executing, before the timeout and cancellation callbacks have been registered, or where the timeout fires before the cancellation callback is registered. Without the following block, those later registrations might never be cleaned up, leaking resources.

// If one of the callbacks fired, it's possible they did so prior to our having registered the other callbacks,
// and thus cleanup may have missed those additional registrations.  Just in case, check here, and if we're
// already completed, unregister everything again.  Unregistration is idempotent and thread-safe.
if (IsCompleted)
{
    Cleanup();
}

That's all the code in the constructor. Once constructed, the CancellationPromise<T> is returned from the WaitAsync() method as a Task (or a Task<T>), and can be awaited just as any other Task. In the next section we'll see what happens when the source Task completes.

Implementing ITaskCompletionAction

In the constructor of CancellationPromise<T> we registered a completion action with the source Task (the one we called WaitAsync() on):

_task = source;
source.AddCompletionAction(this);

The object passed to AddCompletionAction() must implement ITaskCompletionAction (as CancellationPromise<T> does). The ITaskCompletionAction interface is simple, consisting of a single method (which is invoked when the source Task completes) and a single property:

internal interface ITaskCompletionAction
{
    // Invoked to run the action
    void Invoke(Task completingTask);
    // Should only return false for specialised scenarios for performance reasons
    // Controls whether to force running as a continuation (synchronously)
    bool InvokeMayRunArbitraryCode { get; }
}

CancellationPromise<T> implements this interface as shown below. It returns true from InvokeMayRunArbitraryCode (as all non-specialised scenarios do) and implements the Invoke() method, receiving the completed source Task as an argument.

The implementation essentially "copies" the status of the completed source Task into the CancellationPromise<T> task:

  • If the source Task was cancelled, it calls TrySetCanceled(), re-using the exception dispatch information to "hide" the details of CancellationPromise<T>
  • If the source Task was faulted, it calls TrySetException()
  • If the source Task completed successfully, it calls TrySetResult()

Note that whatever the status of the source Task, the TrySet* method may fail if cancellation was requested or the timeout expired in the meantime. In those cases set is false, and we skip calling Cleanup() (as the path that succeeded will call it instead).

class CancellationPromise<TResult> : ITaskCompletionAction
{
    bool ITaskCompletionAction.InvokeMayRunArbitraryCode => true;

    void ITaskCompletionAction.Invoke(Task completingTask)
    {
        Debug.Assert(completingTask.IsCompleted);

        bool set = completingTask.Status switch
        {
            TaskStatus.Canceled => TrySetCanceled(completingTask.CancellationToken, completingTask.GetCancellationExceptionDispatchInfo()),
            TaskStatus.Faulted => TrySetException(completingTask.GetExceptionDispatchInfos()),
            _ => completingTask is Task<TResult> taskTResult ? TrySetResult(taskTResult.Result) : TrySetResult(),
        };

        if (set)
        {
            Cleanup();
        }
    }
}
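The same "copy the outcome" step can be written against a public TaskCompletionSource<TResult>. The helper name CopyOutcome is hypothetical, and because the internal GetCancellationExceptionDispatchInfo() and GetExceptionDispatchInfos() helpers aren't available, this sketch passes the AggregateException's inner exceptions instead and drops the original cancellation token. The final call shows a TrySet* losing because the target already has an outcome, which is exactly the case where Cleanup() is skipped.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper mirroring ITaskCompletionAction.Invoke() with public APIs.
static bool CopyOutcome<TResult>(Task<TResult> completed, TaskCompletionSource<TResult> target) =>
    completed.Status switch
    {
        TaskStatus.Canceled => target.TrySetCanceled(),
        TaskStatus.Faulted => target.TrySetException(completed.Exception!.InnerExceptions),
        TaskStatus.RanToCompletion => target.TrySetResult(completed.Result),
        _ => throw new InvalidOperationException("The source task has not completed."),
    };

var faultedSource = Task.FromException<int>(new InvalidOperationException("boom"));
var canceledSource = Task.FromCanceled<int>(new CancellationToken(true));
var okSource = Task.FromResult(7);

var t1 = new TaskCompletionSource<int>();
var t2 = new TaskCompletionSource<int>();
var t3 = new TaskCompletionSource<int>();
CopyOutcome(faultedSource, t1);
CopyOutcome(canceledSource, t2);
CopyOutcome(okSource, t3);

// A second copy fails: the target already has an outcome.
bool secondSet = CopyOutcome(okSource, t3);

Console.WriteLine($"{t1.Task.Status} {t2.Task.Status} {t3.Task.Result} {secondSet}");
// Faulted Canceled 7 False
```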

Now you've seen all three callbacks for the three possible outcomes of WaitAsync(). In each case, whether the task, timeout, or cancellation completes first, there's some cleanup to do.

Cleaning up

One thing that's easy to forget when working with CancellationTokens and timers is to clean up after yourself. CancellationPromise<T> makes sure to do this by always calling Cleanup(). This does three things:

  • Disposes the CancellationTokenRegistration returned from CancellationToken.UnsafeRegister()
  • Closes the TimerQueueTimer (if it exists), which cleans up the underlying resources
  • Removes the callback from the source Task, so the ITaskCompletionAction.Invoke() method on CancellationPromise<T> won't be called

private void Cleanup()
{
    _registration.Dispose();
    _timer?.Close();
    _task.RemoveContinuation(this);
}

Each of these operations is idempotent and thread-safe, so it's safe to call Cleanup() from multiple callbacks, which might happen if, for example, something fires while we're still running the CancellationPromise<T> constructor.

One point to bear in mind is that even if a timeout occurs, or the cancellation token fires and the CancellationPromise<T> completes, the source Task will continue to execute in the background. The caller who executed source.WaitAsync() won't ever see the result of the Task, but if that Task has side effects, they will still occur.
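The sketch below makes that visible. A TaskCompletionSource "gate" stands in for slow work (an assumption chosen to keep the demo deterministic): the caller's WaitAsync() times out, yet once the gate opens the original task still runs to completion and its side effect, a hypothetical flag, still happens.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

int sideEffectRan = 0;
var gate = new TaskCompletionSource();   // simulates slow work that hasn't finished yet

Task work = Task.Run(async () =>
{
    await gate.Task;
    Interlocked.Exchange(ref sideEffectRan, 1);   // the "side effect"
});

bool timedOut = false;
try
{
    await work.WaitAsync(TimeSpan.FromMilliseconds(20));
}
catch (TimeoutException)
{
    timedOut = true;   // the caller gives up here...
}

bool ranBeforeTimeout = Volatile.Read(ref sideEffectRan) == 1;

gate.SetResult();      // ...but the original task is still alive
await work;
bool ranAfterwards = Volatile.Read(ref sideEffectRan) == 1;

Console.WriteLine($"{timedOut} {ranBeforeTimeout} {ranAfterwards}"); // True False True
```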

And that's it! It took a while to go through it, but there's not actually much code involved in the implementation of WaitAsync(), and it's somewhat comparable to the "naive" approach you might have used in previous versions of .NET, but using some of .NET's internal types for performance reasons. I hope it was interesting!
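For comparison, here is roughly what such a "naive" approach might look like on earlier versions of .NET, racing the task against Task.Delay with Task.WhenAny. The helper name WaitWithTimeoutAsync is hypothetical and this is a sketch, not a drop-in replacement: it allocates an extra delay task, and it cancels that delay once the task wins so the timer doesn't linger (the same cleanup concern CancellationPromise<T> handles). One behavioural difference: when the caller's token fires, this version throws an OperationCanceledException, whereas awaiting the task from WaitAsync() throws a TaskCanceledException.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical pre-.NET 6 helper: race the task against a delay.
static async Task WaitWithTimeoutAsync(Task task, TimeSpan timeout, CancellationToken cancellationToken)
{
    using var delayCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
    Task delay = Task.Delay(timeout, delayCts.Token);

    Task winner = await Task.WhenAny(task, delay).ConfigureAwait(false);
    if (winner == delay)
    {
        // Either the caller's token fired or the timeout elapsed.
        cancellationToken.ThrowIfCancellationRequested();
        throw new TimeoutException();
    }

    // The task won: stop the delay's timer, then propagate the task's own outcome.
    delayCts.Cancel();
    await task.ConfigureAwait(false);
}

var neverCompletes = new TaskCompletionSource().Task;
string outcome;
try
{
    await WaitWithTimeoutAsync(neverCompletes, TimeSpan.FromMilliseconds(50), CancellationToken.None);
    outcome = "completed";
}
catch (TimeoutException)
{
    outcome = "timed out";
}

// A task that finishes well within the timeout simply completes.
await WaitWithTimeoutAsync(Task.Delay(10), TimeSpan.FromSeconds(5), CancellationToken.None);

Console.WriteLine(outcome); // timed out
```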

Summary

In this post I took an in-depth look at the new Task.WaitAsync() method in .NET 6, exploring how it is implemented using internal types of the BCL. I showed that the Task returned from WaitAsync() is actually a CancellationPromise<T> instance, which derives from Task<T>, but which supports cancellation and timeouts directly. Finally, I walked through the implementation of CancellationPromise<T>, showing how it wraps the source Task.

Andrew Lock | .Net Escapades