Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement limit connections #1527

Open
wants to merge 1 commit into
base: v3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions internal/ssh/listener.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright 2025 Canonical.
package ssh

import (
"net"
"sync"
"time"
)

// N.B.:
// This is a copypaste of netutil.LimiLister, but we add a timeout so when we are at the limit
// we actively close connections instead of waiting indefinetely. (Look at line 44)

// LimitListenerWithTimeout returns a Listener that accepts at most n simultaneous
// connections from the provided Listener, and it timeouts when the max
// has been reached and no seats has been freed for the timeout period.
func LimitListenerWithTimeout(l net.Listener, n int, timeout time.Duration) net.Listener {
return &limitListener{
Listener: l,
sem: make(chan struct{}, n),
done: make(chan struct{}),
timeout: timeout,
}
}

type limitListener struct {
net.Listener
sem chan struct{}
closeOnce sync.Once // ensures the done chan is only closed once
done chan struct{} // no values sent; closed when Close is called
timeout time.Duration // timeout for acquiring the connection
}

// acquire acquires the limiting semaphore. Returns true if successfully
// acquired, false if the listener is closed and the semaphore is not
// acquired.
func (l *limitListener) acquire() bool {
select {
case <-l.done:
return false
case l.sem <- struct{}{}:
return true
// we add a timeout here, so the connection is closed when the timeout has passed instead of waiting.
case <-time.After(l.timeout):
return false
}
}
func (l *limitListener) release() { <-l.sem }

func (l *limitListener) Accept() (net.Conn, error) {
if !l.acquire() {
// If the semaphore isn't acquired because the listener was closed, expect
// that this call to accept won't block, but immediately return an error.
// If it instead returns a spurious connection (due to a bug in the
// Listener, such as https://golang.org/issue/50216), we immediately close
// it and try again. Some buggy Listener implementations (like the one in
// the aforementioned issue) seem to assume that Accept will be called to
// completion, and may otherwise fail to clean up the client end of pending
// connections.
for {
c, err := l.Listener.Accept()
if err != nil {
return nil, err
}
c.Close()
}
}

c, err := l.Listener.Accept()
if err != nil {
l.release()
return nil, err
}
return &limitListenerConn{Conn: c, release: l.release}, nil
}

func (l *limitListener) Close() error {
err := l.Listener.Close()
l.closeOnce.Do(func() { close(l.done) })
return err
}

type limitListenerConn struct {
net.Conn
releaseOnce sync.Once
release func()
}

func (l *limitListenerConn) Close() error {
err := l.Conn.Close()
l.releaseOnce.Do(l.release)
return err
}
64 changes: 47 additions & 17 deletions internal/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net"
"time"

"github.com/gliderlabs/ssh"
"github.com/juju/names/v5"
Expand All @@ -19,6 +20,7 @@ import (

// juju_ssh_default_port is the default port we expect the juju controllers to respond on.
const juju_ssh_default_port = 17022
const defaultAcceptConnectionTimeout = time.Second

type publicKeySSHUserKey struct{}

Expand Down Expand Up @@ -47,40 +49,68 @@ type forwardMessage struct {
type Config struct {
Port string
HostKey []byte
MaxConcurrentConnections string
MaxConcurrentConnections int
AcceptConnectionTimeout time.Duration
}

type Server struct {
*ssh.Server

MaxConcurrentConnections int
AcceptConnectionTimeout time.Duration
}

// NewJumpServer creates the jump server struct.
func NewJumpServer(ctx context.Context, config Config, sshAuthorizer SSHAuthorizer, sshResolver SSHResolver) (*ssh.Server, error) {
func NewJumpServer(ctx context.Context, config Config, sshAuthorizer SSHAuthorizer, sshResolver SSHResolver) (Server, error) {
zapctx.Info(ctx, "NewJumpServer")

if sshResolver == nil {
return nil, fmt.Errorf("Cannot create JumpSSHServer with a nil resolver.")
return Server{}, fmt.Errorf("Cannot create JumpSSHServer with a nil resolver.")
}
server := &ssh.Server{
Addr: fmt.Sprintf(":%s", config.Port),
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": directTCPIPHandler(sshResolver),
},
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
user, err := sshAuthorizer.PublicKeyHandler(ctx, ctx.User(), key.Marshal())
if err != nil {
zapctx.Debug(ctx, fmt.Sprintf("cannot verify key for user %s", ctx.User()), zap.Error(err))
return false
}
ctx.SetValue(publicKeySSHUserKey{}, user)
return true
server := Server{
Server: &ssh.Server{
Addr: fmt.Sprintf(":%s", config.Port),
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": directTCPIPHandler(sshResolver),
},
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
user, err := sshAuthorizer.PublicKeyHandler(ctx, ctx.User(), key.Marshal())
if err != nil {
zapctx.Debug(ctx, fmt.Sprintf("cannot verify key for user %s", ctx.User()), zap.Error(err))
return false
}
ctx.SetValue(publicKeySSHUserKey{}, user)
return true
},
},
MaxConcurrentConnections: config.MaxConcurrentConnections,
AcceptConnectionTimeout: config.AcceptConnectionTimeout,
}
hostKey, err := gossh.ParsePrivateKey([]byte(config.HostKey))
if err != nil {
return nil, fmt.Errorf("Cannot parse hostkey.")
return Server{}, fmt.Errorf("Cannot parse hostkey.")
}
server.AddHostKey(hostKey)

return server, nil
}

// ListenAndServe create a LimitListenerWithTimeout and Serve requests.
func (srv Server) ListenAndServe() error {
ln, err := net.Listen("tcp", srv.Addr)
if srv.MaxConcurrentConnections == 0 {
srv.MaxConcurrentConnections = 100
}
if srv.AcceptConnectionTimeout == 0 {
srv.AcceptConnectionTimeout = defaultAcceptConnectionTimeout
}
ln = LimitListenerWithTimeout(ln, srv.MaxConcurrentConnections, srv.AcceptConnectionTimeout)
if err != nil {
return err
}
return srv.Serve(ln)
}

func directTCPIPHandler(sshResolver SSHResolver) func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
return func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
d := forwardMessage{}
Expand Down
37 changes: 34 additions & 3 deletions internal/ssh/ssh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
type sshSuite struct {
destinationJujuSSHServer *gliderssh.Server
destinationServerPort int
jumpSSHServer *gliderssh.Server
jumpSSHServer ssh.Server
jumpServerPort int
privateKey gossh.Signer
hostKey gossh.Signer
Expand Down Expand Up @@ -112,8 +112,9 @@ func (s *sshSuite) Init(c *qt.C) {

jumpServer, err := ssh.NewJumpServer(context.Background(),
ssh.Config{
Port: fmt.Sprint(port),
HostKey: hostKey,
Port: fmt.Sprint(port),
HostKey: hostKey,
MaxConcurrentConnections: 10,
},
mocks.SSHAuthorizer{
PublicKeyHandler_: func(ctx context.Context, claimUser string, key []byte) (*openfga.User, error) {
Expand Down Expand Up @@ -269,6 +270,36 @@ func (s *sshSuite) TestSSHFinalDestinationDialFail(c *qt.C) {
c.Assert(err, qt.ErrorMatches, ".*connect failed.*")
}

func (s *sshSuite) TestMaxConcurrentConnections(c *qt.C) {
// fill the max of concurrent connection
maxConcurrentConnections := 10
clients := make([]*gossh.Client, 0)
for range maxConcurrentConnections {
client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{
HostKeyCallback: gossh.FixedHostKey(s.hostKey.PublicKey()),
Auth: []gossh.AuthMethod{
gossh.PublicKeys(s.privateKey),
},
User: "alice",
})
c.Check(err, qt.IsNil)
clients = append(clients, client)
}
// this connection is sent when we are at maximum connection>
_, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{
HostKeyCallback: gossh.FixedHostKey(s.hostKey.PublicKey()),
Auth: []gossh.AuthMethod{
gossh.PublicKeys(s.privateKey),
},
User: "alice",
Timeout: 50 * time.Millisecond,
})
c.Check(err, qt.ErrorMatches, ".*connection reset.*")
for _, client := range clients {
client.Close()
}
}

func TestIdentityManager(t *testing.T) {
qtsuite.Run(qt.New(t), &sshSuite{})
}
Loading