The other day I was looking at a simple, classic implementation of a shared counter in C++
that used a mutex, and I wondered what other thread-safe implementations existed. I usually use Go
to explore my curiosity. The result of this exploration is a compilation of ways to implement a goroutine-safe counter.
Don't Do This
Let's start with the non-safe implementation.
type NotSafeCounter struct {
	number uint64
}
func NewNotSafeCounter() Counter {
	return &NotSafeCounter{0}
}
func (c *NotSafeCounter) Add(num uint64) {
	c.number = c.number + num
}
func (c *NotSafeCounter) Read() uint64 {
	return c.number
}
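The constructors in this post return a Counter, but the interface itself is never shown. A minimal sketch of what it presumably looks like, inferred from the two methods every implementation provides (the exact definition is an assumption):

```go
package main

import "fmt"

// Counter is an assumed reconstruction of the interface the constructors
// in this post return; the post itself never shows its definition.
type Counter interface {
	Add(num uint64) // increase the counter by num
	Read() uint64   // return the current value
}

// NotSafeCounter repeats the type from above so the sketch compiles alone.
type NotSafeCounter struct {
	number uint64
}

func (c *NotSafeCounter) Add(num uint64) { c.number += num }
func (c *NotSafeCounter) Read() uint64   { return c.number }

// Compile-time check that *NotSafeCounter satisfies Counter.
var _ Counter = (*NotSafeCounter)(nil)

func main() {
	var c Counter = &NotSafeCounter{}
	c.Add(2)
	c.Add(3)
	fmt.Println(c.Read()) // only one goroutine here, so there is no race
}
```

Programming against the interface is what lets a single test function exercise every implementation.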
Nothing magical. Let's test its correctness by running 100 goroutines, 66 of which Add 1 to the shared counter.
func testCorrectness(t *testing.T, counter Counter) {
	wg := &sync.WaitGroup{}
	for i := 0; i < 100; i++ {
		wg.Add(1)
		if i%3 == 0 {
			go func(counter Counter) {
				counter.Read()
				wg.Done()
			}(counter)
		} else if i%3 == 1 {
			go func(counter Counter) {
				counter.Add(1)
				counter.Read()
				wg.Done()
			}(counter)
		} else {
			go func(counter Counter) {
				counter.Add(1)
				wg.Done()
			}(counter)
		}
	}
	wg.Wait()
	if counter.Read() != 66 {
		t.Errorf("counter should be %d and was %d", 66, counter.Read())
	}
}
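The result of the test is not deterministic. Sometimes it passes, but sometimes you get a message like counter_test.go:34: counter should be 66 and was 65. The reason is that c.number = c.number + num is a read followed by a write, and two goroutines can both read the same old value before either one writes.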
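To make the failure concrete, here is a small sketch that replays one bad interleaving by hand in a single goroutine. The function name lostUpdate is hypothetical, and the two "goroutines" are only simulated by the order of the statements:

```go
package main

import "fmt"

// lostUpdate replays, in a single goroutine, one possible interleaving of
// two concurrent calls to Add(1) on the unsafe counter.
func lostUpdate() uint64 {
	var number uint64

	a := number    // goroutine A reads 0
	b := number    // goroutine B reads 0 before A has written
	number = a + 1 // A writes 1
	number = b + 1 // B also writes 1, overwriting A's update

	return number
}

func main() {
	// Two increments were attempted, but only one is visible.
	fmt.Println(lostUpdate())
}
```

Running the test with go test -race enables Go's race detector, which should report this kind of unsynchronized access even on runs where the final count happens to be right.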
The Classic
The traditional way to implement a correct counter is to use a mutex, which guarantees that only one goroutine modifies the counter at a time (the RWMutex used here still lets several reads run together). In Go, we simply use the sync package.
type MutexCounter struct {
	mu     *sync.RWMutex
	number uint64
}
func NewMutexCounter() Counter {
	return &MutexCounter{&sync.RWMutex{}, 0}
}
func (c *MutexCounter) Add(num uint64) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.number = c.number + num
}
func (c *MutexCounter) Read() uint64 {
	c.mu.RLock()
	defer c.mu.RUnlock()
	return c.number
}
Now the tests run deterministically and always pass.
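As a variation, the mutex can also be stored by value, which makes the zero value of the struct usable without a constructor. The sketch below assumes a plain sync.Mutex instead of the RWMutex above; the names ValueMutexCounter and addConcurrently are hypothetical:

```go
package main

import (
	"fmt"
	"sync"
)

// ValueMutexCounter is a hypothetical variation of MutexCounter: the mutex
// is stored by value, so the zero value of the struct is ready to use.
type ValueMutexCounter struct {
	mu     sync.Mutex
	number uint64
}

func (c *ValueMutexCounter) Add(num uint64) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.number += num
}

func (c *ValueMutexCounter) Read() uint64 {
	c.mu.Lock()
	defer c.mu.Unlock()
	return c.number
}

// addConcurrently starts n goroutines that each add 1, then waits for them.
func addConcurrently(c *ValueMutexCounter, n int) {
	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			c.Add(1)
		}()
	}
	wg.Wait()
}

func main() {
	c := &ValueMutexCounter{}
	addConcurrently(c, 1000)
	fmt.Println(c.Read())
}
```

Because the struct now contains the mutex itself, it must always be passed around by pointer; copying it would copy the lock.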
Using Channels
Locks are low-level primitives that let you achieve synchronization. Go offers a higher-level primitive called channels. There are a lot of mutexes-versus-channels, which-one-is-better, or which-one-should-I-use kinds of discussions. Some of them are valid and very interesting, but they are not the point of this blog post.
The way we are going to implement a goroutine-safe counter using channels is by queuing every operation (Add or Read) called on the counter in a channel. Each operation is represented as a function, func(). When created, the counter spawns a goroutine that executes the queued operations one at a time, in the order they arrive.
Here is the counter definition:
type ChannelCounter struct {
	ch     chan func()
	number uint64
}
func NewChannelCounter() Counter {
	counter := &ChannelCounter{make(chan func(), 100), 0}
	go func(counter *ChannelCounter) {
		for f := range counter.ch {
			f()
		}
	}(counter)
	return counter
}
See how the counter's goroutine only reads the operations from the channel and executes them.
When a goroutine calls Add, we queue a write operation:
func (c *ChannelCounter) Add(num uint64) {
	c.ch <- func() {
		c.number = c.number + num
	}
}
When a goroutine calls Read, we queue a read operation:
func (c *ChannelCounter) Read() uint64 {
	ret := make(chan uint64)
	c.ch <- func() {
		ret <- c.number
		close(ret)
	}
	return <-ret
}
What I really like about this implementation is how clear it is to visualize the operations being executed in serial order.
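One thing the channel version leaves open is how to stop the worker goroutine, since the operations channel is never closed. Below is a rough sketch of one way to do it, assuming callers never call Add or Read after Close (sending on a closed channel panics). The StoppableCounter type, its done channel, and the Close method are hypothetical additions, not part of the original implementation:

```go
package main

import "fmt"

// StoppableCounter is a hypothetical variant of ChannelCounter that can
// shut down its worker goroutine.
type StoppableCounter struct {
	ch     chan func()
	done   chan struct{} // closed once the worker has exited
	number uint64
}

func NewStoppableCounter() *StoppableCounter {
	c := &StoppableCounter{
		ch:   make(chan func(), 100),
		done: make(chan struct{}),
	}
	go func() {
		defer close(c.done)
		// The loop ends when c.ch is closed and fully drained.
		for f := range c.ch {
			f()
		}
	}()
	return c
}

func (c *StoppableCounter) Add(num uint64) {
	c.ch <- func() { c.number += num }
}

func (c *StoppableCounter) Read() uint64 {
	ret := make(chan uint64, 1)
	c.ch <- func() { ret <- c.number }
	return <-ret
}

// Close lets the worker finish the queued operations and waits for it to
// exit. Add and Read must not be called after Close.
func (c *StoppableCounter) Close() {
	close(c.ch)
	<-c.done
}

func main() {
	c := NewStoppableCounter()
	for i := 0; i < 50; i++ {
		c.Add(1)
	}
	fmt.Println(c.Read())
	c.Close()
}
```

All 50 Add calls come from one goroutine and are queued before the Read, so the worker executes them first and the Read observes every one of them.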
The Atomic Way
We can use even lower-level primitives and execute atomic instructions provided by the sync/atomic package.
type AtomicCounter struct {
	number uint64
}
func NewAtomicCounter() Counter {
	return &AtomicCounter{0}
}
func (c *AtomicCounter) Add(num uint64) {
	atomic.AddUint64(&c.number, num)
}
func (c *AtomicCounter) Read() uint64 {
	return atomic.LoadUint64(&c.number)
}
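Since Go 1.19, sync/atomic also provides typed values such as atomic.Uint64, which bundle the variable with its atomic operations. A minimal sketch of the same counter written that way, assuming Go 1.19 or later (the names TypedAtomicCounter and addConcurrently are hypothetical):

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// TypedAtomicCounter is a hypothetical variant of AtomicCounter built on
// the atomic.Uint64 type, which requires Go 1.19 or later.
type TypedAtomicCounter struct {
	number atomic.Uint64
}

func (c *TypedAtomicCounter) Add(num uint64) { c.number.Add(num) }
func (c *TypedAtomicCounter) Read() uint64   { return c.number.Load() }

// addConcurrently starts n goroutines that each add 1, then waits for them.
func addConcurrently(c *TypedAtomicCounter, n int) {
	var wg sync.WaitGroup
	for i := 0; i < n; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			c.Add(1)
		}()
	}
	wg.Wait()
}

func main() {
	c := &TypedAtomicCounter{}
	addConcurrently(c, 1000)
	fmt.Println(c.Read())
}
```

The typed form makes it harder to touch the field non-atomically by accident, and atomic.Uint64 is automatically aligned for 64-bit access on 32-bit platforms.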
Compare And Swap
Alternatively, we can use the classic atomic primitive Compare And Swap (CAS) to Add a number to the counter.
// Assumed definition: the post does not show this struct, which presumably mirrors AtomicCounter.
type CASCounter struct {
	number uint64
}
func (c *CASCounter) Add(num uint64) {
	for {
		v := atomic.LoadUint64(&c.number)
		if atomic.CompareAndSwapUint64(&c.number, v, v+num) {
			return
		}
	}
}
func (c *CASCounter) Read() uint64 {
	return atomic.LoadUint64(&c.number)
}
Basically, Add retries until the compare-and-swap succeeds, which happens as soon as no other goroutine changes the counter between the load and the swap.
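The same retry loop works for any update that can be computed from the old value, not just addition. As an illustration only, the sketch below factors the loop into a hypothetical helper, casUpdate, and uses it for a "store the maximum" operation; the names storeMax and maxOf are also hypothetical:

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// casUpdate is a hypothetical helper: it replaces the value at addr with
// f(old), retrying whenever another goroutine changed the value in between.
func casUpdate(addr *uint64, f func(old uint64) uint64) {
	for {
		old := atomic.LoadUint64(addr)
		if atomic.CompareAndSwapUint64(addr, old, f(old)) {
			return
		}
	}
}

// storeMax raises the value at addr to candidate if candidate is larger.
func storeMax(addr *uint64, candidate uint64) {
	casUpdate(addr, func(old uint64) uint64 {
		if candidate > old {
			return candidate
		}
		return old
	})
}

// maxOf reports the values 1..n from n goroutines and returns the largest
// value recorded.
func maxOf(n uint64) uint64 {
	var highest uint64
	var wg sync.WaitGroup
	for i := uint64(1); i <= n; i++ {
		wg.Add(1)
		go func(v uint64) {
			defer wg.Done()
			storeMax(&highest, v)
		}(i)
	}
	wg.Wait()
	return atomic.LoadUint64(&highest)
}

func main() {
	fmt.Println(maxOf(100))
}
```

Note that f may run several times under contention, so it should be a pure function of the old value.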
What About Float Types?
In my exploration, I came across an awesome talk called Prometheus: Designing and Implementing a Modern Monitoring Solution in Go, which discusses these techniques and benchmarks them. Toward the end, it talks about how to implement a counter of floats. All the techniques presented so far work for floats, except the ones that use sync/atomic, because sync/atomic does not provide atomic operations on floats. In the talk, Björn Rabenstein shows how to solve this by storing the float as a uint64 and using math.Float64bits and math.Float64frombits to convert between float64 and uint64.
type CASFloatCounter struct {
	number uint64
}
func NewCASFloatCounter() *CASFloatCounter {
	return &CASFloatCounter{0}
}
func (c *CASFloatCounter) Add(num float64) {
	for {
		v := atomic.LoadUint64(&c.number)
		newValue := math.Float64bits(math.Float64frombits(v) + num)
		if atomic.CompareAndSwapUint64(&c.number, v, newValue) {
			return
		}
	}
}
func (c *CASFloatCounter) Read() float64 {
	return math.Float64frombits(atomic.LoadUint64(&c.number))
}
Final Words
This is a simple collection of implementations of a shared counter. It is the result of my curiosity and of trying to reach a fundamental understanding of concurrency. If you know more ways to do this, I'd love to hear about them.
You can check the implementations, run the tests and benchmarks at brunocalza/sharedcounter.
I'm always trying to share what I am learning about database internals.
I'm on Twitter.