-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8c6df2a
commit 2ebd48e
Showing
6 changed files
with
335 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
// Copyright 2025 Canonical. | ||
package ssh | ||
|
||
type ForwardMessage = forwardMessage |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
// Copyright 2025 Canonical. | ||
|
||
package ssh | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"io" | ||
|
||
"github.com/gliderlabs/ssh" | ||
"github.com/juju/zaputil/zapctx" | ||
"go.uber.org/zap" | ||
gossh "golang.org/x/crypto/ssh" | ||
|
||
"github.com/canonical/jimm/v3/internal/openfga" | ||
) | ||
|
||
// JUJU_SSH_DEFAULT_PORT is the default port we expect the juju controllers to respond on. | ||
const JUJU_SSH_DEFAULT_PORT = 2223 | ||
|
||
// Resolver is the interface with the methods needed by the ssh jump server to route request. | ||
type Resolver interface { | ||
// GetAddrFromModelUUID is the method to resolve the address of the controller to contact given the model UUID. | ||
GetAddrFromModelUUID(ctx context.Context, user openfga.User, modelUUID string) (string, error) | ||
} | ||
|
||
// fowardMessage is the struct holding the information about the jump message received by the ssh client. | ||
type forwardMessage struct { | ||
DestAddr string | ||
DestPort uint32 | ||
SrcAddr string | ||
SrcPort uint32 | ||
} | ||
|
||
// Server is the custom struct to embed the gliderlabs.ssh server and a resolver. | ||
type Server struct { | ||
*ssh.Server | ||
|
||
resolver Resolver | ||
} | ||
|
||
// NewJumpSSHServer creates the jump server struct. | ||
func NewJumpSSHServer(ctx context.Context, port int, resolver Resolver) (Server, error) { | ||
zapctx.Info(ctx, "NewSSHServer") | ||
server := Server{ | ||
Server: &ssh.Server{ | ||
Addr: fmt.Sprintf(":%d", port), | ||
ChannelHandlers: map[string]ssh.ChannelHandler{ | ||
"direct-tcpip": directTCPIPHandler(resolver), | ||
}, | ||
PublicKeyHandler: func(ctx ssh.Context, key ssh.PublicKey) bool { | ||
return true | ||
}, | ||
}, | ||
resolver: resolver, | ||
} | ||
|
||
return server, nil | ||
} | ||
|
||
func directTCPIPHandler(resolver Resolver) func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { | ||
return func(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { | ||
d := forwardMessage{} | ||
|
||
k := newChan.ExtraData() | ||
|
||
if err := gossh.Unmarshal(k, &d); err != nil { | ||
rejectConnectionAndLogError(ctx, newChan, "Failed to parse channel data", err) | ||
return | ||
} | ||
|
||
dest := fmt.Sprintf("%s:%d", d.DestAddr, d.DestPort) | ||
if d.DestPort == 0 { | ||
d.DestPort = JUJU_SSH_DEFAULT_PORT | ||
} | ||
addr, err := resolver.GetAddrFromModelUUID(ctx, openfga.User{}, dest) | ||
if err != nil { | ||
rejectConnectionAndLogError(ctx, newChan, "Failed to parse channel data", err) | ||
return | ||
} | ||
// this is temporary. The way we dial to the controller will heavily change. | ||
client, err := gossh.Dial("tcp", fmt.Sprintf("%s:%d", addr, d.DestPort), &gossh.ClientConfig{ | ||
//nolint:gosec // this will be removed once we handle hostkeys | ||
HostKeyCallback: gossh.InsecureIgnoreHostKey(), | ||
Auth: []gossh.AuthMethod{ | ||
gossh.PasswordCallback(func() (secret string, err error) { | ||
return "jwt", nil | ||
}), | ||
}, | ||
}) | ||
if err != nil { | ||
rejectConnectionAndLogError(ctx, newChan, fmt.Sprintf("Failed to connect to %s: %v", dest, err), err) | ||
return | ||
} | ||
|
||
dChan, reqs, err := client.OpenChannel("direct-tcpip", gossh.Marshal(d)) | ||
if err != nil { | ||
rejectConnectionAndLogError(ctx, newChan, "Failed to open destination channel", err) | ||
return | ||
} | ||
|
||
go gossh.DiscardRequests(reqs) | ||
|
||
ch, reqs, err := newChan.Accept() | ||
if err != nil { | ||
dChan.Close() | ||
return | ||
} | ||
|
||
go gossh.DiscardRequests(reqs) | ||
|
||
go func() { | ||
defer ch.Close() | ||
defer dChan.Close() | ||
_, err := io.Copy(ch, dChan) | ||
rejectConnectionAndLogError(ctx, newChan, "Failed to copy data from src to dts", err) | ||
}() | ||
go func() { | ||
defer ch.Close() | ||
defer dChan.Close() | ||
_, err := io.Copy(dChan, ch) | ||
rejectConnectionAndLogError(ctx, newChan, "Failed to copy data from dst to src", err) | ||
}() | ||
zapctx.Info(ctx, fmt.Sprintf("Proxying connection from %s:%d to %s:%d \n", d.SrcAddr, d.SrcPort, d.DestAddr, d.DestPort)) | ||
} | ||
} | ||
|
||
func rejectConnectionAndLogError(ctx context.Context, newChan gossh.NewChannel, msg string, err error) { | ||
zapctx.Error(ctx, msg, zap.Error(err)) | ||
err = newChan.Reject(gossh.ConnectionFailed, msg) | ||
if err != nil { | ||
zapctx.Error(ctx, "Failed to reject channel", zap.Error(err)) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
// Copyright 2025 Canonical. | ||
package ssh_test | ||
|
||
import ( | ||
"context" | ||
"crypto/rand" | ||
"crypto/rsa" | ||
"crypto/x509" | ||
"encoding/pem" | ||
"fmt" | ||
"strconv" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
qt "github.com/frankban/quicktest" | ||
"github.com/frankban/quicktest/qtsuite" | ||
gliderssh "github.com/gliderlabs/ssh" | ||
gossh "golang.org/x/crypto/ssh" | ||
|
||
"github.com/canonical/jimm/v3/internal/openfga" | ||
"github.com/canonical/jimm/v3/internal/ssh" | ||
"github.com/canonical/jimm/v3/internal/utils" | ||
) | ||
|
||
type resolver struct{} | ||
|
||
func (r resolver) GetAddrFromModelUUID(ctx context.Context, user openfga.User, modelName string) (string, error) { | ||
return "", nil | ||
} | ||
|
||
type sshSuite struct { | ||
destinationJujuSSHServer gliderssh.Server | ||
destinationServerPort int | ||
jumpSSHServer ssh.Server | ||
jumpServerPort int | ||
privateKey gossh.Signer | ||
testF func(fm ssh.ForwardMessage) | ||
received chan bool | ||
} | ||
|
||
func (s *sshSuite) Init(c *qt.C) { | ||
s.received = make(chan bool) | ||
port, err := utils.GetFreePort() | ||
c.Assert(err, qt.IsNil) | ||
s.destinationServerPort = port | ||
s.destinationJujuSSHServer = gliderssh.Server{ | ||
Addr: fmt.Sprintf(":%d", port), | ||
ChannelHandlers: map[string]gliderssh.ChannelHandler{ | ||
"direct-tcpip": func(srv *gliderssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx gliderssh.Context) { | ||
d := ssh.ForwardMessage{} | ||
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { | ||
err := newChan.Reject(gossh.ConnectionFailed, "Failed to parse channel data") | ||
c.Assert(err, qt.IsNil) | ||
return | ||
} | ||
_, _, err := newChan.Accept() | ||
c.Assert(err, qt.IsNil) | ||
s.testF(d) | ||
s.received <- true | ||
}, | ||
}, | ||
} | ||
go func() { | ||
err := s.destinationJujuSSHServer.ListenAndServe() | ||
c.Assert(err, qt.IsNil) | ||
}() | ||
s.destinationServerPort, err = strconv.Atoi(strings.Split(s.destinationJujuSSHServer.Addr, ":")[1]) | ||
c.Assert(err, qt.IsNil) | ||
|
||
port, err = utils.GetFreePort() | ||
c.Assert(err, qt.IsNil) | ||
s.jumpServerPort = port | ||
s.jumpSSHServer, err = ssh.NewJumpSSHServer(context.Background(), port, resolver{}) | ||
c.Assert(err, qt.IsNil) | ||
go func() { | ||
err := s.jumpSSHServer.ListenAndServe() | ||
c.Assert(err, qt.IsNil) | ||
}() | ||
|
||
k, err := rsa.GenerateKey(rand.Reader, 2048) | ||
c.Assert(err, qt.IsNil) | ||
keyPEM := pem.EncodeToMemory( | ||
&pem.Block{ | ||
Type: "RSA PRIVATE KEY", | ||
Bytes: x509.MarshalPKCS1PrivateKey(k), | ||
}, | ||
) | ||
|
||
s.privateKey, err = gossh.ParsePrivateKey(keyPEM) | ||
c.Assert(err, qt.IsNil) | ||
} | ||
|
||
// CleanUp doesn't exist in qtsuite, so it needs to be called manually | ||
func (s *sshSuite) CleanUp(c *qt.C) { | ||
err := s.destinationJujuSSHServer.Close() | ||
c.Assert(err, qt.IsNil) | ||
err = s.jumpSSHServer.Close() | ||
c.Assert(err, qt.IsNil) | ||
} | ||
|
||
func (s *sshSuite) TestSSHJump(c *qt.C) { | ||
defer s.CleanUp(c) | ||
client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{ | ||
//nolint:gosec // this will be removed once we handle hostkeys | ||
HostKeyCallback: gossh.InsecureIgnoreHostKey(), | ||
Auth: []gossh.AuthMethod{ | ||
gossh.PublicKeys(s.privateKey), | ||
}, | ||
}) | ||
c.Assert(err, qt.IsNil) | ||
defer client.Close() | ||
|
||
// send forward message | ||
msg := ssh.ForwardMessage{ | ||
DestAddr: "model1", | ||
//nolint:gosec | ||
DestPort: uint32(s.destinationServerPort), | ||
SrcAddr: "localhost", | ||
SrcPort: 0, | ||
} | ||
s.testF = func(fm ssh.ForwardMessage) { | ||
c.Assert(fm.DestAddr, qt.Equals, "model1") | ||
} | ||
ch, _, err := client.OpenChannel("direct-tcpip", gossh.Marshal(&msg)) | ||
c.Assert(err, qt.IsNil) | ||
defer ch.Close() | ||
select { | ||
case <-s.received: | ||
case <-time.After(100 * time.Millisecond): | ||
c.Fatalf("ssh jump test timeout") | ||
} | ||
} | ||
|
||
func (s *sshSuite) TestSSHJumpDialFail(c *qt.C) { | ||
defer s.CleanUp(c) | ||
_, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort+1), &gossh.ClientConfig{ | ||
//nolint:gosec // this will be removed once we handle hostkeys | ||
HostKeyCallback: gossh.InsecureIgnoreHostKey(), | ||
Auth: []gossh.AuthMethod{ | ||
gossh.PublicKeys(s.privateKey), | ||
}, | ||
}) | ||
c.Assert(err, qt.ErrorMatches, ".*connect: connection refused.*") | ||
} | ||
|
||
func (s *sshSuite) TestSSHFinalDestinationDialFail(c *qt.C) { | ||
defer s.CleanUp(c) | ||
|
||
client, err := gossh.Dial("tcp", fmt.Sprintf(":%d", s.jumpServerPort), &gossh.ClientConfig{ | ||
//nolint:gosec // this will be removed once we handle hostkeys | ||
HostKeyCallback: gossh.InsecureIgnoreHostKey(), | ||
Auth: []gossh.AuthMethod{ | ||
gossh.PublicKeys(s.privateKey), | ||
}, | ||
}) | ||
c.Assert(err, qt.IsNil) | ||
|
||
// send forward message | ||
msg := ssh.ForwardMessage{ | ||
DestAddr: "model1", | ||
//nolint:gosec | ||
DestPort: uint32(s.destinationServerPort + 1), | ||
SrcAddr: "localhost", | ||
SrcPort: 0, | ||
} | ||
s.testF = func(fm ssh.ForwardMessage) { | ||
c.Assert(fm.DestAddr, qt.Equals, "model1") | ||
} | ||
_, _, err = client.OpenChannel("direct-tcpip", gossh.Marshal(&msg)) | ||
c.Assert(err, qt.ErrorMatches, ".*connect failed.*") | ||
|
||
} | ||
|
||
func TestIdentityManager(t *testing.T) { | ||
qtsuite.Run(qt.New(t), &sshSuite{}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters