diff --git a/cmd/jimmsrv/main.go b/cmd/jimmsrv/main.go index 5624b807f..07bf8cba9 100644 --- a/cmd/jimmsrv/main.go +++ b/cmd/jimmsrv/main.go @@ -1,4 +1,4 @@ -// Copyright 2024 Canonical. +// Copyright 2025 Canonical. package main @@ -209,5 +209,6 @@ func start(ctx context.Context, s *service.Service) error { }) s.Go(httpsrv.ListenAndServe) zapctx.Info(ctx, "Successfully started JIMM server") + return nil } diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go index 5a554be58..127a3d7df 100644 --- a/internal/ssh/ssh.go +++ b/internal/ssh/ssh.go @@ -25,7 +25,7 @@ type Resolver interface { AddrFromModelUUID(ctx context.Context, user openfga.User, modelUUID string) (string, error) } -// fowardMessage is the struct holding the information about the jump message received by the ssh client. +// forwardMessage is the struct holding the information about the jump message received by the ssh client. type forwardMessage struct { DestAddr string DestPort uint32 @@ -40,16 +40,23 @@ type Server struct { resolver Resolver } -// NewJumpSSHServer creates the jump server struct. -func NewJumpSSHServer(ctx context.Context, port int, resolver Resolver) (Server, error) { - zapctx.Info(ctx, "NewSSHServer") +// Config is the struct holding the configuration for the jump server. +type Config struct { + Port string + HostKey []byte + MaxConcurrentConnections string +} + +// NewJumpServer creates the jump server struct. +func NewJumpServer(ctx context.Context, config Config, resolver Resolver) (Server, error) { + zapctx.Info(ctx, "NewJumpServer") if resolver == nil { return Server{}, fmt.Errorf("Cannot create JumpSSHServer with a nil resolver.") } server := Server{ Server: &ssh.Server{ - Addr: fmt.Sprintf(":%d", port), + Addr: fmt.Sprintf(":%s", config.Port), ChannelHandlers: map[string]ssh.ChannelHandler{ "direct-tcpip": directTCPIPHandler(resolver), }, @@ -59,6 +66,11 @@ func NewJumpSSHServer(ctx context.Context, port int, resolver Resolver) (Server, }, resolver: resolver, } + s, err := gossh.ParsePrivateKey([]byte(config.HostKey)) + if err != nil { + return Server{}, fmt.Errorf("Cannot parse hostkey.") + } + server.AddHostKey(s) return server, nil } diff --git a/internal/ssh/ssh_test.go b/internal/ssh/ssh_test.go index 687afcbaa..ae7ff18b3 100644 --- a/internal/ssh/ssh_test.go +++ b/internal/ssh/ssh_test.go @@ -36,6 +36,7 @@ type sshSuite struct { jumpSSHServer ssh.Server jumpServerPort int privateKey gossh.Signer + hostKey gossh.Signer testInDestinationServerF func(fm ssh.ForwardMessage) received chan bool } @@ -71,13 +72,29 @@ func (s *sshSuite) Init(c *qt.C) { port, err = jimmtest.GetFreePort() c.Assert(err, qt.IsNil) s.jumpServerPort = port - s.jumpSSHServer, err = ssh.NewJumpSSHServer(context.Background(), port, resolver{}) + k, err := rsa.GenerateKey(rand.Reader, 2048) + c.Assert(err, qt.IsNil) + hostKey := pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(k), + }, + ) + s.hostKey, err = gossh.ParsePrivateKey(hostKey) + c.Assert(err, qt.IsNil) + + s.jumpSSHServer, err = ssh.NewJumpServer(context.Background(), + ssh.Config{ + Port: fmt.Sprint(port), + HostKey: hostKey}, + resolver{}, + ) c.Assert(err, qt.IsNil) go func() { _ = s.jumpSSHServer.ListenAndServe() }() - k, err := rsa.GenerateKey(rand.Reader, 2048) + k, err = rsa.GenerateKey(rand.Reader, 2048) c.Assert(err, qt.IsNil) keyPEM := pem.EncodeToMemory( &pem.Block{ @@ -98,8 +115,7 @@ func (s *sshSuite) Init(c *qt.C) { func (s *sshSuite) TestSSHJump(c *qt.C) { client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{ - //nolint:gosec // this will be removed once we handle hostkeys - HostKeyCallback: gossh.InsecureIgnoreHostKey(), + HostKeyCallback: gossh.FixedHostKey(s.hostKey.PublicKey()), Auth: []gossh.AuthMethod{ gossh.PublicKeys(s.privateKey), }, @@ -130,8 +146,7 @@ func (s *sshSuite) TestSSHJump(c *qt.C) { func (s *sshSuite) TestSSHJumpDialFail(c *qt.C) { _, err := gossh.Dial("tcp", fmt.Sprintf(":%d", 1), &gossh.ClientConfig{ - //nolint:gosec // this will be removed once we handle hostkeys - HostKeyCallback: gossh.InsecureIgnoreHostKey(), + HostKeyCallback: gossh.FixedHostKey(s.hostKey.PublicKey()), Auth: []gossh.AuthMethod{ gossh.PublicKeys(s.privateKey), }, @@ -142,8 +157,7 @@ func (s *sshSuite) TestSSHJumpDialFail(c *qt.C) { func (s *sshSuite) TestSSHFinalDestinationDialFail(c *qt.C) { client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{ - //nolint:gosec // this will be removed once we handle hostkeys - HostKeyCallback: gossh.InsecureIgnoreHostKey(), + HostKeyCallback: gossh.FixedHostKey(s.hostKey.PublicKey()), Auth: []gossh.AuthMethod{ gossh.PublicKeys(s.privateKey), },