16/02/2024 - GO, REDIS
In this example we are going to implement a rate limiter for a Go API to manage client requests. I assume you already know what rate limiting is, so I won't go into its details. If you don't, I suggest reading up on it first so this post makes sense.
We have individual clients, and each can have a different quota for our API. Some have unlimited access, some have no access at all, and others have a specific number of tokens to use within a given period of time. Although many implementations default to a minute, a period can be a second, a minute or an hour; I'll be using a minute in this example. As you may have gathered, we will be using a "window"-based algorithm with a set of tokens. The whole approach is similar to how Redis describes rate limiting here. Please read it; it is a very useful write-up.
The `X-Rate-Limit-*` response headers are adjusted accordingly; see the examples at the bottom. All code is open to improvement, and you won't always get a production-grade example in blog posts. I kept this post as short and simple as possible for the sake of readability, but you should improve it as you wish. For example, the Redis storage package contains a bit of code duplication which can easily be refactored. The same goes for the rate limiting package itself; it can be tidied up a bit.
├── Makefile
├── api
│ ├── comment.go
│ └── home.go
├── errorx
│ └── error.go
├── main.go
├── middleware
│ └── ratelimit.go
├── migration
│ └── clients.sql
├── postgres
│ ├── client.go
│ └── storage.go
├── ratelimit
│ └── ratelimit.go
└── redis
├── args.go
├── lua.go
├── ratelimit.go
└── storage.go
# run: start the API with the Go race detector enabled.
run:
go run -race main.go
# cache: start a Redis 7.2 container published on port 6379; --rm removes the
# container when it stops.
cache:
docker run \
--rm \
--publish 6379:6379 \
--name rate-limiter-redis \
redis:7.2-alpine3.19 \
--loglevel notice
# database: start a PostgreSQL 15 container published on port 5432; --rm
# removes the container when it stops. ./migration is mounted read-only at
# /docker-entrypoint-initdb.d, the directory the official postgres image reads
# init scripts from (clients.sql creates and seeds the clients table). Every
# statement is logged to stderr.
database:
docker run \
--rm \
--env POSTGRES_DB=postgres \
--env POSTGRES_USER=postgres \
--env POSTGRES_PASSWORD=postgres \
--volume ${PWD}/migration:/docker-entrypoint-initdb.d:ro \
--publish 5432:5432 \
--name rate-limiter-postgres \
postgres:15.0-alpine \
postgres -c log_statement=all -c log_destination=stderr
package api
import "net/http"
// GetComment handles GET /api/v1/comments by replying with a fixed
// plain-text body.
func GetComment(w http.ResponseWriter, r *http.Request) {
	body := []byte(`Good job!`)
	// The write result is deliberately discarded: there is nothing useful
	// this handler could do with a failed write.
	_, _ = w.Write(body)
}
package api
import "net/http"
// GetHome handles GET / by replying with a fixed plain-text greeting.
func GetHome(w http.ResponseWriter, r *http.Request) {
	greeting := []byte(`Welcome home!`)
	// The write result is deliberately discarded: there is nothing useful
	// this handler could do with a failed write.
	_, _ = w.Write(greeting)
}
-- clients stores each API client's rate limit quota. The seed rows below show
-- the two special values (0 and -1); any other value is the number of
-- requests a client may make per rate limiting window.
CREATE TABLE IF NOT EXISTS clients
(
id text NOT NULL,
rate_limit_quota integer NOT NULL,
CONSTRAINT clients_prk_id PRIMARY KEY (id)
);
INSERT INTO clients
(id, rate_limit_quota)
VALUES
('client-0', 0), -- Prohibited Access
('client-1', -1), -- Unlimited Access
('client-2', 3) -- Up to 3 requests per window
;
package errorx
import (
"errors"
)
// ErrInternal signals an unexpected failure in an underlying dependency.
var ErrInternal = errors.New("internal")

// ErrResourceNotFound signals that the requested record does not exist.
var ErrResourceNotFound = errors.New("resource not found")
package middleware
import (
"context"
"log"
"net/http"
"time"
)
// rateLimiter decides whether a client's request may proceed.
//
// Limit returns the response headers describing the client's rate limit
// state, whether the request is allowed, and an error if no decision could be
// made.
type rateLimiter interface {
	Limit(ctx context.Context, now time.Time, clientID string) (map[string]string, bool, error)
}

// RateLimit is an HTTP middleware that applies a rateLimiter to each request,
// identifying the caller by its X-Client-Id header.
type RateLimit struct {
	RateLimit rateLimiter
}

// Handle wraps next with rate limiting.
//
// Requests without an X-Client-Id header are rejected with 401 Unauthorized
// before the limiter is consulted. If the limiter returns an error, it is
// logged and 500 Internal Server Error is sent. Otherwise the limiter's
// headers are copied onto the response, and the request is either passed to
// next or rejected with 429 Too Many Requests.
func (r RateLimit) Handle(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, rr *http.Request) {
		clientID := rr.Header.Get("X-Client-Id")
		if clientID == "" {
			// A missing identity is a client error. Without this guard the
			// empty ID is handed to the limiter, and any lookup failure it
			// causes is reported below as a 500.
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		headers, ok, err := r.RateLimit.Limit(rr.Context(), time.Now().UTC(), clientID)
		if err != nil {
			log.Println(err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}

		// Headers must be set before WriteHeader/next writes the response.
		for k, v := range headers {
			w.Header().Set(k, v)
		}

		if !ok {
			w.WriteHeader(http.StatusTooManyRequests)
			return
		}

		next.ServeHTTP(w, rr)
	})
}
package postgres
import "github.com/jackc/pgx/v5/pgxpool"
// Storage provides PostgreSQL-backed access to client data.
type Storage struct {
// Postgres is the connection pool the storage methods query through.
Postgres *pgxpool.Pool
}
package postgres
import (
"context"
"ratelimit/errorx"
"github.com/jackc/pgx/v5"
"github.com/pkg/errors"
)
// ClientQuota returns clients rate limiting quota.
func (s Storage) ClientQuota(ctx context.Context, id string) (int, error) {
qry := `SELECT rate_limit_quota FROM clients WHERE id = $1 LIMIT 1`
var quota int
err := s.Postgres.QueryRow(ctx, qry, id).Scan("a)
switch {
case err == nil:
return quota, nil
case errors.Is(err, pgx.ErrNoRows):
return 0, errorx.ErrResourceNotFound
}
return 0, errors.Wrap(errorx.ErrInternal, err.Error())
}
package ratelimit
import (
"context"
"errors"
"fmt"
"strconv"
"time"
"ratelimit/errorx"
"ratelimit/redis"
)
// clientStorer looks up a client's allowed quota in the source of truth.
type clientStorer interface {
	ClientQuota(ctx context.Context, clientID string) (quota int, err error)
}

// cacheStorer consumes rate limiting quota kept in a cache, either using an
// already cached quota or caching a freshly loaded one first.
type cacheStorer interface {
	UseRateLimitQuota(ctx context.Context, request redis.UseRateLimitRequest) (redis.RateLimitResponse, error)
	SetRateLimitQuota(ctx context.Context, request redis.SetRateLimitRequest) (redis.RateLimitResponse, error)
}
// RateLimit is a window-based rate limiter that keeps quota and usage in a
// cache and falls back to client storage when the quota is not cached.
type RateLimit struct {
// ClientStorage is used to look up a client and its rate limiting quota.
ClientStorage clientStorer
// CacheStorage is used to store rate limiting info in a cache driver.
CacheStorage cacheStorer
// QuoteExpiryPeriod dictates how long the quota information should be kept
// in the cache. It directly affects how often the database will be hit to
// get the quota, e.g. if it is set to 10 minutes, the database will be hit
// once every 10 minutes per client.
QuoteExpiryPeriod time.Duration
// WindowExpiryPeriod dictates how long a client has to use its allowed
// quota. If rate limiting is per minute, this value should be a minute,
// e.g. if a client is allowed to make 60 requests per minute, pass
// `time.Minute`.
WindowExpiryPeriod time.Duration
}
// Limit consumes one request from clientID's quota for the rate limiting
// window that contains now.
//
// It returns the X-Rate-Limit-* headers to send to the client, whether the
// request may proceed, and an error if no decision could be made.
//
// The quota is first read from CacheStorage. Only when it is not cached
// (errorx.ErrResourceNotFound) is it loaded from ClientStorage; it is then
// cached and consumed in a single SetRateLimitQuota call.
func (r RateLimit) Limit(ctx context.Context, now time.Time, clientID string) (map[string]string, bool, error) {
	// Align windows to multiples of WindowExpiryPeriod. Every request in the
	// same period shares one counter key, and that key expires exactly when
	// the period ends. Keying by the window start (rather than, say, the
	// minute of the hour) keeps this correct for second, minute or hour
	// windows alike. A non-positive period falls back to one minute.
	period := r.WindowExpiryPeriod
	if period <= 0 {
		period = time.Minute
	}
	windowStart := now.Truncate(period)

	var (
		quotaKey     = r.key(clientID, "quota")
		windowKey    = r.key(clientID, strconv.FormatInt(windowStart.Unix(), 10))
		quotaExpiry  = r.expiry(now, r.QuoteExpiryPeriod)
		windowExpiry = windowStart.Add(period).Unix()
	)

	// STAGE: Use quota --------------------------------------------------------
	useReq := redis.UseRateLimitRequest{
		QuotaKey:     quotaKey,
		WindowKey:    windowKey,
		WindowExpiry: windowExpiry,
	}
	res, err := r.CacheStorage.UseRateLimitQuota(ctx, useReq)
	switch {
	case err == nil:
		return r.headers(res, now), res.Allow, nil
	case !errors.Is(err, errorx.ErrResourceNotFound):
		return nil, false, fmt.Errorf("unable to use rate limit, terminating: %w", err)
	}

	// STAGE: Find quota -------------------------------------------------------
	// The quota is not cached yet (or has expired), so consult the source of
	// truth.
	quota, err := r.ClientStorage.ClientQuota(ctx, clientID)
	if err != nil {
		return nil, false, fmt.Errorf("unable to find client quota, terminating: %w", err)
	}

	// STAGE: Set and use quota ------------------------------------------------
	setReq := redis.SetRateLimitRequest{
		QuotaAllowed: quota,
		QuotaKey:     quotaKey,
		WindowKey:    windowKey,
		QuotaExpiry:  quotaExpiry,
		WindowExpiry: windowExpiry,
	}
	res, err = r.CacheStorage.SetRateLimitQuota(ctx, setReq)
	if err != nil {
		return nil, false, fmt.Errorf("unable to set rate limit quota, terminating: %w", err)
	}

	return r.headers(res, now), res.Allow, nil
}
// key builds the cache key "rate-limit:{<clientID>}:<suffix>". The braces form
// a Redis Cluster hash tag, so every key of one client hashes to the same slot.
func (r RateLimit) key(clientID, suffix string) string {
	const prefix = "rate-limit:"
	hashTag := "{" + clientID + "}"
	return prefix + hashTag + ":" + suffix
}
// expiry returns the Unix time, in seconds, that lies period after the start
// of the minute containing now.
func (r RateLimit) expiry(now time.Time, period time.Duration) int64 {
	startOfMinute := now.Truncate(time.Minute)
	deadline := startOfMinute.Add(period)
	return deadline.Unix()
}
// headers converts a rate limit response into X-Rate-Limit-* response
// headers. Policy and Quota are always present; Reset and Used are added only
// when the quota is positive, i.e. not one of the special values below 1.
func (r RateLimit) headers(res redis.RateLimitResponse, now time.Time) map[string]string {
	out := make(map[string]string, 4)
	out["X-Rate-Limit-Policy"] = "Window"
	out["X-Rate-Limit-Quota"] = strconv.Itoa(res.Quota)

	if res.Quota >= 1 {
		out["X-Rate-Limit-Reset"] = strconv.FormatInt(res.Reset, 10)
		out["X-Rate-Limit-Used"] = strconv.Itoa(res.Used)
	}

	return out
}
package redis
import "github.com/redis/go-redis/v9"
// Storage implements rate limiting storage on top of a Redis client.
type Storage struct {
Redis *redis.Client
// Environment is used to construct the cache prefix for all items.
Environment string
// Application is used to construct the cache prefix for all items.
Application string
// UnlimitedAccess is the quota value that marks a client as having unlimited
// access to the API. Such a client is not rate limited and all of its
// requests are allowed. A client's allowed quota must equal this value for
// it to apply.
UnlimitedAccess int
// ProhibitedAccess is the quota value that marks a client as forbidden from
// accessing the API. Such a client is not rate limited and all of its
// requests are disallowed. A client's allowed quota must equal this value
// for it to apply.
ProhibitedAccess int
}
// key namespaces key as "<Environment>:<Application>:<key>", keeping keys
// from different environments and applications apart.
func (s Storage) key(key string) string {
	namespace := s.Environment + ":" + s.Application
	return namespace + ":" + key
}
package redis
// SetRateLimitRequest carries what is needed to cache a client's quota and
// consume one request from its current window in the same call.
type SetRateLimitRequest struct {
// QuotaAllowed is the client's allowed quota, as read from the source of
// truth.
QuotaAllowed int
// QuotaKey is the (not yet namespaced) cache key the quota is stored under.
QuotaKey string
// WindowKey is the (not yet namespaced) cache key of the current window's
// counter.
WindowKey string
// QuotaExpiry specifies the absolute Unix time, in seconds, at which the
// quota key will expire. This directly controls how often the database is
// hit to retrieve the allowed quota, as it is the source of truth. e.g. a
// value 600 seconds after the current time means the database is hit once
// every 10 minutes.
QuotaExpiry int64
// WindowExpiry specifies the absolute Unix time, in seconds, at which the
// window key will expire, typically the end of the current window. e.g. for
// a 1 minute window and a current time of `19:21:38`, this would be the Unix
// time of `19:22:00`.
WindowExpiry int64
}
// UseRateLimitRequest carries what is needed to consume one request from a
// client's current window using an already cached quota.
type UseRateLimitRequest struct {
// QuotaKey is the (not yet namespaced) cache key the quota is read from.
QuotaKey string
// WindowKey is the (not yet namespaced) cache key of the current window's
// counter.
WindowKey string
// WindowExpiry specifies the absolute Unix time, in seconds, at which the
// window key will expire, typically the end of the current window. e.g. for
// a 1 minute window and a current time of `19:21:38`, this would be the Unix
// time of `19:22:00`.
WindowExpiry int64
}
// RateLimitResponse reports the outcome of consuming rate limiting quota.
type RateLimitResponse struct {
// Quota contains the allowed quota for the client. Always set unless the
// function returns an error.
Quota int
// Used contains how much of their allowed quota the client has used in the
// current window. Always set unless the client either has unlimited access
// or is forbidden from accessing the API.
Used int
// Reset contains the number of seconds until the current window ends, i.e.
// how long the client should wait before sending more requests once the
// quota is used up. Set under the same conditions as `Used`.
Reset int64
// Allow reports whether the request is allowed to proceed. Always set.
Allow bool
}
package redis
import "github.com/redis/go-redis/v9"
// useRateLimitQuota atomically increments a counter by a positive number
// (`inc`) as long as the result does not exceed an upper limit (`max`).
//
// KEYS[1] is the counter key. ARGV[1] is the increment, ARGV[2] the upper
// limit and ARGV[3] the absolute expiry as a Unix timestamp in seconds.
//
// If the incremented value would exceed the limit, the counter is left
// untouched and the overflowing value is returned. Otherwise the counter is
// stored with the given absolute expiry (EXAT) and its new value is returned.
// The new value is returned directly instead of being re-read with GET, so the
// script never replies with nil (which a re-read could do when the expiry is
// already in the past and Redis does not keep the key).
//
// Example:
//
//	// Allow incrementing a counter by 1 up to 5.
//	num, err := useRateLimitQuota.Run(ctx, s.Redis, []string{"key"}, 1, 5, expiry.Unix()).Int()
var useRateLimitQuota = redis.NewScript(`
local key = KEYS[1]
local inc = tonumber(ARGV[1])
local max = tonumber(ARGV[2])
local exp = ARGV[3]
local val = tonumber(redis.call("GET", key) or 0)
val = val + inc
if val > max then
	return val
end
redis.call("SET", key, val, "EXAT", exp)
return val
`)
package redis
import (
"context"
"fmt"
"time"
"ratelimit/errorx"
"github.com/redis/go-redis/v9"
)
// UseRateLimitQuota consumes one request from the client's current window,
// using the quota already cached under req.QuotaKey.
//
// It returns errorx.ErrResourceNotFound (unwrapped) when no quota is cached,
// so the caller can load the quota from the source of truth and call
// SetRateLimitQuota instead. Other Redis failures are returned wrapped.
func (s Storage) UseRateLimitQuota(ctx context.Context, req UseRateLimitRequest) (RateLimitResponse, error) {
	// STAGE: Get quota --------------------------------------------------------
	quota, err := s.Redis.Get(ctx, s.key(req.QuotaKey)).Int()
	switch {
	case err == redis.Nil:
		return RateLimitResponse{}, errorx.ErrResourceNotFound
	case err != nil:
		return RateLimitResponse{}, fmt.Errorf("get rate limit quota: %w", err)
	}

	// STAGE: Use quota --------------------------------------------------------
	return s.consumeWindow(ctx, quota, req.WindowKey, req.WindowExpiry)
}

// SetRateLimitQuota caches req.QuotaAllowed under req.QuotaKey until
// req.QuotaExpiry. It then consumes one request from the client's current
// window, exactly as UseRateLimitQuota would.
func (s Storage) SetRateLimitQuota(ctx context.Context, req SetRateLimitRequest) (RateLimitResponse, error) {
	// STAGE: Set quota --------------------------------------------------------
	// A plain SET (without NX, XX or GET) always replies OK and never nil, so
	// any error here is a genuine failure.
	err := s.Redis.Do(ctx, "SET", s.key(req.QuotaKey), req.QuotaAllowed, "EXAT", req.QuotaExpiry).Err()
	if err != nil {
		return RateLimitResponse{}, fmt.Errorf("set rate limit quota: %w", err)
	}

	// STAGE: Use quota --------------------------------------------------------
	return s.consumeWindow(ctx, req.QuotaAllowed, req.WindowKey, req.WindowExpiry)
}

// consumeWindow applies quota to the window counter stored under windowKey.
// It is the shared "use quota" stage of UseRateLimitQuota and
// SetRateLimitQuota.
//
// A quota equal to UnlimitedAccess is always allowed, and one equal to
// ProhibitedAccess is always denied. In both cases the counter is not touched
// and Used/Reset stay zero. For any other quota, the counter is incremented by
// one (expiring at windowExpiry) unless that would exceed the quota; in that
// case the request is denied and Used reports the full quota.
func (s Storage) consumeWindow(ctx context.Context, quota int, windowKey string, windowExpiry int64) (RateLimitResponse, error) {
	response := RateLimitResponse{Quota: quota}

	switch quota {
	case s.UnlimitedAccess:
		response.Allow = true
		return response, nil
	case s.ProhibitedAccess:
		return response, nil
	}

	response.Reset = windowExpiry - time.Now().UTC().Unix()
	used, err := useRateLimitQuota.Run(ctx, s.Redis, []string{s.key(windowKey)}, 1, quota, windowExpiry).Int()
	switch {
	case err == redis.Nil:
		// No counter value came back; fail closed and treat the window as
		// exhausted rather than letting the request through.
		response.Used = quota
		return response, nil
	case err != nil:
		return RateLimitResponse{}, fmt.Errorf("use rate limit window: %w", err)
	}

	if used > quota {
		response.Used = quota
		return response, nil
	}

	response.Used = used
	response.Allow = true
	return response, nil
}
package main
import (
"context"
"log"
"net/http"
"time"
"ratelimit/api"
"ratelimit/middleware"
"ratelimit/postgres"
"ratelimit/ratelimit"
"ratelimit/redis"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgxpool"
redisx "github.com/redis/go-redis/v9"
)
// main wires together the PostgreSQL client store, the Redis cache store and
// the rate limiter, then serves the API on 0.0.0.0:1234. Only routes under
// /api pass through the rate limiting middleware; / is served without it.
func main() {
	// POSTGRES CLIENT ---------------------------------------------------------
	postgresClient, err := pgxpool.New(context.Background(), "postgres://postgres:postgres@0.0.0.0:5432/postgres?sslmode=disable")
	if err != nil {
		panic(err)
	}
	defer postgresClient.Close()

	// POSTGRES STORAGE --------------------------------------------------------
	postgresStorage := postgres.Storage{
		Postgres: postgresClient,
	}

	// REDIS CLIENT ------------------------------------------------------------
	redisClient := redisx.NewClient(&redisx.Options{
		Addr: "0.0.0.0:6379",
		DB:   0,
	})
	defer redisClient.Close()

	// REDIS STORAGE -----------------------------------------------------------
	// These two quota values must match the special values seeded in the
	// clients table (-1 unlimited, 0 prohibited).
	redisStorage := redis.Storage{
		Redis:            redisClient,
		Environment:      "local",
		Application:      "blog",
		UnlimitedAccess:  -1,
		ProhibitedAccess: 0,
	}

	// RATE LIMITER ------------------------------------------------------------
	rateLimit := ratelimit.RateLimit{
		ClientStorage:      postgresStorage,
		CacheStorage:       redisStorage,
		QuoteExpiryPeriod:  time.Minute * 30,
		WindowExpiryPeriod: time.Minute,
	}

	// MIDDLEWARE --------------------------------------------------------------
	rateLimitMiddleware := middleware.RateLimit{
		RateLimit: rateLimit,
	}

	// ROUTING -----------------------------------------------------------------
	router := chi.NewMux()
	router.Route("/", func(rtr chi.Router) {
		rtr.Get("/", api.GetHome)
		rtr.Route("/api", func(rtr chi.Router) {
			rtr.Use(rateLimitMiddleware.Handle)
			rtr.Route("/v1", func(rtr chi.Router) {
				rtr.Get("/comments", api.GetComment)
			})
		})
	})

	// SERVER ------------------------------------------------------------------
	// http.ListenAndServe runs a server with no timeouts at all, so slow or
	// idle clients can hold connections open indefinitely. Bound each phase
	// of a request explicitly instead.
	server := &http.Server{
		Addr:              "0.0.0.0:1234",
		Handler:           router,
		ReadHeaderTimeout: 5 * time.Second,
		ReadTimeout:       10 * time.Second,
		WriteTimeout:      10 * time.Second,
		IdleTimeout:       time.Minute,
	}
	log.Fatalln(server.ListenAndServe())
}
$ curl -kv GET 'http://0.0.0.0:1234/api/v1/comments' --header 'X-Client-Id: client-0'
X-Client-Id: client-0
HTTP/1.1 429 Too Many Requests
X-Rate-Limit-Policy: Window
X-Rate-Limit-Quota: 0
Content-Length: 0
$ curl -kv GET 'http://0.0.0.0:1234/api/v1/comments' --header 'X-Client-Id: client-1'
X-Client-Id: client-1
HTTP/1.1 200 OK
X-Rate-Limit-Policy: Window
X-Rate-Limit-Quota: -1
Content-Length: 9
$ curl -kv GET 'http://0.0.0.0:1234/api/v1/comments' --header 'X-Client-Id: client-2'
X-Client-Id: client-2
HTTP/1.1 200 OK
X-Rate-Limit-Quota: 3
X-Rate-Limit-Reset: 36
X-Rate-Limit-Used: 1
Content-Length: 9
$ curl -kv GET 'http://0.0.0.0:1234/api/v1/comments' --header 'X-Client-Id: client-2'
X-Client-Id: client-2
HTTP/1.1 200 OK
X-Rate-Limit-Quota: 3
X-Rate-Limit-Reset: 33
X-Rate-Limit-Used: 2
Content-Length: 9
$ curl -kv GET 'http://0.0.0.0:1234/api/v1/comments' --header 'X-Client-Id: client-2'
X-Client-Id: client-2
HTTP/1.1 200 OK
X-Rate-Limit-Quota: 3
X-Rate-Limit-Reset: 31
X-Rate-Limit-Used: 3
Content-Length: 9
$ curl -kv GET 'http://0.0.0.0:1234/api/v1/comments' --header 'X-Client-Id: client-2'
X-Client-Id: client-2
HTTP/1.1 429 Too Many Requests
X-Rate-Limit-Quota: 3
X-Rate-Limit-Reset: 25
X-Rate-Limit-Used: 3
Content-Length: 0