Concurrency in Golang, Zero to Hero
Learn concurrency with goroutines, WaitGroups, channels, context with a cancel function, and worker pools.
What is Concurrency?
Concurrency is the ability of a program to make progress on multiple tasks that run independently but remain part of the same program. It is important when some work needs to happen without blocking the program's main flow. Modern software often has many tasks to handle at once, and running them concurrently lets the program finish them faster, especially when the tasks spend time waiting on I/O.
Concurrency in Golang
Go's concurrency model is based on CSP (communicating sequential processes): a problem is broken into smaller sequential processes, and many instances of these processes, called goroutines, are scheduled to run concurrently. When we start a function as a goroutine, the Go runtime treats it as an independent unit of work, schedules it, and executes it on an available logical processor. In the CSP style, goroutines communicate by passing messages over channels rather than by sharing memory.
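Goroutines:
A goroutine is a function call that runs independently of the rest of the program, including the main goroutine. Putting the go keyword in front of a function call starts it as a new goroutine. Goroutines are not OS threads: the Go runtime multiplexes many goroutines onto a small number of threads, which makes goroutines much cheaper to create than threads.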
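To make the "cheaper than threads" point concrete, here is a minimal sketch (the function name spawnParked is hypothetical, not from this article) that starts 10,000 goroutines and counts them with runtime.NumGoroutine while they are all blocked on a channel. It uses sync.WaitGroup, which is covered later in this article, only to know when the goroutines have started and finished.

```go
package main

import (
	"fmt"
	"runtime"
	"sync"
)

// spawnParked starts n goroutines that each block on a channel receive,
// and returns runtime.NumGoroutine() measured while all of them are parked.
func spawnParked(n int) int {
	release := make(chan struct{})
	var started, finished sync.WaitGroup
	started.Add(n)
	finished.Add(n)
	for i := 0; i < n; i++ {
		go func() {
			defer finished.Done()
			started.Done()
			<-release // wait here until release is closed
		}()
	}
	started.Wait()                  // all n goroutines are now running
	count := runtime.NumGoroutine() // includes the caller and the n parked goroutines
	close(release)                  // unblock every parked goroutine at once
	finished.Wait()
	return count
}

func main() {
	fmt.Println("goroutines alive:", spawnParked(10000))
}
```

The exact count depends on how many other goroutines the runtime has running, but it is always at least n + 1.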
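Channels:
A channel is a typed pipe for sending values between goroutines. On an unbuffered channel, a send blocks until another goroutine receives the value, so the channel also synchronizes the two goroutines.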
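As a minimal sketch (names are illustrative), the following starts one goroutine per number; each sends its square on an unbuffered channel, and the caller receives exactly one value per goroutine. Because every receive waits for a matching send, the caller cannot return before all goroutines have delivered their results.

```go
package main

import "fmt"

// square runs in its own goroutine and sends its result on out.
func square(n int, out chan<- int) {
	out <- n * n
}

// sumOfSquares starts one goroutine per input and receives exactly
// len(nums) results from an unbuffered channel.
func sumOfSquares(nums []int) int {
	out := make(chan int) // unbuffered: each send waits for a receive
	for _, n := range nums {
		go square(n, out)
	}
	total := 0
	for range nums {
		total += <-out // blocks until some goroutine sends
	}
	return total
}

func main() {
	fmt.Println(sumOfSquares([]int{1, 2, 3, 4})) // 30
}
```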
How do you achieve it?
Let’s start with a normal app that processes several inputs one after another in a loop.
Normal App
package main

import (
	"log/slog"
	"os"
	"time"
)

var logger = slog.New(slog.NewTextHandler(os.Stdout, nil))

func main() {
	logger.Info("Start Program.")
	start := time.Now()
	var data = []int{2000, 1000, 5000, 4000}
	// Process each input one after another.
	for _, d := range data {
		processData(d)
	}
	logger.Info("Program Stopped.", "duration", time.Since(start).String())
}

// processData simulates a task that takes inp milliseconds.
func processData(inp int) {
	logger.Info("start processing", "input", inp)
	start := time.Now()
	defer func() {
		logger.Info("Stop processing", "input", inp, "duration", time.Since(start).String())
	}()
	time.Sleep(time.Duration(inp) * time.Millisecond)
}
Output
time=2024-10-03T07:18:21.954+10:00 level=INFO msg="Start Program."
time=2024-10-03T07:18:21.954+10:00 level=INFO msg="start processing" input=2000
time=2024-10-03T07:18:23.955+10:00 level=INFO msg="Stop processing" input=2000 duration=2.001110917s
time=2024-10-03T07:18:23.955+10:00 level=INFO msg="start processing" input=1000
time=2024-10-03T07:18:24.956+10:00 level=INFO msg="Stop processing" input=1000 duration=1.001026708s
time=2024-10-03T07:18:24.956+10:00 level=INFO msg="start processing" input=5000
time=2024-10-03T07:18:29.958+10:00 level=INFO msg="Stop processing" input=5000 duration=5.001245083s
time=2024-10-03T07:18:29.958+10:00 level=INFO msg="start processing" input=4000
time=2024-10-03T07:18:33.958+10:00 level=INFO msg="Stop processing" input=4000 duration=4.000528167s
time=2024-10-03T07:18:33.959+10:00 level=INFO msg="Program Stopped." duration=12.00467425s
Since the tasks are independent, running them concurrently instead of one after another should take less time. Let's use goroutines and check.
App with GoRoutines
package main

import (
	"log/slog"
	"os"
	"time"
)

var logger = slog.New(slog.NewTextHandler(os.Stdout, nil))

func main() {
	logger.Info("Start Program.")
	start := time.Now()
	var data = []int{2000, 1000, 5000, 4000}
	for _, d := range data {
		go processData(d) // start a goroutine; main does not wait for it
	}
	logger.Info("Program Stopped.", "duration", time.Since(start).String())
}

// processData simulates a task that takes inp milliseconds.
func processData(inp int) {
	logger.Info("start processing", "input", inp)
	start := time.Now()
	defer func() {
		logger.Info("Stop processing", "input", inp, "duration", time.Since(start).String())
	}()
	time.Sleep(time.Duration(inp) * time.Millisecond)
}
Output
time=2024-10-03T07:27:20.811+10:00 level=INFO msg="Start Program."
time=2024-10-03T07:27:20.811+10:00 level=INFO msg="Program Stopped." duration=15.708µs
The problem here is that I started the goroutines but didn't wait for them. When main returns, the program exits and every other goroutine is stopped with it, so in this run none of them even got to log "start processing".
To solve this, let’s use a sync.WaitGroup to wait for all the goroutines to complete.
GoRoutines and WaitGroup
package main

import (
	"log/slog"
	"os"
	"sync"
	"time"
)

var logger = slog.New(slog.NewTextHandler(os.Stdout, nil))

func main() {
	logger.Info("Start Program.")
	start := time.Now()
	var data = []int{2000, 1000, 5000, 4000}
	var wg sync.WaitGroup
	for _, d := range data {
		wg.Add(1) // one pending task per goroutine
		go processData(d, &wg)
	}
	wg.Wait() // block until every goroutine has called wg.Done()
	logger.Info("Program Stopped.", "duration", time.Since(start).String())
}

// processData simulates a task that takes inp milliseconds and
// marks it as done on wg when it returns.
func processData(inp int, wg *sync.WaitGroup) {
	defer wg.Done()
	logger.Info("start processing", "input", inp)
	start := time.Now()
	defer func() {
		logger.Info("Stop processing", "input", inp, "duration", time.Since(start).String())
	}()
	time.Sleep(time.Duration(inp) * time.Millisecond)
}
Output
time=2024-10-03T07:29:43.023+10:00 level=INFO msg="Start Program."
time=2024-10-03T07:29:43.023+10:00 level=INFO msg="start processing" input=4000
time=2024-10-03T07:29:43.023+10:00 level=INFO msg="start processing" input=1000
time=2024-10-03T07:29:43.023+10:00 level=INFO msg="start processing" input=5000
time=2024-10-03T07:29:43.023+10:00 level=INFO msg="start processing" input=2000
time=2024-10-03T07:29:44.027+10:00 level=INFO msg="Stop processing" input=1000 duration=1.003094292s
time=2024-10-03T07:29:45.025+10:00 level=INFO msg="Stop processing" input=2000 duration=2.001338959s
time=2024-10-03T07:29:47.024+10:00 level=INFO msg="Stop processing" input=4000 duration=4.000604792s
time=2024-10-03T07:29:48.024+10:00 level=INFO msg="Stop processing" input=5000 duration=5.000899875s
time=2024-10-03T07:29:48.024+10:00 level=INFO msg="Program Stopped." duration=5.001070084s
Here processData now takes a *sync.WaitGroup argument and calls wg.Done() when it finishes. In main, I create a WaitGroup and call wg.Add(1) for each item before starting its goroutine. After the loop, wg.Wait() blocks until every goroutine has called Done and the counter drops to zero.
This reduced the processing time from about 12 seconds to about 5 seconds. The tasks now overlap, so the total time is bounded by the longest task (5000 ms) instead of the sum of all tasks (2000 + 1000 + 5000 + 4000 = 12000 ms).
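A variation on the same idea, sketched under the assumption that each task returns a value: processAll (a hypothetical helper) starts one goroutine per input using a closure instead of passing &wg, and stores each result at its input's index so no two goroutines write to the same slot. Passing i and d as arguments gives each goroutine its own copy, which matters on Go versions before 1.22, where loop variables were shared across iterations.

```go
package main

import (
	"fmt"
	"sync"
)

// processAll applies work to every input in its own goroutine and waits
// for all of them with a sync.WaitGroup before returning.
func processAll(data []int, work func(int) int) []int {
	results := make([]int, len(data))
	var wg sync.WaitGroup
	for i, d := range data {
		wg.Add(1) // call Add before starting the goroutine, never inside it
		go func(i, d int) {
			defer wg.Done()
			results[i] = work(d) // each goroutine writes only its own index
		}(i, d)
	}
	wg.Wait() // after this, all writes to results are visible here
	return results
}

func main() {
	double := func(x int) int { return x * 2 }
	fmt.Println(processAll([]int{2000, 1000, 5000, 4000}, double)) // [4000 2000 10000 8000]
}
```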
This approach is fine when we have a small number of tasks. But with thousands of records, starting one goroutine per record can overwhelm the resources those tasks use (database connections, API rate limits, memory), even though goroutines themselves are cheap. To handle that, we can implement a worker pool so that only a fixed number X of goroutines are working at any point in time.
Implementing Worker Pool
package main

import (
	"fmt"
	"log/slog"
	"os"
	"sync"
	"time"
)

var logger = slog.New(slog.NewTextHandler(os.Stderr, nil))

func main() {
	logger.Info("Start Program.")
	start := time.Now()
	data := []int{2000, 1000, 5000, 4000}
	workerCount := 2
	workerPool(workerCount, data)
	logger.Info("Program Stopped.", "duration", time.Since(start).String())
}

// workerPool starts workersNum workers, feeds them every item in data
// through the jobs channel, and waits until all workers have exited.
func workerPool(workersNum int, data []int) {
	jobsCount := len(data)
	jobs := make(chan int, jobsCount) // buffered so the sends below never block
	wg := &sync.WaitGroup{}
	for w := 1; w <= workersNum; w++ {
		wg.Add(1)
		go worker(w, jobs, wg)
	}
	for _, d := range data {
		jobs <- d
	}
	close(jobs) // lets each worker's range loop end once the channel is drained
	wg.Wait()
}

// worker processes jobs until the jobs channel is closed and empty.
func worker(id int, jobs <-chan int, wg *sync.WaitGroup) {
	defer wg.Done()
	infoF("Thread Started %d", id)
	for job := range jobs {
		processData(job, id)
	}
	infoF("Thread %d stopped", id)
}

// processData simulates a task that takes inp milliseconds.
func processData(inp, thread int) {
	logger.Info("start processing", "input", inp, "thread", thread)
	start := time.Now()
	defer func() {
		logger.Info("Stop processing", "input", inp, "thread", thread, "duration", time.Since(start).String())
	}()
	time.Sleep(time.Duration(inp) * time.Millisecond)
}

// infoF logs a printf-style message at INFO level.
func infoF(format string, args ...any) {
	logger.Info(fmt.Sprintf(format, args...))
}
Output
time=2024-10-03T07:42:16.681+10:00 level=INFO msg="Start Program."
time=2024-10-03T07:42:16.681+10:00 level=INFO msg="Thread Started 2"
time=2024-10-03T07:42:16.681+10:00 level=INFO msg="Thread Started 1"
time=2024-10-03T07:42:16.681+10:00 level=INFO msg="start processing" input=2000 thread=2
time=2024-10-03T07:42:16.681+10:00 level=INFO msg="start processing" input=1000 thread=1
time=2024-10-03T07:42:17.682+10:00 level=INFO msg="Stop processing" input=1000 thread=1 duration=1.001129583s
time=2024-10-03T07:42:17.682+10:00 level=INFO msg="start processing" input=5000 thread=1
time=2024-10-03T07:42:18.682+10:00 level=INFO msg="Stop processing" input=2000 thread=2 duration=2.000665875s
time=2024-10-03T07:42:18.682+10:00 level=INFO msg="start processing" input=4000 thread=2
time=2024-10-03T07:42:22.683+10:00 level=INFO msg="Stop processing" input=5000 thread=1 duration=5.000444125s
time=2024-10-03T07:42:22.683+10:00 level=INFO msg="Thread 1 stopped"
time=2024-10-03T07:42:22.683+10:00 level=INFO msg="Stop processing" input=4000 thread=2 duration=4.001053166s
time=2024-10-03T07:42:22.683+10:00 level=INFO msg="Thread 2 stopped"
time=2024-10-03T07:42:22.683+10:00 level=INFO msg="Program Stopped." duration=6.002219167s
Here I have created a pool of 2 workers. I then created a jobs channel buffered to the number of items to process. Each worker keeps receiving jobs from the channel and processing them until the channel is closed and empty, and then it exits.
I then send the data to the channel in a loop. After all the items are sent, I close the channel so that each worker's range loop ends once the remaining jobs are drained. A WaitGroup tracks the workers, and wg.Wait() blocks until all of them have finished. With 2 workers the run took about 6 seconds: worker 1 handled 1000 + 5000 ms and worker 2 handled 2000 + 4000 ms. This way we can achieve concurrency with worker pools.
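The close(jobs) step relies on a property of channels that is worth seeing in isolation: closing a buffered channel does not throw away values already in it. A minimal sketch (drainAfterClose is an illustrative name):

```go
package main

import "fmt"

// drainAfterClose fills a buffered channel, closes it, then ranges over it.
// range keeps receiving buffered values after close and exits only when
// the buffer is empty.
func drainAfterClose(values []int) []int {
	ch := make(chan int, len(values))
	for _, v := range values {
		ch <- v
	}
	close(ch) // no more sends allowed; receivers can still drain the buffer
	var got []int
	for v := range ch {
		got = append(got, v)
	}
	return got
}

func main() {
	fmt.Println(drainAfterClose([]int{2000, 1000, 5000, 4000})) // [2000 1000 5000 4000]
}
```

This is why the workers above still process every job even though close(jobs) runs right after the last send.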
Now let’s say we need to stop all workers if processing any item fails, and we also need to capture the result of each item.
Worker pool with receiver channel and context
package main

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"time"
)

var logger = slog.New(slog.NewTextHandler(os.Stderr, nil))

func main() {
	logger.Info("Start Program.")
	start := time.Now()
	var data = []int{1000, 1500, 500, 3001, 200, 201, 300, 400}
	workerCount := 2
	workerPool(context.Background(), workerCount, data)
	logger.Info("End Program.", "duration", time.Since(start).String())
}

func workerPool(ctx context.Context, workerCount int, data []int) {
	// Derive a cancellable context shared by all workers.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	var res = make(chan map[int]string, len(data)) // results from workers
	var jobs = make(chan int, len(data))           // work to be done
	var errs = make(chan error)                    // errors from workers
	for i := 1; i <= workerCount; i++ {
		go worker(ctx, cancel, i, jobs, res, errs)
	}
	for _, d := range data {
		jobs <- d
	}
	close(jobs)
	// Expect one result per item, but stop at the first error.
	for j := 0; j < len(data); j++ {
		select {
		case err := <-errs:
			errorF("error received %v ending program", err)
			return
		case d := <-res:
			infoF("Response received %+v", d)
		}
	}
}

func worker(ctx context.Context, cancel context.CancelFunc, thread int, jobs <-chan int, res chan<- map[int]string, errs chan<- error) {
	for d := range jobs {
		select {
		case <-ctx.Done():
			// Another worker cancelled the context: stop taking jobs.
			infoF("Context cancel called thread: %d stopped", thread)
			return
		default:
			rs, err := processEvenData(d, thread)
			if err != nil {
				errorF("Error received thread: %d data %d stopped %v", thread, d, err)
				errs <- err
				cancel() // tell the other workers to stop
				return
			}
			res <- rs
		}
	}
}

// processData always succeeds. Call it from worker instead of
// processEvenData to get the error-free run.
func processData(inp, thread int) (map[int]string, error) {
	logger.Info("start processing", "input", inp, "thread", thread)
	start := time.Now()
	defer func() {
		logger.Info("Stop processing", "input", inp, "thread", thread, "duration", time.Since(start).String())
	}()
	time.Sleep(time.Duration(inp) * time.Millisecond)
	res := map[int]string{inp: getCurrentTime()}
	return res, nil
}

// processEvenData fails for odd inputs to simulate a processing error.
func processEvenData(inp, thread int) (res map[int]string, err error) {
	logger.Info("start processing", "input", inp, "thread", thread)
	start := time.Now()
	defer func() {
		logger.Info("Stop processing", "input", inp, "thread", thread, "duration", time.Since(start).String())
	}()
	time.Sleep(time.Duration(inp) * time.Millisecond)
	if inp%2 == 0 {
		return map[int]string{inp: getCurrentTime()}, nil
	}
	return nil, errors.New("failed to process data")
}

func getCurrentTime() string {
	return time.Now().UTC().Format("2006-01-02T15:04:05Z")
}

// infoF logs a printf-style message at INFO level.
func infoF(format string, args ...any) {
	logger.Info(fmt.Sprintf(format, args...))
}

// errorF logs a printf-style message at ERROR level.
func errorF(format string, args ...any) {
	logger.Error(fmt.Sprintf(format, args...))
}
Let’s look at the breakdown for the worker pool
ctx, cancel := context.WithCancel(ctx)
A context with a cancel function lets us stop the work. Each worker checks whether cancel has been called (by any worker), and if it has, that worker stops processing further jobs.
var res = make(chan map[int]string, len(data))
Workers send each result on this channel. It is buffered to len(data), so a worker never blocks when sending a result.
var errs = make(chan error)
This is an error channel that will be used to write errors if any occur.
Now inside the worker function
select {
case <-ctx.Done():
}
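This checks whether the context has been cancelled, and if it has, the worker returns. In the full worker, the default branch makes this check non-blocking, and it runs only between jobs, so a job that has already started is not interrupted.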
if err != nil {
	cancel()
}
If there’s an error, we call cancel so that the other workers stop picking up new jobs.
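Calling cancel closes the context's Done channel, and a receive from a closed channel succeeds immediately for every receiver. That is why one cancel call from any worker reaches all of them. A minimal sketch (stopAll is a hypothetical name):

```go
package main

import (
	"context"
	"fmt"
)

// stopAll starts n goroutines that each block until ctx is cancelled, calls
// cancel once, and returns how many goroutines observed the cancellation.
func stopAll(n int) int {
	ctx, cancel := context.WithCancel(context.Background())
	stopped := make(chan int, n) // buffered so no goroutine blocks on send
	for i := 1; i <= n; i++ {
		go func(id int) {
			<-ctx.Done() // blocks until cancel is called
			stopped <- id
		}(i)
	}
	cancel() // a single call wakes every goroutine above
	count := 0
	for i := 0; i < n; i++ {
		<-stopped
		count++
	}
	return count
}

func main() {
	fmt.Println("workers stopped:", stopAll(4)) // workers stopped: 4
}
```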
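There are two processing functions. processData(inp, thread int) never returns an error, while processEvenData(inp, thread int) returns an error for odd inputs. The worker above calls processEvenData, so to reproduce the following output, change that call to processData.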
time=2024-10-03T08:07:07.359+10:00 level=INFO msg="Start Program."
time=2024-10-03T08:07:07.360+10:00 level=INFO msg="start processing" input=1000 thread=2
time=2024-10-03T08:07:07.360+10:00 level=INFO msg="start processing" input=1500 thread=1
time=2024-10-03T08:07:08.361+10:00 level=INFO msg="Stop processing" input=1000 thread=2 duration=1.001164208s
time=2024-10-03T08:07:08.361+10:00 level=INFO msg="start processing" input=500 thread=2
time=2024-10-03T08:07:08.361+10:00 level=INFO msg="Response received map[1000:2024-10-02T22:07:08Z]"
time=2024-10-03T08:07:08.861+10:00 level=INFO msg="Stop processing" input=1500 thread=1 duration=1.500901625s
time=2024-10-03T08:07:08.861+10:00 level=INFO msg="start processing" input=3001 thread=1
time=2024-10-03T08:07:08.861+10:00 level=INFO msg="Response received map[1500:2024-10-02T22:07:08Z]"
time=2024-10-03T08:07:08.861+10:00 level=INFO msg="Stop processing" input=500 thread=2 duration=500.077459ms
time=2024-10-03T08:07:08.861+10:00 level=INFO msg="start processing" input=200 thread=2
time=2024-10-03T08:07:08.861+10:00 level=INFO msg="Response received map[500:2024-10-02T22:07:08Z]"
time=2024-10-03T08:07:09.062+10:00 level=INFO msg="Stop processing" input=200 thread=2 duration=201.1625ms
time=2024-10-03T08:07:09.063+10:00 level=INFO msg="start processing" input=201 thread=2
time=2024-10-03T08:07:09.063+10:00 level=INFO msg="Response received map[200:2024-10-02T22:07:09Z]"
time=2024-10-03T08:07:09.265+10:00 level=INFO msg="Stop processing" input=201 thread=2 duration=202.230333ms
time=2024-10-03T08:07:09.265+10:00 level=INFO msg="start processing" input=300 thread=2
time=2024-10-03T08:07:09.265+10:00 level=INFO msg="Response received map[201:2024-10-02T22:07:09Z]"
time=2024-10-03T08:07:09.569+10:00 level=INFO msg="Stop processing" input=300 thread=2 duration=303.591791ms
time=2024-10-03T08:07:09.569+10:00 level=INFO msg="start processing" input=400 thread=2
time=2024-10-03T08:07:09.569+10:00 level=INFO msg="Response received map[300:2024-10-02T22:07:09Z]"
time=2024-10-03T08:07:09.970+10:00 level=INFO msg="Stop processing" input=400 thread=2 duration=401.056583ms
time=2024-10-03T08:07:09.970+10:00 level=INFO msg="Response received map[400:2024-10-02T22:07:09Z]"
time=2024-10-03T08:07:11.862+10:00 level=INFO msg="Stop processing" input=3001 thread=1 duration=3.001260292s
time=2024-10-03T08:07:11.862+10:00 level=INFO msg="Response received map[3001:2024-10-02T22:07:11Z]"
time=2024-10-03T08:07:11.862+10:00 level=INFO msg="End Program." duration=4.502660875s
Here all the records are processed without error. Now let’s use the other function, processEvenData, which returns an error if the input is not an even number.
time=2024-10-03T08:10:16.450+10:00 level=INFO msg="Start Program."
time=2024-10-03T08:10:16.451+10:00 level=INFO msg="start processing" input=1000 thread=2
time=2024-10-03T08:10:16.451+10:00 level=INFO msg="start processing" input=1500 thread=1
time=2024-10-03T08:10:17.452+10:00 level=INFO msg="Stop processing" input=1000 thread=2 duration=1.001280334s
time=2024-10-03T08:10:17.452+10:00 level=INFO msg="start processing" input=500 thread=2
time=2024-10-03T08:10:17.453+10:00 level=INFO msg="Response received map[1000:2024-10-02T22:10:17Z]"
time=2024-10-03T08:10:17.952+10:00 level=INFO msg="Stop processing" input=1500 thread=1 duration=1.500988708s
time=2024-10-03T08:10:17.952+10:00 level=INFO msg="start processing" input=3001 thread=1
time=2024-10-03T08:10:17.952+10:00 level=INFO msg="Response received map[1500:2024-10-02T22:10:17Z]"
time=2024-10-03T08:10:17.952+10:00 level=INFO msg="Stop processing" input=500 thread=2 duration=500.075584ms
time=2024-10-03T08:10:17.952+10:00 level=INFO msg="start processing" input=200 thread=2
time=2024-10-03T08:10:17.952+10:00 level=INFO msg="Response received map[500:2024-10-02T22:10:17Z]"
time=2024-10-03T08:10:18.154+10:00 level=INFO msg="Stop processing" input=200 thread=2 duration=201.066375ms
time=2024-10-03T08:10:18.154+10:00 level=INFO msg="start processing" input=201 thread=2
time=2024-10-03T08:10:18.154+10:00 level=INFO msg="Response received map[200:2024-10-02T22:10:18Z]"
time=2024-10-03T08:10:18.356+10:00 level=INFO msg="Stop processing" input=201 thread=2 duration=202.121ms
time=2024-10-03T08:10:18.356+10:00 level=ERROR msg="Error received thread: 2 data 201 stopped failed to process data"
time=2024-10-03T08:10:18.356+10:00 level=ERROR msg="error received failed to process data ending program"
time=2024-10-03T08:10:18.356+10:00 level=INFO msg="End Program." duration=1.905565709s
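Here the odd input 201 fails on thread 2. That worker reports the error and calls cancel, and workerPool stops waiting for results and returns, so the remaining inputs (300 and 400) are never processed. Note that thread 1 was still in the middle of processing 3001 when the program ended: the worker checks the context only between jobs, and workerPool returns without waiting for the workers to exit.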
I hope this makes things clear.
Recent Updates:
- September 2024: Use the log/slog package for structured logging (introduced in Go 1.21)