Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pkg/auth/discovery/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,9 @@ func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfi
}
}

// Check if we have manual OAuth endpoints configured
// Check if we have OAuth endpoints configured
if config.AuthorizeURL != "" && config.TokenURL != "" {
logger.Infof("Using manual OAuth endpoints - authorize_url: %s, token_url: %s",
logger.Infof("Using OAuth endpoints - authorize_url: %s, token_url: %s",
config.AuthorizeURL, config.TokenURL)

oauthConfig, err = oauth.CreateOAuthConfigManual(
Expand Down
11 changes: 10 additions & 1 deletion pkg/runner/config_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package runner
import (
"context"
"fmt"
"net/url"
"slices"
"strings"

Expand Down Expand Up @@ -94,7 +95,15 @@ func (b *RunConfigBuilder) WithHost(host string) *RunConfigBuilder {

// WithTargetHost sets the target host (applies default if empty)
func (b *RunConfigBuilder) WithTargetHost(targetHost string) *RunConfigBuilder {
if targetHost == "" {
if b.config.RemoteURL != "" {
remoteURL, err := url.Parse(b.config.RemoteURL)
if err == nil {
targetHost = remoteURL.Host
} else {
logger.Warnf("Failed to parse remote URL: %v", err)
targetHost = transport.LocalhostIPv4
}
} else if targetHost == "" {
targetHost = transport.LocalhostIPv4
}
b.config.TargetHost = targetHost
Expand Down
7 changes: 6 additions & 1 deletion pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,12 @@ func (r *Runner) Run(ctx context.Context) error {
logger.Warnf("Warning: Failed to create client manager: %v", err)
} else {
transportType := labels.GetTransportType(r.Config.ContainerLabels)
serverURL := transport.GenerateMCPServerURL(transportType, "localhost", r.Config.Port, r.Config.ContainerName)
serverURL := transport.GenerateMCPServerURL(
transportType,
"localhost",
r.Config.Port,
r.Config.ContainerName,
r.Config.RemoteURL)

if err := clientManager.AddServerToClients(ctx, r.Config.ContainerName, serverURL, transportType, r.Config.Group); err != nil {
logger.Warnf("Warning: Failed to add server to client configurations: %v", err)
Expand Down
18 changes: 14 additions & 4 deletions pkg/transport/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"net/url"
"sync"

"golang.org/x/oauth2"
Expand Down Expand Up @@ -97,8 +98,8 @@ func NewHTTPTransport(
}

// SetRemoteURL sets the remote URL for the MCP server
func (t *HTTPTransport) SetRemoteURL(url string) {
t.remoteURL = url
func (t *HTTPTransport) SetRemoteURL(remoteURL string) {
t.remoteURL = remoteURL
}

// SetTokenSource sets the OAuth token source for remote authentication
Expand Down Expand Up @@ -271,8 +272,17 @@ func (t *HTTPTransport) Start(ctx context.Context) error {
var targetURI string

if t.remoteURL != "" {
// For remote MCP servers, use the remote URL directly
targetURI = t.remoteURL
remoteURL, err := url.Parse(t.remoteURL)
if err != nil {
return fmt.Errorf("failed to parse remote URL: %w", err)
}
// If the remote URL is a full URL, we need to extract the scheme and host
// and use them to construct the target URI
targetURI = (&url.URL{
Scheme: remoteURL.Scheme,
Host: remoteURL.Host,
}).String()

logger.Infof("Setting up transparent proxy to forward from host port %d to remote URL %s",
t.proxyPort, targetURI)
} else {
Expand Down
69 changes: 8 additions & 61 deletions pkg/transport/proxy/transparent/transparent_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,70 +135,27 @@ func (t *tracingTransport) forward(req *http.Request) (*http.Response, error) {
return tr.RoundTrip(req)
}

// manualForward manually forwards a request to the remote server using an HTTP client
func (t *tracingTransport) manualForward(req *http.Request) (*http.Response, error) {
// Create a new request to the target URL
targetURL := t.p.targetURI
if req.URL.RawQuery != "" {
targetURL += "?" + req.URL.RawQuery
}

newReq, err := http.NewRequest(req.Method, targetURL, req.Body)
if err != nil {
return nil, fmt.Errorf("failed to create new request: %w", err)
}

// Copy headers from the original request
for name, values := range req.Header {
for _, value := range values {
newReq.Header.Add(name, value)
}
}

// Create HTTP client and make the request
client := &http.Client{
Timeout: 30 * time.Second,
}

return client.Do(newReq)
}

// nolint:gocyclo // This function handles multiple request types and is complex by design
func (t *tracingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
if t.p.isRemote {
// Route based on transport type for remote servers
switch t.p.transportType {
case "sse":
// Use reverse proxy for SSE (streaming)
return t.forward(req)
case "streamable-http":
// Use manual HTTP client for streamable-http (request/response)
resp, err := t.manualForward(req)
if err != nil {
if errors.Is(err, context.Canceled) {
// Expected during shutdown or client disconnect—silently ignore
return nil, err
}
logger.Errorf("Failed to forward request: %v", err)
return nil, err
}
return resp, nil
default:
// Default to manual forwarding for unknown transport types
logger.Warnf("Unknown transport type '%s', using manual forwarding", t.p.transportType)
return t.manualForward(req)
// In case of remote servers, req.Host is set to the proxy host (localhost) which may cause 403 error,
// so we need to set it to the target URI host
if req.URL.Host != req.Host {
req.Host = req.URL.Host
}
}

// Original logic for local containers
reqBody := readRequestBody(req)

// thv proxy does not provide the transport type, so we need to detect it from the request
path := req.URL.Path
isMCP := strings.HasPrefix(path, "/mcp")
isJSON := strings.Contains(req.Header.Get("Content-Type"), "application/json")
sawInitialize := false

if isMCP && isJSON && len(reqBody) > 0 {
if len(reqBody) > 0 &&
((isMCP && isJSON) ||
t.p.transportType == types.TransportTypeStreamableHTTP.String()) {
sawInitialize = t.detectInitialize(reqBody)
}

Expand Down Expand Up @@ -343,17 +300,7 @@ func (p *TransparentProxy) Start(ctx context.Context) error {
return p.modifyForSessionID(resp)
}

// Create a handler that logs requests and strips /mcp path for remote servers
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// For remote servers, strip the /mcp path since they expect requests at the root
if p.isRemote && strings.HasPrefix(r.URL.Path, "/mcp") {
// Strip /mcp from the path for remote servers
r.URL.Path = strings.TrimPrefix(r.URL.Path, "/mcp")
if r.URL.Path == "" {
r.URL.Path = "/"
}
logger.Infof("Transparent proxy: %s %s -> %s", r.Method, r.URL.Path, targetURL)
}
proxy.ServeHTTP(w, r)
})

Expand Down
24 changes: 21 additions & 3 deletions pkg/transport/url.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,38 @@ package transport

import (
"fmt"
"net/url"

"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/transport/ssecommon"
"github.com/stacklok/toolhive/pkg/transport/streamable"
"github.com/stacklok/toolhive/pkg/transport/types"
)

// GenerateMCPServerURL generates the URL for an MCP server
func GenerateMCPServerURL(transportType string, host string, port int, containerName string) string {
// if remoteURL is provided, remote server path will be used as the path of the proxy
func GenerateMCPServerURL(transportType string, host string, port int, containerName, remoteURL string) string {
path := ""
if remoteURL != "" {
targetURL, err := url.Parse(remoteURL)
if err != nil {
logger.Errorf("Failed to parse target URI: %v", err)
return ""
}
path = targetURL.Path
}
// The URL format is: http://host:port/sse#container-name
// Both SSE and STDIO transport types use an SSE proxy
if transportType == types.TransportTypeSSE.String() || transportType == types.TransportTypeStdio.String() {
return fmt.Sprintf("http://%s:%d%s#%s", host, port, ssecommon.HTTPSSEEndpoint, containerName)
if path == "" || path == "/" {
path = ssecommon.HTTPSSEEndpoint
}
return fmt.Sprintf("http://%s:%d%s#%s", host, port, path, containerName)
} else if transportType == types.TransportTypeStreamableHTTP.String() {
return fmt.Sprintf("http://%s:%d/%s", host, port, streamable.HTTPStreamableHTTPEndpoint)
if path == "" || path == "/" {
path = "/" + streamable.HTTPStreamableHTTPEndpoint
}
return fmt.Sprintf("http://%s:%d%s", host, port, path)
}
return ""
}
70 changes: 61 additions & 9 deletions pkg/transport/url_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func TestGenerateMCPServerURL(t *testing.T) {
host string
port int
containerName string
targetURI string
expected string
}{
{
Expand All @@ -25,6 +26,7 @@ func TestGenerateMCPServerURL(t *testing.T) {
host: "localhost",
port: 12345,
containerName: "test-container",
targetURI: "",
expected: "http://localhost:12345" + ssecommon.HTTPSSEEndpoint + "#test-container",
},
{
Expand All @@ -33,6 +35,7 @@ func TestGenerateMCPServerURL(t *testing.T) {
host: "localhost",
port: 12345,
containerName: "test-container",
targetURI: "",
expected: "http://localhost:12345" + ssecommon.HTTPSSEEndpoint + "#test-container",
},
{
Expand All @@ -41,30 +44,79 @@ func TestGenerateMCPServerURL(t *testing.T) {
host: "localhost",
port: 12345,
containerName: "test-container",
targetURI: "",
expected: "http://localhost:12345/" + streamable.HTTPStreamableHTTPEndpoint,
},
{
name: "Different host with SSE",
transportType: types.TransportTypeSSE.String(),
host: "192.168.1.100",
port: 54321,
containerName: "another-container",
expected: "http://192.168.1.100:54321" + ssecommon.HTTPSSEEndpoint + "#another-container",
},
{
name: "Unsupported transport type",
transportType: "unsupported",
host: "localhost",
port: 12345,
containerName: "test-container",
targetURI: "",
expected: "",
},
{
name: "SSE transport with targetURI path",
transportType: types.TransportTypeSSE.String(),
host: "localhost",
port: 12345,
containerName: "test-container",
targetURI: "http://example.com/api/v1",
expected: "http://localhost:12345/api/v1#test-container",
},
{
name: "SSE transport with targetURI domain only",
transportType: types.TransportTypeSSE.String(),
host: "localhost",
port: 12345,
containerName: "test-container",
targetURI: "http://example.com",
expected: "http://localhost:12345/sse#test-container",
},
{
name: "SSE transport with targetURI root path",
transportType: types.TransportTypeSSE.String(),
host: "localhost",
port: 12345,
containerName: "test-container",
targetURI: "http://example.com/",
expected: "http://localhost:12345/sse#test-container",
},
// Major targetURI test cases - Streamable HTTP transport
{
name: "Streamable HTTP transport with targetURI path",
transportType: types.TransportTypeStreamableHTTP.String(),
host: "localhost",
port: 12345,
containerName: "test-container",
targetURI: "http://remote-server.com/path",
expected: "http://localhost:12345/path",
},
{
name: "Streamable HTTP transport with targetURI domain only",
transportType: types.TransportTypeStreamableHTTP.String(),
host: "localhost",
port: 12345,
containerName: "test-container",
targetURI: "http://remote-server.com",
expected: "http://localhost:12345/mcp",
},
{
name: "Streamable HTTP transport with targetURI root path",
transportType: types.TransportTypeStreamableHTTP.String(),
host: "localhost",
port: 12345,
containerName: "test-container",
targetURI: "http://remote-server.com/",
expected: "http://localhost:12345/mcp",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
url := GenerateMCPServerURL(tt.transportType, tt.host, tt.port, tt.containerName)
url := GenerateMCPServerURL(tt.transportType, tt.host, tt.port, tt.containerName, tt.targetURI)
if url != tt.expected {
t.Errorf("GenerateMCPServerURL() = %v, want %v", url, tt.expected)
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/workloads/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func WorkloadFromContainerInfo(container *runtime.ContainerInfo) (core.Workload,
// Generate URL for the MCP server
url := ""
if port > 0 {
url = transport.GenerateMCPServerURL(transportType, transport.LocalhostIPv4, port, name)
url = transport.GenerateMCPServerURL(transportType, transport.LocalhostIPv4, port, name, "")
}

tType, err := types.ParseTransportType(transportType)
Expand Down
Loading