forgejo-tickets/internal/middleware/ratelimit.go

87 lines
1.8 KiB
Go

package middleware
import (
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// ipRecord tracks the requests admitted from a single client IP.
type ipRecord struct {
	mu         sync.Mutex  // guards timestamps
	timestamps []time.Time // admission times, appended in order so the newest is last
}

// RateLimiter holds the per-IP rate limiting state.
//
// Create one with NewRateLimiter, which also starts a background cleanup
// goroutine; call Stop to end that goroutine once the limiter is no longer
// needed.
type RateLimiter struct {
	ips    sync.Map // map[string]*ipRecord
	limit  int
	window time.Duration

	stopOnce sync.Once     // makes Stop safe to call more than once
	done     chan struct{} // closed by Stop to end the cleanup goroutine
}

// NewRateLimiter creates a rate limiter that allows `limit` requests per `window` per IP.
//
// It starts a goroutine that periodically removes idle entries. That goroutine
// runs until Stop is called.
func NewRateLimiter(limit int, window time.Duration) *RateLimiter {
	rl := &RateLimiter{
		limit:  limit,
		window: window,
		done:   make(chan struct{}),
	}
	// Periodically clean up stale entries.
	go rl.cleanup()
	return rl
}

// Stop ends the background cleanup goroutine started by NewRateLimiter.
// It is safe to call multiple times and from multiple goroutines. Per-IP
// state already stored is left in place.
func (rl *RateLimiter) Stop() {
	rl.stopOnce.Do(func() {
		// done is nil when the limiter was not built by NewRateLimiter;
		// closing a nil channel would panic.
		if rl.done != nil {
			close(rl.done)
		}
	})
}

// cleanup runs sweep every 5 minutes until Stop is called.
func (rl *RateLimiter) cleanup() {
	const interval = 5 * time.Minute

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-rl.done:
			return
		case <-ticker.C:
			rl.sweep(time.Now())
		}
	}
}

// sweep performs one cleanup pass: it deletes every record that has no
// timestamps or whose newest timestamp is more than rl.window before now.
func (rl *RateLimiter) sweep(now time.Time) {
	rl.ips.Range(func(key, value any) bool {
		rec := value.(*ipRecord)

		rec.mu.Lock()
		n := len(rec.timestamps)
		stale := n == 0 || now.Sub(rec.timestamps[n-1]) > rl.window
		rec.mu.Unlock()

		// NOTE: the record is deleted after its lock is released, so a
		// request that loaded it just before this point can still append to
		// the now-orphaned record. Only records idle for longer than the
		// window reach this branch, so that request's timestamp goes
		// uncounted for the IP.
		if stale {
			rl.ips.Delete(key)
		}
		return true
	})
}
// Middleware returns a Gin handler that applies the per-IP sliding-window
// limit. A request that would exceed the limit is aborted with
// 429 Too Many Requests; any other request is recorded and passed on to the
// next handler.
func (rl *RateLimiter) Middleware() gin.HandlerFunc {
	// admit prunes the record down to the current window and reports whether
	// one more request fits, recording it when it does.
	admit := func(rec *ipRecord) bool {
		rec.mu.Lock()
		defer rec.mu.Unlock()

		now := time.Now()
		oldest := now.Add(-rl.window)

		// Filter in place, reusing the backing array.
		kept := rec.timestamps[:0]
		for _, ts := range rec.timestamps {
			if ts.After(oldest) {
				kept = append(kept, ts)
			}
		}
		rec.timestamps = kept

		if len(kept) >= rl.limit {
			return false
		}
		rec.timestamps = append(kept, now)
		return true
	}

	return func(c *gin.Context) {
		entry, _ := rl.ips.LoadOrStore(c.ClientIP(), &ipRecord{})
		if !admit(entry.(*ipRecord)) {
			c.AbortWithStatus(http.StatusTooManyRequests)
			return
		}
		// The record's lock is already released here, so downstream
		// handlers never run while holding it.
		c.Next()
	}
}