In my previous post I showed how to create a Simple Moving Average calculator. The implementation was very simple, so it was conceptually easy to follow, but had a few limitations. In this post, I address some of those limitations to make the implementation thread-safe and to separate the updating of the average from the fetching of the current value.
Simple Moving Average: brief recap
The Simple Moving Average (SMA) is the mean of the last k values, where k is specified ahead of time. The following example shows the SMA for a series of values with k=3 (values before the start of the series are treated as 0):
Value | Simple Moving Average | Calculation |
---|---|---|
2 | 0.67 | (0 + 0 + 2) / 3 |
4 | 2 | (0 + 2 + 4) / 3 |
5 | 3.67 | (2 + 4 + 5) / 3 |
3 | 4 | (4 + 5 + 3) / 3 |
8 | 5.33 | (5 + 3 + 8) / 3 |
6 | 5.67 | (3 + 8 + 6) / 3 |
4 | 6 | (8 + 6 + 4) / 3 |
In my previous post, I implemented this in C# using the class below. If you haven't already, I strongly suggest checking out that post, as it contains a deeper explanation of the code.
public class SimpleMovingAverage
{
    private readonly int _k;
    private readonly int[] _values;

    private int _index = 0;
    private int _sum = 0;

    public SimpleMovingAverage(int k)
    {
        _k = k;
        _values = new int[k];
    }

    public double Update(int nextInput)
    {
        _sum = _sum - _values[_index] + nextInput;
        _values[_index] = nextInput;
        _index = (_index + 1) % _k;
        return ((double) _sum) / _k;
    }
}
This simple class has a single method, Update(), that takes the next entry for the SMA calculation and returns the new SMA value. However, as I mentioned in my previous post, this implementation has several limitations:
- It is not thread-safe
- You can only fetch the SMA value when you update it
- There's no handling of overflow in _sum
In this post, I provide a more robust implementation that addresses the first two of these limitations.
Designing the thread safe Simple Moving Average calculator
The limitations shown above are not too difficult to address, but before I do, I want to take another look at the overall design of the SimpleMovingAverage class above. With a single Update() method, the expectation is that you will supply the next value in the sequence, i.e. 2, 4, 5, 3, 8 etc. from the table at the start of this post.
However, for my purposes, I wanted to track the number of events occurring in a given period of time (1 second). So the values in the table above represent 2, 4, 5 etc. events per second. With the previous design, I would have to track how many events had occurred in the last second externally to the SimpleMovingAverage class, and then pass this value in to the Update() method. While possible, I would prefer the SimpleMovingAverage class keep track of that itself.
In addition, other threads would need to fetch the current SMA value without updating it. That led me to the following design:
This design places the responsibility for keeping track of the number of events within a given time bucket on the SimpleMovingAverage class. The "Event Generator" doesn't have to keep track of the events itself; every time an event occurs, it tells the SimpleMovingAverage to increment the current count.
Every 1 second, the SimpleMovingAverage checks the current event count and resets it to 0. The current event count is used to update the running _sum and SMA values, using the same efficient approach described in the previous post, where we keep track of the previous k values.
Separately, the "SMA reporter" can read the current _sum and use this to calculate the current SMA value. This can happen at any time, and isn't tied to the reporting of events. In practice, the calculation will happen inside the SimpleMovingAverage class, but conceptually that's what's happening.
That covers the overall design, so now let's look at how to build it.
Using the Interlocked class for thread-safe code
It's hopefully clear from the previous design that we need our new SimpleMovingAverage class to be thread-safe: we will have different threads updating the current bucket value, updating the _sum value, and reading the _sum value.
There are a few different approaches to making code thread-safe in C#, for example:
- Use lock() to "guard" locations that read or modify data that is shared across multiple threads. Only one thread is allowed to access the "locked" section at any one time.
- Use immutable data structures to ensure that data never changes; it's only appended to.
- Ensure shared data is updated atomically using the Interlocked or Volatile classes.
All of these approaches can be valid, but I chose to use the Interlocked class, as I only needed to update single values at a time atomically. If I needed to update multiple values together, I probably would have chosen to use a lock() instead.
The Interlocked methods I'll be using are:
- Interlocked.Add(ref _currentBucketValue, count). This increments the field _currentBucketValue by the value count.
- var newBucketValue = Interlocked.Exchange(ref _currentBucketValue, 0). This fetches the value from the _currentBucketValue field, returns it in the variable newBucketValue, and sets the value of _currentBucketValue to 0.
Note that both of these methods work on the field _currentBucketValue, and they will be called from different threads. However, because we're using the Interlocked methods, our code to update and fetch the _currentBucketValue is thread-safe 🎉.
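To make that guarantee concrete, here is a minimal sketch (not part of the final class) in which several writer tasks call Interlocked.Add() while the calling thread repeatedly drains the counter with Interlocked.Exchange(). The InterlockedCounterDemo class, the CountEvents method, and its parameters are hypothetical names used only for illustration.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public static class InterlockedCounterDemo
{
    // Starts `writers` tasks that each record `eventsPerWriter` events,
    // while the calling thread keeps swapping the shared counter for 0.
    // Returns the total number of events collected by the drain step.
    public static long CountEvents(int writers, int eventsPerWriter)
    {
        int currentBucketValue = 0; // shared between all tasks
        long drained = 0;           // only touched by the calling thread

        var tasks = new Task[writers];
        for (int i = 0; i < writers; i++)
        {
            tasks[i] = Task.Run(() =>
            {
                for (int j = 0; j < eventsPerWriter; j++)
                {
                    Interlocked.Add(ref currentBucketValue, 1);
                }
            });
        }

        Task allWriters = Task.WhenAll(tasks);
        while (!allWriters.IsCompleted)
        {
            // Atomically take whatever has accumulated and reset to 0
            drained += Interlocked.Exchange(ref currentBucketValue, 0);
            Thread.Yield();
        }

        allWriters.Wait(); // surface any exception from the writers

        // Pick up anything added after the last drain in the loop
        drained += Interlocked.Exchange(ref currentBucketValue, 0);
        return drained;
    }
}
```

Calling InterlockedCounterDemo.CountEvents(4, 100_000) should return exactly 400,000 however the threads interleave, because every increment lands in exactly one drained value. Replacing the Interlocked calls with plain reads and writes would typically lose some events.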
That covers the thread-safety aspect, so now let's look at some code!
Implementing the thread-safe SimpleMovingAverage class
The following is the full implementation for the thread-safe SimpleMovingAverage class. There's quite a lot of code here, but I'll walk through each part of the implementation below, to show how it works.
using System;
using System.Threading;
using System.Threading.Tasks;

public class ThreadSafeSimpleMovingAverage: IDisposable
{
    private readonly int _k;
    private readonly TimeSpan _bucketDuration;
    private readonly int[] _values;
    private readonly TaskCompletionSource<bool> _processExit = new TaskCompletionSource<bool>();

    private int _index = 0;
    private long _sum = 0;
    private int _currentBucketValue = 0;

    public ThreadSafeSimpleMovingAverage(int k, TimeSpan bucketDuration)
    {
        if (k <= 0) throw new ArgumentOutOfRangeException(nameof(k), "Must be greater than 0");

        _k = k;
        _bucketDuration = bucketDuration;
        _values = new int[k];

        // start the background update task
        Task.Run(UpdateBucketTaskLoopAsync)
            .ContinueWith(t => Console.WriteLine(t.Exception), TaskContinuationOptions.OnlyOnFaulted);
    }

    public void IncrementCurrentBucket(int count)
    {
        // Note, can cause overflows!
        Interlocked.Add(ref _currentBucketValue, count);
    }

    public double GetLatestAverage()
    {
        var sum = Interlocked.Read(ref _sum);
        return ((double)sum) / _k;
    }

    private void UpdateBucket()
    {
        int previousBucketValue = _values[_index];
        int newBucketValue = Interlocked.Exchange(ref _currentBucketValue, 0);

        long newSum = _sum - previousBucketValue + newBucketValue;
        Interlocked.Exchange(ref _sum, newSum);

        _values[_index] = newBucketValue;
        _index = (_index + 1) % _k;
    }

    private async Task UpdateBucketTaskLoopAsync()
    {
        while (true)
        {
            if (_processExit.Task.IsCompleted)
            {
                return;
            }

            UpdateBucket();

            await Task.WhenAny(
                    Task.Delay(_bucketDuration),
                    _processExit.Task)
                .ConfigureAwait(false);
        }
    }

    public void CancelUpdates()
    {
        _processExit.TrySetResult(true);
    }

    public void Dispose() => CancelUpdates();
}
In the following sections we'll look in detail at each part of this class:
The constructor and field definitions
We'll start with the fields and the constructor for the class:
public class ThreadSafeSimpleMovingAverage: IDisposable
{
    private readonly int _k;
    private readonly TimeSpan _bucketDuration;
    private readonly int[] _values;
    private readonly TaskCompletionSource<bool> _processExit = new TaskCompletionSource<bool>();

    private int _index = 0;
    private long _sum = 0;
    private int _currentBucketValue = 0;

    public ThreadSafeSimpleMovingAverage(int k, TimeSpan bucketDuration)
    {
        if (k <= 0) throw new ArgumentOutOfRangeException(nameof(k), "Must be greater than 0");

        _k = k;
        _bucketDuration = bucketDuration;
        _values = new int[k];

        // start the background update task
        Task.Run(UpdateBucketTaskLoopAsync)
            .ContinueWith(t => Console.WriteLine(t.Exception), TaskContinuationOptions.OnlyOnFaulted);
    }

    //...
}
Much of this is the same as for the implementation in the previous post, but we now have a few more fields:
- TimeSpan _bucketDuration: How long we accumulate values for before resetting the current bucket count. In the example I showed earlier, this is set to 1 second, but I've made it configurable via a constructor parameter.
- int _currentBucketValue: The accumulated event count for the current time bucket.
- TaskCompletionSource<bool> _processExit: Used to signal when the app is shutting down, so we can terminate the "update" loop.
In addition, notice that we're using a long for the _sum field. This is to avoid overflow issues when there are a high number of events and we're summing the k previous values.
The last step in the constructor starts the UpdateBucketTaskLoopAsync update loop using Task.Run(). We'll come to the update loop method shortly.
Updating the current bucket and retrieving the latest SMA value
Before we look at the update loop, let's look at the methods which will be called by the "event generator" and the "SMA reporter" components, i.e. the classes that interact with the calculator.
The IncrementCurrentBucket method is used by the "event generator" to increment the count of events. This uses the Interlocked.Add() method to atomically add the count to the value stored in _currentBucketValue:
public void IncrementCurrentBucket(int count)
{
    // Note, can cause overflows!
    Interlocked.Add(ref _currentBucketValue, count);
}
Unfortunately, this method is vulnerable to integer overflows. For example, if _currentBucketValue is already set to int.MaxValue (2147483647), and you call IncrementCurrentBucket(1), then the new value will be -2147483648: we have overflowed! There's no simple solution to that problem without switching to using lock instead, but it wasn't a practical concern in my case, so I chose to just document it and accept the limitation.
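The wrap-around is easy to reproduce in isolation. The sketch below shows it, alongside one more involved lock-free alternative that this post does not use: a compare-and-swap loop that clamps the counter instead of letting it wrap. OverflowDemo, WrappingAdd, and SaturatingAdd are hypothetical names introduced purely for illustration.

```csharp
using System;
using System.Threading;

public static class OverflowDemo
{
    // Interlocked.Add does not throw on overflow; the result silently wraps.
    public static int WrappingAdd(int start, int count)
    {
        int value = start;
        return Interlocked.Add(ref value, count);
    }

    // Hypothetical alternative (an assumption, not the approach in this post):
    // retry with CompareExchange until we manage to store a clamped result.
    public static int SaturatingAdd(ref int location, int count)
    {
        while (true)
        {
            int current = Volatile.Read(ref location);

            // Do the arithmetic in 64 bits so the intermediate cannot overflow
            long sum = (long)current + count;
            int next = sum > int.MaxValue ? int.MaxValue
                     : sum < int.MinValue ? int.MinValue
                     : (int)sum;

            // Only write if nobody changed the value since we read it
            if (Interlocked.CompareExchange(ref location, next, current) == current)
            {
                return next;
            }
        }
    }
}
```

WrappingAdd(int.MaxValue, 1) returns int.MinValue, matching the behaviour described above. SaturatingAdd stops at int.MaxValue instead, but it costs a retry loop on every increment, and a clamped count is still an inaccurate count.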
The "SMA reporter" component fetches the latest SMA value from the calculator by calling GetLatestAverage(). This reads the _sum field, and converts it to an average by dividing by the window size, _k (if this calculation doesn't make sense, refer to my previous post).
public double GetLatestAverage()
{
    var sum = Interlocked.Read(ref _sum);
    return ((double)sum) / _k;
}
If you're wondering why I don't store the SMA in a double field, it's because there's no Interlocked.Read() for double values to ensure the read is atomic, and I wanted to demonstrate that method! There are ways to atomically read a double, as you'll see in the next post.
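As one illustration of what such a workaround can look like (the next post may well use a different mechanism), you can store the raw bits of the double in a long field and convert on the way in and out. The AtomicDouble class below is a hypothetical helper, not something from this post's implementation.

```csharp
using System;
using System.Threading;

// Hypothetical helper: keeps a double as its 64-bit pattern so that the
// long overloads of the Interlocked methods can be used to access it.
public class AtomicDouble
{
    private long _bits; // raw IEEE 754 bits of the stored value (0 == 0.0)

    public double Read()
        => BitConverter.Int64BitsToDouble(Interlocked.Read(ref _bits));

    public void Write(double value)
        => Interlocked.Exchange(ref _bits, BitConverter.DoubleToInt64Bits(value));
}
```

A reader calling Read() sees either the old or the new value, never a mix of the two halves, which is the same guarantee Interlocked.Read() gives for the _sum field.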
We've covered how the calculator interfaces with the external components. In the next section we'll look at how the _sum is updated.
Updating the SMA sum
We update the _sum value in the UpdateBucket() method. This is called internally by the SimpleMovingAverage class on a schedule, and is used to set the _currentBucketValue as the next value, and to update the _sum accordingly.
private void UpdateBucket()
{
    int previousBucketValue = _values[_index];
    int newBucketValue = Interlocked.Exchange(ref _currentBucketValue, 0);

    long newSum = _sum - previousBucketValue + newBucketValue;
    Interlocked.Exchange(ref _sum, newSum);

    _values[_index] = newBucketValue;
    _index = (_index + 1) % _k;
}
This method first fetches the previousBucketValue, which is the "oldest" value in the time window. We then fetch the current value stored in _currentBucketValue and reset its value to 0 at the same time.
Interlocked.Exchange ensures the fetch and replace happens atomically, so even if something is calling IncrementCurrentBucket() at the same time, we won't lose or double-count any events.
Once we have previousBucketValue and newBucketValue, we can calculate the newSum using the technique described in the previous post, and use Interlocked.Exchange() to safely update the current SMA _sum.
Finally, we overwrite the previousBucketValue in our array of stored values with newBucketValue, and advance the _index, wrapping back to 0 when it reaches _k.
This is the only method that accesses the _values and _index fields, and it is never executed in parallel, so we don't have to worry about thread safety for those fields.
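To check the bucket logic without waiting on a timer, here is a cut-down sketch of the same update steps in which the "tick" is triggered by hand. ManualTickMovingAverage, Tick(), and Replay() are hypothetical names for illustration; the real class drives the update from a background loop instead.

```csharp
using System;
using System.Threading;

// Hypothetical, timer-free variant of the calculator, for illustration only.
public class ManualTickMovingAverage
{
    private readonly int _k;
    private readonly int[] _values;
    private int _index;
    private long _sum;
    private int _currentBucketValue;

    public ManualTickMovingAverage(int k)
    {
        if (k <= 0) throw new ArgumentOutOfRangeException(nameof(k), "Must be greater than 0");
        _k = k;
        _values = new int[k];
    }

    public void IncrementCurrentBucket(int count)
        => Interlocked.Add(ref _currentBucketValue, count);

    public double GetLatestAverage()
        => ((double)Interlocked.Read(ref _sum)) / _k;

    // Same steps as UpdateBucket(); must only be called from one thread at a time.
    public void Tick()
    {
        int previousBucketValue = _values[_index];
        int newBucketValue = Interlocked.Exchange(ref _currentBucketValue, 0);

        Interlocked.Exchange(ref _sum, _sum - previousBucketValue + newBucketValue);

        _values[_index] = newBucketValue;
        _index = (_index + 1) % _k;
    }

    // Feeds one bucket's worth of events per tick and records each average.
    public static double[] Replay(int k, int[] eventsPerBucket)
    {
        var sma = new ManualTickMovingAverage(k);
        var averages = new double[eventsPerBucket.Length];
        for (int i = 0; i < eventsPerBucket.Length; i++)
        {
            sma.IncrementCurrentBucket(eventsPerBucket[i]);
            sma.Tick();
            averages[i] = sma.GetLatestAverage();
        }
        return averages;
    }
}
```

Calling ManualTickMovingAverage.Replay(3, new[] { 2, 4, 5, 3, 8, 6, 4 }) should reproduce the averages from the table at the start of this post, for example 2/3 after the first bucket and 6 after the last.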
The UpdateBucket() method is called in a loop, so the final part of the class is the looping code itself.
Running the UpdateBucket() method in a loop
In the constructor of the SimpleMovingAverage class we started a Task that runs continuously (unless an exception occurs in the task):
Task.Run(UpdateBucketTaskLoopAsync)
    .ContinueWith(t => Console.WriteLine(t.Exception), TaskContinuationOptions.OnlyOnFaulted);
The UpdateBucketTaskLoopAsync method loops continuously until the TaskCompletionSource _processExit completes, at which point we break out of the loop:
private async Task UpdateBucketTaskLoopAsync()
{
    while (true)
    {
        if (_processExit.Task.IsCompleted)
        {
            return;
        }

        UpdateBucket();

        await Task.WhenAny(
                Task.Delay(_bucketDuration),
                _processExit.Task)
            .ConfigureAwait(false);
    }
}
This pattern with the TaskCompletionSource allows us to wait for the _bucketDuration delay in the loop, but also to immediately stop waiting and exit if the _processExit task completes.
You could use a CancellationTokenSource in a similar way, but one of the advantages of using TaskCompletionSource is that you don't have to worry about catching cancellation exceptions etc.
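For comparison, here is a minimal sketch of what a CancellationTokenSource-based loop might look like. CancellationLoopSketch, RunLoopAsync, and the onTick parameter are hypothetical names and are not part of the class in this post. Note the extra try/catch around the delay.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public static class CancellationLoopSketch
{
    // Calls onTick, waits for the interval, and repeats until cancelled.
    public static async Task RunLoopAsync(
        Action onTick, TimeSpan interval, CancellationToken token)
    {
        while (!token.IsCancellationRequested)
        {
            onTick();

            try
            {
                await Task.Delay(interval, token).ConfigureAwait(false);
            }
            catch (OperationCanceledException)
            {
                // Task.Delay throws if the token is cancelled mid-wait,
                // so we have to swallow that to exit the loop cleanly.
                return;
            }
        }
    }
}
```

To stop the loop, the owner calls Cancel() on the CancellationTokenSource that produced the token, which plays the same role as TrySetResult(true) on the TaskCompletionSource.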
When the SimpleMovingAverage class is disposed, or CancelUpdates is explicitly called, we complete the _processExit task, which causes the while loop to exit:
public void CancelUpdates()
{
    _processExit.TrySetResult(true);
}

public void Dispose() => CancelUpdates();
And there you have it, a thread-safe implementation of the Simple Moving Average calculator from the previous post. In the next post, I'll tweak this a little bit further for my final use case, where we need to store two values, instead of a single number. I'll also tackle the overflow issue by constraining the problem slightly, and show a more optimised mechanism for the thread safety.
Summary
In this post, I showed how to implement a thread-safe Simple Moving Average (SMA) calculator, using the Interlocked class to provide thread safety guarantees. The calculator accumulates values within a given time period (e.g. 1 second) when external components call IncrementCurrentBucket(), and periodically updates its internal counts. External components can call GetLatestAverage() at any time to get the current SMA.