From 2ebd48edfa1acc82faf6f9fa8d623973967ef4e3 Mon Sep 17 00:00:00 2001 From: SimoneDutto Date: Fri, 10 Jan 2025 11:32:44 +0100 Subject: [PATCH] feat: add ssh jump server --- go.mod | 2 + go.sum | 4 + internal/ssh/export_test.go | 4 + internal/ssh/ssh.go | 134 +++++++++++++++++++++++++++ internal/ssh/ssh_test.go | 177 ++++++++++++++++++++++++++++++++++++ internal/utils/utils.go | 14 +++ 6 files changed, 335 insertions(+) create mode 100644 internal/ssh/export_test.go create mode 100644 internal/ssh/ssh.go create mode 100644 internal/ssh/ssh_test.go diff --git a/go.mod b/go.mod index a56f08639..9ed11d213 100644 --- a/go.mod +++ b/go.mod @@ -349,6 +349,8 @@ require ( require github.com/golang-migrate/migrate/v4 v4.17.1 require ( + github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect + github.com/gliderlabs/ssh v0.3.8 // indirect github.com/klauspost/compress v1.17.7 // indirect github.com/moby/patternmatcher v0.6.0 // indirect github.com/moby/sys/sequential v0.6.0 // indirect diff --git a/go.sum b/go.sum index 4b98f0c71..a340e473e 100644 --- a/go.sum +++ b/go.sum @@ -91,6 +91,8 @@ github.com/adrg/xdg v0.3.3 h1:s/tV7MdqQnzB1nKY8aqHvAMD+uCiuEDzVB5HLRY849U= github.com/adrg/xdg v0.3.3/go.mod h1:61xAR2VZcggl2St4O9ohF5qCKe08+JDmE4VNzPFQvOQ= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ= github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw= @@ -253,6 +255,8 @@ github.com/gdamore/tcell/v2 v2.5.1/go.mod h1:wSkrPaXoiIWZqW/g7Px4xc79di6FTcpB8tv github.com/getkin/kin-openapi v0.125.0 h1:jyQCyf2qXS1qvs2U00xQzkGCqYPhEhZDmSmVt65fXno= github.com/getkin/kin-openapi v0.125.0/go.mod h1:wb1aSZA/iWmorQP9KTAS/phLj/t17B5jT7+fS8ed9NM= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= +github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU= github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s= github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-chi/render v1.0.2 h1:4ER/udB0+fMWB2Jlf15RV3F4A2FDuYi/9f+lFttR/Lg= diff --git a/internal/ssh/export_test.go b/internal/ssh/export_test.go new file mode 100644 index 000000000..5a5bec014 --- /dev/null +++ b/internal/ssh/export_test.go @@ -0,0 +1,4 @@ +// Copyright 2025 Canonical. +package ssh + +type ForwardMessage = forwardMessage diff --git a/internal/ssh/ssh.go b/internal/ssh/ssh.go new file mode 100644 index 000000000..5bb0b4741 --- /dev/null +++ b/internal/ssh/ssh.go @@ -0,0 +1,134 @@ +// Copyright 2025 Canonical. + +package ssh + +import ( + "context" + "fmt" + "io" + + "github.com/gliderlabs/ssh" + "github.com/juju/zaputil/zapctx" + "go.uber.org/zap" + gossh "golang.org/x/crypto/ssh" + + "github.com/canonical/jimm/v3/internal/openfga" +) + +// JUJU_SSH_DEFAULT_PORT is the default port we expect the juju controllers to respond on. +const JUJU_SSH_DEFAULT_PORT = 2223 + +// Resolver is the interface with the methods needed by the ssh jump server to route request. +type Resolver interface { + // GetAddrFromModelUUID is the method to resolve the address of the controller to contact given the model UUID. + GetAddrFromModelUUID(ctx context.Context, user openfga.User, modelUUID string) (string, error) +} + +// fowardMessage is the struct holding the information about the jump message received by the ssh client. +type forwardMessage struct { + DestAddr string + DestPort uint32 + SrcAddr string + SrcPort uint32 +} + +// Server is the custom struct to embed the gliderlabs.ssh server and a resolver. +type Server struct { + *ssh.Server + + resolver Resolver +} + +// NewJumpSSHServer creates the jump server struct. +func NewJumpSSHServer(ctx context.Context, port int, resolver Resolver) (Server, error) { + zapctx.Info(ctx, "NewSSHServer") + server := Server{ + Server: &ssh.Server{ + Addr: fmt.Sprintf(":%d", port), + ChannelHandlers: map[string]ssh.ChannelHandler{ + "direct-tcpip": directTCPIPHandler(resolver), + }, + PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool { + return true + }, + }, + resolver: resolver, + } + + return server, nil +} + +func directTCPIPHandler(resolver Resolver) func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { + return func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { + d := forwardMessage{} + + k := newChan.ExtraData() + + if err := gossh.Unmarshal(k, &d); err != nil { + rejectConnectionAndLogError(ctx, newChan, "Failed to parse channel data", err) + return + } + + dest := fmt.Sprintf("%s:%d", d.DestAddr, d.DestPort) + if d.DestPort == 0 { + d.DestPort = JUJU_SSH_DEFAULT_PORT + } + addr, err := resolver.GetAddrFromModelUUID(ctx, openfga.User{}, dest) + if err != nil { + rejectConnectionAndLogError(ctx, newChan, "Failed to parse channel data", err) + return + } + // this is temporary. The way we dial to the controller will heavily change. + client, err := gossh.Dial("tcp", fmt.Sprintf("%s:%d", addr, d.DestPort), &gossh.ClientConfig{ + //nolint:gosec // this will be removed once we handle hostkeys + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + Auth: []gossh.AuthMethod{ + gossh.PasswordCallback(func() (secret string, err error) { + return "jwt", nil + }), + }, + }) + if err != nil { + rejectConnectionAndLogError(ctx, newChan, fmt.Sprintf("Failed to connect to %s: %v", dest, err), err) + return + } + + dChan, reqs, err := client.OpenChannel("direct-tcpip", gossh.Marshal(d)) + if err != nil { + rejectConnectionAndLogError(ctx, newChan, "Failed to open destination channel", err) + return + } + + go gossh.DiscardRequests(reqs) + + ch, reqs, err := newChan.Accept() + if err != nil { + dChan.Close() + return + } + + go gossh.DiscardRequests(reqs) + + go func() { + defer ch.Close() + defer dChan.Close() + _, err := io.Copy(ch, dChan) + rejectConnectionAndLogError(ctx, newChan, "Failed to copy data from src to dts", err) + }() + go func() { + defer ch.Close() + defer dChan.Close() + _, err := io.Copy(dChan, ch) + rejectConnectionAndLogError(ctx, newChan, "Failed to copy data from dst to src", err) + }() + zapctx.Info(ctx, fmt.Sprintf("Proxying connection from %s:%d to %s:%d \n", d.SrcAddr, d.SrcPort, d.DestAddr, d.DestPort)) + } +} + +func rejectConnectionAndLogError(ctx context.Context, newChan gossh.NewChannel, msg string, err error) { + zapctx.Error(ctx, msg, zap.Error(err)) + err = newChan.Reject(gossh.ConnectionFailed, msg) + if err != nil { + zapctx.Error(ctx, "Failed to reject channel", zap.Error(err)) + } +} diff --git a/internal/ssh/ssh_test.go b/internal/ssh/ssh_test.go new file mode 100644 index 000000000..609679f6e --- /dev/null +++ b/internal/ssh/ssh_test.go @@ -0,0 +1,177 @@ +// Copyright 2025 Canonical. +package ssh_test + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "strconv" + "strings" + "testing" + "time" + + qt "github.com/frankban/quicktest" + "github.com/frankban/quicktest/qtsuite" + gliderssh "github.com/gliderlabs/ssh" + gossh "golang.org/x/crypto/ssh" + + "github.com/canonical/jimm/v3/internal/openfga" + "github.com/canonical/jimm/v3/internal/ssh" + "github.com/canonical/jimm/v3/internal/utils" +) + +type resolver struct{} + +func (r resolver) GetAddrFromModelUUID(ctx context.Context, user openfga.User, modelName string) (string, error) { + return "", nil +} + +type sshSuite struct { + destinationJujuSSHServer gliderssh.Server + destinationServerPort int + jumpSSHServer ssh.Server + jumpServerPort int + privateKey gossh.Signer + testF func(fm ssh.ForwardMessage) + received chan bool +} + +func (s *sshSuite) Init(c *qt.C) { + s.received = make(chan bool) + port, err := utils.GetFreePort() + c.Assert(err, qt.IsNil) + s.destinationServerPort = port + s.destinationJujuSSHServer = gliderssh.Server{ + Addr: fmt.Sprintf(":%d", port), + ChannelHandlers: map[string]gliderssh.ChannelHandler{ + "direct-tcpip": func(srv *gliderssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx gliderssh.Context) { + d := ssh.ForwardMessage{} + if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { + err := newChan.Reject(gossh.ConnectionFailed, "Failed to parse channel data") + c.Assert(err, qt.IsNil) + return + } + _, _, err := newChan.Accept() + c.Assert(err, qt.IsNil) + s.testF(d) + s.received <- true + }, + }, + } + go func() { + err := s.destinationJujuSSHServer.ListenAndServe() + c.Assert(err, qt.IsNil) + }() + s.destinationServerPort, err = strconv.Atoi(strings.Split(s.destinationJujuSSHServer.Addr, ":")[1]) + c.Assert(err, qt.IsNil) + + port, err = utils.GetFreePort() + c.Assert(err, qt.IsNil) + s.jumpServerPort = port + s.jumpSSHServer, err = ssh.NewJumpSSHServer(context.Background(), port, resolver{}) + c.Assert(err, qt.IsNil) + go func() { + err := s.jumpSSHServer.ListenAndServe() + c.Assert(err, qt.IsNil) + }() + + k, err := rsa.GenerateKey(rand.Reader, 2048) + c.Assert(err, qt.IsNil) + keyPEM := pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(k), + }, + ) + + s.privateKey, err = gossh.ParsePrivateKey(keyPEM) + c.Assert(err, qt.IsNil) +} + +// CleanUp doesn't exist in qtsuite, so it needs to be called manually +func (s *sshSuite) CleanUp(c *qt.C) { + err := s.destinationJujuSSHServer.Close() + c.Assert(err, qt.IsNil) + err = s.jumpSSHServer.Close() + c.Assert(err, qt.IsNil) +} + +func (s *sshSuite) TestSSHJump(c *qt.C) { + defer s.CleanUp(c) + client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{ + //nolint:gosec // this will be removed once we handle hostkeys + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + Auth: []gossh.AuthMethod{ + gossh.PublicKeys(s.privateKey), + }, + }) + c.Assert(err, qt.IsNil) + defer client.Close() + + // send forward message + msg := ssh.ForwardMessage{ + DestAddr: "model1", + //nolint:gosec + DestPort: uint32(s.destinationServerPort), + SrcAddr: "localhost", + SrcPort: 0, + } + s.testF = func(fm ssh.ForwardMessage) { + c.Assert(fm.DestAddr, qt.Equals, "model1") + } + ch, _, err := client.OpenChannel("direct-tcpip", gossh.Marshal(&msg)) + c.Assert(err, qt.IsNil) + defer ch.Close() + select { + case <-s.received: + case <-time.After(100 * time.Millisecond): + c.Fatalf("ssh jump test timeout") + } +} + +func (s *sshSuite) TestSSHJumpDialFail(c *qt.C) { + defer s.CleanUp(c) + _, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort+1), &gossh.ClientConfig{ + //nolint:gosec // this will be removed once we handle hostkeys + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + Auth: []gossh.AuthMethod{ + gossh.PublicKeys(s.privateKey), + }, + }) + c.Assert(err, qt.ErrorMatches, ".*connect: connection refused.*") +} + +func (s *sshSuite) TestSSHFinalDestinationDialFail(c *qt.C) { + defer s.CleanUp(c) + + client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{ + //nolint:gosec // this will be removed once we handle hostkeys + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + Auth: []gossh.AuthMethod{ + gossh.PublicKeys(s.privateKey), + }, + }) + c.Assert(err, qt.IsNil) + + // send forward message + msg := ssh.ForwardMessage{ + DestAddr: "model1", + //nolint:gosec + DestPort: uint32(s.destinationServerPort + 1), + SrcAddr: "localhost", + SrcPort: 0, + } + s.testF = func(fm ssh.ForwardMessage) { + c.Assert(fm.DestAddr, qt.Equals, "model1") + } + _, _, err = client.OpenChannel("direct-tcpip", gossh.Marshal(&msg)) + c.Assert(err, qt.ErrorMatches, ".*connect failed.*") + +} + +func TestIdentityManager(t *testing.T) { + qtsuite.Run(qt.New(t), &sshSuite{}) +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index f277a48b9..35ab2e716 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -5,6 +5,8 @@ import ( "context" "crypto/rand" "encoding/hex" + "errors" + "net" "github.com/juju/zaputil/zapctx" "go.uber.org/zap" @@ -21,3 +23,15 @@ func NewConversationID() string { } return hex.EncodeToString(buf) } + +// GetFreePort asks the kernel for a free open port that is ready to use. +func GetFreePort() (int, error) { + if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil { + var l *net.TCPListener + if l, err = net.ListenTCP("tcp", a); err == nil { + defer l.Close() + return l.Addr().(*net.TCPAddr).Port, nil + } + } + return 0, errors.New("Couldn't find any free port") +}