In this post I discuss the new Task.WaitAsync() APIs introduced in .NET 6, how you can use them to "cancel" an await call, and how they can replace other approaches you may be using currently.

The new Task.WaitAsync API in .NET 6

In a recent post, I described how to use a TaskCompletionSource with IHostApplicationLifetime as a way of "pausing" a background service until the application starts up. In that code I used the following function that waits for a TaskCompletionSource.Task to complete, but also supports cancellation via a CancellationToken:

static async Task<bool> WaitForAppStartup(IHostApplicationLifetime lifetime, CancellationToken stoppingToken)
{
    var startedSource = new TaskCompletionSource();
    var cancelledSource = new TaskCompletionSource();

    using var reg1 = lifetime.ApplicationStarted.Register(() => startedSource.SetResult());
    using var reg2 = stoppingToken.Register(() => cancelledSource.SetResult());

    Task completedTask = await Task.WhenAny(
        startedSource.Task,
        cancelledSource.Task).ConfigureAwait(false);

    // If the completed task was the "app started" task, return true, otherwise false
    return completedTask == startedSource.Task;
}

This code works on many versions of .NET, but as the post was specifically about .NET 6, Andreas Gehrke pointed out that I could have used a simpler approach.

Andreas was referring to a new method introduced on Task (and Task<T>) in .NET 6, which allows you to await a Task while also making that await cancellable:

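namespace System.Threading.Tasks;

public class Task
{
    public Task WaitAsync(CancellationToken cancellationToken);
    public Task WaitAsync(TimeSpan timeout);
    public Task WaitAsync(TimeSpan timeout, CancellationToken cancellationToken);
}

public class Task<TResult> : Task
{
    public new Task<TResult> WaitAsync(CancellationToken cancellationToken);
    public new Task<TResult> WaitAsync(TimeSpan timeout);
    public new Task<TResult> WaitAsync(TimeSpan timeout, CancellationToken cancellationToken);
}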

As you can see, there are three new methods added to Task, all of them overloads of WaitAsync(). This is useful for the exact scenario I described earlier: you want to await the completion of a Task, but you want that await to be cancellable by a CancellationToken.
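To make the behaviour concrete, here's a minimal sketch, assuming a standalone console app on .NET 6 or later using top-level statements (the variable names are purely illustrative). It waits on a Task<int> that never completes, abandons the wait when a token fires, and then shows that a task which has already completed simply hands back its result.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// A task that never completes on its own: nothing ever calls SetResult()
var neverCompletes = new TaskCompletionSource<int>();

// A token that is cancelled automatically after 100ms
using var cts = new CancellationTokenSource(TimeSpan.FromMilliseconds(100));

bool waitCancelled = false;
try
{
    // Task<int>.WaitAsync() returns Task<int>, so the result would flow through if the task completed in time
    int value = await neverCompletes.Task.WaitAsync(cts.Token);
    Console.WriteLine($"Got {value}");
}
catch (TaskCanceledException)
{
    // Only the wait was cancelled; the underlying task is still pending
    waitCancelled = true;
    Console.WriteLine($"Wait cancelled, underlying task status: {neverCompletes.Task.Status}");
}

// An already-completed task is returned as-is, so its result is available immediately
int ready = await Task.FromResult(42).WaitAsync(CancellationToken.None);
Console.WriteLine($"Already-completed task returned {ready}");
```

Note that cancelling the token never touches neverCompletes itself: WaitAsync() returns a separate task that represents the wait, not the original operation.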

Based on this new API, we could rewrite the WaitForAppStartup function as the following:

static async Task<bool> WaitForAppStartup(IHostApplicationLifetime lifetime, CancellationToken stoppingToken)
{
    try
    {
        var tcs = new TaskCompletionSource();
        using var _ = lifetime.ApplicationStarted.Register(() => tcs.SetResult());
        await tcs.Task.WaitAsync(stoppingToken).ConfigureAwait(false);
        return true;
    }
    catch(TaskCanceledException)
    {
        return false;
    }
}

I think this is much easier to read, so thanks Andreas for pointing it out!

Awaiting a Task with a timeout

The Task.WaitAsync(CancellationToken cancellationToken) method (and its counterpart on Task<T>) is very useful when you want to make an await cancellable by a CancellationToken. The other overloads are useful if you want to make it cancellable based on a timeout.

For example, consider the following pseudocode:

public async Task<int> GetResult()
{
    var cachedResult = await LoadFromCache();
    if (cachedResult is not null)
    {
        return cachedResult.Value;
    }

    return await LoadDirectly(); //TODO: store the result in the cache

    async Task<int?> LoadFromCache()
    {
        // simulating something quick
        await Task.Delay(TimeSpan.FromMilliseconds(10));

        return 123;
    }

    async Task<int> LoadDirectly()
    {
        // simulating something slow
        await Task.Delay(TimeSpan.FromSeconds(30));

        return 123;
    }
}

This code shows a single public method, with two local functions:

  • GetResult() returns the result of an expensive operation, the result of which may be cached
  • LoadFromCache() returns the result from a cache, with a short delay
  • LoadDirectly() returns the result from the original source, which takes a lot longer

This is a pretty typical pattern for caching the result of an expensive operation. Note, though, that the "caching API" in this example is async, as it would be if, for example, you were using ASP.NET Core's IDistributedCache.

If all goes well then calling GetResult() multiple times should work like the following:

var result1 = await GetResult(); // takes ~30s, as the result isn't cached yet
var result2 = await GetResult(); // takes ~10ms, as the result is cached
var result3 = await GetResult(); // takes ~10ms, as the result is cached

In this case, the cache is doing a great job speeding up subsequent requests for the result.

But what if something goes wrong with the distributed cache?

For example, maybe you're using Redis as a distributed cache, which is lightning-fast most of the time. But for some reason, your Redis server suddenly becomes unavailable: maybe the server crashes, there are network problems, or the network just becomes very slow.

Suddenly, your LoadFromCache() method is actually making the call to GetResult() slower, not faster!😱

Ideally, you want to be able to say "Try to load this from the cache, but if it takes longer than x milliseconds, stop trying". In other words, you want to set a timeout.

Now, you may well be able to configure a sensible timeout within the Redis connection library itself, but assume for a moment that you can't, or that your caching API doesn't expose such an option. In that case, you can use .NET 6's Task<T>.WaitAsync(TimeSpan):

public async Task<int> GetResult()
{
    // set a threshold to wait for the cached result
    var cacheTimeout = TimeSpan.FromMilliseconds(100);
    try
    {
        var cachedResult = await LoadFromCache().WaitAsync(cacheTimeout);
        if (cachedResult is not null)
        {
            return cachedResult.Value;
        }
    }
    catch(TimeoutException)
    {
        // cache took too long
    }

    return await LoadDirectly(); //TODO: store the result in the cache
    // ...
}

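With this change, GetResult() won't wait longer than 100ms for the cache to return. If LoadFromCache() exceeds that timeout, Task.WaitAsync() throws a TimeoutException and the function falls back to LoadDirectly() instead. Bear in mind that WaitAsync() only abandons the wait: the LoadFromCache() call itself isn't cancelled, and keeps running in the background until it completes.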
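The following sketch makes that last point visible. It assumes a standalone .NET 6 console app with top-level statements, and SlowOperation() is a hypothetical stand-in for a cache call that hangs. The wait times out after 100ms, but the original task still runs to completion afterwards.

```csharp
using System;
using System.Diagnostics;
using System.Threading.Tasks;

// Hypothetical slow operation standing in for a cache call that hangs
static async Task<int> SlowOperation()
{
    await Task.Delay(TimeSpan.FromMilliseconds(500));
    return 123;
}

var stopwatch = Stopwatch.StartNew();
Task<int> slowTask = SlowOperation();
bool timedOut = false;

try
{
    // Give up waiting after 100ms
    int result = await slowTask.WaitAsync(TimeSpan.FromMilliseconds(100));
    Console.WriteLine($"Got {result}");
}
catch (TimeoutException)
{
    timedOut = true;
    Console.WriteLine($"Timed out after ~{stopwatch.ElapsedMilliseconds}ms, task completed? {slowTask.IsCompleted}");
}

// The original task was never cancelled, so it still finishes on its own
int lateResult = await slowTask;
Console.WriteLine($"Underlying task finished with {lateResult} after ~{stopwatch.ElapsedMilliseconds}ms");
```

If the underlying operation accepts a CancellationToken, passing one through as well is the way to actually stop the work; WaitAsync() on its own only controls how long the caller waits.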

Note that if you're using the CancellationToken overload of WaitAsync(), you'll get a TaskCanceledException when the token is cancelled. If you use a timeout, you'll get a TimeoutException instead.

If you wanted this behaviour before .NET 6, you could replicate it using an extension method something like the following:

using System;
using System.Threading;
using System.Threading.Tasks;

public static class TaskExtensions
{
    // Extension method on `Task<T>`
    public static async Task<TResult> TimeoutAfter<TResult>(this Task<TResult> task, TimeSpan timeout)
    {
        // We need to be able to cancel the "timeout" task, so create a token source
        using var cts = new CancellationTokenSource();

        // Create the timeout task (don't await it)
        var timeoutTask = Task.Delay(timeout, cts.Token);

        // Wait for whichever completes first: the original task or the timeout
        var completedTask = await Task.WhenAny(task, timeoutTask).ConfigureAwait(false);

        if (completedTask == task)
        {
            // Cancel the "timeout" task so we don't leak a Timer
            cts.Cancel();
            // await the task to bubble up any errors etc
            return await task.ConfigureAwait(false);
        }
        else
        {
            throw new TimeoutException($"Task timed out after {timeout}");
        }
    }
}

Having this code in the .NET base class library is obviously very handy, but it also helps you avoid the subtle bugs that are easy to introduce when writing this code yourself. In the extension above, for example, it would be easy to forget to cancel the Task.Delay() call. That would leak a Timer instance until the delay fires in the background, and in high-throughput code that could easily become an issue!

On top of that, .NET 6 adds a further overload that supports both a timeout and a CancellationToken, saving you one more extension method to write 🙂 In the next post I'll dive into how this is actually implemented under the hood, as there's a lot more to it than the extension method above!
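As a rough illustration of the combined overload, here's a sketch assuming a standalone .NET 6 console app with top-level statements; WaitWithBoth() is a hypothetical helper name, not part of the framework. Whichever of the timeout or the token fires first decides the outcome: a TimeoutException for the timeout, a TaskCanceledException for the token.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper: waits on a never-completing task using both a timeout and a token,
// and reports which exception type ended the wait
static async Task<string> WaitWithBoth(TimeSpan timeout, TimeSpan cancelAfter)
{
    var neverCompletes = new TaskCompletionSource();
    using var cts = new CancellationTokenSource(cancelAfter);
    try
    {
        await neverCompletes.Task.WaitAsync(timeout, cts.Token);
        return "completed";
    }
    catch (TimeoutException)
    {
        return nameof(TimeoutException);
    }
    catch (TaskCanceledException)
    {
        return nameof(TaskCanceledException);
    }
}

// The timeout (50ms) fires before the token (1s)
string first = await WaitWithBoth(TimeSpan.FromMilliseconds(50), TimeSpan.FromSeconds(1));

// The token (50ms) fires before the timeout (1s)
string second = await WaitWithBoth(TimeSpan.FromSeconds(1), TimeSpan.FromMilliseconds(50));

Console.WriteLine(first);
Console.WriteLine(second);
```

Because TimeoutException doesn't derive from OperationCanceledException, callers that need to distinguish "gave up waiting" from "was asked to stop" can simply catch the two separately, as above.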

Summary

In this post I discussed the new Task.WaitAsync() method overloads introduced in .NET 6, and how you can use them to simplify code where you want to wait for a Task but want the await to be cancellable, either via a CancellationToken or after a specified timeout.