Skip to content

Commit

Permalink
feat: add ssh jump server
Browse files Browse the repository at this point in the history
  • Loading branch information
SimoneDutto committed Jan 10, 2025
1 parent 8c6df2a commit 2ebd48e
Show file tree
Hide file tree
Showing 6 changed files with 335 additions and 0 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,8 @@ require (
require github.com/golang-migrate/migrate/v4 v4.17.1

require (
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/gliderlabs/ssh v0.3.8 // indirect
github.com/klauspost/compress v1.17.7 // indirect
github.com/moby/patternmatcher v0.6.0 // indirect
github.com/moby/sys/sequential v0.6.0 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ github.com/adrg/xdg v0.3.3 h1:s/tV7MdqQnzB1nKY8aqHvAMD+uCiuEDzVB5HLRY849U=
github.com/adrg/xdg v0.3.3/go.mod h1:61xAR2VZcggl2St4O9ohF5qCKe08+JDmE4VNzPFQvOQ=
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/antlr4-go/antlr/v4 v4.13.1 h1:SqQKkuVZ+zWkMMNkjy5FZe5mr5WURWnlpmOuzYWrPrQ=
github.com/antlr4-go/antlr/v4 v4.13.1/go.mod h1:GKmUxMtwp6ZgGwZSva4eWPC5mS6vUAmOABFgjdkM7Nw=
Expand Down Expand Up @@ -253,6 +255,8 @@ github.com/gdamore/tcell/v2 v2.5.1/go.mod h1:wSkrPaXoiIWZqW/g7Px4xc79di6FTcpB8tv
github.com/getkin/kin-openapi v0.125.0 h1:jyQCyf2qXS1qvs2U00xQzkGCqYPhEhZDmSmVt65fXno=
github.com/getkin/kin-openapi v0.125.0/go.mod h1:wb1aSZA/iWmorQP9KTAS/phLj/t17B5jT7+fS8ed9NM=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c=
github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU=
github.com/go-chi/chi/v5 v5.0.12 h1:9euLV5sTrTNTRUU9POmDUvfxyj6LAABLUcEWO+JJb4s=
github.com/go-chi/chi/v5 v5.0.12/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8=
github.com/go-chi/render v1.0.2 h1:4ER/udB0+fMWB2Jlf15RV3F4A2FDuYi/9f+lFttR/Lg=
Expand Down
4 changes: 4 additions & 0 deletions internal/ssh/export_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// Copyright 2025 Canonical.
package ssh

type ForwardMessage = forwardMessage
134 changes: 134 additions & 0 deletions internal/ssh/ssh.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright 2025 Canonical.

package ssh

import (
"context"
"fmt"
"io"

"github.com/gliderlabs/ssh"
"github.com/juju/zaputil/zapctx"
"go.uber.org/zap"
gossh "golang.org/x/crypto/ssh"

"github.com/canonical/jimm/v3/internal/openfga"
)

// JUJU_SSH_DEFAULT_PORT is the default port we expect the juju controllers to respond on.
const JUJU_SSH_DEFAULT_PORT = 2223

// Resolver is the interface with the methods needed by the ssh jump server to route request.
type Resolver interface {
// GetAddrFromModelUUID is the method to resolve the address of the controller to contact given the model UUID.
GetAddrFromModelUUID(ctx context.Context, user openfga.User, modelUUID string) (string, error)
}

// fowardMessage is the struct holding the information about the jump message received by the ssh client.
type forwardMessage struct {
DestAddr string
DestPort uint32
SrcAddr string
SrcPort uint32
}

// Server is the custom struct to embed the gliderlabs.ssh server and a resolver.
type Server struct {
*ssh.Server

resolver Resolver
}

// NewJumpSSHServer creates the jump server struct.
func NewJumpSSHServer(ctx context.Context, port int, resolver Resolver) (Server, error) {
zapctx.Info(ctx, "NewSSHServer")
server := Server{
Server: &ssh.Server{
Addr: fmt.Sprintf(":%d", port),
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": directTCPIPHandler(resolver),
},
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool {
return true
},
},
resolver: resolver,
}

return server, nil
}

func directTCPIPHandler(resolver Resolver) func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
return func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
d := forwardMessage{}

k := newChan.ExtraData()

if err := gossh.Unmarshal(k, &d); err != nil {
rejectConnectionAndLogError(ctx, newChan, "Failed to parse channel data", err)
return
}

dest := fmt.Sprintf("%s:%d", d.DestAddr, d.DestPort)
if d.DestPort == 0 {
d.DestPort = JUJU_SSH_DEFAULT_PORT
}
addr, err := resolver.GetAddrFromModelUUID(ctx, openfga.User{}, dest)
if err != nil {
rejectConnectionAndLogError(ctx, newChan, "Failed to parse channel data", err)
return
}
// this is temporary. The way we dial to the controller will heavily change.
client, err := gossh.Dial("tcp", fmt.Sprintf("%s:%d", addr, d.DestPort), &gossh.ClientConfig{
//nolint:gosec // this will be removed once we handle hostkeys
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
Auth: []gossh.AuthMethod{
gossh.PasswordCallback(func() (secret string, err error) {
return "jwt", nil
}),
},
})
if err != nil {
rejectConnectionAndLogError(ctx, newChan, fmt.Sprintf("Failed to connect to %s: %v", dest, err), err)
return
}

dChan, reqs, err := client.OpenChannel("direct-tcpip", gossh.Marshal(d))
if err != nil {
rejectConnectionAndLogError(ctx, newChan, "Failed to open destination channel", err)
return
}

go gossh.DiscardRequests(reqs)

ch, reqs, err := newChan.Accept()
if err != nil {
dChan.Close()
return
}

go gossh.DiscardRequests(reqs)

go func() {
defer ch.Close()
defer dChan.Close()
_, err := io.Copy(ch, dChan)
rejectConnectionAndLogError(ctx, newChan, "Failed to copy data from src to dts", err)
}()
go func() {
defer ch.Close()
defer dChan.Close()
_, err := io.Copy(dChan, ch)
rejectConnectionAndLogError(ctx, newChan, "Failed to copy data from dst to src", err)
}()
zapctx.Info(ctx, fmt.Sprintf("Proxying connection from %s:%d to %s:%d \n", d.SrcAddr, d.SrcPort, d.DestAddr, d.DestPort))
}
}

func rejectConnectionAndLogError(ctx context.Context, newChan gossh.NewChannel, msg string, err error) {
zapctx.Error(ctx, msg, zap.Error(err))
err = newChan.Reject(gossh.ConnectionFailed, msg)
if err != nil {
zapctx.Error(ctx, "Failed to reject channel", zap.Error(err))
}
}
177 changes: 177 additions & 0 deletions internal/ssh/ssh_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
// Copyright 2025 Canonical.
package ssh_test

import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"fmt"
"strconv"
"strings"
"testing"
"time"

qt "github.com/frankban/quicktest"
"github.com/frankban/quicktest/qtsuite"
gliderssh "github.com/gliderlabs/ssh"
gossh "golang.org/x/crypto/ssh"

"github.com/canonical/jimm/v3/internal/openfga"
"github.com/canonical/jimm/v3/internal/ssh"
"github.com/canonical/jimm/v3/internal/utils"
)

type resolver struct{}

func (r resolver) GetAddrFromModelUUID(ctx context.Context, user openfga.User, modelName string) (string, error) {
return "", nil
}

type sshSuite struct {
destinationJujuSSHServer gliderssh.Server
destinationServerPort int
jumpSSHServer ssh.Server
jumpServerPort int
privateKey gossh.Signer
testF func(fm ssh.ForwardMessage)
received chan bool
}

func (s *sshSuite) Init(c *qt.C) {
s.received = make(chan bool)
port, err := utils.GetFreePort()
c.Assert(err, qt.IsNil)
s.destinationServerPort = port
s.destinationJujuSSHServer = gliderssh.Server{
Addr: fmt.Sprintf(":%d", port),
ChannelHandlers: map[string]gliderssh.ChannelHandler{
"direct-tcpip": func(srv *gliderssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx gliderssh.Context) {
d := ssh.ForwardMessage{}
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
err := newChan.Reject(gossh.ConnectionFailed, "Failed to parse channel data")
c.Assert(err, qt.IsNil)
return
}
_, _, err := newChan.Accept()
c.Assert(err, qt.IsNil)
s.testF(d)
s.received <- true
},
},
}
go func() {
err := s.destinationJujuSSHServer.ListenAndServe()
c.Assert(err, qt.IsNil)
}()
s.destinationServerPort, err = strconv.Atoi(strings.Split(s.destinationJujuSSHServer.Addr, ":")[1])
c.Assert(err, qt.IsNil)

port, err = utils.GetFreePort()
c.Assert(err, qt.IsNil)
s.jumpServerPort = port
s.jumpSSHServer, err = ssh.NewJumpSSHServer(context.Background(), port, resolver{})
c.Assert(err, qt.IsNil)
go func() {
err := s.jumpSSHServer.ListenAndServe()
c.Assert(err, qt.IsNil)
}()

k, err := rsa.GenerateKey(rand.Reader, 2048)
c.Assert(err, qt.IsNil)
keyPEM := pem.EncodeToMemory(
&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(k),
},
)

s.privateKey, err = gossh.ParsePrivateKey(keyPEM)
c.Assert(err, qt.IsNil)
}

// CleanUp doesn't exist in qtsuite, so it needs to be called manually
func (s *sshSuite) CleanUp(c *qt.C) {
err := s.destinationJujuSSHServer.Close()
c.Assert(err, qt.IsNil)
err = s.jumpSSHServer.Close()
c.Assert(err, qt.IsNil)
}

func (s *sshSuite) TestSSHJump(c *qt.C) {
defer s.CleanUp(c)
client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{
//nolint:gosec // this will be removed once we handle hostkeys
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
Auth: []gossh.AuthMethod{
gossh.PublicKeys(s.privateKey),
},
})
c.Assert(err, qt.IsNil)
defer client.Close()

// send forward message
msg := ssh.ForwardMessage{
DestAddr: "model1",
//nolint:gosec
DestPort: uint32(s.destinationServerPort),
SrcAddr: "localhost",
SrcPort: 0,
}
s.testF = func(fm ssh.ForwardMessage) {
c.Assert(fm.DestAddr, qt.Equals, "model1")
}
ch, _, err := client.OpenChannel("direct-tcpip", gossh.Marshal(&msg))
c.Assert(err, qt.IsNil)
defer ch.Close()
select {
case <-s.received:
case <-time.After(100 * time.Millisecond):
c.Fatalf("ssh jump test timeout")
}
}

func (s *sshSuite) TestSSHJumpDialFail(c *qt.C) {
defer s.CleanUp(c)
_, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort+1), &gossh.ClientConfig{
//nolint:gosec // this will be removed once we handle hostkeys
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
Auth: []gossh.AuthMethod{
gossh.PublicKeys(s.privateKey),
},
})
c.Assert(err, qt.ErrorMatches, ".*connect: connection refused.*")
}

func (s *sshSuite) TestSSHFinalDestinationDialFail(c *qt.C) {
defer s.CleanUp(c)

client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{
//nolint:gosec // this will be removed once we handle hostkeys
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
Auth: []gossh.AuthMethod{
gossh.PublicKeys(s.privateKey),
},
})
c.Assert(err, qt.IsNil)

// send forward message
msg := ssh.ForwardMessage{
DestAddr: "model1",
//nolint:gosec
DestPort: uint32(s.destinationServerPort + 1),
SrcAddr: "localhost",
SrcPort: 0,
}
s.testF = func(fm ssh.ForwardMessage) {
c.Assert(fm.DestAddr, qt.Equals, "model1")
}
_, _, err = client.OpenChannel("direct-tcpip", gossh.Marshal(&msg))
c.Assert(err, qt.ErrorMatches, ".*connect failed.*")

}

func TestIdentityManager(t *testing.T) {
qtsuite.Run(qt.New(t), &sshSuite{})
}
14 changes: 14 additions & 0 deletions internal/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import (
"context"
"crypto/rand"
"encoding/hex"
"errors"
"net"

"github.com/juju/zaputil/zapctx"
"go.uber.org/zap"
Expand All @@ -21,3 +23,15 @@ func NewConversationID() string {
}
return hex.EncodeToString(buf)
}

// GetFreePort asks the kernel for a free open port that is ready to use.
func GetFreePort() (int, error) {
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
var l *net.TCPListener
if l, err = net.ListenTCP("tcp", a); err == nil {
defer l.Close()
return l.Addr().(*net.TCPAddr).Port, nil
}
}
return 0, errors.New("Couldn't find any free port")
}

0 comments on commit 2ebd48e

Please sign in to comment.