Context Cancellation With Cause


Go 1.20 has been released, and one of its many interesting features is the WithCancelCause function, which simplifies passing an error when canceling a context. Previously, when handling HTTP server requests, I passed errors from the innermost handler up to the logging middleware by defining a custom handler type that returns an error, for example func(w http.ResponseWriter, r *http.Request) error. The main problem is that I then had to rewrite every middleware to follow that signature so that they could be composed like middlewareA(middlewareB(handler)).

Fortunately, with WithCancelCause we simply call the cancel function with the reason (a non-nil error), like cancel(customErr), and can later retrieve that error with context.Cause(ctx).

Here is a simple demo: https://go.dev/play/p/K3_RcvniXxZ

package main

import (
	"context"
	"errors"
	"fmt"
)

func main() {
	customErr := errors.New("not found")
	parent := context.Background()
	// WithCancelCause returns a CancelCauseFunc that takes the reason for canceling.
	ctx, cancel := context.WithCancelCause(parent)
	cancel(customErr)
	fmt.Println(ctx.Err())          // prints context.Canceled ("context canceled")
	fmt.Println(context.Cause(ctx)) // prints customErr ("not found")
}

// output:
// context canceled
// not found

Implementing it in a simple HTTP server

package main

import (
	"context"
	"errors"
	"fmt"
	"log"
	"net/http"
	"time"
)

var (
	errGotEven = errors.New("ups we got even")
)

type RequestKey string

const (
	CancelKey RequestKey = "cancel"
)

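// cancelContext cancels the request context with err as the cause, using the
// CancelCauseFunc stored under CancelKey by the logger middleware.
// It does nothing if no cancel function is present.
func cancelContext(ctx context.Context, err error) {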
	cancel, ok := ctx.Value(CancelKey).(context.CancelCauseFunc)
	if ok {
		cancel(err)
	}
}

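// getOdd responds with the current time during odd minutes. During even
// minutes it writes the error to the client and cancels the request
// context with errGotEven as the cause.
func getOdd(w http.ResponseWriter, r *http.Request) {
	now := time.Now()
	if now.Minute()%2 == 1 {
		fmt.Fprintf(w, "now %v \n", now)
		return
	}

	fmt.Fprintf(w, "err:%v", errGotEven)
	cancelContext(r.Context(), errGotEven)
}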

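// logger derives a cancelable context for each request, stores its
// CancelCauseFunc under CancelKey so handlers can cancel with a reason,
// and logs the cause (or "OK") once the handler returns.
func logger(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		ctx, cancel := context.WithCancelCause(r.Context())
		ctx = context.WithValue(ctx, CancelKey, cancel)

		defer func() {
			msg := "OK"
			// Cause is nil while ctx is still active; otherwise it is the error
			// passed to cancel, or the parent's cause (e.g. a timeout).
			if err := context.Cause(ctx); err != nil {
				msg = err.Error()
			}

			log.Printf("%s - %s %s %s %s %v", r.RemoteAddr, r.Proto, r.Method, r.URL.RequestURI(), msg, time.Since(start))
			// Release the context after logging; a no-op if it was already canceled.
			cancel(nil)
		}()

		next.ServeHTTP(w, r.WithContext(ctx))
	})
}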

func main() {
	srv := http.Server{
		Addr:         ":8888",
		WriteTimeout: 5 * time.Second,
		Handler:      http.TimeoutHandler(logger(http.HandlerFunc(getOdd)), 3*time.Second, "Timeout!\n"),
	}

	if err := srv.ListenAndServe(); err != nil {
		fmt.Printf("Server failed: %s\n", err)
	}
}

Run it with “go run main.go”, then call “curl localhost:8888” several times. Here are the server logs:

$ go run main.go
2023/02/10 01:39:03 127.0.0.1:56144 - HTTP/1.1 GET / OK 64.823µs
2023/02/10 01:39:08 127.0.0.1:51530 - HTTP/1.1 GET / OK 23.485µs
2023/02/10 01:39:11 127.0.0.1:51536 - HTTP/1.1 GET / OK 32.462µs
2023/02/10 01:39:14 127.0.0.1:51540 - HTTP/1.1 GET / OK 24.417µs
2023/02/10 01:39:56 127.0.0.1:50684 - HTTP/1.1 GET / OK 26.29µs
2023/02/10 01:40:07 127.0.0.1:34338 - HTTP/1.1 GET / ups we got even 3.156µs

And here are the responses seen by curl:

$ curl localhost:8888
now 2023-02-10 01:41:29.905447708 +0000 UTC m=+151.664912082
$ curl localhost:8888
err:ups we got even%