From b345fed8c0c9c3b67c81e9befc295a142cc272a7 Mon Sep 17 00:00:00 2001 From: SimoneDutto Date: Fri, 17 Jan 2025 15:14:33 +0100 Subject: [PATCH] implement limit connections --- internal/ssh/listener.go | 93 ++++++++++++++++++++++++++++++++++++++++ internal/ssh/ssh.go | 64 +++++++++++++++++++-------- internal/ssh/ssh_test.go | 37 ++++++++++++++-- 3 files changed, 174 insertions(+), 20 deletions(-) create mode 100644 internal/ssh/listener.go diff --git a/internal/ssh/listener.go b/internal/ssh/listener.go new file mode 100644 index 000000000..64d93eaea --- /dev/null +++ b/internal/ssh/listener.go @@ -0,0 +1,93 @@ +// Copyright 2025 Canonical. +package ssh + +import ( + "net" + "sync" + "time" +) + +// N.B.: +// This is a copypaste of netutil.LimiLister, but we add a timeout so when we are at the limit +// we actively close connections instead of waiting indefinetely. (Look at line 44) + +// LimitListenerWithTimeout returns a Listener that accepts at most n simultaneous +// connections from the provided Listener, and it timeouts when the max +// has been reached and no seats has been freed for the timeout period. +func LimitListenerWithTimeout(l net.Listener, n int, timeout time.Duration) net.Listener { + return &limitListener{ + Listener: l, + sem: make(chan struct{}, n), + done: make(chan struct{}), + timeout: timeout, + } +} + +type limitListener struct { + net.Listener + sem chan struct{} + closeOnce sync.Once // ensures the done chan is only closed once + done chan struct{} // no values sent; closed when Close is called + timeout time.Duration // timeout for acquiring the connection +} + +// acquire acquires the limiting semaphore. Returns true if successfully +// acquired, false if the listener is closed and the semaphore is not +// acquired. +func (l *limitListener) acquire() bool { + select { + case <-l.done: + return false + case l.sem <- struct{}{}: + return true + // we add a timeout here, so the connection is closed when the timeout has passed instead of waiting. + case <-time.After(l.timeout): + return false + } +} +func (l *limitListener) release() { <-l.sem } + +func (l *limitListener) Accept() (net.Conn, error) { + if !l.acquire() { + // If the semaphore isn't acquired because the listener was closed, expect + // that this call to accept won't block, but immediately return an error. + // If it instead returns a spurious connection (due to a bug in the + // Listener, such as https://golang.org/issue/50216), we immediately close + // it and try again. Some buggy Listener implementations (like the one in + // the aforementioned issue) seem to assume that Accept will be called to + // completion, and may otherwise fail to clean up the client end of pending + // connections. + for { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + c.Close() + } + } + + c, err := l.Listener.Accept() + if err != nil { + l.release() + return nil, err + } + return &limitListenerConn{Conn: c, release: l.release}, nil +} + +func (l *limitListener) Close() error { + err := l.Listener.Close() + l.closeOnce.Do(func() { close(l.done) }) + return err +} + +type limitListenerConn struct { + net.Conn + releaseOnce sync.Once + release func() +} + +func (l *limitListenerConn) Close() error { + err := l.Conn.Close() + l.releaseOnce.Do(l.release) + return err +} diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 196142c99..6525da938 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net" + "time" "github.com/gliderlabs/ssh" "github.com/juju/names/v5" @@ -19,6 +20,7 @@ import ( // juju_ssh_default_port is the default port we expect the juju controllers to respond on. const juju_ssh_default_port = 17022 +const defaultAcceptConnectionTimeout = time.Second type publicKeySSHUserKey struct{} @@ -47,40 +49,68 @@ type forwardMessage struct { type Config struct { Port string HostKey []byte - MaxConcurrentConnections string + MaxConcurrentConnections int + AcceptConnectionTimeout time.Duration +} + +type Server struct { + *ssh.Server + + MaxConcurrentConnections int + AcceptConnectionTimeout time.Duration } // NewJumpServer creates the jump server struct. -func NewJumpServer(ctx context.Context, config Config, sshAuthorizer SSHAuthorizer, sshResolver SSHResolver) (*ssh.Server, error) { +func NewJumpServer(ctx context.Context, config Config, sshAuthorizer SSHAuthorizer, sshResolver SSHResolver) (Server, error) { zapctx.Info(ctx, "NewJumpServer") if sshResolver == nil { - return nil, fmt.Errorf("Cannot create JumpSSHServer with a nil resolver.") + return Server{}, fmt.Errorf("Cannot create JumpSSHServer with a nil resolver.") } - server := &ssh.Server{ - Addr: fmt.Sprintf(":%s", config.Port), - ChannelHandlers: map[string]ssh.ChannelHandler{ - "direct-tcpip": directTCPIPHandler(sshResolver), - }, - PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool { - user, err := sshAuthorizer.PublicKeyHandler(ctx, ctx.User(), key.Marshal()) - if err != nil { - zapctx.Debug(ctx, fmt.Sprintf("cannot verify key for user %s", ctx.User()), zap.Error(err)) - return false - } - ctx.SetValue(publicKeySSHUserKey{}, user) - return true + server := Server{ + Server: &ssh.Server{ + Addr: fmt.Sprintf(":%s", config.Port), + ChannelHandlers: map[string]ssh.ChannelHandler{ + "direct-tcpip": directTCPIPHandler(sshResolver), + }, + PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool { + user, err := sshAuthorizer.PublicKeyHandler(ctx, ctx.User(), key.Marshal()) + if err != nil { + zapctx.Debug(ctx, fmt.Sprintf("cannot verify key for user %s", ctx.User()), zap.Error(err)) + return false + } + ctx.SetValue(publicKeySSHUserKey{}, user) + return true + }, }, + MaxConcurrentConnections: config.MaxConcurrentConnections, + AcceptConnectionTimeout: config.AcceptConnectionTimeout, } hostKey, err := gossh.ParsePrivateKey([]byte(config.HostKey)) if err != nil { - return nil, fmt.Errorf("Cannot parse hostkey.") + return Server{}, fmt.Errorf("Cannot parse hostkey.") } server.AddHostKey(hostKey) return server, nil } +// ListenAndServe create a LimitListenerWithTimeout and Serve requests. +func (srv Server) ListenAndServe() error { + ln, err := net.Listen("tcp", srv.Addr) + if srv.MaxConcurrentConnections == 0 { + srv.MaxConcurrentConnections = 100 + } + if srv.AcceptConnectionTimeout == 0 { + srv.AcceptConnectionTimeout = defaultAcceptConnectionTimeout + } + ln = LimitListenerWithTimeout(ln, srv.MaxConcurrentConnections, srv.AcceptConnectionTimeout) + if err != nil { + return err + } + return srv.Serve(ln) +} + func directTCPIPHandler(sshResolver SSHResolver) func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { return func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { d := forwardMessage{} diff --git a/internal/ssh/ssh_test.go b/internal/ssh/ssh_test.go index e07953167..cdde29395 100644 --- a/internal/ssh/ssh_test.go +++ b/internal/ssh/ssh_test.go @@ -32,7 +32,7 @@ import ( type sshSuite struct { destinationJujuSSHServer *gliderssh.Server destinationServerPort int - jumpSSHServer *gliderssh.Server + jumpSSHServer ssh.Server jumpServerPort int privateKey gossh.Signer hostKey gossh.Signer @@ -112,8 +112,9 @@ func (s *sshSuite) Init(c *qt.C) { jumpServer, err := ssh.NewJumpServer(context.Background(), ssh.Config{ - Port: fmt.Sprint(port), - HostKey: hostKey, + Port: fmt.Sprint(port), + HostKey: hostKey, + MaxConcurrentConnections: 10, }, mocks.SSHAuthorizer{ PublicKeyHandler_: func(ctx context.Context, claimUser string, key []byte) (*openfga.User, error) { @@ -269,6 +270,36 @@ func (s *sshSuite) TestSSHFinalDestinationDialFail(c *qt.C) { c.Assert(err, qt.ErrorMatches, ".*connect failed.*") } +func (s *sshSuite) TestMaxConcurrentConnections(c *qt.C) { + // fill the max of concurrent connection + maxConcurrentConnections := 10 + clients := make([]*gossh.Client, 0) + for range maxConcurrentConnections { + client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{ + HostKeyCallback: gossh.FixedHostKey(s.hostKey.PublicKey()), + Auth: []gossh.AuthMethod{ + gossh.PublicKeys(s.privateKey), + }, + User: "alice", + }) + c.Check(err, qt.IsNil) + clients = append(clients, client) + } + // this connection is sent when we are at maximum connection> + _, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{ + HostKeyCallback: gossh.FixedHostKey(s.hostKey.PublicKey()), + Auth: []gossh.AuthMethod{ + gossh.PublicKeys(s.privateKey), + }, + User: "alice", + Timeout: 50 * time.Millisecond, + }) + c.Check(err, qt.ErrorMatches, ".*connection reset.*") + for _, client := range clients { + client.Close() + } +} + func TestIdentityManager(t *testing.T) { qtsuite.Run(qt.New(t), &sshSuite{}) }