-
Notifications
You must be signed in to change notification settings - Fork 2
/
tlslimit.go
189 lines (163 loc) · 5.67 KB
/
tlslimit.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
// Package tlslimit provides a rate limiter that reduces the number of expensive TLS handshakes a server performs.
package tlslimit
import (
"crypto/tls"
"errors"
"net"
"time"
"github.com/shaj13/libcache"
_ "github.com/shaj13/libcache/lru"
"golang.org/x/time/rate"
)
// Option configures a Limiter using the functional options pattern
// popularized by Rob Pike and Dave Cheney. For background, see
// https://commandcenter.blogspot.com/2014/01/self-referential-functions-and-design.html and
// https://dave.cheney.net/2014/10/17/functional-options-for-friendly-apis.
type Option interface {
	apply(*Limiter)
}

// optionFunc adapts an ordinary function to the Option interface.
type optionFunc func(*Limiter)

// apply invokes the wrapped function on lim, mutating its configuration.
func (f optionFunc) apply(lim *Limiter) {
	f(lim)
}
// WithGetCertificate registers fn as the certificate callback that
// Limiter.GetCertificate delegates to. fn runs only when
// Limiter.GetCertificate is installed as tls.Config.GetCertificate and the
// handshake obtained a token from the rate limiter.
//
// When fn is nil, Limiter.GetCertificate returns (nil, nil) for handshakes
// within the limit. The same happens when fn itself returns a nil certificate
// with a nil error. In both cases crypto/tls falls back to the certificates
// configured on the tls.Config, and the handshake fails if there are none.
//
// See the documentation of [tls.Config]: https://pkg.go.dev/crypto/tls#Config for more information.
func WithGetCertificate(fn func(*tls.ClientHelloInfo) (*tls.Certificate, error)) Option {
	setCallback := func(lim *Limiter) {
		lim.getCertificate = fn
	}
	return optionFunc(setCallback)
}
// WithGetConfigForClient registers fn as the per-client configuration
// callback that Limiter.GetConfigForClient delegates to. fn runs only when
// Limiter.GetConfigForClient is installed as tls.Config.GetConfigForClient
// and the handshake obtained a token from the rate limiter.
//
// When fn is nil, Limiter.GetConfigForClient returns (nil, nil) for
// handshakes within the limit. The same happens when fn itself returns a nil
// config with a nil error. In both cases crypto/tls continues with the
// original tls.Config.
//
// See the documentation of [tls.Config]: https://pkg.go.dev/crypto/tls#Config for more information.
func WithGetConfigForClient(fn func(*tls.ClientHelloInfo) (*tls.Config, error)) Option {
	setCallback := func(lim *Limiter) {
		lim.getConfigForClient = fn
	}
	return optionFunc(setCallback)
}
// WithLimit sets the refill interval of every token bucket: one handshake
// token is added every r. A zero (or negative) r makes the Limiter reject all
// TLS handshakes.
func WithLimit(r time.Duration) Option {
	setInterval := func(lim *Limiter) { lim.r = r }
	return optionFunc(setInterval)
}
// WithBursts sets the capacity of every token bucket. This is the maximum
// number of TLS handshakes that can be admitted back to back before the
// refill interval configured by WithLimit throttles them. A zero burst admits
// no TLS handshakes.
func WithBursts(b int) Option {
	setBurst := func(lim *Limiter) { lim.b = b }
	return optionFunc(setBurst)
}
// WithTLSHostname applies rate limiting per requested server name. Each
// bucket is keyed by tls.ClientHelloInfo.ServerName (the SNI value). All
// clients that send no SNI share the bucket stored under the empty-string key.
func WithTLSHostname() Option {
	byServerName := func(ci *tls.ClientHelloInfo) string {
		return ci.ServerName
	}
	return optionFunc(func(lim *Limiter) {
		lim.keyFn = byServerName
	})
}
// WithTLSClientIP applies rate limiting per client address. Each bucket is
// keyed by the host part of tls.ClientHelloInfo.Conn.RemoteAddr().
//
// When the remote address cannot be split into host and port (for example a
// Unix socket or an in-memory net.Pipe connection), the full address string
// is used as the key instead.
func WithTLSClientIP() Option {
	return optionFunc(func(l *Limiter) {
		l.keyFn = func(ci *tls.ClientHelloInfo) string {
			addr := ci.Conn.RemoteAddr().String()
			host, _, err := net.SplitHostPort(addr)
			if err != nil {
				// SplitHostPort returns an empty host on error. Keep the raw
				// address so such peers are not all collapsed onto the "" key.
				return addr
			}
			return host
		}
	})
}
// WithCacheMaxSize sets the maximum number of per-key token buckets kept in
// the Limiter's LRU cache. When the cache is full, the least recently used
// bucket is evicted, and its key starts again with a fresh, full bucket.
//
// NOTE(review): the meaning of size <= 0 is defined by libcache (presumably
// "no limit") — confirm against the libcache documentation.
func WithCacheMaxSize(size int) Option {
	useLRU := func(lim *Limiter) {
		lim.cache = libcache.LRU.New(size)
	}
	return optionFunc(useLRU)
}
// NewLimiter builds a Limiter from the given options, applied in order.
//
// The refill interval (WithLimit) and burst size (WithBursts) default to
// zero, which rejects every handshake, so callers normally set both. When
// WithCacheMaxSize is not given, the buckets are kept in a libcache LRU
// created with capacity 0.
func NewLimiter(opts ...Option) *Limiter {
	l := &Limiter{}
	for i := range opts {
		opts[i].apply(l)
	}

	if l.cache != nil {
		return l
	}
	l.cache = libcache.LRU.New(0)
	return l
}
// Limiter controls how frequently TLS handshakes are allowed to happen.
// It uses a token bucket that holds at most b tokens (WithBursts), starts
// full, and regains one token every r (WithLimit).
// See https://en.wikipedia.org/wiki/Token_bucket for more about token buckets.
//
// The zero value is a valid Limiter, but it rejects every TLS handshake.
// Use NewLimiter to create usable Limiters.
//
// Install GetCertificate and/or GetConfigForClient in a tls.Config. Each call
// to either method takes one token from the bucket selected for the client.
// When no token is available, the method returns an error, which aborts the
// handshake. The methods are only invoked while a handshake is in progress,
// so requests carried over an already established connection (for example
// multiplexed HTTP/2 streams) are not rate limited. This reduces the number
// of expensive handshakes, helps mitigate TLS handshake exhaustion DDoS
// attacks, and lowers server resource usage without limiting how many
// requests established connections may carry.
//
// NOTE(review): if both methods are installed on the same tls.Config and
// GetConfigForClient yields a nil config, a single handshake presumably goes
// through both callbacks and consumes two tokens — confirm against
// crypto/tls before relying on per-handshake accounting.
//
// By default, a single global bucket is shared by all clients. Use
// WithTLSHostname or WithTLSClientIP to keep one bucket per server name or
// per client IP.
type Limiter struct {
	// Token-bucket parameters: one token every r, at most b tokens.
	r time.Duration
	b int

	// Buckets indexed by key. keyFn derives the key from the ClientHello;
	// when it is nil, every client shares one bucket.
	cache libcache.Cache
	keyFn func(*tls.ClientHelloInfo) string

	// Optional user callbacks, invoked only after a token was obtained.
	getCertificate     func(*tls.ClientHelloInfo) (*tls.Certificate, error)
	getConfigForClient func(*tls.ClientHelloInfo) (*tls.Config, error)
}
// GetCertificate is meant to be installed as tls.Config.GetCertificate.
//
// It first takes one token from the client's bucket. If none is available,
// it returns the limiter's error, which aborts the handshake. Otherwise it
// delegates to the callback registered with WithGetCertificate. When no
// callback is registered, it returns (nil, nil).
func (lim *Limiter) GetCertificate(ci *tls.ClientHelloInfo) (*tls.Certificate, error) {
	err := lim.limit(ci)
	switch {
	case err != nil:
		return nil, err
	case lim.getCertificate == nil:
		return nil, nil
	default:
		return lim.getCertificate(ci)
	}
}
// GetConfigForClient is meant to be installed as
// tls.Config.GetConfigForClient.
//
// It first takes one token from the client's bucket. If none is available,
// it returns the limiter's error, which aborts the handshake. Otherwise it
// delegates to the callback registered with WithGetConfigForClient. When no
// callback is registered, it returns (nil, nil), which keeps the original
// tls.Config.
func (lim *Limiter) GetConfigForClient(ci *tls.ClientHelloInfo) (*tls.Config, error) {
	if err := lim.limit(ci); err != nil {
		return nil, err
	}
	if fn := lim.getConfigForClient; fn != nil {
		return fn(ci)
	}
	return nil, nil
}
// errTooManyHandshakes is returned when a handshake finds no token available.
var errTooManyHandshakes = errors.New("tlslimit: too many TLS handshakes")

// limit takes one token from the token bucket selected for ci. It returns
// errTooManyHandshakes when the handshake must be rejected, and nil otherwise.
//
// The bucket key is "global" unless a key function was configured
// (WithTLSHostname / WithTLSClientIP). A Limiter without a cache (the zero
// value) or with a non-positive refill interval rejects every handshake.
//
// NOTE(review): Load followed by StoreWithTTL is not atomic. Two concurrent
// first handshakes for the same key may each create their own bucket, briefly
// admitting more than b handshakes for that key.
func (lim *Limiter) limit(ci *tls.ClientHelloInfo) error {
	if lim.cache == nil || lim.r <= 0 {
		return errTooManyHandshakes
	}

	key := "global"
	if lim.keyFn != nil {
		key = lim.keyFn(ci)
	}

	var bucket *rate.Limiter
	if v, ok := lim.cache.Load(key); ok {
		bucket = v.(*rate.Limiter)
	} else {
		bucket = rate.NewLimiter(rate.Every(lim.r), lim.b)
	}

	// A bucket refills completely within b*r. (Re)storing it on every call
	// moves its expiry to b*r after the latest attempt, so an entry is only
	// dropped once it would be full anyway. Recreating it as a fresh, full
	// bucket is then indistinguishable from keeping it. A TTL of just r would
	// reset a drained bucket to b tokens after one interval, admitting b
	// handshakes per r instead of one.
	ttl := lim.r
	if lim.b > 1 {
		ttl *= time.Duration(lim.b)
	}
	lim.cache.StoreWithTTL(key, bucket, ttl)

	if !bucket.Allow() {
		return errTooManyHandshakes
	}
	return nil
}