Skip to content

Commit

Permalink
allow consumers to wrap the net.Listener to customize handling
Browse files Browse the repository at this point in the history
Allow consumers applications to provide a function which wraps around a
raw net.Listener and a loaded *tls.Config, returning another
net.Listener, in order to customize the manner in which net.Conn
instances are accepted and built.
  • Loading branch information
jgraettinger committed Sep 23, 2024
1 parent b64cfe9 commit 7392043
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 8 deletions.
2 changes: 1 addition & 1 deletion cmd/gazette/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (cmdServe) Execute(args []string) error {
}

// Bind our server listener, grabbing a random available port if Port is zero.
srv, err := server.New("", Config.Broker.Host, Config.Broker.Port, serverTLS, peerTLS, Config.Broker.MaxGRPCRecvSize)
srv, err := server.New("", Config.Broker.Host, Config.Broker.Port, serverTLS, peerTLS, Config.Broker.MaxGRPCRecvSize, nil)
mbp.Must(err, "building Server instance")

// If a file:// root was provided, ensure it exists and apply it.
Expand Down
15 changes: 12 additions & 3 deletions mainboilerplate/runconsumer/run_consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"context"
"crypto/tls"
"fmt"
"net"
"os"
"os/signal"
"syscall"
Expand Down Expand Up @@ -112,8 +113,9 @@ const iniFilename = "gazette.ini"

// Cmd wraps a Config and Application to provide an Execute entry-point.
type Cmd struct {
Cfg Config
App Application
Cfg Config
App Application
WrapListener func(net.Listener, *tls.Config) (net.Listener, error)
}

func (sc Cmd) Execute(args []string) error {
Expand Down Expand Up @@ -155,7 +157,14 @@ func (sc Cmd) Execute(args []string) error {
}

// Bind our server listener, grabbing a random available port if Port is zero.
srv, err := server.New("", bc.Consumer.Host, bc.Consumer.Port, serverTLS, peerTLS, bc.Consumer.MaxGRPCRecvSize)
srv, err := server.New(
"", // Bind all interfaces
bc.Consumer.Host,
bc.Consumer.Port,
serverTLS, peerTLS,
bc.Consumer.MaxGRPCRecvSize,
sc.WrapListener,
)
mbp.Must(err, "building Server instance")

if bc.Broker.Cache.Size <= 0 {
Expand Down
17 changes: 13 additions & 4 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ type Server struct {
// and `port` for serving traffic directed at `host`.
// `port` may be empty, in which case a random free port is assigned.
// if `tlsConfig` is non-nil, the Server uses TLS (and is otherwise in the clear).
func New(iface, host, port string, serverTLS, peerTLS *tls.Config, maxSize uint32) (*Server, error) {
func New(
iface, host, port string,
serverTLS, peerTLS *tls.Config,
maxGRPCRecvSize uint32,
wrapListener func(net.Listener, *tls.Config) (net.Listener, error),
) (*Server, error) {
var network, bind string
if port == "" {
network, bind = "tcp", fmt.Sprintf("%s:0", iface) // Assign a random free port.
Expand Down Expand Up @@ -100,10 +105,14 @@ func New(iface, host, port string, serverTLS, peerTLS *tls.Config, maxSize uint3
GRPCServer: grpc.NewServer(
grpc.StreamInterceptor(grpc_prometheus.StreamServerInterceptor),
grpc.UnaryInterceptor(grpc_prometheus.UnaryServerInterceptor),
grpc.MaxRecvMsgSize(int(maxSize)),
grpc.MaxRecvMsgSize(int(maxGRPCRecvSize)),
),
}
if serverTLS != nil {
if wrapListener != nil {
if listener, err = wrapListener(listener, serverTLS); err != nil {
return nil, fmt.Errorf("failed to wrap listener: %w", err)
}
} else if serverTLS != nil {
listener = tls.NewListener(listener, serverTLS)
}
srv.CMux = cmux.New(listener)
Expand Down Expand Up @@ -193,7 +202,7 @@ func BuildTLSConfig(certPath, keyPath, trustedCAPath string) (*tls.Config, error
// MustLoopback builds and returns a new Server instance bound to a random
// port on the loopback interface. It panics on error.
func MustLoopback() *Server {
if srv, err := New("127.0.0.1", "127.0.0.1", "", nil, nil, 1<<20); err != nil {
if srv, err := New("127.0.0.1", "127.0.0.1", "", nil, nil, 1<<20, nil); err != nil {
log.WithField("err", err).Panic("failed to build Server")
panic("not reached")
} else {
Expand Down

0 comments on commit 7392043

Please sign in to comment.