diff --git a/.env.docker.template b/.env.docker.template new file mode 100644 index 0000000..2e4bbec --- /dev/null +++ b/.env.docker.template @@ -0,0 +1,25 @@ +# Admin credentials +ADMIN_USER=admin +ADMIN_PASSWORD=your_secure_password_here + +# API Keys +POLYGON_API_KEY=your_polygon_api_key_here + +# AWS Configuration +AWS_REGION=us-east-1 + +# OAuth Configuration +MCP_AUTH_ENABLED=true +MCP_AUTH_PROVIDER_TYPE=cognito +MCP_AUTH_BASE_URL=https://mcpgateway.ddns.net + +# Cognito Configuration +MCP_AUTH_COGNITO_USER_POOL_ID=your_user_pool_id_here +MCP_AUTH_COGNITO_CLIENT_ID=your_client_id_here +MCP_AUTH_COGNITO_CLIENT_SECRET=your_client_secret_here +MCP_AUTH_COGNITO_CALLBACK_URI=${MCP_AUTH_BASE_URL}/oauth/callback/cognito +MCP_AUTH_COGNITO_REGION=us-east-1 +MCP_AUTH_COGNITO_CUSTOM_DOMAIN=your_custom_domain_here + +# Secret Key (will be auto-generated if not provided) +# SECRET_KEY=your_secret_key_here \ No newline at end of file diff --git a/.gitignore b/.gitignore index ba8a3e2..4c3df39 100644 --- a/.gitignore +++ b/.gitignore @@ -182,3 +182,5 @@ cookies.txt registry/server_state.json registry/nginx_mcp_revproxy.conf logs/ +agents/test_results/ +.env.docker diff --git a/Dockerfile b/Dockerfile index 7d08259..bf420ef 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,7 +28,6 @@ COPY . /app/ # Note: We copy it here so it's part of the image layer COPY docker/nginx_rev_proxy.conf /app/docker/nginx_rev_proxy.conf - # Make the entrypoint script executable COPY docker/entrypoint.sh /app/docker/entrypoint.sh RUN chmod +x /app/docker/entrypoint.sh @@ -42,11 +41,16 @@ ARG SECRET_KEY="" ARG ADMIN_USER="admin" ARG ADMIN_PASSWORD="" ARG POLYGON_API_KEY="" +ARG MCP_AUTH_ENABLED="false" +ARG MCP_GATEWAY_DEV_MODE="true" +# Pass build args to runtime environment ENV SECRET_KEY=$SECRET_KEY ENV ADMIN_USER=$ADMIN_USER ENV ADMIN_PASSWORD=$ADMIN_PASSWORD ENV POLYGON_API_KEY=$POLYGON_API_KEY +ENV MCP_AUTH_ENABLED=$MCP_AUTH_ENABLED +ENV MCP_GATEWAY_DEV_MODE=$MCP_GATEWAY_DEV_MODE # Run the entrypoint script when the container launches ENTRYPOINT ["/app/docker/entrypoint.sh"] \ No newline at end of file diff --git a/README.md b/README.md index 1ecb61a..c41acea 100644 --- a/README.md +++ b/README.md @@ -121,7 +121,14 @@ flowchart TB * **Unified access to a governed list of MCP servers:** Access multiple MCP servers through a common MCP gateway, enabling AI Agents to dynamically discover and execute MCP tools. * **Service Registration:** Register MCP services via JSON files or the web UI/API. * **Web UI:** Manage services, view status, and monitor health through a web interface. -* **Authentication:** Secure login system for the web UI and API access. +* **OAuth 2.1 Authentication:** + * Standards-compliant OAuth 2.1 authorization code flow with PKCE + * Support for multiple identity providers (AWS Cognito, Okta, and others) + * Fine-grained scope-based access control + * Secure session management with proper token handling +* **MCP Protocol Support:** + * Server-Sent Events (SSE) transport via `/api/execute/{service}` endpoint + * StreamableHTTP transport via `/api/streamable/{service}` endpoint * **Health Checks:** * Periodic background checks for enabled services (checks `/sse` endpoint). * Manual refresh trigger via UI button or API endpoint. @@ -175,12 +182,14 @@ The Gateway and the Registry are available as a Docker container. The package in export ADMIN_USER=admin export ADMIN_PASSWORD=your-admin-password export POLYGON_API_KEY=your-polygon-api-key + export GATEWAY_HOSTNAME=your-ec2-hostname # stop any previous instance docker stop mcp-gateway-container && docker rm mcp-gateway-container docker run -p 80:80 -p 443:443 -p 7860:7860 \ -e ADMIN_USER=$ADMIN_USER \ -e ADMIN_PASSWORD=$ADMIN_PASSWORD \ -e POLYGON_API_KEY=$POLYGON_API_KEY \ + -e GATEWAY_HOSTNAME=$GATEWAY_HOSTNAME \ -e SECRET_KEY=$(python3 -c 'import secrets; print(secrets.token_hex(32))') \ -v /var/log/mcp-gateway:/app/logs \ -v /opt/mcp-gateway/servers:/app/registry/servers \ @@ -224,6 +233,19 @@ The Gateway and the Registry are available as a Docker container. The package in 1. **View MCP server metadata:** Metadata about all MCP servers connected to the Registry is available in `/opt/mcp-gateway/servers` directory. The metadata includes information gathered from `ListTools` as well as information provided while registering the server. +1. **Test the Gateway and Registry with the sample Agent and test suite** + The repo includes a test agent that can connect to the Registry to discover tools and invoke them to do interesting tasks. This functionality can be invoked either standalone or as part of a test suite. + + ```{.python} + python agents\agent.py --mcp-registry-url http://localhost/mcpgw/sse --message "what is the current time in clarksburg, md" + ``` + + You can also run the full test suite and get a handy agent evaluation report. This test suite exercises the Registry functionality as well as tests the multiple built-in MCP servers provided by the Gateway. + ```{python} + python agents/test_suite.py + ``` + The result of the tests suites are available in the `agents/test_results` folder. It contains an `accuracy.json`, a `summary.json`, a `logs` folder and a `raw_data` folder that contains the verbose output from the agent. The test suite uses an LLM as a judge to evaluate the results for accuracy and tool usage quality. + #### Running the Gateway over HTTPS 1. Enable access to TCP port 443 from the IP address of your MCP client (your laptop, or anywhere) in the inbound rules in the security group associated with your EC2 instance. @@ -239,6 +261,7 @@ The Gateway and the Registry are available as a Docker container. The package in -e ADMIN_USER=$ADMIN_USER \ -e ADMIN_PASSWORD=$ADMIN_PASSWORD \ -e POLYGON_API_KEY=$POLYGON_API_KEY \ + -e GATEWAY_HOSTNAME=$GATEWAY_HOSTNAME \ -e SECRET_KEY=$(python3 -c 'import secrets; print(secrets.token_hex(32))') \ -v /path/to/certs:/etc/ssl/certs \ -v /path/to/private:/etc/ssl/private \ @@ -331,14 +354,96 @@ See the full API spec [here](docs/registry_api.md). *(Authentication via session cookie is required for most non-login routes)* +## OAuth 2.1 Authentication + +MCP Gateway now supports OAuth 2.1 authentication for secure access to the gateway and its services. + +```mermaid +sequenceDiagram + participant Client as MCP Client + participant Gateway as MCP Gateway + participant IdP as Identity Provider + participant Server as MCP Server + + Client->>Gateway: 1. Request tool execution + Gateway->>Gateway: 2. Check authentication + alt Not authenticated + Gateway->>Client: 3. Redirect to /oauth/login + Client->>Gateway: 4. Request /oauth/login + Gateway->>IdP: 5. Redirect to IdP with PKCE + Client->>IdP: 6. User authentication + IdP->>Gateway: 7. Authorization code + Gateway->>IdP: 8. Exchange code for tokens + Gateway->>Gateway: 9. Validate tokens, extract scopes + Gateway->>Client: 10. Set session cookie + end + Client->>Gateway: 11. Request with auth cookie + Gateway->>Gateway: 12. Verify scopes + alt Authorized + Gateway->>Server: 13. Proxy to MCP Server + Server->>Gateway: 14. Server response + Gateway->>Client: 15. Return response + else Not authorized + Gateway->>Client: Return 403 Forbidden + end +``` + +### Key Features + +- Standards-compliant OAuth 2.1 authorization code flow with PKCE +- Multiple identity provider support (AWS Cognito, Okta, and others) +- Fine-grained scope-based access control +- Support for both SSE and StreamableHTTP transport protocols +- Configurable through environment variables or config files + +### Setting up OAuth Authentication + +To enable OAuth 2.1 authentication: + +1. **Configure Environment Variables:** + - Set `MCP_AUTH_ENABLED=true` + - Generate a secure `SECRET_KEY` + - Configure provider-specific settings + +2. **Docker Run Command Example:** + +```bash +docker run -p 80:80 -p 443:443 -p 7860:7860 \ + -e ADMIN_USER=$ADMIN_USER \ + -e ADMIN_PASSWORD=$ADMIN_PASSWORD \ + -e POLYGON_API_KEY=$POLYGON_API_KEY \ + -e SECRET_KEY=$(python3 -c 'import secrets; print(secrets.token_hex(32))') \ + -e MCP_AUTH_ENABLED=true \ + -e MCP_AUTH_PROVIDER_TYPE=$PROVIDER_TYPE \ + # Add your provider-specific environment variables here + -v /var/log/mcp-gateway:/app/logs \ + -v /opt/mcp-gateway/servers:/app/registry/servers \ + --name mcp-gateway-container mcp-gateway +``` + +See the [OAuth Documentation](docs/oauth.md) for detailed configuration options for AWS Cognito, Okta, and other providers. + +### Transport Protocol Support + +MCP Gateway supports both transport protocols defined in the MCP specification: + +- **Server-Sent Events (SSE)**: Access via `/api/execute/{service}` endpoint +- **StreamableHTTP**: Access via `/api/streamable/{service}` endpoint + +### For More Information + +For comprehensive documentation on OAuth setup, transport protocols, scope-based access control, and more, see: + +- [OAuth Documentation](docs/oauth.md) +- [API Documentation](docs/registry_api.md) + ## Roadmap 1. Store the server information in persistent storage. -1. Add OAUTH 2.1 support to Gateway and Registry. -1. Use GitHub API to retrieve information (license, programming language etc.) about MCP servers. -1. Add option to deploy MCP servers. +2. Use GitHub API to retrieve information (license, programming language etc.) about MCP servers. +3. Add option to deploy MCP servers. ## License - Free for non-commercial use under AGPL-3.0 -- Commercial use requires a paid license +- Commercial use requires a paid license \ No newline at end of file diff --git a/agents/agent.py b/agents/agent.py index f2df741..e7b48a4 100644 --- a/agents/agent.py +++ b/agents/agent.py @@ -172,7 +172,7 @@ async def invoke_mcp_tool(mcp_registry_url: str, server_name: str, tool_name: st How to use intelligent_tool_finder: 1. When you identify that a task requires a specialized tool (e.g., image generation, specialized API access, etc.) -2. Call the tool with a description of what you need: `intelligent_tool_finder("description of needed capability")`, Use admin/password for authentication. +2. Call the tool with a description of what you need: `intelligent_tool_finder("description of needed capability")` 3. The tool will return the most appropriate specialized tool along with usage instructions 4. You can then use the invoke_mcp_tool to invoke this discovered tool by providing the MCP Registry URL, server name, tool name, and required arguments diff --git a/agents/test_suite.py b/agents/test_suite.py new file mode 100755 index 0000000..e315333 --- /dev/null +++ b/agents/test_suite.py @@ -0,0 +1,428 @@ +#!/usr/bin/env python3 +""" +Test Suite for MCP Agent + +This script runs a series of test commands against the agent.py script, +captures the output, and uses an LLM to evaluate the correctness of the responses. + +Usage: + python agents/test_suite.py [--mcp-registry-url URL] [--num-tests N] + +Options: + --mcp-registry-url URL MCP Registry URL (default: http://localhost/mcpgw/sse) + --num-tests N Number of tests to run (default: run all tests) +""" + +import subprocess +import json +import sys +import os +import argparse +import logging +import time +from datetime import datetime +from typing import Dict, List, Any, Tuple +from langchain_aws import ChatBedrockConverse + +# Configure logging +def setup_logging(): + """ + Configure the logging system with detailed formatting. + """ + # Create logs directory if it doesn't exist + os.makedirs("agents/test_results/logs", exist_ok=True) + + # Generate log filename with timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = f"agents/test_results/logs/test_suite_{timestamp}.log" + + # Configure root logger + logger = logging.getLogger() + logger.setLevel(logging.INFO) + + # Create formatter with detailed information + formatter = logging.Formatter( + '%(asctime)s | %(process)d | %(levelname)s | %(filename)s:%(lineno)d | %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + # Configure file handler + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + # Configure console handler + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + logger.addHandler(console_handler) + + logging.info(f"Logging initialized. Log file: {log_file}") + return log_file + +def parse_arguments() -> argparse.Namespace: + """ + Parse command line arguments for the test suite. + + Returns: + argparse.Namespace: The parsed command line arguments + """ + parser = argparse.ArgumentParser(description='MCP Agent Test Suite') + + # MCP Registry URL argument + parser.add_argument('--mcp-registry-url', type=str, default='http://localhost/mcpgw/sse', + help='MCP Registry URL (default: http://localhost/mcpgw/sse)') + + # Number of tests to run argument + parser.add_argument('--num-tests', type=int, default=None, + help='Number of tests to run (default: run all tests)') + + return parser.parse_args() + +# Define the test cases +TEST_CASES = [ + { + "id": "test1", + "command_template": "python agents/agent.py --mcp-registry-url {mcp_registry_url} --message \"what mcp servers do i have access to\"", + "description": "Query available MCP servers" + }, + { + "id": "test2", + "command_template": "python agents/agent.py --mcp-registry-url {mcp_registry_url} --message \"what is the current time in clarksburg, md\"", + "description": "Query current time in Clarksburg, MD" + }, + { + "id": "test3", + "command_template": "python agents/agent.py --mcp-registry-url {mcp_registry_url} --message \"stock performance for apple in the last one week\"", + "description": "Query Apple stock performance for the last week" + } +] + +def ensure_directories_exist(): + """ + Ensure that the necessary directories for test results exist. + Creates agents/test_results/ and agents/test_results/raw_data/ if they don't exist. + """ + os.makedirs("agents/test_results", exist_ok=True) + os.makedirs("agents/test_results/raw_data", exist_ok=True) + logging.info("Ensured test results directories exist") + +def run_command(command: str) -> Tuple[str, str]: + """ + Run a command and capture its stdout and stderr. + + Args: + command (str): The command to run + + Returns: + Tuple[str, str]: A tuple containing (stdout, stderr) + """ + logging.info(f"Executing command: {command}") + start_time = time.time() + + process = subprocess.Popen( + command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + + stdout, stderr = process.communicate() + + execution_time = time.time() - start_time + logging.info(f"Command completed in {execution_time:.2f} seconds with exit code: {process.returncode}") + + if stderr: + logging.warning(f"Command produced stderr output: {stderr[:200]}...") + + return stdout, stderr + +def evaluate_response(question: str, output: str, model_id: str = "anthropic.claude-3-sonnet-20240229-v1:0") -> Dict[str, Any]: + """ + Evaluate the agent's response using an LLM. + + Args: + question (str): The original question asked to the agent + output (str): The combined stdout and stderr from the agent + model_id (str): The Bedrock model ID to use for evaluation + + Returns: + Dict[str, Any]: A dictionary containing the evaluation results + """ + logging.info(f"Evaluating response for question: {question}") + logging.info(f"Using model: {model_id}") + + # Initialize the LLM + try: + llm = ChatBedrockConverse(model_id=model_id, region_name='us-east-1') + logging.info("Successfully initialized Bedrock model") + except Exception as e: + logging.error(f"Failed to initialize Bedrock model: {str(e)}") + return { + "correct": False, + "reasoning": f"Failed to initialize Bedrock model: {str(e)}", + "summary": "Evaluation error" + } + + # Create the prompt for evaluation with enhanced criteria + prompt = f""" + You are an expert evaluator for AI assistant responses. You need to thoroughly evaluate the following response + to determine if it correctly answers the given question and follows proper process. + + Question: {question} + + Response Output: + ``` + {output} + ``` + + Please evaluate the response based on the following criteria: + + 1. Tool Selection: Did the agent invoke the appropriate tool for the task? + 2. Parameter Correctness: Did the agent pass parameters to the tool according to the tool's schema? + 3. Error Handling: Did the agent have to retry any failed requests? If so, how did it handle them? + 4. Answer Accuracy: Is the final answer correct and responsive to the question? + 5. Process Efficiency: Did the agent take a direct and efficient path to the answer? + + Provide your assessment in the following JSON format: + {{ + "correct": true/false, + "tool_selection": {{ + "appropriate": true/false, + "comments": "Your assessment of tool selection" + }}, + "parameter_usage": {{ + "correct": true/false, + "comments": "Your assessment of parameter usage" + }}, + "error_handling": {{ + "errors_encountered": true/false, + "handled_properly": true/false, + "comments": "Your assessment of error handling" + }}, + "answer_quality": {{ + "accurate": true/false, + "complete": true/false, + "comments": "Your assessment of the answer quality" + }}, + "reasoning": "Your detailed overall reasoning for the assessment", + "summary": "A concise summary of the response content" + }} + + Only respond with the JSON object, nothing else. + """ + + # Get the evaluation from the LLM + try: + logging.info("Sending prompt to LLM for evaluation") + start_time = time.time() + response = llm.invoke(prompt) + execution_time = time.time() - start_time + logging.info(f"LLM evaluation completed in {execution_time:.2f} seconds") + except Exception as e: + logging.error(f"LLM evaluation failed: {str(e)}") + return { + "correct": False, + "reasoning": f"LLM evaluation failed: {str(e)}", + "summary": "Evaluation error" + } + + # Extract the JSON from the response + try: + # The response might contain markdown formatting, so we need to extract just the JSON part + response_text = response.content + + # Find JSON content (between curly braces) + json_start = response_text.find('{') + json_end = response_text.rfind('}') + 1 + + if json_start >= 0 and json_end > json_start: + json_str = response_text[json_start:json_end] + result = json.loads(json_str) + logging.info("Successfully parsed LLM response as JSON") + return result + else: + # If no JSON found, return an error + logging.error("Failed to find JSON content in LLM response") + return { + "correct": False, + "reasoning": "Failed to parse LLM response as JSON", + "summary": "Evaluation error" + } + except json.JSONDecodeError as e: + logging.error(f"Failed to parse LLM response as JSON: {str(e)}") + return { + "correct": False, + "reasoning": f"Failed to parse LLM response as JSON: {str(e)}", + "summary": "Evaluation error" + } + +def run_test_case(test_case: Dict[str, str], mcp_registry_url: str) -> Dict[str, Any]: + """ + Run a single test case and evaluate the results. + + Args: + test_case (Dict[str, str]): The test case to run + mcp_registry_url (str): The MCP Registry URL to use + + Returns: + Dict[str, Any]: The test results including the evaluation + """ + # Format the command with the MCP Registry URL + command = test_case["command_template"].format(mcp_registry_url=mcp_registry_url) + description = test_case["description"] + test_id = test_case["id"] + + # Extract the question from the command + question = command.split("--message")[1].strip().strip('"') + + logging.info(f"{'=' * 80}") + logging.info(f"Running test: {description} (ID: {test_id})") + logging.info(f"Command: {command}") + logging.info(f"{'=' * 80}") + + # Run the command + stdout, stderr = run_command(command) + + # Combine stdout and stderr for evaluation + combined_output = f"STDOUT:\n{stdout}\n\nSTDERR:\n{stderr}" + + # Save the raw output to a file + raw_data_path = f"agents/test_results/raw_data/{test_id}.json" + try: + with open(raw_data_path, "w") as f: + json.dump({ + "question": question, + "stdout": stdout, + "stderr": stderr + }, f, indent=2) + logging.info(f"Raw test data saved to {raw_data_path}") + except Exception as e: + logging.error(f"Failed to save raw test data: {str(e)}") + + # Evaluate the response + logging.info(f"Evaluating response for test: {test_id}") + evaluation = evaluate_response(question, combined_output) + + # Log evaluation result + if evaluation.get("correct", False): + logging.info(f"Test {test_id} PASSED") + else: + logging.warning(f"Test {test_id} FAILED") + + # Return the test results (without including the full stdout/stderr) + return { + "id": test_id, + "description": description, + "question": question, + "raw_data_file": raw_data_path, + "evaluation": evaluation + } + +def main(): + """ + Run all test cases and display the results. + """ + # Setup logging + log_file = setup_logging() + + # Log start of test suite with Python version + logging.info(f"Starting MCP Agent Test Suite") + logging.info(f"Python version: {sys.version}") + + # Parse command line arguments + args = parse_arguments() + mcp_registry_url = args.mcp_registry_url + num_tests = args.num_tests + logging.info(f"Using MCP Registry URL: {mcp_registry_url}") + + # Ensure the necessary directories exist + ensure_directories_exist() + + # Determine how many tests to run + tests_to_run = TEST_CASES + if num_tests is not None and num_tests > 0: + tests_to_run = TEST_CASES[:num_tests] + logging.info(f"Running first {num_tests} of {len(TEST_CASES)} test cases") + else: + logging.info(f"Running all {len(TEST_CASES)} test cases") + + results = [] + + # Run each test case + for i, test_case in enumerate(tests_to_run, 1): + logging.info(f"Starting test case {i}/{len(tests_to_run)}: {test_case['id']}") + result = run_test_case(test_case, mcp_registry_url) + results.append(result) + + # Display summary of results + logging.info(f"{'=' * 80}") + logging.info(f"Test Suite Summary") + logging.info(f"{'=' * 80}") + + for result in results: + evaluation = result["evaluation"] + correct = evaluation.get("correct", False) + status = "PASSED" if correct else "FAILED" + + logging.info(f"Test: {result['description']} (ID: {result['id']})") + logging.info(f"Status: {status}") + logging.info(f"Summary: {evaluation.get('summary', 'N/A')}") + + # Log more detailed evaluation if available + if "tool_selection" in evaluation: + tool_appropriate = evaluation["tool_selection"].get("appropriate", False) + tool_status = "Appropriate" if tool_appropriate else "Inappropriate" + logging.info(f"Tool Selection: {tool_status} - {evaluation['tool_selection'].get('comments', '')}") + + if "parameter_usage" in evaluation: + params_correct = evaluation["parameter_usage"].get("correct", False) + params_status = "Correct" if params_correct else "Incorrect" + logging.info(f"Parameter Usage: {params_status} - {evaluation['parameter_usage'].get('comments', '')}") + + if "error_handling" in evaluation and evaluation["error_handling"].get("errors_encountered", False): + errors_handled = evaluation["error_handling"].get("handled_properly", False) + error_status = "Properly handled" if errors_handled else "Improperly handled" + logging.info(f"Error Handling: {error_status} - {evaluation['error_handling'].get('comments', '')}") + + # Count passed tests + passed = sum(1 for result in results if result["evaluation"].get("correct", False)) + + logging.info(f"{'-' * 80}") + logging.info(f"Results: {passed}/{len(results)} tests passed") + logging.info(f"{'=' * 80}") + + # Save summary results to a JSON file + results_path = "agents/test_results/summary.json" + try: + with open(results_path, "w") as f: + json.dump(results, f, indent=2) + logging.info(f"Summary results saved to {results_path}") + except Exception as e: + logging.error(f"Failed to save summary results: {str(e)}") + + # Generate and save accuracy metrics + total_tests = len(results) + passed_tests = passed + accuracy = (passed_tests / total_tests * 100) if total_tests > 0 else 0.0 + + accuracy_data = { + "total_tests": total_tests, + "tests_passed": passed_tests, + "accuracy": round(accuracy, 2) # Round to 2 decimal places + } + + accuracy_path = "agents/test_results/accuracy.json" + try: + with open(accuracy_path, "w") as f: + json.dump(accuracy_data, f, indent=2) + logging.info(f"Accuracy metrics saved to {accuracy_path}") + logging.info(f"Accuracy: {accuracy_data['accuracy']}% ({passed_tests}/{total_tests} tests passed)") + except Exception as e: + logging.error(f"Failed to save accuracy metrics: {str(e)}") + + logging.info(f"Test suite execution completed") + logging.info(f"Log file: {log_file}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/build_and_run.sh b/build_and_run.sh new file mode 100755 index 0000000..b7b1a49 --- /dev/null +++ b/build_and_run.sh @@ -0,0 +1,84 @@ +#!/bin/bash + +# Enable error handling +set -e + +# Function for logging with timestamp +log() { + echo "[$(date '+%Y-%m-%d %H:%M:%S')] $1" +} + +# Function for error handling +handle_error() { + log "ERROR: $1" + exit 1 +} + +log "Starting MCP Gateway deployment script" + +# Check if .env.docker file exists +if [ ! -f .env.docker ]; then + log "ERROR: .env.docker file not found" + log "Please create a .env.docker file by copying .env.docker.template and filling in your values:" + log "cp .env.docker.template .env.docker" + log "Then edit .env.docker with your configuration values" + exit 1 +fi + +log "Found .env.docker file" + +# Stop and remove existing container if it exists +log "Stopping and removing existing container (if any)..." +if docker ps -a | grep -q mcp-gateway-container; then + docker stop mcp-gateway-container || log "Container was not running" + docker rm mcp-gateway-container || handle_error "Failed to remove container" + log "Container stopped and removed successfully" +else + log "No existing container found" +fi + +# Build the Docker image +log "Building Docker image..." +docker build -t mcp-gateway . || handle_error "Docker build failed" +log "Docker image built successfully" + +# Generate a random SECRET_KEY if not already in .env.docker +if ! grep -q "SECRET_KEY=" .env.docker; then + log "Generating SECRET_KEY..." + SECRET_KEY=$(python3 -c 'import secrets; print(secrets.token_hex(32))') || handle_error "Failed to generate SECRET_KEY" + echo "SECRET_KEY=$SECRET_KEY" >> .env.docker + log "SECRET_KEY added to .env.docker" +fi + +# Run the Docker container +log "Starting Docker container..." +docker run -d \ + -p 80:80 \ + -p 443:443 \ + -p 7860:7860 \ + --env-file .env.docker \ + -v /path/to/certs:/etc/ssl/certs \ + -v /path/to/private:/etc/ssl/private \ + -v /var/log/mcp-gateway:/app/logs \ + -v /opt/mcp-gateway/servers:/app/registry/servers \ + --name mcp-gateway-container \ + mcp-gateway || handle_error "Failed to start container" + +# Keep .env.docker file for future runs +log "Keeping .env.docker file for future runs" + +# Verify container is running +if docker ps | grep -q mcp-gateway-container; then + log "Container started successfully" + log "MCP Gateway is now running" + docker ps | grep mcp-gateway-container +else + handle_error "Container failed to start properly" +fi + +log "Deployment completed successfully" + +# Follow container logs +log "Following container logs (press Ctrl+C to stop following logs without stopping the container):" +echo "---------- CONTAINER LOGS ----------" +docker logs -f mcp-gateway-container diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index ff157bc..8fa2924 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -23,6 +23,19 @@ generate_secret_key() { python -c 'import secrets; print(secrets.token_hex(32))' } +# Function to handle errors +handle_error() { + echo "ERROR: An error occurred at line $1, exiting..." + exit 1 +} + +# Set up error handling +trap 'handle_error $LINENO' ERR + +# Install PyJWT with crypto support if needed - required for OAuth +echo "Ensuring required packages for OAuth are installed..." +pip install "pyjwt[crypto]>=2.8.0" "pycognito>=2024.3.1" "boto3>=1.28.0" + # --- Environment Variable Setup --- # 1. Registry .env @@ -40,12 +53,60 @@ fi ADMIN_PASSWORD_VALUE=${ADMIN_PASSWORD} -# Create .env file from template structure, substituting values -echo "SECRET_KEY=${SECRET_KEY_VALUE}" > "$REGISTRY_ENV_FILE" -echo "ADMIN_USER=${ADMIN_USER_VALUE}" >> "$REGISTRY_ENV_FILE" -echo "ADMIN_PASSWORD=${ADMIN_PASSWORD_VALUE}" >> "$REGISTRY_ENV_FILE" -echo "Registry .env created." -cat "$REGISTRY_ENV_FILE" # Print for verification +# Check if ADMIN_PASSWORD is set when OAuth is enabled +if [ "${MCP_AUTH_ENABLED}" = "true" ] && [ -z "${ADMIN_PASSWORD}" ]; then + echo "ERROR: ADMIN_PASSWORD environment variable is not set." + echo "When OAuth is enabled, please set ADMIN_PASSWORD to a secure value before running the container." + exit 1 +fi + +# Create .env file for registry +echo "Setting up Registry environment (/app/registry/.env)..." +if [ ! -f /app/registry/.env ]; then + # Generate a secure random key if not provided + SECURE_KEY=${SECRET_KEY:-$(python -c 'import secrets; print(secrets.token_hex(32))')} + + cat > /app/registry/.env << EOL +SECRET_KEY=${SECURE_KEY} +ADMIN_USER=${ADMIN_USER:-admin} +ADMIN_PASSWORD=${ADMIN_PASSWORD:-password} + +# Gateway Configuration +MCP_GATEWAY_DEV_MODE=${MCP_GATEWAY_DEV_MODE:-true} + +# OAuth Configuration +MCP_AUTH_ENABLED=${MCP_AUTH_ENABLED:-false} +MCP_AUTH_PROVIDER_TYPE=${MCP_AUTH_PROVIDER_TYPE:-} +MCP_AUTH_CONFIG=${MCP_AUTH_CONFIG:-} +MCP_AUTH_BASE_URL=${MCP_AUTH_BASE_URL:-https://mcpgateway.ddns.net} + +# For Cognito provider +MCP_AUTH_COGNITO_USER_POOL_ID=${MCP_AUTH_COGNITO_USER_POOL_ID:-} +MCP_AUTH_COGNITO_CLIENT_ID=${MCP_AUTH_COGNITO_CLIENT_ID:-} +MCP_AUTH_COGNITO_CLIENT_SECRET=${MCP_AUTH_COGNITO_CLIENT_SECRET:-} +MCP_AUTH_COGNITO_CALLBACK_URI=${MCP_AUTH_COGNITO_CALLBACK_URI:-${MCP_AUTH_BASE_URL}/oauth/callback/cognito} +MCP_AUTH_COGNITO_REGION=${MCP_AUTH_COGNITO_REGION:-us-east-1} +MCP_AUTH_COGNITO_CUSTOM_DOMAIN=${MCP_AUTH_COGNITO_CUSTOM_DOMAIN:-} + +# For Okta provider +MCP_AUTH_OKTA_TENANT_URL=${MCP_AUTH_OKTA_TENANT_URL:-} +MCP_AUTH_OKTA_CLIENT_ID=${MCP_AUTH_OKTA_CLIENT_ID:-} +MCP_AUTH_OKTA_CLIENT_SECRET=${MCP_AUTH_OKTA_CLIENT_SECRET:-} +MCP_AUTH_OKTA_CALLBACK_URI=${MCP_AUTH_OKTA_CALLBACK_URI:-${MCP_AUTH_BASE_URL}/oauth/callback/okta} + +# For generic OAuth providers +MCP_AUTH_CLIENT_ID=${MCP_AUTH_CLIENT_ID:-} +MCP_AUTH_CLIENT_SECRET=${MCP_AUTH_CLIENT_SECRET:-} +MCP_AUTH_AUTHORIZE_URL=${MCP_AUTH_AUTHORIZE_URL:-} +MCP_AUTH_TOKEN_URL=${MCP_AUTH_TOKEN_URL:-} +MCP_AUTH_JWKS_URL=${MCP_AUTH_JWKS_URL:-} +MCP_AUTH_CALLBACK_URI=${MCP_AUTH_CALLBACK_URI:-${MCP_AUTH_BASE_URL}/oauth/callback} +MCP_AUTH_SCOPES=${MCP_AUTH_SCOPES:-openid profile email} +MCP_AUTH_AUDIENCE=${MCP_AUTH_AUDIENCE:-} +MCP_AUTH_ISSUER=${MCP_AUTH_ISSUER:-} +EOL + echo "Registry .env created." +fi # 2. Fininfo Server .env echo "Setting up Fininfo server environment ($FININFO_ENV_FILE)..." @@ -57,49 +118,60 @@ echo "POLYGON_API_KEY=${POLYGON_API_KEY_VALUE}" > "$FININFO_ENV_FILE" echo "Fininfo .env created." cat "$FININFO_ENV_FILE" # Print for verification +# Generate OAuth configuration file if specified +if [ ! -z "$MCP_AUTH_CONFIG_JSON" ]; then + echo "Generating OAuth configuration file from JSON environment variable..." + echo "$MCP_AUTH_CONFIG_JSON" > /app/registry/auth_config.json + export MCP_AUTH_CONFIG="/app/registry/auth_config.json" + echo "OAuth configuration file created at $MCP_AUTH_CONFIG" +fi + # --- Python Environment Setup --- -echo "Checking for Python virtual environment..." -if [ ! -d "$VENV_DIR" ] || [ ! -f "$VENV_DIR/bin/activate" ]; then - echo "Setting up Python environment..." - - # Install uv if not already installed - if ! command -v uv &> /dev/null; then - echo "Installing uv package manager..." - pip install uv - fi - - # Create virtual environment - echo "Creating virtual environment..." - uv venv "$VENV_DIR" --python 3.12 - - # Install dependencies - echo "Installing Python dependencies..." - source "$VENV_DIR/bin/activate" - uv pip install \ - "fastapi>=0.115.12" \ - "itsdangerous>=2.2.0" \ - "jinja2>=3.1.6" \ - "mcp>=1.6.0" \ - "pydantic>=2.11.3" \ - "httpx>=0.27.0" \ - "python-dotenv>=1.1.0" \ - "python-multipart>=0.0.20" \ - "uvicorn[standard]>=0.34.2" \ - "faiss-cpu>=1.7.4" \ - "sentence-transformers>=2.2.2" \ - "websockets>=15.0.1" \ - "scikit-learn>=1.3.0" \ - "torch>=1.6.0" \ - "huggingface-hub[cli,hf_xet]>=0.31.1" \ - "hf_xet>=0.1.0" - - # Install the package itself - uv pip install -e /app - - echo "Python environment setup complete." -else - echo "Python virtual environment already exists, skipping setup." +echo "Setting up Python environment..." + +# Install uv if not already installed +if ! command -v uv &> /dev/null; then + echo "Installing uv package manager..." + pip install uv +fi + +# Create virtual environment (recreate if it exists) +echo "Creating virtual environment..." +if [ -d "$VENV_DIR" ]; then + echo "Removing existing virtual environment..." + rm -rf "$VENV_DIR" fi +uv venv "$VENV_DIR" --python 3.12 + +# Install dependencies +echo "Installing Python dependencies..." +source "$VENV_DIR/bin/activate" +uv pip install \ + "fastapi>=0.115.12" \ + "itsdangerous>=2.2.0" \ + "jinja2>=3.1.6" \ + "mcp>=1.6.0" \ + "pydantic>=2.11.3" \ + "httpx>=0.27.0" \ + "python-dotenv>=1.1.0" \ + "python-multipart>=0.0.20" \ + "uvicorn[standard]>=0.34.2" \ + "faiss-cpu>=1.7.4" \ + "sentence-transformers>=2.2.2" \ + "websockets>=15.0.1" \ + "scikit-learn>=1.3.0" \ + "torch>=1.6.0" \ + "huggingface-hub[cli,hf_xet]>=0.31.1" \ + "hf_xet>=0.1.0" \ + "pyjwt[crypto]>=2.8.0" \ + "pycognito>=2024.3.1" \ + "boto3>=1.28.0" \ + "requests>=2.32.3" + +# Install the package itself +uv pip install -e /app + +echo "Python environment setup complete." # --- SSL Certificate Generation --- echo "Checking for SSL certificates..." @@ -118,8 +190,25 @@ else fi # --- Nginx Configuration --- -echo "Copying custom Nginx configuration..." -cp "$NGINX_CONF_SRC" "$NGINX_CONF_DEST" +echo "Setting up Nginx configuration..." + +# Check if GATEWAY_HOSTNAME is set and add it to server_name directives +if [ ! -z "$GATEWAY_HOSTNAME" ]; then + echo "Adding $GATEWAY_HOSTNAME to server_name directives in Nginx configuration..." + # Create a temporary file + TEMP_NGINX_CONF=$(mktemp) + # Use sed to append the GATEWAY_HOSTNAME to both server_name lines + sed 's/\(server_name .*\);/\1 '"$GATEWAY_HOSTNAME"';/g' "$NGINX_CONF_SRC" > "$TEMP_NGINX_CONF" + # Use the modified file as the source + cp "$TEMP_NGINX_CONF" "$NGINX_CONF_DEST" + # Clean up the temporary file + rm "$TEMP_NGINX_CONF" + echo "Added $GATEWAY_HOSTNAME to server_name directives." +else + echo "GATEWAY_HOSTNAME not set, using default Nginx configuration..." + cp "$NGINX_CONF_SRC" "$NGINX_CONF_DEST" +fi + echo "Nginx configuration copied to $NGINX_CONF_DEST." # --- Model Verification --- @@ -161,8 +250,12 @@ else fi # --- Start Background Services --- +# Set the base URL for authentication if not already set +export MCP_AUTH_BASE_URL=${MCP_AUTH_BASE_URL:-https://mcpgateway.ddns.net} +echo "Setting MCP_AUTH_BASE_URL to: $MCP_AUTH_BASE_URL" + export EMBEDDINGS_MODEL_NAME=$EMBEDDINGS_MODEL_NAME -export EMBEDDINGS_MODEL_DIMENSIONS=$EMBEDDINGS_MODEL_DIMENSIONS +export EMBEDDINGS_MODEL_DIMENSIONS=$EMBEDDINGS_MODEL_DIMENSIONS # 1. Start Example MCP Servers echo "Starting example MCP servers in the background..." @@ -174,13 +267,13 @@ sleep 5 # 2. Start MCP Registry echo "Starting MCP Registry in the background..." -# Navigate to the registry directory to ensure relative paths work -cd /app/registry +# Navigate to the app directory to ensure imports work correctly +cd /app # Use uv run to start uvicorn, ensuring it uses the correct environment # Run on 0.0.0.0 to be accessible within the container network # Use port 7860 as configured in nginx proxy_pass source "$SCRIPT_DIR/.venv/bin/activate" -cd /app/registry && uvicorn main:app --host 0.0.0.0 --port 7860 & +uvicorn registry.main:app --host 0.0.0.0 --port 7860 & echo "MCP Registry start command issued." # Give registry a moment to initialize and generate initial nginx config sleep 10 diff --git a/docker/nginx_rev_proxy.conf b/docker/nginx_rev_proxy.conf index 0782402..e8b6ab9 100644 --- a/docker/nginx_rev_proxy.conf +++ b/docker/nginx_rev_proxy.conf @@ -1,9 +1,10 @@ -# First server block now directly handles HTTP requests instead of redirecting +# First server block handles HTTP requests server { listen 80; - server_name mcpgateway localhost mcpgateway.ddns.net ec2-44-192-72-20.compute-1.amazonaws.com; + listen [::]:80; + server_name mcpgateway localhost mcpgateway.ddns.net; - # Route for Cost Explorer service + # Route for registry service location / { proxy_pass http://127.0.0.1:7860/; proxy_http_version 1.1; @@ -11,85 +12,52 @@ server { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + # Pass through authentication headers for service-specific tokens + proxy_pass_request_headers on; + # Preserve headers prefixed with X-Service-Auth- + proxy_set_header X-Service-Auth-Github $http_x_service_auth_github; + proxy_set_header X-Service-Auth-AWS $http_x_service_auth_aws; + proxy_set_header X-Service-Auth-Token $http_x_service_auth_token; } - # REMOVE HARDCODED /mcpgw - # location /mcpgw/ { - # proxy_pass http://127.0.0.1:8003/; - # proxy_http_version 1.1; - # proxy_set_header Host $host; - # proxy_set_header X-Real-IP $remote_addr; - # proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - # } - - # REMOVE HARDCODED /currenttime - # location /currenttime/ { - # proxy_pass http://127.0.0.1:8001/; - # proxy_http_version 1.1; - # proxy_set_header Host $host; - # proxy_set_header X-Real-IP $remote_addr; - # proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - # } - - # REMOVE HARDCODED /fininfo - # location /fininfo/ { - # proxy_pass http://127.0.0.1:8002/; - # proxy_http_version 1.1; - # proxy_set_header Host $host; + # Dynamic locations - ONLY for HTTP server + # DYNAMIC_LOCATIONS_START + # Dynamic locations will be inserted here by the registry service + # DYNAMIC_LOCATIONS_END + + # Bedrock integration endpoint (OPTIONAL - configure via environment variables) + # To enable, set AWS_API_GATEWAY_URL environment variable + # Example: AWS_API_GATEWAY_URL=your-api-id.execute-api.region.amazonaws.com + # location /tsbedrock/ { + # proxy_pass https://${AWS_API_GATEWAY_URL}/prod/; + # proxy_set_header Host ${AWS_API_GATEWAY_URL}; # proxy_set_header X-Real-IP $remote_addr; # proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - # - # # Additional settings for SSE support - # proxy_set_header Connection ''; - # chunked_transfer_encoding off; - # proxy_buffering off; - # proxy_cache off; - # proxy_read_timeout 3600s; + # proxy_set_header X-Forwarded-Proto $scheme; + # proxy_ssl_server_name on; + # proxy_buffer_size 16k; + # proxy_buffers 4 16k; + # rewrite ^/tsbedrock/(.*)$ /prod/$1 break; # } - # --- ADD DYNAMIC MARKERS --- START - # DYNAMIC_LOCATIONS_START - - # DYNAMIC_LOCATIONS_END - # --- ADD DYNAMIC MARKERS --- END - - location /tsbedrock/ { - # Fix the path handling by adding trailing slash and using $request_uri - proxy_pass https://hwfo2k8szg.execute-api.us-east-1.amazonaws.com/prod/; - - # AWS API Gateway often needs Host header to match the API Gateway domain - proxy_set_header Host hwfo2k8szg.execute-api.us-east-1.amazonaws.com; - - # These headers help with request routing - proxy_set_header X-Real-IP $remote_addr; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Forwarded-Proto $scheme; - - # API Gateway often requires these settings - proxy_ssl_server_name on; - proxy_buffer_size 16k; - proxy_buffers 4 16k; - - # Adjust the rewrite to handle the path correctly - rewrite ^/tsbedrock/(.*)$ /prod/$1 break; - } - - error_log /var/log/nginx/error.log debug; + error_log /var/log/nginx/error.log info; } -# Keep the HTTPS server for clients that prefer it +# HTTPS server for clients that prefer it server { listen 443 ssl; - server_name mcpgateway localhost mcpgateway.ddns.net ec2-44-192-72-20.compute-1.amazonaws.com; + listen [::]:443 ssl; + server_name mcpgateway localhost mcpgateway.ddns.net; - # SSL Configuration using self-signed certs generated in Dockerfile + # SSL Configuration using the self-signed certificates generated in the Dockerfile/entrypoint ssl_certificate /etc/ssl/certs/fullchain.pem; ssl_certificate_key /etc/ssl/private/privkey.pem; ssl_protocols TLSv1.2 TLSv1.3; ssl_prefer_server_ciphers off; # Stronger cipher suite ssl_ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384; - + # Duplicate the same location blocks for HTTPS access location / { proxy_pass http://127.0.0.1:7860/; @@ -98,59 +66,34 @@ server { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + # Pass through authentication headers for service-specific tokens + proxy_pass_request_headers on; + # Preserve headers prefixed with X-Service-Auth- + proxy_set_header X-Service-Auth-Github $http_x_service_auth_github; + proxy_set_header X-Service-Auth-AWS $http_x_service_auth_aws; + proxy_set_header X-Service-Auth-Token $http_x_service_auth_token; } - # REMOVE HARDCODED /mcpgw - # location /mcpgw/ { - # proxy_pass http://127.0.0.1:8003/; - # proxy_http_version 1.1; - # proxy_set_header Host $host; - # proxy_set_header X-Real-IP $remote_addr; - # proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - # } - - # REMOVE HARDCODED /currenttime - # location /currenttime/ { - # proxy_pass http://127.0.0.1:8001/; - # proxy_http_version 1.1; - # proxy_set_header Host $host; - # proxy_set_header X-Real-IP $remote_addr; - # proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - # } + # Dynamic locations - ONLY for HTTP server + # DYNAMIC_LOCATIONS_START + # Dynamic locations will be inserted here by the registry service + # DYNAMIC_LOCATIONS_END - # REMOVE HARDCODED /fininfo - # location /fininfo/ { - # proxy_pass http://127.0.0.1:8002/; - # proxy_http_version 1.1; - # proxy_set_header Host $host; + # Bedrock integration endpoint (OPTIONAL - configure via environment variables) + # To enable, set AWS_API_GATEWAY_URL environment variable + # Example: AWS_API_GATEWAY_URL=your-api-id.execute-api.region.amazonaws.com + # location /tsbedrock/ { + # proxy_pass https://${AWS_API_GATEWAY_URL}/prod/; + # proxy_set_header Host ${AWS_API_GATEWAY_URL}; # proxy_set_header X-Real-IP $remote_addr; # proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - # - # # Additional settings for SSE support - # proxy_set_header Connection ''; - # chunked_transfer_encoding off; - # proxy_buffering off; - # proxy_cache off; - # proxy_read_timeout 3600s; + # proxy_set_header X-Forwarded-Proto $scheme; + # proxy_ssl_server_name on; + # proxy_buffer_size 16k; + # proxy_buffers 4 16k; + # rewrite ^/tsbedrock/(.*)$ /prod/$1 break; # } - - # --- ADD DYNAMIC MARKERS --- START - # DYNAMIC_LOCATIONS_START - - # DYNAMIC_LOCATIONS_END - # --- ADD DYNAMIC MARKERS --- END - - location /tsbedrock/ { - proxy_pass https://hwfo2k8szg.execute-api.us-east-1.amazonaws.com/prod/; - proxy_set_header Host hwfo2k8szg.execute-api.us-east-1.amazonaws.com; - proxy_set_header X-Real-IP $remote_addr; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Forwarded-Proto $scheme; - proxy_ssl_server_name on; - proxy_buffer_size 16k; - proxy_buffers 4 16k; - rewrite ^/tsbedrock/(.*)$ /prod/$1 break; - } - error_log /var/log/nginx/error.log debug; + error_log /var/log/nginx/error.log info; } \ No newline at end of file diff --git a/docs/oauth.md b/docs/oauth.md new file mode 100644 index 0000000..201a1ce --- /dev/null +++ b/docs/oauth.md @@ -0,0 +1,442 @@ +# MCP Gateway OAuth 2.1 Integration + +This guide covers everything you need to know about using OAuth 2.1 with MCP Gateway, including configuration, setup, and practical examples. + +## Table of Contents + +1. [Overview](#overview) +2. [Features and Capabilities](#features-and-capabilities) +3. [Setup and Configuration](#setup-and-configuration) + - [AWS Cognito Setup](#aws-cognito-setup) + - [Okta Setup](#okta-setup) + - [Generic OAuth Provider Setup](#generic-oauth-provider-setup) +4. [Scope-Based Access Control](#scope-based-access-control) +5. [Transport Protocol Support](#transport-protocol-support) +6. [Configuration Reference](#configuration-reference) + - [Environment Variables](#environment-variables) + - [JSON Configuration](#json-configuration) +7. [End-to-End Authentication](#end-to-end-authentication) +8. [Deployment Examples](#deployment-examples) +9. [Troubleshooting](#troubleshooting) + +## Overview + +MCP Gateway implements OAuth 2.1 for secure authentication, allowing you to: + +- Control access to the MCP Gateway Registry and its servers +- Integrate with popular identity providers (AWS Cognito, Okta, others) +- Implement fine-grained permissions using scope-based access control +- Support both SSE and StreamableHTTP transport protocols +- Pass service-specific tokens for end-to-end authentication + +The authentication flow follows the OAuth 2.1 authorization code flow with PKCE (Proof Key for Code Exchange) for enhanced security: + +```mermaid +sequenceDiagram + participant Client as MCP Client + participant Gateway as MCP Gateway + participant IdP as Identity Provider + participant Server as MCP Server + + Client->>Gateway: 1. Access /api/execute/{service} + Gateway->>Gateway: 2. Check auth (request.state.user) + Gateway->>Client: 3. Redirect to /oauth/login if not authenticated + Client->>Gateway: 4. Request /oauth/login + Gateway->>IdP: 5. Redirect to IdP authorize URL with PKCE + Client->>IdP: 6. User authentication + IdP->>Gateway: 7. Authorization code to /oauth/callback + Gateway->>IdP: 8. Exchange code for tokens + Gateway->>Gateway: 9. Validate tokens, extract user & scopes + Gateway->>Client: 10. Set session cookie + Client->>Gateway: 11. Retry /api/execute/{service} with cookie + Gateway->>Gateway: 12. Verify auth & check scopes + Gateway->>Server: 13. Proxy request to MCP Server + Server->>Gateway: 14. Server response + Gateway->>Client: 15. Return server response to client +``` + +## Features and Capabilities + +MCP Gateway's OAuth 2.1 implementation includes: + +- **Multiple Identity Provider Support**: + - AWS Cognito (fully implemented and tested) + - Okta (implemented) + - Generic OAuth 2.1 providers (implemented with configuration options) + +- **Secure Authentication**: + - OAuth 2.1 authorization code flow with PKCE + - JWT validation and signature verification + - Session management with proper token handling + +- **Fine-Grained Access Control**: + - Registry-level permissions (admin, read) + - Server-level permissions (execute, read, toggle, edit) + - Tool-specific permissions + +- **Multiple Transport Protocol Support**: + - Server-Sent Events (SSE) via `/api/execute/{service}` endpoint + - StreamableHTTP via `/api/streamable/{service}` endpoint + - Dynamically generated nginx configuration for both transports + +- **End-to-End Authentication**: + - Direct parameter approach for service-specific tokens + +- **Flexible Configuration**: + - Environment variables + - JSON configuration files + - Runtime scope mapping + + +## Scope-Based Access Control + +MCP Gateway uses a scope-based permission system mapped to identity provider groups: + +### Standard Scopes + +| Scope | Description | +|-------|-------------| +| `mcp:registry:admin` | Full administrative access to the registry | +| `mcp:registry:read` | Read-only access to the registry | +| `mcp:server:{name}:execute` | Permission to use a specific server | +| `mcp:server:{name}:read` | Read-only access to a specific server | +| `mcp:server:{name}:toggle` | Permission to enable/disable a server | +| `mcp:server:{name}:edit` | Permission to edit a server's configuration | +| `mcp:server:{name}:tool:{tool}:execute` | Permission to use a specific tool | + +### Group to Scope Mapping + +MCP Gateway maps IdP groups to MCP Gateway scopes using this convention: + +| Group Name | Resulting Scopes | +|------------|------------------| +| `mcp-admin` | `mcp:registry:admin` | +| `mcp-user` | `mcp:registry:read` | +| `mcp-server-{name}` | `mcp:server:{name}:read`, `mcp:server:{name}:execute` | +| `mcp-server-{name}-admin` | All server scopes for the named server | +| `mcp-server-{name}-toggle` | `mcp:server:{name}:toggle` | +| `mcp-server-{name}-edit` | `mcp:server:{name}:edit` | +| `mcp-server-{name}-tool-{tool}` | `mcp:server:{name}:tool:{tool}:execute` | + +## Transport Protocol Support + +MCP Gateway supports both transport protocols defined in the MCP specification: + +### Server-Sent Events (SSE) + +The SSE transport is accessed through the `/api/execute/{service}` endpoint: + +- **Description**: A streaming protocol where the server sends a continuous stream of events to the client +- **HTTP Method**: POST to initiate, then GET for the SSE stream +- **Endpoint**: `/api/execute/{service}` +- **Authentication**: Requires valid session cookie with appropriate scopes +- **Headers**: + - `Content-Type: application/json` for the request + - `Accept: text/event-stream` for the response + +### StreamableHTTP + +The StreamableHTTP transport is accessed through the `/api/streamable/{service}` endpoint: + +- **Description**: A bidirectional HTTP transport with request and response bodies +- **HTTP Method**: POST +- **Endpoint**: `/api/streamable/{service}` +- **Authentication**: Requires valid session cookie with appropriate scopes +- **Headers**: + - `Content-Type: application/json` for the request + +### Example Usage + +#### SSE Transport (Python) + +```python +from mcp.client.sse import SSEClient + +# Create SSE client with session cookie authentication +client = SSEClient( + base_url="https://your-gateway.example.com/api/execute/currenttime", + cookies={"mcp_gateway_session": session_cookie} +) + +# Execute tool +result = await client.execute_tool( + "current_time_by_timezone", + {"params": {"tz_name": "America/New_York"}} +) +``` + +#### StreamableHTTP Transport (Python) + +```python +from mcp.client.streamable_http import StreamableHTTPClient + +# Create StreamableHTTP client with session cookie authentication +client = StreamableHTTPClient( + base_url="https://your-gateway.example.com/api/streamable/currenttime", + cookies={"mcp_gateway_session": session_cookie} +) + +# Execute tool +result = await client.execute_tool( + "current_time_by_timezone", + {"params": {"tz_name": "America/New_York"}} +) +``` + +## Setup and Configuration + +### AWS Cognito Setup + +#### 1. Create a Cognito User Pool + +1. Go to the AWS Management Console and navigate to Amazon Cognito +2. Create a new user pool with these key settings: + - Sign-in options: Email + - Security requirements: According to your needs + - Required attributes: Email + - App client: Confidential client type + - Callback URL: `http://localhost:7860/oauth/callback/cognito` (add your domain for production) + - OAuth grant types: Authorization code + +#### 2. Set Up Groups for Access Control + +Create the following groups in your Cognito user pool: + +```bash +# Admin access +aws cognito-idp create-group \ + --user-pool-id YOUR_USER_POOL_ID \ + --group-name mcp-admin \ + --description "Full administrative access" + +# Basic user access +aws cognito-idp create-group \ + --user-pool-id YOUR_USER_POOL_ID \ + --group-name mcp-user \ + --description "Read-only access" + +# Server-specific access +aws cognito-idp create-group \ + --user-pool-id YOUR_USER_POOL_ID \ + --group-name mcp-server-currenttime \ + --description "Access to currenttime server" + +# Server toggle permission +aws cognito-idp create-group \ + --user-pool-id YOUR_USER_POOL_ID \ + --group-name mcp-server-currenttime-toggle \ + --description "Permission to toggle currenttime server" + +# Tool-specific access +aws cognito-idp create-group \ + --user-pool-id YOUR_USER_POOL_ID \ + --group-name mcp-server-currenttime-tool-current_time_by_timezone \ + --description "Access to specific tool" +``` + +#### 3. Create Test Users + +```bash +# Create admin user +aws cognito-idp admin-create-user \ + --user-pool-id YOUR_USER_POOL_ID \ + --username admin@example.com \ + --user-attributes Name=email,Value=admin@example.com Name=email_verified,Value=true + +# Set password +aws cognito-idp admin-set-user-password \ + --user-pool-id YOUR_USER_POOL_ID \ + --username admin@example.com \ + --password "SecurePassword123!" \ + --permanent + +# Add to groups +aws cognito-idp admin-add-user-to-group \ + --user-pool-id YOUR_USER_POOL_ID \ + --username admin@example.com \ + --group-name mcp-admin +``` + +#### 4. Configure MCP Gateway for Cognito + +```bash +# Basic OAuth settings +export MCP_AUTH_ENABLED=true +export MCP_AUTH_PROVIDER_TYPE=cognito +export SECRET_KEY=$(python3 -c 'import secrets; print(secrets.token_hex(32))') + +# Cognito-specific settings +export MCP_AUTH_COGNITO_USER_POOL_ID=YOUR_USER_POOL_ID +export MCP_AUTH_COGNITO_CLIENT_ID=YOUR_CLIENT_ID +export MCP_AUTH_COGNITO_CLIENT_SECRET=YOUR_CLIENT_SECRET +export MCP_AUTH_COGNITO_CALLBACK_URI=http://localhost:7860/oauth/callback/cognito +export MCP_AUTH_COGNITO_REGION=YOUR_AWS_REGION +``` + +### Okta Setup + +#### 1. Create an Okta Application + +1. Sign in to the Okta Admin Console +2. Create a new app integration: + - Sign-in method: OIDC - OpenID Connect + - Application type: Web Application + - Grant types: Authorization Code + - Sign-in redirect URIs: `http://localhost:7860/oauth/callback/okta` + - Assignments: Assign to groups + +#### 2. Create Groups for Access Control + +Create the following groups in Okta: +- mcp-admin +- mcp-user +- mcp-server-{servername} +- mcp-server-{servername}-toggle + +#### 3. Configure MCP Gateway for Okta + +```bash +# Basic OAuth settings +export MCP_AUTH_ENABLED=true +export MCP_AUTH_PROVIDER_TYPE=okta +export SECRET_KEY=$(python3 -c 'import secrets; print(secrets.token_hex(32))') + +# Okta-specific settings +export MCP_AUTH_OKTA_TENANT_URL=https://your-domain.okta.com +export MCP_AUTH_OKTA_CLIENT_ID=YOUR_CLIENT_ID +export MCP_AUTH_OKTA_CLIENT_SECRET=YOUR_CLIENT_SECRET +export MCP_AUTH_OKTA_CALLBACK_URI=http://localhost:7860/oauth/callback/okta +``` + +### Generic OAuth Provider Setup + +For other OAuth 2.1 providers: + +```bash +# Basic OAuth settings +export MCP_AUTH_ENABLED=true +export MCP_AUTH_PROVIDER_TYPE=generic +export SECRET_KEY=$(python3 -c 'import secrets; print(secrets.token_hex(32))') + +# Provider settings +export MCP_AUTH_CLIENT_ID=YOUR_CLIENT_ID +export MCP_AUTH_CLIENT_SECRET=YOUR_CLIENT_SECRET +export MCP_AUTH_AUTHORIZE_URL=https://your-provider.com/oauth2/authorize +export MCP_AUTH_TOKEN_URL=https://your-provider.com/oauth2/token +export MCP_AUTH_JWKS_URL=https://your-provider.com/.well-known/jwks.json +export MCP_AUTH_CALLBACK_URI=http://localhost:7860/oauth/callback +export MCP_AUTH_SCOPES="openid profile email" +export MCP_AUTH_ISSUER=https://your-provider.com +``` + +## End-to-End Authentication + +MCP Gateway supports passing service-specific tokens (like GitHub, AWS) to MCP tools. Currently, the direct parameter approach can be leveraged: + +### Direct Parameter Approach + +This approach passes service-specific tokens directly as parameters to the tool: + +```python +@mcp.tool() +async def get_github_repos( + github_token: str = Field(..., description="GitHub Personal Access Token"), + username: str = Field(..., description="MCP Gateway username"), + password: str = Field(..., description="MCP Gateway password"), + # Other parameters... +) -> Dict[str, Any]: + # First authenticate with MCP Gateway + auth_credentials = Credentials(username=username, password=password) + await _ensure_authenticated(auth_credentials) + + # Use GitHub token for API calls + headers = { + "Authorization": f"token {github_token}", + "Accept": "application/vnd.github.v3+json" + } + + # Make GitHub API calls and return results + # ... +``` +## Deployment Examples + +### Docker Deployment with Cognito + +```bash +docker run -p 80:80 -p 443:443 -p 7860:7860 \ + -e ADMIN_USER=admin \ + -e ADMIN_PASSWORD=password \ + -e SECRET_KEY=$(python3 -c 'import secrets; print(secrets.token_hex(32))') \ + -e MCP_AUTH_ENABLED=true \ + -e MCP_AUTH_PROVIDER_TYPE=cognito \ + -e MCP_AUTH_COGNITO_USER_POOL_ID=YOUR_USER_POOL_ID \ + -e MCP_AUTH_COGNITO_CLIENT_ID=YOUR_CLIENT_ID \ + -e MCP_AUTH_COGNITO_CLIENT_SECRET=YOUR_CLIENT_SECRET \ + -e MCP_AUTH_COGNITO_CALLBACK_URI=http://localhost:7860/oauth/callback/cognito \ + -e MCP_AUTH_COGNITO_REGION=us-east-1 \ + -v $(pwd)/logs:/app/logs \ + -v $(pwd)/registry/servers:/app/registry/servers \ + --name mcp-gateway-container mcp-gateway +``` + +## Troubleshooting + +### Common Issues and Solutions + +#### 1. Authentication Failures + +- **Invalid Redirect URI Error**: + - Ensure the callback URI in IdP settings exactly matches the one in MCP Gateway + - Check for protocol mismatches (http vs https) + +- **Invalid Client Error**: + - Verify client ID and client secret + - Ensure the client is allowed to use the authorization code grant + +- **Token Validation Errors**: + - Check if JWKS endpoint is accessible + - Verify issuer and audience configuration + - Look for clock skew between servers + +#### 2. Permission Issues + +- **Access Denied Errors**: + - Confirm user is in the correct groups in the IdP + - Verify group naming convention follows the required pattern + - Check scope mapping configuration + +#### 3. JWT Validation Issues + +- **Signature Verification Failures**: + - Ensure JWKS URL is correct and accessible + - Check if token is properly formatted + +### Debugging + +To enable detailed debug logging: + +```bash +export LOGGING_LEVEL=DEBUG +``` + +Monitor authentication logs: + +```bash +docker logs -f mcp-gateway-container | grep -i 'auth\|scope\|access' +``` + +### Security Best Practices + +1. **Always use HTTPS** in production environments +2. **Minimize token scope** to follow least privilege principle +3. **Use short-lived tokens** and implement proper refresh logic +4. **Enable MFA** in your identity provider +5. **Log and monitor** authentication attempts + +--- + +## Further Resources + +- [OAuth 2.1 Specification](https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-09) +- [AWS Cognito Documentation](https://docs.aws.amazon.com/cognito/latest/developerguide/cognito-user-pools.html) +- [Okta Developer Documentation](https://developer.okta.com/docs/guides/) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 60fde23..6d3d944 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "scikit-learn>=1.3.0", "torch>=1.6.0", "huggingface-hub[cli]>=0.31.1", + "pyjwt[crypto]>=2.8.0", "bandit>=1.8.3", "langchain-mcp-adapters>=0.0.11", "langgraph>=0.4.3", diff --git a/reddit_post.md b/reddit_post.md deleted file mode 100644 index ab31522..0000000 --- a/reddit_post.md +++ /dev/null @@ -1,63 +0,0 @@ -# [Open Source] MCP Gateway: A Centralized Hub for Managing Your Model Context Protocol Servers - -Hey r/mcp! - -I'm excited to share a new open-source project that might solve a pain point many of us are experiencing as we scale our MCP implementations: **MCP Gateway & Registry**. - -## The Problem - -As many of us have discovered, while MCP is revolutionizing how AI models connect with external tools and data sources, managing a growing collection of MCP servers quickly becomes challenging: - -* **Discoverability issues**: Which servers are available? What tools do they offer? -* **Configuration headaches**: Constantly updating URLs in AI agents for different servers -* **Management overhead**: Tracking health and status across multiple independent servers -* **Inconsistent access patterns**: Different teams implementing different approaches - -## The Solution: MCP Gateway & Registry - -The MCP Gateway transforms your scattered MCP landscape into an organized, manageable ecosystem: - -* **Single entry point** for all MCP traffic (both SSE and Streamable HTTP) -* **Centralized registry** with a web UI showing all available servers and their tools -* **Unified URL structure** (e.g., `gateway.mycorp.com/weather`, `gateway.mycorp.com/fininfo`) -* **Real-time health monitoring** with WebSocket updates -* **Dynamic configuration** that automatically updates routing rules - -![MCP Registry UI](https://github.com/aarora79/mcp-gateway/raw/main/docs/img/registry.png) - -## Meta-Capability: Self-Management Through MCP - -One of the coolest features is that the Gateway includes its own MCP server (`mcpgw`) that exposes management capabilities as MCP tools. This means AI agents can manage the Gateway directly through the MCP protocol! - -Tools include: -* `toggle_service`: Enable/disable servers -* `register_service`: Add new servers programmatically -* `get_service_tools`: List all tools from specific or all servers -* And more! - -## Tech Stack - -* **Nginx** as a powerful reverse proxy -* **FastAPI** for the Registry application -* **Docker** for easy deployment -* **WebSockets** for real-time updates - -## Getting Started - -The project is designed for both quick proof-of-concept deployments and production-ready implementations. Check out the [GitHub repo](https://github.com/aarora79/mcp-gateway/tree/main) for detailed instructions. - -## Roadmap - -Future plans include: -* OAUTH 2.1 support -* Intelligent tool finder -* Deployment automation for MCP servers -* GitHub API integration - -## Join Us! - -* **Try it out**: Follow the installation steps in the [README](https://github.com/aarora79/mcp-gateway/tree/main?tab=readme-ov-file#installation) -* **Contribute**: We welcome feedback, feature requests, and code contributions -* **Connect**: Join our community of AI practitioners building the future of AI tool integration - -Has anyone else been struggling with managing multiple MCP servers? Would love to hear your thoughts on this approach! \ No newline at end of file diff --git a/registry/__init__.py b/registry/__init__.py new file mode 100644 index 0000000..14cbe48 --- /dev/null +++ b/registry/__init__.py @@ -0,0 +1 @@ +# This file makes the registry directory a Python package \ No newline at end of file diff --git a/registry/auth/__init__.py b/registry/auth/__init__.py new file mode 100644 index 0000000..62176e3 --- /dev/null +++ b/registry/auth/__init__.py @@ -0,0 +1,18 @@ +""" +Authentication module for MCP Gateway. + +This module implements OAuth 2.1 authentication for MCP Gateway, +allowing integration with external identity providers like AWS Cognito and Okta. +""" + +from .provider import CognitoOAuthProvider, ConfigurableIdPAdapter +from .middleware import setup_auth_middleware, requires_scope +from .settings import AuthSettings + +__all__ = [ + "CognitoOAuthProvider", + "ConfigurableIdPAdapter", + "setup_auth_middleware", + "requires_scope", + "AuthSettings", +] \ No newline at end of file diff --git a/registry/auth/integration.py b/registry/auth/integration.py new file mode 100644 index 0000000..9590407 --- /dev/null +++ b/registry/auth/integration.py @@ -0,0 +1,302 @@ +""" +Integration module for MCP Gateway authentication. + +This module provides the integration functions for adding +OAuth 2.1 authentication to the MCP Gateway FastAPI application. +""" +import os +import json +import logging +from typing import Optional + +from fastapi import FastAPI +from fastapi.templating import Jinja2Templates + +from .middleware import setup_auth_middleware +from .settings import AuthSettings, IdPSettings, ScopeMapping +from .provider import ConfigurableIdPAdapter, CognitoOAuthProvider, OktaOAuthProvider +from .routes import setup_auth_routes + +logger = logging.getLogger(__name__) + + +def integrate_oauth(app: FastAPI, templates: Jinja2Templates) -> Optional[ConfigurableIdPAdapter]: + """ + Integrate OAuth 2.1 authentication with the MCP Gateway. + + This function sets up the authentication middleware, routes, and provider + for the MCP Gateway FastAPI application based on environment configuration. + + Args: + app: The FastAPI application + templates: The templates engine + + Returns: + The configured OAuth provider adapter, or None if disabled + """ + # Check if OAuth is enabled + enabled = os.environ.get("MCP_AUTH_ENABLED", "").lower() in ("true", "1", "yes") + if not enabled: + logger.info("OAuth 2.1 integration is disabled") + return None + + # Check for config file + config_path = os.environ.get("MCP_AUTH_CONFIG") + if config_path: + logger.info(f"Loading OAuth 2.1 configuration from {config_path}") + return setup_from_config(app, templates, config_path) + + # Load configuration from environment variables + logger.info("Loading OAuth 2.1 configuration from environment variables") + + # Determine provider type + provider_type = os.environ.get("MCP_AUTH_PROVIDER_TYPE", "").lower() + if not provider_type: + logger.warning("OAuth 2.1 provider type not specified") + return None + + # Configure the provider based on type + if provider_type == "cognito": + return _setup_cognito_from_env(app, templates) + elif provider_type == "okta": + return _setup_okta_from_env(app, templates) + else: + return _setup_generic_from_env(app, templates, provider_type) + + +def _setup_cognito_from_env(app: FastAPI, templates: Jinja2Templates) -> Optional[CognitoOAuthProvider]: + """Set up Cognito OAuth provider from environment variables.""" + user_pool_id = os.environ.get("MCP_AUTH_COGNITO_USER_POOL_ID") + client_id = os.environ.get("MCP_AUTH_COGNITO_CLIENT_ID") + client_secret = os.environ.get("MCP_AUTH_COGNITO_CLIENT_SECRET") + callback_uri = os.environ.get("MCP_AUTH_COGNITO_CALLBACK_URI") + region = os.environ.get("MCP_AUTH_COGNITO_REGION", "us-east-1") + custom_domain = os.environ.get("MCP_AUTH_COGNITO_CUSTOM_DOMAIN") + + if not all([user_pool_id, client_id, client_secret, callback_uri]): + logger.warning("Missing required Cognito configuration") + return None + + # Log the configuration for debugging + logger.info(f"Setting up Cognito with user_pool_id={user_pool_id}, region={region}") + if custom_domain: + logger.info(f"Using custom Cognito domain: {custom_domain}") + + # Create the Cognito provider with enhanced security + provider = CognitoOAuthProvider.from_user_pool( + user_pool_id=user_pool_id, + client_id=client_id, + client_secret=client_secret, + callback_uri=callback_uri, + region=region, + custom_domain=custom_domain + ) + + # Ensure we don't set audience for Cognito + if hasattr(provider.settings.idp_settings, 'audience'): + provider.settings.idp_settings.audience = None + logger.info("Removed audience parameter for Cognito OAuth flow") + + # Set up middleware and routes + setup_auth_middleware(app, provider, provider.settings) + setup_auth_routes(app, provider, provider.settings, templates) + + logger.info(f"Cognito OAuth provider set up with user pool {user_pool_id}") + + return provider + + +def _setup_okta_from_env(app: FastAPI, templates: Jinja2Templates) -> Optional[OktaOAuthProvider]: + """Set up Okta OAuth provider from environment variables.""" + tenant_url = os.environ.get("MCP_AUTH_OKTA_TENANT_URL") + client_id = os.environ.get("MCP_AUTH_OKTA_CLIENT_ID") + client_secret = os.environ.get("MCP_AUTH_OKTA_CLIENT_SECRET") + callback_uri = os.environ.get("MCP_AUTH_OKTA_CALLBACK_URI") + + if not all([tenant_url, client_id, client_secret, callback_uri]): + logger.warning("Missing required Okta configuration") + return None + + # Create the provider + provider = OktaOAuthProvider.from_tenant( + tenant_url=tenant_url, + client_id=client_id, + client_secret=client_secret, + callback_uri=callback_uri + ) + + # Set up middleware and routes + setup_auth_middleware(app, provider, provider.settings) + setup_auth_routes(app, provider, provider.settings, templates) + + logger.info(f"Okta OAuth provider set up with tenant {tenant_url}") + + return provider + + +def _setup_generic_from_env(app: FastAPI, templates: Jinja2Templates, provider_type: str) -> Optional[ConfigurableIdPAdapter]: + """Set up a generic OAuth provider from environment variables.""" + client_id = os.environ.get("MCP_AUTH_CLIENT_ID") + client_secret = os.environ.get("MCP_AUTH_CLIENT_SECRET") + authorize_url = os.environ.get("MCP_AUTH_AUTHORIZE_URL") + token_url = os.environ.get("MCP_AUTH_TOKEN_URL") + jwks_url = os.environ.get("MCP_AUTH_JWKS_URL") + callback_uri = os.environ.get("MCP_AUTH_CALLBACK_URI") + + if not all([client_id, client_secret, authorize_url, token_url, jwks_url, callback_uri]): + logger.warning(f"Missing required configuration for {provider_type} OAuth provider") + return None + + # Create settings + settings = AuthSettings() + settings.idp_settings = IdPSettings( + provider_type=provider_type, + client_id=client_id, + client_secret=client_secret, + authorize_url=authorize_url, + token_url=token_url, + jwks_url=jwks_url, + callback_uri=callback_uri, + scopes=os.environ.get("MCP_AUTH_SCOPES", "openid profile email").split(), + audience=os.environ.get("MCP_AUTH_AUDIENCE"), + issuer=os.environ.get("MCP_AUTH_ISSUER") + ) + + # Create default scope mapping + settings.scope_mapping = ScopeMapping( + idp_to_mcp={ + "admin": ["mcp:registry:admin"], + "user": ["mcp:registry:read"], + } + ) + + # Set default client ID and secret for client access + settings.default_client_id = client_id + settings.default_client_secret = client_secret + + # Create the provider + provider = ConfigurableIdPAdapter(settings) + + # Set up middleware and routes + setup_auth_middleware(app, provider, settings) + setup_auth_routes(app, provider, settings, templates) + + logger.info(f"{provider_type.title()} OAuth provider set up") + + return provider + + +def setup_from_config(app: FastAPI, templates, config_path: Optional[str] = None): + """ + Set up OAuth 2.1 from a configuration file. + + Args: + app: The FastAPI application + templates: The templates engine + config_path: Path to the configuration file + """ + settings = load_auth_settings(config_path) + + if not settings.enabled or not settings.idp_settings: + logger.info("OAuth 2.1 integration is disabled or not configured") + return None + + # Create the provider based on provider type + provider_type = settings.idp_settings.provider_type.lower() + if provider_type == "cognito": + provider = CognitoOAuthProvider(settings) + elif provider_type == "okta": + provider = OktaOAuthProvider(settings) + else: + # Generic provider for other IdPs + provider = ConfigurableIdPAdapter(settings) + + # Set up middleware and routes + setup_auth_middleware(app, provider, settings) + setup_auth_routes(app, provider, settings, templates) + + logger.info(f"OAuth 2.1 set up with {provider_type} provider from config file") + + return provider + + +def load_auth_settings(config_path: str) -> AuthSettings: + """ + Load authentication settings from a configuration file. + + Args: + config_path: Path to the configuration file + + Returns: + Authentication settings + """ + settings = AuthSettings() + + try: + with open(config_path, "r") as f: + config = json.load(f) + # Process config + settings.enabled = config.get("enabled", True) + + idp_config = config.get("idp", {}) + if idp_config: + settings.idp_settings = IdPSettings( + provider_type=idp_config.get("provider_type", ""), + client_id=idp_config.get("client_id", ""), + client_secret=idp_config.get("client_secret", ""), + authorize_url=idp_config.get("authorize_url", ""), + token_url=idp_config.get("token_url", ""), + jwks_url=idp_config.get("jwks_url", ""), + callback_uri=idp_config.get("callback_uri", ""), + scopes=idp_config.get("scopes", ["openid", "profile", "email"]), + audience=idp_config.get("audience"), + issuer=idp_config.get("issuer") + ) + + # Set default client ID and secret for client access + settings.default_client_id = idp_config.get("client_id", "") + settings.default_client_secret = idp_config.get("client_secret", "") + + # Process scope mappings + scope_mapping_config = config.get("scope_mapping", {}) + if scope_mapping_config: + settings.scope_mapping = ScopeMapping( + idp_to_mcp=scope_mapping_config.get("idp_to_mcp", {}), + mcp_to_idp=scope_mapping_config.get("mcp_to_idp", {}) + ) + else: + # Create default scope mapping if none provided + settings.scope_mapping = ScopeMapping( + idp_to_mcp={ + "admin": ["mcp:registry:admin"], + "user": ["mcp:registry:read"], + } + ) + + # Process scope names + scopes = config.get("scopes", {}) + if scopes: + settings.registry_admin_scope = scopes.get( + "registry_admin", settings.registry_admin_scope + ) + settings.registry_read_scope = scopes.get( + "registry_read", settings.registry_read_scope + ) + settings.server_execute_scope_prefix = scopes.get( + "server_prefix", settings.server_execute_scope_prefix + ) + settings.server_execute_scope_suffix = scopes.get( + "server_suffix", settings.server_execute_scope_suffix + ) + + # Process public routes + public_routes = config.get("public_routes") + if public_routes: + settings.public_routes = public_routes + + logger.info(f"Loaded auth settings from {config_path}") + return settings + + except Exception as e: + logger.error(f"Error loading auth config from {config_path}: {e}") + raise ValueError(f"Failed to load OAuth configuration from {config_path}: {e}") \ No newline at end of file diff --git a/registry/auth/middleware.py b/registry/auth/middleware.py new file mode 100644 index 0000000..993c59e --- /dev/null +++ b/registry/auth/middleware.py @@ -0,0 +1,416 @@ +""" +Authentication middleware for MCP Gateway. + +This module implements middleware for JWT verification and scope-based +access control using the MCP SDK's authentication components. +""" +import logging +from typing import List, Dict, Any, Union + +from fastapi import Request, HTTPException, status +from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.authentication import BaseUser + +logger = logging.getLogger(__name__) + +from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend +from mcp.server.auth.middleware.auth_context import AuthContextMiddleware + +from .provider import ConfigurableIdPAdapter +from .settings import AuthSettings + +from .provider import ConfigurableIdPAdapter +from .settings import AuthSettings + +logger = logging.getLogger(__name__) + + +class MCPUser(BaseUser): + """User with MCP-specific attributes.""" + + def __init__(self, client_id: str, scopes: List[str], claims: Dict[str, Any]): + self.client_id = client_id + self.scopes = set(scopes) + self.claims = claims + + @property + def is_authenticated(self) -> bool: + return True + + @property + def display_name(self) -> str: + return self.claims.get("name", self.claims.get("email", self.client_id)) + + def has_scope(self, scope: Union[str, List[str]]) -> bool: + """Check if the user has the specified scope(s).""" + if isinstance(scope, str): + return scope in self.scopes + return all(s in self.scopes for s in scope) + + +class SessionUser(BaseUser): + """User authenticated via session cookie.""" + + def __init__(self, username: str, groups: List[str] = None): + self.username = username + self.groups = groups or [] + + # Extract scopes from groups using the same logic as in provider.py + self.scopes = set() + for group in self.groups: + if isinstance(group, str): + # Add any group that starts with mcp: directly + if group.startswith("mcp:"): + self.scopes.add(group) + # Map specific groups to scopes + elif group == "mcp-admin": + self.scopes.add("mcp:registry:admin") + elif group == "mcp-user": + self.scopes.add("mcp:registry:read") + # Add server-specific scopes based on group names + elif group.startswith("mcp-server-"): + auth_settings = AuthSettings() + + # Process more specific scopes first + if "-toggle" in group: + # Extract server name (e.g., "mcp-server-currenttime-toggle" -> "currenttime") + server_name = group[len("mcp-server-"):group.find("-toggle")] + if server_name: + toggle_scope = f"{auth_settings.server_execute_scope_prefix}{server_name}:toggle" + self.scopes.add(toggle_scope) + # Also add read access with toggle permission + read_scope = f"{auth_settings.server_execute_scope_prefix}{server_name}:read" + self.scopes.add(read_scope) + # Log the scope mapping for debugging + logger.info(f"Group '{group}' mapped to scopes: {toggle_scope}, {read_scope}") + elif "-edit" in group: + # Extract server name (e.g., "mcp-server-currenttime-edit" -> "currenttime") + server_name = group[len("mcp-server-"):group.find("-edit")] + if server_name: + edit_scope = f"{auth_settings.server_execute_scope_prefix}{server_name}:edit" + self.scopes.add(edit_scope) + # Also add read access with edit permission + read_scope = f"{auth_settings.server_execute_scope_prefix}{server_name}:read" + self.scopes.add(read_scope) + elif "-tool-" in group: + # Handle tool-specific groups (e.g., "mcp-server-currenttime-tool-xyz") + parts = group.split("-tool-") + if len(parts) == 2: + server_part = parts[0] + tool_part = parts[1] + server_name = server_part[len("mcp-server-"):] + tool_scope = f"{auth_settings.server_execute_scope_prefix}{server_name}:tool:{tool_part}:execute" + self.scopes.add(tool_scope) + else: + # Extract server name from group (e.g., "mcp-server-currenttime" -> "currenttime") + server_name = group[len("mcp-server-"):] + if server_name: + # Create the server execute scope + server_scope = f"{auth_settings.server_execute_scope_prefix}{server_name}{auth_settings.server_execute_scope_suffix}" + self.scopes.add(server_scope) + # Also add read access for the server + read_scope = f"{auth_settings.server_execute_scope_prefix}{server_name}:read" + self.scopes.add(read_scope) + + @property + def is_authenticated(self) -> bool: + return True + + @property + def display_name(self) -> str: + return self.username + + def has_scope(self, scope: Union[str, List[str]]) -> bool: + """Check if the user has the specified scope(s).""" + if isinstance(scope, str): + return scope in self.scopes + return all(s in self.scopes for s in scope) + +class MCPAuthBackend(BearerAuthBackend): + """Authentication backend for MCP Gateway extending the SDK BearerAuthBackend.""" + + def __init__(self, provider: ConfigurableIdPAdapter, settings: AuthSettings): + super().__init__(provider) + self.settings = settings + + async def authenticate(self, request: Request): + """Authenticate the request using JWT.""" + # Skip authentication for public routes + path = request.url.path + if any(path.startswith(public_path) for public_path in self.settings.public_routes): + # Return anonymous user with no credentials + return None + + # Use the SDK's authenticate method for JWT validation + credentials = await super().authenticate(request) + + if not credentials: + return None + + auth_credentials, auth_user = credentials + + # Convert the SDK's AuthenticatedUser to our MCPUser + claims = getattr(auth_user.access_token, 'raw_claims', {}) or {} + user = MCPUser( + client_id=auth_user.username, + scopes=auth_user.scopes, + claims=claims + ) + + # Return the auth credentials and user + return auth_credentials, user + + +def setup_auth_middleware(app, provider: ConfigurableIdPAdapter, settings: AuthSettings): + """ + Set up authentication middleware for the application. + + Args: + app: The FastAPI application + provider: The OAuth provider adapter + settings: Authentication settings + """ + # Add the authentication middleware + app.add_middleware( + AuthenticationMiddleware, + backend=MCPAuthBackend(provider, settings) + ) + + # Add the auth context middleware + app.add_middleware(AuthContextMiddleware) + + +def requires_scope(scope: Union[str, List[str]]): + """ + Dependency for requiring a specific scope or scopes. + + Args: + scope: The scope or list of scopes required + + Returns: + A dependency function that checks if the user has the required scope(s) + """ + def dependency(request: Request): + # Check for user in request.state instead of directly on request + if not hasattr(request.state, "user") or not hasattr(request.state.user, "has_scope"): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + + if not request.state.user.has_scope(scope): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Missing required scope: {scope}", + ) + + return True + + return dependency + + +def requires_server_access(server_path: str): + """ + Dependency for requiring access to a specific server. + + Args: + server_path: The path of the server (e.g., "/currenttime") + + Returns: + A dependency function that checks if the user has access to the server + """ + auth_settings = AuthSettings() + required_scope = auth_settings.get_server_read_scope(server_path) + execute_scope = auth_settings.get_server_execute_scope(server_path) + + def dependency(request: Request): + # Check if user has admin scope (grants access to all servers) + if hasattr(request.state, "user") and hasattr(request.state.user, "has_scope"): + if request.state.user.has_scope(auth_settings.registry_admin_scope): + logger.info(f"User {request.state.user.display_name} granted access to {server_path} via admin scope") + return True + + # Check if user has the specific read scope + if request.state.user.has_scope(required_scope): + logger.info(f"User {request.state.user.display_name} granted access to {server_path} via read scope") + return True + + # Check if user has the execute scope (which implies read access) + if request.state.user.has_scope(execute_scope): + logger.info(f"User {request.state.user.display_name} granted access to {server_path} via execute scope") + return True + + # User doesn't have the required scope + logger.warning(f"User {request.state.user.display_name} denied access to {server_path} - missing scope: {required_scope}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Missing required scope for server access: {required_scope}", + ) + + # User is not authenticated + logger.warning(f"Unauthenticated user denied access to {server_path}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return dependency + + +def requires_server_toggle(server_path: str): + """ + Dependency for requiring toggle permission for a specific server. + + Args: + server_path: The path of the server (e.g., "/currenttime") + + Returns: + A dependency function that checks if the user has toggle permission + """ + + auth_settings = AuthSettings() + base_scope = auth_settings.server_execute_scope_prefix + server_path.lstrip("/") + toggle_scope = f"{base_scope}:toggle" + + def dependency(request: Request): + # Check if user has admin scope (grants access to all servers) + if hasattr(request.state, "user") and hasattr(request.state.user, "has_scope"): + if request.state.user.has_scope(auth_settings.registry_admin_scope): + logger.info(f"User {request.state.user.display_name} granted toggle access to {server_path} via admin scope") + return True + + # Check if user has the specific toggle scope + if request.state.user.has_scope(toggle_scope): + logger.info(f"User {request.state.user.display_name} granted toggle access to {server_path} via toggle scope") + return True + + # User doesn't have the required scope + logger.warning(f"User {request.state.user.display_name} denied toggle access to {server_path} - missing scope: {toggle_scope}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Missing required scope for server toggle: {toggle_scope}", + ) + + # User is not authenticated + logger.warning(f"Unauthenticated user denied toggle access to {server_path}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return dependency + + +def requires_server_edit(server_path: str): + """ + Dependency for requiring edit permission for a specific server. + + Args: + server_path: The path of the server (e.g., "/currenttime") + + Returns: + A dependency function that checks if the user has edit permission + """ + + auth_settings = AuthSettings() + base_scope = auth_settings.server_execute_scope_prefix + server_path.lstrip("/") + edit_scope = f"{base_scope}:edit" + + def dependency(request: Request): + # Check if user has admin scope (grants access to all servers) + if hasattr(request.state, "user") and hasattr(request.state.user, "has_scope"): + if request.state.user.has_scope(auth_settings.registry_admin_scope): + logger.info(f"User {request.state.user.display_name} granted edit access to {server_path} via admin scope") + return True + + # Check if user has the specific edit scope + if request.state.user.has_scope(edit_scope): + logger.info(f"User {request.state.user.display_name} granted edit access to {server_path} via edit scope") + return True + + # User doesn't have the required scope + logger.warning(f"User {request.state.user.display_name} denied edit access to {server_path} - missing scope: {edit_scope}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Missing required scope for server edit: {edit_scope}", + ) + + # User is not authenticated + logger.warning(f"Unauthenticated user denied edit access to {server_path}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + + return dependency +# Helper functions for route-specific dependencies +def require_toggle_for_path(service_path: str): + """ + Dependency function for requiring toggle permission for a specific path. + + Args: + service_path: The path parameter from the route + + Returns: + A dependency function that checks if the user has toggle permission + """ + def dependency(request: Request): + return requires_server_toggle(service_path)(request) + return dependency + +def require_edit_for_path(service_path: str): + """ + Dependency function for requiring edit permission for a specific path. + + Args: + service_path: The path parameter from the route + + Returns: + A dependency function that checks if the user has edit permission + """ + def dependency(request: Request): + return requires_server_edit(service_path)(request) + return dependency + +def require_access_for_path(service_path: str): + """ + Dependency function for requiring access permission for a specific path. + + Args: + service_path: The path parameter from the route + + Returns: + A dependency function that checks if the user has access permission + """ + def dependency(request: Request): + return requires_server_access(service_path)(request) + return dependency +def check_admin_scope(): + """Dependency function for requiring admin scope.""" + def dependency(request: Request): + auth_settings = AuthSettings() + + if not hasattr(request.state, "user") or not hasattr(request.state.user, "has_scope"): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + + if not request.state.user.has_scope(auth_settings.registry_admin_scope): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Missing required scope: {auth_settings.registry_admin_scope}", + ) + + return True + + return dependency + +def require_registry_admin(): + """Dependency function for requiring registry admin scope.""" + return requires_scope("mcp:registry:admin") diff --git a/registry/auth/provider.py b/registry/auth/provider.py new file mode 100644 index 0000000..9856028 --- /dev/null +++ b/registry/auth/provider.py @@ -0,0 +1,1337 @@ +""" +OAuth 2.1 provider adapters for MCP Gateway. + +This module implements the OAuthAuthorizationServerProvider interface from +the MCP Python SDK, delegating authentication to external identity providers. +""" +import json +import time +import os +import secrets +import urllib.parse +from typing import Dict, List, Optional, Any, Tuple +import logging +import hashlib +import base64 +import httpx +import jwt +from fastapi import Request + +from mcp.server.auth.provider import ( + OAuthAuthorizationServerProvider, + AuthorizationParams, + AuthorizationCode, + RefreshToken, + AccessToken, + TokenError, + AuthorizeError, + RegistrationError +) +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthToken, +) + +from .settings import IdPSettings, AuthSettings, ScopeMapping + +logger = logging.getLogger(__name__) + +def generate_pkce_pair(): + """Generate a PKCE code_verifier and code_challenge pair following SDK standards. + + Returns: + tuple: A tuple containing (code_verifier, code_challenge) + """ + # Generate a secure random string for the code verifier + code_verifier = secrets.token_urlsafe(43) # 43 bytes → ≈ 58 characters + + # Create code challenge with S256 method + code_challenge = base64.urlsafe_b64encode( + hashlib.sha256(code_verifier.encode()).digest() + ).decode().rstrip("=") + + return code_verifier, code_challenge + + +def create_authorization_params( + redirect_uri: str, + scopes: List[str] = None, + state: str = None, + code_challenge_method: str = "S256" +) -> AuthorizationParams: + """ + Create AuthorizationParams for OAuth flows. + + Args: + redirect_uri: The callback URI for the authorization flow + scopes: List of requested scopes (optional) + state: State parameter for the flow (generated if not provided) + code_challenge_method: PKCE code challenge method + + Returns: + Properly configured AuthorizationParams + """ + if state is None: + state = secrets.token_hex(16) + + # Generate PKCE code verifier and challenge using SDK-compatible method + code_verifier, code_challenge = generate_pkce_pair() + + # Create parameters object with SDK's expected format + params = AuthorizationParams( + redirect_uri=redirect_uri, + scopes=scopes or [], + state=state, + code_challenge=code_challenge, + redirect_uri_provided_explicitly=True, + ) + + # Store code_verifier as a custom attribute + # We need this later but the SDK doesn't have it on AuthorizationParams + setattr(params, "code_verifier", code_verifier) + setattr(params, "code_challenge_method", "S256") # Always use S256 for security + + return params + + +class MCP_AuthCode(AuthorizationCode): + """Authorization code for tracking the external IdP code flow.""" + external_code: Optional[str] = None + state: Optional[str] = None + + +class MCP_AccessToken(AccessToken): + """Access token with JWT-specific fields.""" + id_token: Optional[str] = None + raw_claims: Optional[Dict[str, Any]] = None + + +class ConfigurableIdPAdapter(OAuthAuthorizationServerProvider[MCP_AuthCode, RefreshToken, MCP_AccessToken]): + """ + Generic OAuth provider adapter for external identity providers. + Supports any OAuth 2.1 compliant provider including Cognito, Okta, etc. + """ + + def __init__(self, settings: AuthSettings): + self.settings = settings + self.idp_settings = settings.idp_settings + + # State tracking + self._state_mapping = {} # Maps external state to client request details + self._auth_codes = {} # Maps MCP auth codes to AuthCode objects + self._access_tokens = {} # Maps access tokens to token info + self._refresh_tokens = {} # Maps refresh tokens to token info + self._clients = {} # Maps client IDs to client information + + async def authorize( + self, client: OAuthClientInformationFull, params: AuthorizationParams + ) -> str: + """Generate an authorization URL for the external IdP.""" + if not self.idp_settings: + raise AuthorizeError( + error="server_error", + error_description="Identity provider not configured" + ) + + # Generate state for tracking this request + state = secrets.token_hex(16) + + # Generate PKCE code verifier and challenge + code_verifier, code_challenge = generate_pkce_pair() + + # Add logging for the timestamp-modified state from oauth_login + import logging + logger = logging.getLogger(__name__) + logger.info(f"OAuth authorize - Original params.state: {params.state[:20] if params.state else None}") + + # Store original request details with our state + self._state_mapping[state] = { + "client_id": client.client_id, + "redirect_uri": str(params.redirect_uri), + "redirect_uri_provided_explicitly": params.redirect_uri_provided_explicitly, + "scopes": params.scopes or [], + "state": params.state, # This includes the timestamp from oauth_login + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "code_verifier": code_verifier, # Store code verifier for later + "original_timestamp": int(time.time()), # Add timestamp for debugging + } + + # Map MCP scopes to IdP scopes if needed + requested_scopes = params.scopes or [] + idp_scopes = self.idp_settings.scopes.copy() # Start with default IdP scopes + + # Custom scope mapping for specific providers + if hasattr(self.settings, 'scope_mapping') and self.settings.scope_mapping: + if hasattr(self.settings.scope_mapping, 'mcp_to_idp'): + for scope in requested_scopes: + mapped = self.settings.scope_mapping.mcp_to_idp.get(scope, []) + for mapped_scope in mapped: + if mapped_scope not in idp_scopes: + idp_scopes.append(mapped_scope) + + # Build the authorization URL + # Use the environment variable for the base URL if available + callback_uri = self.idp_settings.callback_uri + base_url = os.environ.get("MCP_AUTH_BASE_URL") + + # Ensure callback_uri is a fully qualified URL + if base_url: + # If callback_uri is relative or contains localhost, update it with base_url + if not callback_uri.startswith('http') or "localhost" in callback_uri: + from urllib.parse import urlparse + parsed_uri = urlparse(callback_uri) if callback_uri.startswith('http') else None + path = parsed_uri.path if parsed_uri else callback_uri + if not path.startswith('/'): + path = '/' + path + callback_uri = f"{base_url}{path}" + logger.info(f"Updated callback URI for authorization: {callback_uri}") + + auth_params = { + "client_id": self.idp_settings.client_id, + "redirect_uri": callback_uri, + "response_type": "code", + "state": state, + "scope": " ".join(idp_scopes), + "code_challenge": code_challenge, + "code_challenge_method": "S256" + } + + # Only add audience if it's explicitly set and not None + if self.idp_settings.audience is not None: + auth_params["audience"] = self.idp_settings.audience + logger.info(f"Adding audience parameter: {self.idp_settings.audience}") + + url = f"{self.idp_settings.authorize_url}?{urllib.parse.urlencode(auth_params)}" + return url + + async def handle_external_callback( + self, code: str, state: str, request: Request + ) -> Tuple[str, str]: + """ + Handle the callback from the external Identity Provider. + + This is a custom extension method not part of the OAuthAuthorizationServerProvider + Protocol interface. It's used to handle the redirection from the external IdP and + create a proper authorization code in our system. + + Args: + code: The authorization code from the external IdP + state: The state parameter from the external IdP + request: The FastAPI request object + + Returns: + A tuple of (redirect_url, authorization_code) where: + - redirect_url: URL to redirect the user back to the client app + - authorization_code: The generated MCP authorization code + + Raises: + AuthorizeError: If the callback contains invalid parameters + """ + # Add debug logging + import logging + logger = logging.getLogger(__name__) + logger.info(f"handle_external_callback called with state={state[:10] if state else None}") + logger.info(f"Current state mappings: {list(self._state_mapping.keys())}") + + # Verify the state is one we generated + if state not in self._state_mapping: + logger.error(f"State '{state[:10] if state else None}' not found in state mapping") + # Try to find similar state values for debugging + similar_states = [s for s in self._state_mapping.keys() if state and s.startswith(state[:5])] + if similar_states: + logger.info(f"Found similar states: {similar_states}") + + # Try using the first similar state as a fallback + state = similar_states[0] + logger.info(f"Using similar state as fallback: {state[:10]}") + else: + # If we can't find any similar states, return redirect to login page with error + logger.error("No similar states found, returning auth error") + raise AuthorizeError( + error="invalid_state", + error_description="Invalid state parameter - no matching or similar state found" + ) + + # Get the original request details + request_details = self._state_mapping[state] + + # Generate a new MCP authorization code + mcp_code = secrets.token_hex(16) + + # Create and store authorization code with external code + auth_code = MCP_AuthCode( + code=mcp_code, + client_id=request_details["client_id"], + redirect_uri=request_details["redirect_uri"], + redirect_uri_provided_explicitly=request_details["redirect_uri_provided_explicitly"], + scopes=request_details["scopes"], + code_challenge=request_details.get("code_challenge"), + code_challenge_method=request_details.get("code_challenge_method"), + expires_at=int(time.time() + 600), # 10 minute expiry + external_code=code, + state=state, # Store state for accessing code_verifier later + ) + + self._auth_codes[mcp_code] = auth_code + + # Build redirect back to original client using our own implementation + # instead of the SDK utility which is causing issues + from urllib.parse import urlparse, urlencode + + # Parse the redirect URI + parsed_uri = urlparse(request_details["redirect_uri"]) + + # Create query parameters + query_params = { + "code": mcp_code, + "state": request_details["state"] + } + + # Construct the new URI with query parameters + scheme = parsed_uri.scheme + netloc = parsed_uri.netloc + path = parsed_uri.path + query = urlencode(query_params) + + # Combine all parts to form the redirect URL + redirect_url = f"{scheme}://{netloc}{path}?{query}" + + # Ensure we return a proper tuple + from typing import Tuple + + # Force return as proper tuple with explicit typing + result: Tuple[str, str] = (redirect_url, mcp_code) + logger.info(f"Returning callback result as tuple: {result}") + + # Return the tuple explicitly to ensure correct unpacking + return redirect_url, mcp_code + + async def load_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: str + ) -> Optional[MCP_AuthCode]: + """Load the authorization code.""" + code = self._auth_codes.get(authorization_code) + if not code: + return None + + # Verify code belongs to this client + if code.client_id != client.client_id: + return None + + # Check if expired + if code.expires_at < time.time(): + if authorization_code in self._auth_codes: + del self._auth_codes[authorization_code] + return None + + return code + + async def exchange_authorization_code( + self, client: OAuthClientInformationFull, authorization_code: MCP_AuthCode + ) -> OAuthToken: + """Exchange auth code for tokens, communicating with external IdP.""" + # Get our code with the external_code + auth_code = self._auth_codes.get(authorization_code.code) + if not auth_code or not hasattr(auth_code, 'external_code'): + raise TokenError( + error="invalid_grant", + error_description="Invalid authorization code" + ) + + # Exchange the external code for tokens with the IdP + try: + token_data = await self._exchange_code_with_idp(auth_code.external_code, auth_code) + except Exception as e: + raise TokenError( + error="server_error", + error_description=f"Failed to exchange token with IdP: {str(e)}" + ) + + # Clean up the used code + if authorization_code.code in self._auth_codes: + del self._auth_codes[authorization_code.code] + + # Generate local tokens for MCP + access_token = secrets.token_hex(32) + refresh_token = secrets.token_hex(32) + + # Store the access token with IdP token info + expires_in = token_data.get("expires_in", 3600) + expires_at = int(time.time() + expires_in) + + # Extract scopes from the token response + scope_str = token_data.get("scope", "") + scopes = scope_str.split() if isinstance(scope_str, str) else [] + + # Add scopes from the authorization code + if authorization_code.scopes: + for scope in authorization_code.scopes: + if scope not in scopes: + scopes.append(scope) + + token_obj = MCP_AccessToken( + token=access_token, + client_id=client.client_id, + scopes=scopes, + expires_at=expires_at, + id_token=token_data.get("id_token"), + raw_claims={"id_token": token_data.get("id_token"), "access_token": token_data.get("access_token")} + ) + + self._access_tokens[access_token] = token_obj + + # Store refresh token + refresh_obj = RefreshToken( + token=refresh_token, + client_id=client.client_id, + scopes=scopes + ) + + self._refresh_tokens[refresh_token] = refresh_obj + + # Return token response + return OAuthToken( + access_token=access_token, + token_type="Bearer", + expires_in=expires_in, + refresh_token=refresh_token, + scope=" ".join(scopes) + ) + + async def _exchange_code_with_idp(self, code: str, auth_code: MCP_AuthCode = None) -> dict: + """Exchange authorization code with the IdP.""" + # Use the environment variable for the base URL if available + callback_uri = self.idp_settings.callback_uri + base_url = os.environ.get("MCP_AUTH_BASE_URL") + + # Ensure callback_uri is a fully qualified URL + if base_url: + # If callback_uri is relative or contains localhost, update it with base_url + if not callback_uri.startswith('http') or "localhost" in callback_uri: + from urllib.parse import urlparse + parsed_uri = urlparse(callback_uri) if callback_uri.startswith('http') else None + path = parsed_uri.path if parsed_uri else callback_uri + if not path.startswith('/'): + path = '/' + path + callback_uri = f"{base_url}{path}" + logger.info(f"Updated callback URI for token exchange: {callback_uri}") + + token_params = { + "grant_type": "authorization_code", + "client_id": self.idp_settings.client_id, + "client_secret": self.idp_settings.client_secret, + "code": code, + "redirect_uri": callback_uri + } + + # Get the code verifier for PKCE + code_verifier = None + + # Get from state mapping if available + if auth_code and auth_code.state and auth_code.state in self._state_mapping: + state_data = self._state_mapping[auth_code.state] + if "code_verifier" in state_data and state_data["code_verifier"]: + code_verifier = state_data["code_verifier"] + logger.info(f"Using code_verifier from state mapping for PKCE token exchange") + + # Add code_verifier if found (required for PKCE with Cognito) + if code_verifier: + token_params["code_verifier"] = code_verifier + logger.info(f"Added code_verifier to token request parameters") + else: + logger.warning(f"No code_verifier found for PKCE token exchange") + + logger.info(f"Making token request to {self.idp_settings.token_url}") + + # Use a reasonably short timeout to prevent delays for users + # Default is 10 seconds, which is generous but prevents very long hangs + timeout_seconds = int(os.environ.get("MCP_AUTH_IDP_TIMEOUT", "10")) + + logger.info(f"Making token request with {timeout_seconds}s timeout to {self.idp_settings.token_url}") + + async with httpx.AsyncClient() as client: + try: + response = await client.post( + self.idp_settings.token_url, + data=token_params, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=timeout_seconds # Add timeout to prevent delays + ) + + # Log response status and headers for debugging + logger.info(f"Token response status: {response.status_code}") + + # Raise for HTTP errors + response.raise_for_status() + + # Parse response JSON + token_data = response.json() + logger.info(f"Token response received with keys: {list(token_data.keys())}") + + return token_data + except httpx.HTTPStatusError as e: + # Log detailed error information + error_detail = f"HTTP Status {e.response.status_code}" + try: + error_body = e.response.json() + error_detail += f" - Error: {error_body.get('error', 'unknown')}" + error_detail += f" - Description: {error_body.get('error_description', 'No description')}" + except Exception: + error_detail += f" - Body: {e.response.text[:200]}" + + logger.error(f"Token exchange HTTP error: {error_detail}") + raise TokenError( + error="invalid_request", + error_description=f"Failed to exchange token with IdP: {error_detail}" + ) + except httpx.RequestError as e: + logger.error(f"Token exchange request error: {e}") + raise TokenError( + error="server_error", + error_description=f"Connection error during token exchange: {e}" + ) + except Exception as e: + logger.error(f"Unexpected error during token exchange: {e}", exc_info=True) + raise TokenError( + error="server_error", + error_description=f"Unexpected error during token exchange: {e}" + ) + raise TokenError( + error="invalid_request", + error_description=f"Failed to exchange code with IdP: {response.status_code} - {response.text[:100]}" + ) + + async def load_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: str + ) -> Optional[RefreshToken]: + """Load refresh token.""" + token = self._refresh_tokens.get(refresh_token) + if not token: + return None + + if token.client_id != client.client_id: + return None + + return token + + async def exchange_refresh_token( + self, client: OAuthClientInformationFull, refresh_token: RefreshToken, scopes: List[str] + ) -> OAuthToken: + """Exchange refresh token for new tokens.""" + # Remove the old refresh token + old_token = refresh_token.token + if old_token in self._refresh_tokens: + del self._refresh_tokens[old_token] + + # Generate new tokens + access_token = secrets.token_hex(32) + new_refresh_token = secrets.token_hex(32) + + # Store new access token + expires_in = 3600 + expires_at = int(time.time() + expires_in) + + # Use requested scopes or fall back to original + final_scopes = scopes if scopes else refresh_token.scopes + + token_obj = MCP_AccessToken( + token=access_token, + client_id=client.client_id, + scopes=final_scopes, + expires_at=expires_at + ) + + self._access_tokens[access_token] = token_obj + + # Store new refresh token + refresh_obj = RefreshToken( + token=new_refresh_token, + client_id=client.client_id, + scopes=final_scopes + ) + + self._refresh_tokens[new_refresh_token] = refresh_obj + + # Return token response + return OAuthToken( + access_token=access_token, + token_type="Bearer", + expires_in=expires_in, + refresh_token=new_refresh_token, + scope=" ".join(final_scopes) + ) + + async def load_access_token(self, token: str) -> Optional[MCP_AccessToken]: + """Load and validate access token.""" + # Check our local token store first + local_token = self._access_tokens.get(token) + if local_token: + # Check if expired + if local_token.expires_at and local_token.expires_at < time.time(): + if token in self._access_tokens: + del self._access_tokens[token] + return None + + return local_token + + # If not in our store, it might be a direct JWT from the IdP + try: + # Extract provider-specific claims and scopes + decoded = self._decode_and_validate_jwt(token) + if not decoded: + return None + + # Extract scopes based on provider type + scopes = self._extract_scopes_from_claims(decoded) + + # Map external scopes to MCP scopes if mappings exist + mapped_scopes = [] + if hasattr(self.settings, 'scope_mapping') and self.settings.scope_mapping: + if hasattr(self.settings.scope_mapping, 'idp_to_mcp'): + for scope in scopes: + mapped = self.settings.scope_mapping.idp_to_mcp.get(scope, []) + mapped_scopes.extend(mapped) + + # Create token object + token_obj = MCP_AccessToken( + token=token, + client_id=decoded.get("client_id", decoded.get("aud", "unknown")), + scopes=list(set(scopes + mapped_scopes)), # Remove duplicates + expires_at=decoded.get("exp"), + id_token=token, + raw_claims=decoded + ) + + # Cache the validated token + self._access_tokens[token] = token_obj + + return token_obj + + except TokenError as te: + # Use TokenError if already raised by underlying validation + logger.warning(f"JWT validation failed: {te.error} - {te.error_description}") + return None + except Exception as e: + # Convert other exceptions to standard format + logger.warning(f"JWT validation failed: {e}") + return None + + async def _get_jwks(self) -> Optional[Dict[str, Any]]: + """ + Fetch JWKS (JSON Web Key Set) from the identity provider. + + Returns: + The JWKS data or None if retrieval fails + """ + if not self.idp_settings.jwks_url: + logger.warning("No JWKS URL configured for JWT validation") + return None + + try: + # Use a reasonably short timeout to prevent delays + timeout_seconds = int(os.environ.get("MCP_AUTH_IDP_TIMEOUT", "10")) + + async with httpx.AsyncClient() as client: + response = await client.get( + self.idp_settings.jwks_url, + timeout=timeout_seconds + ) + response.raise_for_status() + return response.json() + + except Exception as e: + logger.error(f"Failed to fetch JWKS from {self.idp_settings.jwks_url}: {e}") + return None + + def _decode_and_validate_jwt(self, token: str) -> Optional[Dict[str, Any]]: + """ + Decode and validate a JWT token with proper signature verification. + + Subclasses can override this method to implement provider-specific validation. + """ + try: + # First, decode headers to get key ID + unverified_header = jwt.get_unverified_header(token) + kid = unverified_header.get("kid") + + if not kid: + logger.warning("JWT token missing key ID (kid) in header") + raise TokenError( + error="invalid_token", + error_description="Token missing key ID" + ) + + # Get JWKS to find the public key + import asyncio + jwks = asyncio.run(self._get_jwks()) + if not jwks: + logger.error("Unable to retrieve JWKS for token verification") + raise TokenError( + error="invalid_token", + error_description="Unable to verify token signature" + ) + + # Find the matching key + public_key = None + for key in jwks.get("keys", []): + if key.get("kid") == kid: + # Convert JWK to PEM format for PyJWT + from jwt.algorithms import RSAAlgorithm + public_key = RSAAlgorithm.from_jwk(key) + break + + if not public_key: + logger.warning(f"No matching key found for kid: {kid}") + raise TokenError( + error="invalid_token", + error_description="Unable to verify token signature" + ) + + # Decode and verify the token with signature verification enabled + decoded = jwt.decode( + token, + public_key, + algorithms=["RS256"], # Most OAuth providers use RS256 + audience=self.idp_settings.audience, + issuer=self.idp_settings.issuer, + options={ + "verify_signature": True, + "verify_exp": True, + "verify_nbf": True, + "verify_iat": True, + "verify_aud": True if self.idp_settings.audience else False, + "verify_iss": True if self.idp_settings.issuer else False + } + ) + + # PyJWT has already validated exp, nbf, iat, aud, and iss claims + # Return the validated claims + logger.info("JWT token validated successfully with signature verification") + return decoded + + except jwt.ExpiredSignatureError: + logger.warning("JWT token has expired") + raise TokenError( + error="invalid_token", + error_description="Token has expired" + ) + except jwt.InvalidAudienceError: + logger.warning("JWT token has invalid audience") + raise TokenError( + error="invalid_token", + error_description="Invalid token audience" + ) + except jwt.InvalidIssuerError: + logger.warning("JWT token has invalid issuer") + raise TokenError( + error="invalid_token", + error_description="Invalid token issuer" + ) + except jwt.InvalidSignatureError: + logger.warning("JWT token has invalid signature") + raise TokenError( + error="invalid_token", + error_description="Invalid token signature" + ) + except jwt.InvalidTokenError as e: + logger.warning(f"JWT token is invalid: {e}") + raise TokenError( + error="invalid_token", + error_description="Invalid token" + ) + except Exception as e: + logger.error(f"Unexpected error during JWT validation: {e}") + raise TokenError( + error="invalid_token", + error_description="Token validation failed" + ) + + def _extract_scopes_from_claims(self, claims: Dict[str, Any]) -> List[str]: + """ + Extract scopes from JWT claims. + + Different providers store scopes in different claim formats. + Subclasses can override this method for provider-specific scope extraction. + """ + scopes = [] + + # Standard 'scope' claim (space-separated string) + if "scope" in claims: + if isinstance(claims["scope"], str): + scopes.extend(claims["scope"].split()) + elif isinstance(claims["scope"], list): + scopes.extend([s for s in claims["scope"] if isinstance(s, str)]) + + # OIDC 'scp' claim (array of strings) + if "scp" in claims and isinstance(claims["scp"], list): + scopes.extend([s for s in claims["scp"] if isinstance(s, str)]) + + # Generic handling for groups/roles + # Handle standard OIDC/OAuth groups claim + if "groups" in claims and isinstance(claims["groups"], list): + for group in claims["groups"]: + if isinstance(group, str) and group.startswith("mcp:"): + scopes.append(group) + + # Handle roles claim + if "roles" in claims and isinstance(claims["roles"], list): + for role in claims["roles"]: + if isinstance(role, str) and role.startswith("mcp:"): + scopes.append(role) + + return scopes + + async def get_client(self, client_id: str) -> Optional[OAuthClientInformationFull]: + """ + Retrieves client information by client ID. + + This implementation supports two client types: + 1. Clients stored in the internal registry via register_client + 2. A default client using the configured IdP client credentials + + Args: + client_id: The ID of the client to retrieve. + + Returns: + The client information, or None if the client does not exist. + """ + # Return from cache if available + if client_id in self._clients: + return self._clients[client_id] + + # If we're using a default client from settings, create it + if self.settings.default_client_id and self.settings.default_client_id == client_id: + if not self.idp_settings or not self.idp_settings.callback_uri: + return None + + client = OAuthClientInformationFull( + client_id=client_id, + client_secret=self.settings.default_client_secret, + # Use the callback URI from IdP settings as redirect URI + redirect_uris=[self.idp_settings.callback_uri], + scope=" ".join(self.idp_settings.scopes) if self.idp_settings.scopes else None, + token_endpoint_auth_method="client_secret_post", + grant_types=["authorization_code", "refresh_token"], + response_types=["code"], + client_name="MCP Gateway" + ) + # Cache the client + self._clients[client_id] = client + return client + + # Client not found + return None + + async def register_client(self, client_info: OAuthClientInformationFull) -> None: + """ + Saves client information as part of registering it. + + Args: + client_info: The client metadata to register. + + Raises: + RegistrationError: If the client metadata is invalid. + """ + # Validate client metadata before storing + if not client_info.redirect_uris or len(client_info.redirect_uris) == 0: + raise RegistrationError( + error="invalid_redirect_uri", + error_description="At least one redirect URI must be provided" + ) + + # Validate grant types include authorization_code + if "authorization_code" not in client_info.grant_types: + raise RegistrationError( + error="invalid_client_metadata", + error_description="Client must support 'authorization_code' grant type" + ) + + # Store the client + self._clients[client_info.client_id] = client_info + logger.info(f"Registered client with ID {client_info.client_id}") + + async def revoke_token(self, token: AccessToken | RefreshToken) -> None: + """Revoke a token.""" + token_str = token.token + + if isinstance(token, AccessToken): + if token_str in self._access_tokens: + del self._access_tokens[token_str] + elif isinstance(token, RefreshToken): + if token_str in self._refresh_tokens: + del self._refresh_tokens[token_str] + + +class CognitoOAuthProvider(ConfigurableIdPAdapter): + """ + AWS Cognito OAuth provider adapter using pycognito for enhanced security. + + This implementation leverages the pycognito library to handle: + - Proper JWT verification and token validation + - Secure token exchange with PKCE + - Automatic token refresh + + It maintains compatibility with the MCP Python SDK by implementing the + OAuthAuthorizationServerProvider protocol. + """ + + @classmethod + def from_user_pool(cls, user_pool_id: str, client_id: str, client_secret: str, + callback_uri: str, region: str = "us-east-1", + custom_domain: str = None) -> "CognitoOAuthProvider": + """ + Create a Cognito provider from user pool details. + + Args: + user_pool_id: The Cognito user pool ID + client_id: The app client ID + client_secret: The app client secret + callback_uri: The callback URI for the OAuth flow + region: AWS region, defaults to us-east-1 + custom_domain: Optional custom domain for Cognito + + Returns: + Configured CognitoOAuthProvider + """ + # Extract region from the user pool ID (format is region_poolID) + region_from_id = user_pool_id.split('_')[0] + pool_id = user_pool_id.split('_')[1] + + # Log the parsed values for debugging + logger.info(f"Parsed user pool ID: region={region_from_id}, pool_id={pool_id}") + + # Cognito hosted UI domains + domain_prefix = f"{region_from_id}{pool_id}" + logger.info(f"Using domain prefix: {domain_prefix}") + + # Standard Cognito domain formats + domain = f"cognito-idp.{region}.amazonaws.com/{user_pool_id}" + + # Determine the domain to use for authorization and token endpoints + if custom_domain: + auth_domain = custom_domain + logger.info(f"Using custom Cognito domain: {auth_domain}") + else: + # Cognito OAuth uses the hosted UI domain, not the API domain + auth_domain = f"{domain_prefix}.auth.{region}.amazoncognito.com" + logger.info(f"Using Cognito hosted UI domain: {auth_domain}") + + # Build IdP settings for the provider + idp_settings = IdPSettings( + provider_type="cognito", + client_id=client_id, + client_secret=client_secret, + authorize_url=f"https://{auth_domain}/oauth2/authorize", + token_url=f"https://{auth_domain}/oauth2/token", + jwks_url=f"https://cognito-idp.{region}.amazonaws.com/{user_pool_id}/.well-known/jwks.json", + callback_uri=callback_uri, + audience=None, # Setting audience to None for Cognito to avoid scope issues + issuer=f"https://cognito-idp.{region}.amazonaws.com/{user_pool_id}", + scopes=["openid", "email", "profile"] # Default scopes for Cognito + ) + + # Create scope mapping for Cognito groups + scope_mapping = ScopeMapping( + idp_to_mcp={ + "admin": ["mcp:registry:admin"], + "user": ["mcp:registry:read"], + "cognito:admin": ["mcp:registry:admin"], + "cognito:user": ["mcp:registry:read"], + } + ) + + # Create auth settings that combine everything + settings = AuthSettings( + enabled=True, + idp_settings=idp_settings, + scope_mapping=scope_mapping, + default_client_id=client_id, + default_client_secret=client_secret + ) + + # Initialize the provider with these settings + return cls(settings) + + def __init__(self, settings: AuthSettings): + """ + Initialize the Cognito provider using pycognito. + + Args: + settings: Authentication settings including Cognito configuration + """ + super().__init__(settings) + + # Import pycognito for Cognito operations + from pycognito import Cognito + self.Cognito = Cognito + + # Extract Cognito-specific settings for easy access + self.user_pool_id = None + self.region = "us-east-1" # Default region + + # Parse user pool ID from issuer URL if available + if settings.idp_settings and settings.idp_settings.issuer: + issuer_parts = settings.idp_settings.issuer.split('/') + for part in issuer_parts: + if '_' in part: # User pool IDs contain an underscore + self.user_pool_id = part + self.region = part.split('_')[0] + break + + # Create a Cognito client for admin operations + try: + self.cognito_client = self.Cognito( + user_pool_id=self.user_pool_id, + client_id=settings.idp_settings.client_id, + client_secret=settings.idp_settings.client_secret, + ) + logger.info("Initialized pycognito client for Cognito provider") + except Exception as e: + logger.warning(f"Error initializing pycognito client: {e}") + + logger.info(f"Initialized CognitoOAuthProvider with user pool ID: {self.user_pool_id} in region: {self.region}") + + async def _exchange_code_with_idp(self, code: str, auth_code: MCP_AuthCode = None) -> dict: + """ + Exchange authorization code with Cognito. + + Args: + code: The authorization code from Cognito + auth_code: The MCP authorization code object + + Returns: + Token response from Cognito + + Raises: + TokenError: If token exchange fails + """ + # Get the code verifier for PKCE + code_verifier = None + + # Get from state mapping if available + if auth_code and auth_code.state and auth_code.state in self._state_mapping: + state_data = self._state_mapping[auth_code.state] + if "code_verifier" in state_data and state_data["code_verifier"]: + code_verifier = state_data["code_verifier"] + logger.info(f"Using code_verifier from state mapping for PKCE token exchange") + + # Prepare token exchange parameters + token_params = { + "grant_type": "authorization_code", + "client_id": self.settings.idp_settings.client_id, + "client_secret": self.settings.idp_settings.client_secret, + "code": code, + "redirect_uri": self.settings.idp_settings.callback_uri + } + + # Add code_verifier if found (required for PKCE) + if code_verifier: + token_params["code_verifier"] = code_verifier + + try: + # Use HTTPX for token exchange since pycognito doesn't directly support PKCE with code_verifier + # Use a short timeout to prevent long user-visible delays + timeout_seconds = int(os.environ.get("MCP_AUTH_IDP_TIMEOUT", "10")) + logger.info(f"Making Cognito token request with {timeout_seconds}s timeout") + + async with httpx.AsyncClient() as client: + response = await client.post( + self.settings.idp_settings.token_url, + data=token_params, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + timeout=timeout_seconds # Add timeout to prevent delays + ) + + if response.status_code != 200: + logger.error(f"Cognito token error: {response.status_code} {response.text}") + raise TokenError( + error="invalid_request", + error_description=f"Failed to exchange code with Cognito: {response.status_code} - {response.text[:100]}" + ) + + token_data = response.json() + logger.info(f"Successfully exchanged code for tokens with Cognito") + + return token_data + + except Exception as e: + if isinstance(e, TokenError): + raise + logger.error(f"Error during token exchange: {e}") + raise TokenError( + error="server_error", + error_description=f"Failed to exchange token with Cognito: {str(e)}" + ) + + def _decode_and_validate_jwt(self, token: str) -> Optional[Dict[str, Any]]: + """ + Decode and validate a JWT token using pycognito. + + Args: + token: The JWT token to validate + + Returns: + The validated token claims or None if validation fails + """ + if not token: + return None + + try: + # Create a temporary Cognito instance with the token + cognito = self.Cognito( + user_pool_id=self.user_pool_id, + client_id=self.settings.idp_settings.client_id, + id_token=token + ) + + # Verify the token - this will check the signature using JWKS + cognito.verify_token(token, 'id', cognito.id_token) + + # If verification is successful, return the claims + if hasattr(cognito, 'id_claims'): + return cognito.id_claims + + # If pycognito doesn't provide claims directly, decode the token manually + return jwt.decode( + token, + options={"verify_signature": False} # Already verified by pycognito + ) + + except Exception as e: + logger.warning(f"JWT validation failed: {e}") + return None + + # The _get_jwks method is no longer needed since we're using pycognito for JWT validation + # which handles JWKS retrieval internally + + def _extract_scopes_from_claims(self, claims: Dict[str, Any]) -> List[str]: + """ + Extract scopes from Cognito-specific JWT claims. + + Args: + claims: The JWT claims + + Returns: + List of scopes extracted from the claims + """ + scopes = super()._extract_scopes_from_claims(claims) + + # Cognito-specific: extract scopes from Cognito groups + if "cognito:groups" in claims and isinstance(claims["cognito:groups"], list): + for group in claims["cognito:groups"]: + if isinstance(group, str): + # Add any group that starts with mcp: directly + if group.startswith("mcp:"): + scopes.append(group) + logger.info(f"Added direct scope from Cognito group: {group}") + + # Map specific groups to scopes + elif group == "mcp-admin" or group == "admin": + scopes.append("mcp:registry:admin") + logger.info(f"Added admin scope from Cognito group: {group}") + elif group == "mcp-user" or group == "user": + scopes.append("mcp:registry:read") + logger.info(f"Added read scope from Cognito group: {group}") + + # Server-specific groups + elif group.startswith("mcp-server-"): + parts = group[len("mcp-server-"):].split("-") + + # Basic server group (e.g., mcp-server-currenttime) + if len(parts) == 1: + server_name = parts[0] + if server_name: + # Add read and execute scopes + read_scope = f"{self.settings.server_execute_scope_prefix}{server_name}:read" + execute_scope = f"{self.settings.server_execute_scope_prefix}{server_name}{self.settings.server_execute_scope_suffix}" + scopes.append(read_scope) + scopes.append(execute_scope) + logger.info(f"Added read and execute scopes for server {server_name} from group {group}") + + # Server admin group (e.g., mcp-server-currenttime-admin) + elif len(parts) > 1 and parts[-1] == "admin": + server_name = "-".join(parts[:-1]) + if server_name: + # Add all server scopes (read, execute, toggle, edit) + base_scope = f"{self.settings.server_execute_scope_prefix}{server_name}" + scopes.append(f"{base_scope}:read") + scopes.append(f"{base_scope}:execute") + scopes.append(f"{base_scope}:toggle") + scopes.append(f"{base_scope}:edit") + logger.info(f"Added all admin scopes for server {server_name} from group {group}") + + # Server toggle group (e.g., mcp-server-currenttime-toggle) + elif len(parts) > 1 and parts[-1] == "toggle": + server_name = "-".join(parts[:-1]) + if server_name: + toggle_scope = f"{self.settings.server_execute_scope_prefix}{server_name}:toggle" + scopes.append(toggle_scope) + logger.info(f"Added toggle scope for server {server_name} from group {group}") + + # Server edit group (e.g., mcp-server-currenttime-edit) + elif len(parts) > 1 and parts[-1] == "edit": + server_name = "-".join(parts[:-1]) + if server_name: + edit_scope = f"{self.settings.server_execute_scope_prefix}{server_name}:edit" + scopes.append(edit_scope) + logger.info(f"Added edit scope for server {server_name} from group {group}") + + # Tool-specific group (e.g., mcp-server-currenttime-tool-toolname) + elif len(parts) > 2 and parts[-2] == "tool": + server_name = "-".join(parts[:-2]) + tool_name = parts[-1] + if server_name and tool_name: + tool_scope = f"{self.settings.server_execute_scope_prefix}{server_name}:tool:{tool_name}:execute" + scopes.append(tool_scope) + logger.info(f"Added tool-specific scope for {tool_name} on server {server_name} from group {group}") + + # Extract from custom:roles attribute (JSON or comma-separated) + if "custom:roles" in claims and isinstance(claims["custom:roles"], str): + try: + # Try to parse as JSON + roles = json.loads(claims["custom:roles"]) + if isinstance(roles, list): + for role in roles: + if isinstance(role, str) and role.startswith("mcp:"): + scopes.append(role) + logger.info(f"Added scope from custom:roles JSON: {role}") + except json.JSONDecodeError: + # If not JSON, treat as comma-separated string + for role in claims["custom:roles"].split(","): + role = role.strip() + if role.startswith("mcp:"): + scopes.append(role) + logger.info(f"Added scope from custom:roles string: {role}") + + return scopes + + async def load_access_token(self, token: str) -> Optional[MCP_AccessToken]: + """ + Load and validate access token using pycognito. + + Args: + token: The access token to validate + + Returns: + Token object or None if validation fails + """ + # Check our local token store first + local_token = self._access_tokens.get(token) + if local_token: + # Check if expired + if local_token.expires_at and local_token.expires_at < time.time(): + if token in self._access_tokens: + del self._access_tokens[token] + return None + + return local_token + + # If not in our store, validate the JWT token with pycognito + try: + # Extract and validate claims using pycognito + decoded = self._decode_and_validate_jwt(token) + if not decoded: + return None + + # Extract scopes based on provider type + scopes = self._extract_scopes_from_claims(decoded) + + # Map external scopes to MCP scopes if mappings exist + mapped_scopes = [] + if hasattr(self.settings, 'scope_mapping') and self.settings.scope_mapping: + if hasattr(self.settings.scope_mapping, 'idp_to_mcp'): + for scope in scopes: + mapped = self.settings.scope_mapping.idp_to_mcp.get(scope, []) + mapped_scopes.extend(mapped) + + # Create token object + token_obj = MCP_AccessToken( + token=token, + client_id=decoded.get("client_id", decoded.get("aud", "unknown")), + scopes=list(set(scopes + mapped_scopes)), # Remove duplicates + expires_at=decoded.get("exp"), + id_token=token, + raw_claims=decoded + ) + + # Cache the validated token + self._access_tokens[token] = token_obj + + return token_obj + + except Exception as e: + # Handle any validation errors + logger.warning(f"JWT validation failed: {e}") + return None + + +class OktaOAuthProvider(ConfigurableIdPAdapter): + """Okta-specific OAuth provider adapter.""" + + @classmethod + def from_tenant(cls, tenant_url: str, client_id: str, client_secret: str, + callback_uri: str) -> "OktaOAuthProvider": + """ + Create an Okta provider from tenant details. + + Args: + tenant_url: The Okta tenant URL (https://your-org.okta.com) + client_id: The app client ID + client_secret: The app client secret + callback_uri: The callback URI for the OAuth flow + + Returns: + Configured OktaOAuthProvider + """ + # Remove trailing slash if present + tenant_url = tenant_url.rstrip('/') + + idp_settings = IdPSettings( + provider_type="okta", + client_id=client_id, + client_secret=client_secret, + authorize_url=f"{tenant_url}/oauth2/v1/authorize", + token_url=f"{tenant_url}/oauth2/v1/token", + jwks_url=f"{tenant_url}/oauth2/v1/keys", + callback_uri=callback_uri, + audience="api://default", + issuer=tenant_url + ) + + # Create scope mapping + scope_mapping = ScopeMapping( + idp_to_mcp={ + "admin": ["mcp:registry:admin"], + "user": ["mcp:registry:read"], + "mcp-admin": ["mcp:registry:admin"], + "mcp-user": ["mcp:registry:read"], + } + ) + + settings = AuthSettings( + enabled=True, + idp_settings=idp_settings, + scope_mapping=scope_mapping, + default_client_id=client_id, + default_client_secret=client_secret + ) + + return cls(settings) + + def _extract_scopes_from_claims(self, claims: Dict[str, Any]) -> List[str]: + """Extract scopes from Okta-specific JWT claims.""" + scopes = super()._extract_scopes_from_claims(claims) + + # Okta-specific: groups claim + if "groups" in claims and isinstance(claims["groups"], list): + for group in claims["groups"]: + if isinstance(group, str): + # Add mcp: prefix to Okta groups that match our naming convention + if group.startswith("mcp-"): + scopes.append(f"mcp:{group[4:]}") + # Add direct matches for groups already prefixed + elif group.startswith("mcp:"): + scopes.append(group) + + return scopes \ No newline at end of file diff --git a/registry/auth/routes.py b/registry/auth/routes.py new file mode 100644 index 0000000..2b14beb --- /dev/null +++ b/registry/auth/routes.py @@ -0,0 +1,491 @@ +""" +OAuth routes for MCP Gateway. + +This module implements the routes required for the OAuth 2.1 flow, +including the login route and callback handler. +""" +import os +import logging +import secrets +from typing import Optional, List +from urllib.parse import urlparse +from datetime import datetime +import jwt + +from fastapi import APIRouter, Request, Depends, HTTPException, Form +from fastapi.responses import HTMLResponse, RedirectResponse +from fastapi.templating import Jinja2Templates +from starlette import status +from starlette.routing import Route +from pydantic import AnyHttpUrl +from fastapi.responses import HTMLResponse +from itsdangerous import URLSafeTimedSerializer + +from mcp.server.auth.routes import create_auth_routes, AUTHORIZATION_PATH, TOKEN_PATH +from mcp.server.auth.provider import TokenError, AuthorizeError, RegistrationError, AuthorizationParams +from .provider import create_authorization_params +from mcp.shared.auth import OAuthMetadata, OAuthClientInformationFull, OAuthClientMetadata + +from .provider import ConfigurableIdPAdapter +from .settings import AuthSettings + +logger = logging.getLogger(__name__) + +# Create a router for auth routes +router = APIRouter(tags=["auth"]) + +# Global storage for OAuth state +OAUTH_STATE_STORAGE = {} + +# Global variables to store provider, settings, and templates +# These will be set by setup_auth_routes +_provider = None +_settings = None +_templates = None + + +def setup_auth_routes(app, provider: ConfigurableIdPAdapter, settings: AuthSettings, templates: Jinja2Templates): + """ + Set up authentication routes for the application. + + Args: + app: The FastAPI application + provider: The OAuth provider adapter + settings: Authentication settings + templates: The templates engine + """ + # Store globals for route handlers + global _provider, _settings, _templates + _provider = provider + _settings = settings + _templates = templates + + # Add standard OAuth routes using the SDK helper + if settings.enabled and settings.idp_settings: + base_url = os.environ.get("MCP_AUTH_BASE_URL", "http://localhost:7860") + issuer_url = AnyHttpUrl(base_url) + + # Update callback URI to use the base_url + if settings.idp_settings.callback_uri and "localhost" in settings.idp_settings.callback_uri: + # Extract the path part from the callback URI + from urllib.parse import urlparse + parsed_uri = urlparse(settings.idp_settings.callback_uri) + path = parsed_uri.path + + # Create new callback URI with the base_url + settings.idp_settings.callback_uri = f"{base_url}{path}" + logger.info(f"Updated callback URI to: {settings.idp_settings.callback_uri}") + + # Create standard OAuth routes + sdk_routes = create_auth_routes( + provider=provider, + issuer_url=issuer_url, + service_documentation_url=None, + client_registration_options=None, + revocation_options=None + ) + + # Add these routes to the application + for route in sdk_routes: + app.routes.append(route) + + # Add custom routes + app.include_router(router) + +@router.get("/login", response_class=HTMLResponse) +async def login_form(request: Request, error: Optional[str] = None): + """Render the login page with OAuth option.""" + # Create the response with the login template + response = _templates.TemplateResponse( + "login.html", + { + "request": request, + "error": error, + "oauth_enabled": _settings.enabled if _settings else False, + "provider_type": _settings.idp_settings.provider_type if _settings and _settings.enabled and _settings.idp_settings else None, + "timestamp": datetime.now().timestamp(), + } + ) + + # Add cache control headers to prevent browsers from caching the login page + # This helps prevent automatic login after logout + response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate' + response.headers['Pragma'] = 'no-cache' + response.headers['Expires'] = '0' + + return response + + +@router.get("/oauth/login") +async def oauth_login(request: Request, t: str = None): + """ + Initiate the OAuth flow by redirecting to the IdP. + The 't' parameter is a timestamp for cache busting - used for state generation. + """ + if not _settings or not _provider or not _settings.enabled or not _settings.idp_settings: + raise AuthorizeError( + error="server_error", + error_description="OAuth is not enabled or not configured properly" + ) + + # Ensure we always have a timestamp parameter for cache-busting + timestamp = t if t else str(int(datetime.now().timestamp())) + logger.info(f"OAuth login initiated with timestamp: {timestamp}") + + # Create a simple client for the OAuth flow + client_id = _settings.idp_settings.client_id + callback_uri = _settings.idp_settings.callback_uri + + # Create OAuth client information + base_url = os.environ.get("MCP_AUTH_BASE_URL", "http://localhost:7860") + logger.info(f"Using base URL: {base_url}") + + # Ensure callback_uri is a fully qualified URL + if not callback_uri.startswith('http'): + callback_uri = f"{base_url}{callback_uri if callback_uri.startswith('/') else '/' + callback_uri}" + logger.info(f"Converted callback URI to absolute URL: {callback_uri}") + + # Generate a completely unique callback URI with current timestamp + # This ensures the browser cannot reuse a cached redirect + if "?" not in callback_uri: + callback_uri = f"{callback_uri}?t={timestamp}&r={secrets.token_hex(8)}" + else: + callback_uri = f"{callback_uri}&t={timestamp}&r={secrets.token_hex(8)}" + + logger.info(f"Generated callback URI with cache busting: {callback_uri}") + # Ensure we have a valid absolute URL for OAuthClientMetadata + if not callback_uri.startswith('http'): + logger.error(f"Callback URI is not an absolute URL: {callback_uri}") + raise AuthorizeError( + error="server_error", + error_description="Invalid callback URI format - must be absolute URL" + ) + + client = OAuthClientInformationFull( + client_id=client_id, + client_secret=_settings.idp_settings.client_secret, + redirect_uris=[callback_uri], + client_metadata=OAuthClientMetadata( + client_name="MCP Gateway", + client_uri=base_url, + redirect_uris=[callback_uri] + ) + ) + + # Log the base URL and callback URI for debugging + logger.info(f"Using base URL: {base_url}") + logger.info(f"Using callback URI: {callback_uri}") + + # Create completely unique state with timestamp and random token + # This makes each OAuth flow unique and prevents browser caching + unique_state = f"{secrets.token_hex(16)}_{timestamp}_{secrets.token_hex(8)}" + + # Create base params with unique state and additional parameters + params = AuthorizationParams( + redirect_uri=callback_uri, + scopes=_settings.idp_settings.scopes or [], + state=unique_state, + code_challenge=secrets.token_hex(32), + redirect_uri_provided_explicitly=True, + # Add prompt=login to force re-authentication even if already logged in + extra_params={"prompt": "login"} + ) + + # Get the authorization URL - our implementation will handle PKCE internally + auth_url = await _provider.authorize(client, params) + + # Log the full authorization URL for debugging + logger.info(f"Generated authorization URL: {auth_url}") + + # Redirect to the IdP's authorization page with strong no-cache headers + response = RedirectResponse(url=auth_url, status_code=status.HTTP_303_SEE_OTHER) + response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate, private' + response.headers['Pragma'] = 'no-cache' + response.headers['Expires'] = '0' + response.headers['Clear-Site-Data'] = '"cookies", "storage"' + return response + + +@router.get("/oauth/callback") +async def oauth_callback(request: Request, code: str = None, state: str = None): + """Handle the callback from the IdP.""" + logger.error(f"OAuth callback received with code={code is not None}, state={state[:10] if state else None}") + logger.error(f"Full request URL: {request.url}") + logger.error(f"Request query params: {request.query_params}") + + # Check for error parameters in the callback + error = request.query_params.get("error") + error_description = request.query_params.get("error_description") + + if error: + logger.error(f"Error in OAuth callback parameters: {error} - {error_description}") + raise AuthorizeError( + error=error, + error_description=error_description or "Error in OAuth callback" + ) + + if not code or not state: + # Use StandardError format from SDK + logger.error(f"Missing code or state parameter. code={code is not None}, state={state is not None}") + raise AuthorizeError( + error="invalid_request", + error_description="Missing code or state parameter" + ) + + if not _settings or not _provider: + raise AuthorizeError( + error="server_error", + error_description="OAuth is not configured properly" + ) + + try: + # Process the callback using the SDK + try: + # Step 1: Get redirect URL and auth code + # The provider's handle_external_callback method will do the exchange + redirect_url = "/" + auth_code = None + + # Try to process with the provider with improved logging and error handling + try: + logger.info(f"Calling provider.handle_external_callback with code length: {len(code) if code else 0}, state: {state[:10] if state else None}") + + # Get the result from the provider + redirect_result = await _provider.handle_external_callback(code, state, request) + + # Log the received result type for debugging + logger.info(f"Received result from handle_external_callback: type={type(redirect_result)}") + + # Properly handle the tuple result + if isinstance(redirect_result, tuple) and len(redirect_result) == 2: + redirect_url, auth_code = redirect_result + logger.info(f"Successfully unpacked redirect_url: {redirect_url} and auth_code: {auth_code}") + elif isinstance(redirect_result, str): + # If we got just a string (the redirect URL), use that + redirect_url = redirect_result + # Create a new auth code + auth_code = secrets.token_hex(16) + logger.info(f"Got redirect URL string only, generated auth_code: {auth_code}") + else: + # This should not happen with the fixed provider code + logger.error(f"Unexpected result type from handle_external_callback: {type(redirect_result)}") + raise ValueError(f"Expected tuple or string result from handle_external_callback, got {type(redirect_result)}") + except Exception as e: + # Log the error with full traceback and raise it - no fallback + logger.error(f"Error in handle_external_callback: {e}", exc_info=True) + raise + + # Step 2: Get the auth code object (if it exists) + auth_code_obj = _provider._auth_codes.get(auth_code) + + # Step 3: Extract user information from token with enhanced logging + user_info = {"name": None, "email": ""} + + # If we have an auth code object with an external code, use it to get the token + if auth_code_obj and hasattr(auth_code_obj, 'external_code'): + try: + # Get the external token with detailed logging + logger.info(f"Attempting to exchange auth code for token using external_code") + token_data = await _provider._exchange_code_with_idp(auth_code_obj.external_code, auth_code_obj) + logger.info(f"Token exchange successful. Token data keys: {list(token_data.keys())}") + + # Extract user identity from token with detailed logging + if 'id_token' in token_data: + logger.info("ID token found in token data, attempting to decode") + try: + id_token_claims = jwt.decode( + token_data['id_token'], + options={"verify_signature": False} + ) + logger.info(f"JWT decoded successfully. Available claims: {list(id_token_claims.keys())}") + + # Extract user identity information with detailed logging + if 'cognito:username' in id_token_claims: + user_info['name'] = id_token_claims['cognito:username'] + logger.info(f"Using 'cognito:username' claim: {user_info['name']}") + + # Extract Cognito groups if available - but don't store them yet + if 'cognito:groups' in id_token_claims and isinstance(id_token_claims['cognito:groups'], list): + groups = id_token_claims['cognito:groups'] + logger.info(f"Found Cognito groups in token: {groups}") + elif 'preferred_username' in id_token_claims: + user_info['name'] = id_token_claims['preferred_username'] + logger.info(f"Using 'preferred_username' claim: {user_info['name']}") + elif 'name' in id_token_claims: + user_info['name'] = id_token_claims['name'] + logger.info(f"Using 'name' claim: {user_info['name']}") + elif 'email' in id_token_claims: + user_info['name'] = id_token_claims['email'] + user_info['email'] = id_token_claims['email'] + logger.info(f"Using 'email' claim: {user_info['name']}") + else: + # Log all available claims to help diagnose issues + logger.error(f"No usable identity claims found in token. Available claims: {list(id_token_claims.keys())}") + raise ValueError("No usable identity claims found in token") + + logger.info(f"Successfully extracted user identity: {user_info['name']}") + except Exception as e: + logger.error(f"Error decoding ID token: {e}", exc_info=True) + raise # Re-raise to prevent fallback to default user + else: + logger.error(f"No ID token found in token data. Available keys: {list(token_data.keys())}") + raise ValueError("No ID token found in token data") + except Exception as e: + logger.error(f"Error exchanging auth code for token: {e}", exc_info=True) + raise # Re-raise to prevent fallback to default user + else: + logger.error("No auth_code_obj available or missing external_code attribute") + raise ValueError("Cannot proceed without valid authorization code") + + # Ensure we have a valid user name + if not user_info['name']: + logger.error("Failed to extract user identity from token") + raise ValueError("Failed to extract user identity from token") + + # Step 4: Create HTML response + html_content = f""" + + + + Authenticated - Redirecting... + + + + +
+

Authentication Successful

+

Welcome, {user_info['name']}! You have been authenticated successfully.

+

Redirecting to the dashboard...

+

If you are not redirected automatically, click here.

+
+ + + """ + + # Use HTMLResponse instead of RedirectResponse + response = HTMLResponse(content=html_content, status_code=200) + + + # Step 5: Set up the session cookie + secret_key = os.environ.get("SECRET_KEY", "insecure-default-key-for-testing-only") + signer = URLSafeTimedSerializer(secret_key) + session_cookie_name = "mcp_gateway_session" + session_max_age = 60 * 60 * 8 # 8 hours + + # Create session data with actual username, groups, and provider type + session_data = { + "username": user_info['name'], # Use the actual username from token claims + "oauth_code": auth_code, + "is_oauth": True, # Make sure this is explicitly set to True + "email": user_info.get('email', ''), + "provider_type": _settings.idp_settings.provider_type, # Store the provider type for logout + "login_time": datetime.now().isoformat(), # Add timestamp for debugging + "session_id": secrets.token_hex(16), # Unique session identifier + "auth_method": "oauth" # Add redundant auth method indicator for safety + } + + logger.info(f"Creating OAuth session with provider: {_settings.idp_settings.provider_type}") + + # Store Cognito groups in session if available + if 'cognito:groups' in id_token_claims and isinstance(id_token_claims['cognito:groups'], list): + session_data["groups"] = id_token_claims['cognito:groups'] + logger.info(f"Storing Cognito groups in session: {id_token_claims['cognito:groups']}") + + # Serialize the session data + serialized_session = signer.dumps(session_data) + + # Set the session cookie + logger.info(f"Setting session cookie {session_cookie_name} with data length: {len(serialized_session)} for user {user_info['name']}") + response.set_cookie( + key=session_cookie_name, + value=serialized_session, + max_age=session_max_age, + httponly=True, + path="/", # Ensure the cookie is accessible across all paths + samesite="lax" + ) + + # Log to confirm cookie was set + logger.info(f"Session cookie set. Response headers: {response.headers}") + + return response + + except AuthorizeError as ae: + # Pass through SDK AuthorizeError with proper OAuth error format + logger.error(f"OAuth authorization error: {ae.error} - {ae.error_description}") + return RedirectResponse( + url=f"/login?error={ae.error}&error_description={ae.error_description}", + status_code=status.HTTP_303_SEE_OTHER + ) + except TokenError as te: + # Handle SDK TokenError with proper OAuth error format + logger.error(f"OAuth token error: {te.error} - {te.error_description}") + return RedirectResponse( + url=f"/login?error={te.error}&error_description={te.error_description}", + status_code=status.HTTP_303_SEE_OTHER + ) + except Exception as e: + # Convert generic exceptions to standard OAuth server_error + logger.error(f"Error handling OAuth callback: {e}") + return RedirectResponse( + url=f"/login?error=server_error&error_description={str(e)}", + status_code=status.HTTP_303_SEE_OTHER + ) + + +@router.get("/oauth/callback/{provider}") +async def provider_callback(request: Request, provider: str, code: str = None, state: str = None): + """Handle callbacks for specific providers.""" + # Log detailed information to debug session issues + logger.info(f"Provider callback received for {provider}. Code length: {len(code) if code else 0}, State: {state[:10] if state else None}") + logger.info(f"Provider callback URL: {request.url}") + logger.info(f"Provider callback query params: {request.query_params}") + + try: + # Forward to the main callback handler + logger.error(f"Provider callback - Request query params: {request.query_params}") + logger.error(f"Provider callback - Request headers: {request.headers}") + + # Check for error parameters in the callback + error = request.query_params.get("error") + error_description = request.query_params.get("error_description") + if error: + logger.error(f"Error in OAuth callback: {error} - {error_description}") + + # Provide more specific guidance based on the error + user_message = error_description + if error == "invalid_request" and "invalid_scope" in error_description: + user_message = "Invalid scope error. Please check that the scopes configured in Cognito match the requested scopes (openid, profile, email)." + + return RedirectResponse( + url=f"/login?error={error}&error_description={user_message}", + status_code=status.HTTP_303_SEE_OTHER + ) + + response = await oauth_callback(request, code, state) + # Log the response details to debug cookie issues + logger.info(f"Callback completed. Response status: {response.status_code}, Headers: {response.headers}") + return response + except Exception as e: + logger.error(f"Exception in provider callback: {str(e)}", exc_info=True) + # Return a more detailed error for debugging + return RedirectResponse( + url=f"/login?error=callback_error&error_description=Provider+callback+error:+{str(e)}", + status_code=status.HTTP_303_SEE_OTHER + ) \ No newline at end of file diff --git a/registry/auth/settings.py b/registry/auth/settings.py new file mode 100644 index 0000000..c607dfa --- /dev/null +++ b/registry/auth/settings.py @@ -0,0 +1,118 @@ +""" +Settings for MCP Gateway authentication. +""" +from dataclasses import dataclass, field +from typing import Dict, List, Optional + + +@dataclass +class IdPSettings: + """Settings for a specific identity provider.""" + provider_type: str # "cognito", "okta", etc. + client_id: str + client_secret: str + authorize_url: str + token_url: str + jwks_url: str + callback_uri: str + scopes: List[str] = field(default_factory=lambda: ["openid", "profile", "email"]) + audience: Optional[str] = None + issuer: Optional[str] = None + + +@dataclass +class ScopeMapping: + """Maps between IdP scopes and MCP Gateway scopes.""" + idp_to_mcp: Dict[str, List[str]] = field(default_factory=dict) + mcp_to_idp: Dict[str, List[str]] = field(default_factory=dict) + + +@dataclass +class AuthSettings: + """Main authentication settings for MCP Gateway.""" + enabled: bool = True + idp_settings: Optional[IdPSettings] = None + scope_mapping: ScopeMapping = field(default_factory=ScopeMapping) + registry_admin_scope: str = "mcp:registry:admin" + registry_read_scope: str = "mcp:registry:read" + server_execute_scope_prefix: str = "mcp:server:" + server_execute_scope_suffix: str = ":execute" + public_routes: List[str] = field(default_factory=lambda: [ + "/login", "/oauth/callback", "/static", "/favicon.ico" + ]) + # OAuth client ID and secret used by the MCP Gateway with the external IdP + default_client_id: Optional[str] = None + default_client_secret: Optional[str] = None + + def get_server_execute_scope(self, server_path: str) -> str: + """Get the scope needed to execute tools on a specific server.""" + # Remove leading slash and replace remaining slashes with underscores + normalized_path = server_path.lstrip("/").replace("/", "_") + return f"{self.server_execute_scope_prefix}{normalized_path}{self.server_execute_scope_suffix}" + + def get_server_read_scope(self, server_path: str) -> str: + """Get the read scope for a server.""" + # Remove leading slash and replace remaining slashes with underscores + normalized_path = server_path.lstrip("/").replace("/", "_") + return f"{self.server_execute_scope_prefix}{normalized_path}:read" + + def get_server_toggle_scope(self, server_path: str) -> str: + """Get the toggle scope for a server.""" + # Remove leading slash and replace remaining slashes with underscores + normalized_path = server_path.lstrip("/").replace("/", "_") + return f"{self.server_execute_scope_prefix}{normalized_path}:toggle" + + def get_server_edit_scope(self, server_path: str) -> str: + """Get the edit scope for a server.""" + # Remove leading slash and replace remaining slashes with underscores + normalized_path = server_path.lstrip("/").replace("/", "_") + return f"{self.server_execute_scope_prefix}{normalized_path}:edit" + + def get_tool_execute_scope(self, server_path: str, tool_name: str) -> str: + """Get the execute scope for a specific tool.""" + # Remove leading slash and replace remaining slashes with underscores + normalized_path = server_path.lstrip("/").replace("/", "_") + return f"{self.server_execute_scope_prefix}{normalized_path}:tool:{tool_name}:execute" + + def load_from_env(self, env_dict: dict) -> "AuthSettings": + """Load settings from environment variables.""" + self.enabled = env_dict.get("MCP_AUTH_ENABLED", "true").lower() == "true" + + if not self.enabled: + return self + + provider_type = env_dict.get("MCP_AUTH_PROVIDER_TYPE", "").lower() + if not provider_type: + return self + + self.idp_settings = IdPSettings( + provider_type=provider_type, + client_id=env_dict.get("MCP_AUTH_CLIENT_ID", ""), + client_secret=env_dict.get("MCP_AUTH_CLIENT_SECRET", ""), + authorize_url=env_dict.get("MCP_AUTH_AUTHORIZE_URL", ""), + token_url=env_dict.get("MCP_AUTH_TOKEN_URL", ""), + jwks_url=env_dict.get("MCP_AUTH_JWKS_URL", ""), + callback_uri=env_dict.get("MCP_AUTH_CALLBACK_URI", ""), + scopes=env_dict.get("MCP_AUTH_SCOPES", "openid profile email").split(), + audience=env_dict.get("MCP_AUTH_AUDIENCE"), + issuer=env_dict.get("MCP_AUTH_ISSUER") + ) + + # Override default scopes if specified + registry_admin = env_dict.get("MCP_AUTH_REGISTRY_ADMIN_SCOPE") + if registry_admin: + self.registry_admin_scope = registry_admin + + registry_read = env_dict.get("MCP_AUTH_REGISTRY_READ_SCOPE") + if registry_read: + self.registry_read_scope = registry_read + + server_prefix = env_dict.get("MCP_AUTH_SERVER_SCOPE_PREFIX") + if server_prefix: + self.server_execute_scope_prefix = server_prefix + + server_suffix = env_dict.get("MCP_AUTH_SERVER_SCOPE_SUFFIX") + if server_suffix: + self.server_execute_scope_suffix = server_suffix + + return self \ No newline at end of file diff --git a/registry/main.py b/registry/main.py index 5075c9e..0fb323c 100644 --- a/registry/main.py +++ b/registry/main.py @@ -3,16 +3,29 @@ import secrets import asyncio import subprocess +import urllib.parse +import httpx + # argparse removed as we're using environment variables instead from contextlib import asynccontextmanager -from pathlib import Path # Import Path -from typing import Annotated, List, Set +from pathlib import Path as PathLib # Rename Path import to avoid conflict +from typing import Annotated, List, Set, Dict, Any, Union, Callable, Awaitable from datetime import datetime, timezone +import re +from registry.auth.settings import AuthSettings +from itsdangerous import URLSafeTimedSerializer +import uvicorn +import time import faiss import numpy as np from sentence_transformers import SentenceTransformer + +from mcp import ClientSession +from mcp.client.sse import sse_client + + # Get configuration from environment variables EMBEDDINGS_MODEL_NAME = os.environ.get('EMBEDDINGS_MODEL_NAME', 'all-MiniLM-L6-v2') EMBEDDINGS_MODEL_DIMENSIONS = int(os.environ.get('EMBEDDINGS_MODEL_DIMENSIONS', '384')) @@ -27,28 +40,63 @@ Cookie, WebSocket, WebSocketDisconnect, + Response, + File, + UploadFile, + Query, ) from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates + +# Custom exception for redirecting to login +class RedirectToLogin(Exception): + """Exception raised when user needs to be redirected to login page.""" + pass from itsdangerous import URLSafeTimedSerializer, SignatureExpired, BadSignature from dotenv import load_dotenv import logging -# --- MCP Client Imports --- START -from mcp import ClientSession -from mcp.client.sse import sse_client -# --- MCP Client Imports --- END +# Import OAuth integration +from registry.auth.integration import integrate_oauth + +# Lightweight session invalidation using session timestamps +# Instead of storing all invalidated sessions, we track logout times +SESSION_LOGOUT_TIMES: Dict[str, float] = {} +MAX_LOGOUT_ENTRIES = 1000 # Keep only recent logout times + +# Cache for SECRET_KEY validation to avoid repeated validation +_SECRET_KEY_VALIDATED = False +_LAST_VALIDATED_SECRET_KEY = None + +# --- OAuth 2.1 Integration --- START +from registry.auth.middleware import SessionUser + +from registry.auth.integration import integrate_oauth +from registry.auth.middleware import ( + requires_scope, requires_server_access, requires_server_toggle, requires_server_edit, + require_toggle_for_path, require_edit_for_path, require_access_for_path, + check_admin_scope, require_registry_admin +) +from fastapi import Path, Depends, Form, Request, HTTPException, status, WebSocket, WebSocketDisconnect +# --- OAuth 2.1 Integration --- END # --- Define paths based on container structure --- START -CONTAINER_APP_DIR = Path("/app") +CONTAINER_APP_DIR = PathLib("/app") CONTAINER_REGISTRY_DIR = CONTAINER_APP_DIR / "registry" CONTAINER_LOG_DIR = CONTAINER_APP_DIR / "logs" EMBEDDINGS_MODEL_DIR = CONTAINER_REGISTRY_DIR / "models" / EMBEDDINGS_MODEL_NAME # --- Define paths based on container structure --- END +# Helper function to run async dependencies +async def run_async_dependency(dependency, kwargs): + """Run an async dependency with the given kwargs.""" + if asyncio.iscoroutinefunction(dependency): + return await dependency(**kwargs) + return dependency(**kwargs) + # Determine the base directory of this script (registry folder) -# BASE_DIR = Path(__file__).resolve().parent # Less relevant inside container +# BASE_DIR = PathLib(__file__).resolve().parent # Less relevant inside container # --- Load .env if it exists in the expected location relative to the app --- START # Assumes .env might be mounted at /app/.env or similar @@ -66,7 +114,10 @@ # NGINX_CONFIG_PATH = ( # CONTAINER_REGISTRY_DIR / "nginx_mcp_revproxy.conf" # ) -NGINX_CONFIG_PATH = Path("/etc/nginx/conf.d/nginx_rev_proxy.conf") # Target the actual Nginx config file +NGINX_CONFIG_PATH = PathLib("/etc/nginx/conf.d/nginx_rev_proxy.conf") # Target the actual Nginx config file + +# Force dev mode to be enabled for easier startup +os.environ["MCP_GATEWAY_DEV_MODE"] = "true" # Use the mounted volume path for server definitions SERVERS_DIR = CONTAINER_REGISTRY_DIR / "servers" STATIC_DIR = CONTAINER_REGISTRY_DIR / "static" @@ -92,24 +143,6 @@ next_faiss_id_counter = 0 # --- FAISS Vector DB Configuration --- END -# --- REMOVE Logging Setup from here --- START -# # Ensure log directory exists -# CONTAINER_LOG_DIR.mkdir(parents=True, exist_ok=True) -# -# # Configure logging -# logging.basicConfig( -# level=logging.INFO, -# format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", -# handlers=[ -# logging.FileHandler(LOG_FILE_PATH), # Log to file in /app/logs -# logging.StreamHandler() # Log to console (stdout/stderr) -# ] -# ) -# -# logger = logging.getLogger(__name__) # Get a logger instance -# logger.info("Logging configured. Application starting...") -# --- REMOVE Logging Setup from here --- END - # --- Define logger at module level (unconfigured initially) --- START # Configure logging with process ID, filename, line number, and millisecond precision logging.basicConfig( @@ -128,9 +161,153 @@ HEALTH_CHECK_TIMEOUT_SECONDS = 10 # Timeout for each curl check (Increased to 10) SERVER_LAST_CHECK_TIME = {} # path -> datetime of last check attempt (UTC) +# Force all servers to healthy status in development mode +def dev_mode_mark_servers_healthy(): + """Mark all servers as healthy in development mode.""" + if os.environ.get("MCP_GATEWAY_DEV_MODE", "").lower() in ("1", "true", "yes"): + logger.info("Development mode enabled - marking all servers as healthy") + for path in REGISTERED_SERVERS: + SERVER_HEALTH_STATUS[path] = "healthy" + SERVER_LAST_CHECK_TIME[path] = datetime.now(timezone.utc) + + # Get service info for this path + service_info = REGISTERED_SERVERS[path] + + # Set real tools for known servers in dev mode + server_name = path.lstrip("/") + server_path = os.path.join(os.environ.get("SERVER_DIR", "/Users/aaronbw/Documents/DEV/v1/mcp-gateway/servers"), server_name) + server_py_path = os.path.join(server_path, "server.py") + + # Check if this server has a server.py file and we haven't set tools yet + if os.path.exists(server_py_path) and not service_info.get("real_tools_set"): + # Try to automatically extract tools from server.py + extracted_tools = try_extract_tools_from_server_py(server_py_path) + + if extracted_tools: + # Only set tools if we successfully extracted them + service_info["num_tools"] = len(extracted_tools) + service_info["tool_list"] = extracted_tools + logger.info(f"Using {len(extracted_tools)} automatically extracted tools for {path}") + else: + # If extraction failed, leave tools as-is - don't show anything until the real tools are available + logger.warning(f"Failed to extract tools from {path}, no tools will be shown in dev mode") + # Make sure we don't have any old placeholder values by explicitly setting to empty + service_info["num_tools"] = 0 + service_info["tool_list"] = [] + + # Mark that we've set real tools for this server + service_info["real_tools_set"] = True + REGISTERED_SERVERS[path] = service_info + logger.info(f"Set real tools for {path} in dev mode") + # Never use placeholder tools - if we don't have real tools, set an empty list + elif not service_info.get("tool_list"): + service_info["num_tools"] = 0 + service_info["tool_list"] = [] + REGISTERED_SERVERS[path] = service_info + logger.info(f"No tools set for {path} - waiting for real tools") + # --- WebSocket Connection Management --- active_connections: Set[WebSocket] = set() +# --- Helper for extracting tools from server.py files --- START +def try_extract_tools_from_server_py(server_py_path: str) -> list: + """ + Attempts to extract tool definitions from a server.py file. + This is used in dev mode to provide realistic tool definitions without having to query the server. + + Args: + server_py_path: Path to the server.py file + + Returns: + List of tool definitions if successful, empty list otherwise + """ + try: + # Check if the file exists + if not os.path.exists(server_py_path): + logger.warning(f"Server.py file not found at {server_py_path}") + return [] + + # Read the file + with open(server_py_path, "r") as f: + content = f.read() + + # Look for @mcp.tool() decorated functions + tool_pattern = r'@mcp\.tool\(\).*?def\s+(\w+)\(([^)]*)\)' + param_pattern = r'(\w+)\s*:\s*Annotated\[(\w+),\s*Field\(\s*(?:[^)]*?description\s*=\s*"([^"]*)")?(?:[^)]*?default\s*=\s*"?([^",\)]*)"?)?' + + tools = [] + matches = re.finditer(tool_pattern, content, re.DOTALL) + + for match in matches: + try: + func_name = match.group(1) + params_str = match.group(2) + + # Try to extract docstring for description + func_block = content[match.end():].split('\n\n')[0] + desc_match = re.search(r'"""(.*?)"""', func_block, re.DOTALL) + description = desc_match.group(1).strip() if desc_match else f"Function {func_name}" + + # Short description (first sentence) + short_desc = description.split('.')[0].strip() + + # Extract parameters + parameters = { + "type": "object", + "properties": {} + } + + param_matches = re.finditer(param_pattern, params_str, re.DOTALL) + for param_match in param_matches: + param_name = param_match.group(1) + param_type = param_match.group(2).lower() + param_desc = param_match.group(3) if param_match.group(3) else f"Parameter {param_name}" + param_default = param_match.group(4) if param_match.group(4) else None + + param_info = { + "type": "string" if param_type == "str" else + "number" if param_type in ["int", "float"] else + "boolean" if param_type == "bool" else + "string", + "description": param_desc + } + + if param_default: + if param_type == "str": + param_info["default"] = param_default + elif param_type == "int": + try: + param_info["default"] = int(param_default) + except: + pass + elif param_type == "float": + try: + param_info["default"] = float(param_default) + except: + pass + elif param_type == "bool": + param_info["default"] = param_default.lower() in ["true", "1", "yes"] + + parameters["properties"][param_name] = param_info + + tools.append({ + "name": func_name, + "description": short_desc, + "parameters": parameters + }) + + except Exception as e: + logger.warning(f"Error extracting tool info for {match.group(1)}: {e}") + continue + + logger.info(f"Successfully extracted {len(tools)} tools from {server_py_path}") + return tools + + except Exception as e: + logger.warning(f"Failed to extract tools from {server_py_path}: {e}") + return [] +# --- Helper for extracting tools from server.py files --- END + # --- FAISS Helper Functions --- START def _get_text_for_embedding(server_info: dict) -> str: @@ -157,7 +334,7 @@ def load_faiss_data(): os.environ['SENTENCE_TRANSFORMERS_HOME'] = str(model_cache_path) # Check if the model path exists and is not empty - model_path = Path(EMBEDDINGS_MODEL_PATH) + model_path = PathLib(EMBEDDINGS_MODEL_PATH) model_exists = model_path.exists() and any(model_path.iterdir()) if model_path.exists() else False if model_exists: @@ -314,37 +491,36 @@ async def broadcast_health_status(): # --- Add num_tools --- END message = json.dumps(data_to_send) - - # Keep track of connections that fail during send disconnected_clients = set() - # Iterate over a copy of the set to allow modification during iteration - current_connections = list(active_connections) - - # Create send tasks and associate them with the connection + # Concurrent sending + current_connections = list(active_connections) # Make a copy of current set send_tasks = [] - for conn in current_connections: - send_tasks.append((conn, conn.send_text(message))) - # Run tasks concurrently and check results + # Schedule all send operations + for connection in current_connections: + # Store connection and task together for easier error handling + send_tasks.append((connection, connection.send_text(message))) + + # Wait for all to complete results = await asyncio.gather(*(task for _, task in send_tasks), return_exceptions=True) + # Process results for i, result in enumerate(results): - conn, _ = send_tasks[i] # Get the corresponding connection + conn, _ = send_tasks[i] if isinstance(result, Exception): - # Check if it's a connection-related error (more specific checks possible) - # For now, assume any exception during send means the client is gone - logger.warning(f"Error sending to WebSocket client {conn.client}: {result}. Marking for removal.") + logger.warning(f"Error sending status update to WebSocket client {conn.client}: {result}. Marking for removal.") disconnected_clients.add(conn) - # Remove all disconnected clients identified during the broadcast + # Remove any disconnected clients if disconnected_clients: logger.info(f"Removing {len(disconnected_clients)} disconnected clients after broadcast.") for conn in disconnected_clients: if conn in active_connections: active_connections.remove(conn) -# Session management configuration +# --- Setup FastAPI Application --- + # Session management configuration SECRET_KEY = os.environ.get("SECRET_KEY") if not SECRET_KEY: @@ -357,1072 +533,1593 @@ async def broadcast_health_status(): signer = URLSafeTimedSerializer(SECRET_KEY) SESSION_MAX_AGE_SECONDS = 60 * 60 * 8 # 8 hours -# --- Nginx Config Generation --- - -LOCATION_BLOCK_TEMPLATE = """ - location {path}/ {{ - proxy_pass {proxy_pass_url}; - proxy_http_version 1.1; - proxy_set_header Host $host; - proxy_set_header X-Real-IP $remote_addr; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Forwarded-Proto $scheme; - }} -""" - -COMMENTED_LOCATION_BLOCK_TEMPLATE = """ -# location {path}/ {{ -# proxy_pass {proxy_pass_url}; -# proxy_http_version 1.1; -# proxy_set_header Host $host; -# proxy_set_header X-Real-IP $remote_addr; -# proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; -# proxy_set_header X-Forwarded-Proto $scheme; -# }} -""" - -def regenerate_nginx_config(): - """Generates the nginx config file based on registered servers and their state.""" - logger.info(f"Attempting to directly modify Nginx config at {NGINX_CONFIG_PATH}...") +# Lifespan handler to initialize and cleanup resources +@asynccontextmanager +async def lifespan(app: FastAPI): + # --- Startup Code --- + logger.info("Application startup. Initializing...") - # Define markers - START_MARKER = "# DYNAMIC_LOCATIONS_START" - END_MARKER = "# DYNAMIC_LOCATIONS_END" + # Create paths if they don't exist + # --- Ensure Directories Exist --- START + CONTAINER_LOG_DIR.mkdir(parents=True, exist_ok=True) + STATIC_DIR.mkdir(parents=True, exist_ok=True) + TEMPLATES_DIR.mkdir(parents=True, exist_ok=True) + SERVERS_DIR.mkdir(parents=True, exist_ok=True) + # --- Ensure Directories Exist --- END + # --- Load FAISS data and model --- START + logger.info("Pre-loading FAISS data and model...") try: - # Read the *target* Nginx config file - with open(NGINX_CONFIG_PATH, 'r') as f_target: - target_content = f_target.read() - - # Generate the location blocks section content (only needs to be done once) - location_blocks_content = [] - sorted_paths = sorted(REGISTERED_SERVERS.keys()) - - for path in sorted_paths: - server_info = REGISTERED_SERVERS[path] - proxy_url = server_info.get("proxy_pass_url") - is_enabled = MOCK_SERVICE_STATE.get(path, False) - health_status = SERVER_HEALTH_STATUS.get(path) - - if not proxy_url: - logger.warning(f"Skipping server '{server_info['server_name']}' ({path}) - missing proxy_pass_url.") - continue - - if is_enabled and health_status == "healthy": - block = LOCATION_BLOCK_TEMPLATE.format(path=path, proxy_pass_url=proxy_url) - else: - block = COMMENTED_LOCATION_BLOCK_TEMPLATE.format(path=path, proxy_pass_url=proxy_url) - location_blocks_content.append(block) - - generated_section = "\n".join(location_blocks_content).strip() - - # --- Replace content between ALL marker pairs --- START - new_content = "" - current_pos = 0 - while True: - # Find the next start marker - start_index = target_content.find(START_MARKER, current_pos) - if start_index == -1: - # No more start markers found, append the rest of the file - new_content += target_content[current_pos:] - break - - # Find the corresponding end marker after the start marker - end_index = target_content.find(END_MARKER, start_index + len(START_MARKER)) - if end_index == -1: - # Found a start marker without a matching end marker, log error and stop - logger.error(f"Found '{START_MARKER}' at position {start_index} without a matching '{END_MARKER}' in {NGINX_CONFIG_PATH}. Aborting regeneration.") - # Append the rest of the file to avoid data loss, but don't reload - new_content += target_content[current_pos:] - # Write back the partially processed content? Or just return False? - # Let's return False to indicate failure without modifying the file potentially incorrectly. - return False # Indicate failure - - # Append the content before the current start marker - new_content += target_content[current_pos:start_index + len(START_MARKER)] - # Append the newly generated section (with appropriate newlines) - new_content += f"\n\n{generated_section}\n\n " - # Update current position to be after the end marker - current_pos = end_index - - # Check if any replacements were made (i.e., if current_pos moved beyond 0) - if current_pos == 0: - logger.error(f"No marker pairs '{START_MARKER}'...'{END_MARKER}' found in {NGINX_CONFIG_PATH}. Cannot regenerate.") - return False - - final_config = new_content # Use the iteratively built content - # --- Replace content between ALL marker pairs --- END - - # # Find the start and end markers in the target content - # start_index = target_content.find(START_MARKER) - # end_index = target_content.find(END_MARKER) - # - # if start_index == -1 or end_index == -1 or end_index <= start_index: - # logger.error(f"Markers '{START_MARKER}' and/or '{END_MARKER}' not found or in wrong order in {NGINX_CONFIG_PATH}. Cannot regenerate.") - # return False - # - # # Extract the parts before the start marker and after the end marker - # prefix = target_content[:start_index + len(START_MARKER)] - # suffix = target_content[end_index:] - # - # # Construct the new content - # # Add newlines around the generated section for readability - # final_config = f"{prefix}\n\n{generated_section}\n\n {suffix}" - - # Write the modified content back to the target file - with open(NGINX_CONFIG_PATH, 'w') as f_out: - f_out.write(final_config) - logger.info(f"Nginx config file {NGINX_CONFIG_PATH} modified successfully.") - - # --- Reload Nginx --- START + # We do this in a thread to not block + await asyncio.to_thread(load_faiss_data) + except Exception as e: + logger.error(f"Error pre-loading FAISS: {e}", exc_info=True) + # --- Load FAISS data and model --- END + + # --- Load server state --- START + if STATE_FILE_PATH.exists(): try: - logger.info("Attempting to reload Nginx configuration...") - result = subprocess.run(['/usr/sbin/nginx', '-s', 'reload'], check=True, capture_output=True, text=True) - logger.info(f"Nginx reload successful. stdout: {result.stdout.strip()}") - return True - except FileNotFoundError: - logger.error("'nginx' command not found. Cannot reload configuration.") - return False - except subprocess.CalledProcessError as e: - logger.error(f"Failed to reload Nginx configuration. Return code: {e.returncode}") - logger.error(f"Nginx reload stderr: {e.stderr.strip()}") - logger.error(f"Nginx reload stdout: {e.stdout.strip()}") - return False + with open(STATE_FILE_PATH, "r") as f: + MOCK_SERVICE_STATE.update(json.load(f)) + logger.info(f"Loaded server state from {STATE_FILE_PATH} with {len(MOCK_SERVICE_STATE)} entries") except Exception as e: - logger.error(f"An unexpected error occurred during Nginx reload: {e}", exc_info=True) - return False - # --- Reload Nginx --- END - - except FileNotFoundError: - logger.error(f"Target Nginx config file not found at {NGINX_CONFIG_PATH}. Cannot regenerate.") - return False - except Exception as e: - logger.error(f"Failed to modify Nginx config at {NGINX_CONFIG_PATH}: {e}", exc_info=True) - return False - -# --- Helper function to normalize a path to a filename --- -def path_to_filename(path): - # Remove leading slash and replace remaining slashes with underscores - normalized = path.lstrip("/").replace("/", "_") - # Append .json extension if not present - if not normalized.endswith(".json"): - normalized += ".json" - return normalized - - -# --- Data Loading --- -def load_registered_servers_and_state(): - global REGISTERED_SERVERS, MOCK_SERVICE_STATE + logger.error(f"Error loading server state: {e}") + else: + logger.info(f"No server state file found at {STATE_FILE_PATH}. Starting with empty state.") + # --- Load server state --- END + + # --- Load existing server JSON files --- START logger.info(f"Loading server definitions from {SERVERS_DIR}...") - - # Create servers directory if it doesn't exist - SERVERS_DIR.mkdir(parents=True, exist_ok=True) # Added parents=True - - temp_servers = {} - server_files = list(SERVERS_DIR.glob("**/*.json")) - logger.info(f"Found {len(server_files)} JSON files in {SERVERS_DIR} and its subdirectories") - for file in server_files: - logger.info(f"[DEBUG] - {file.relative_to(SERVERS_DIR)}") - - if not server_files: - logger.warning(f"No server definition files found in {SERVERS_DIR}. Initializing empty registry.") - REGISTERED_SERVERS = {} - # Don't return yet, need to load state file - # return - - for server_file in server_files: - if server_file.name == STATE_FILE_PATH.name: # Skip the state file itself - continue - try: - with open(server_file, "r") as f: - server_info = json.load(f) - - if ( - isinstance(server_info, dict) - and "path" in server_info - and "server_name" in server_info - ): - server_path = server_info["path"] - if server_path in temp_servers: - logger.warning(f"Duplicate server path found in {server_file}: {server_path}. Overwriting previous definition.") - - # Add new fields with defaults - server_info["description"] = server_info.get("description", "") - server_info["tags"] = server_info.get("tags", []) - server_info["num_tools"] = server_info.get("num_tools", 0) - server_info["num_stars"] = server_info.get("num_stars", 0) - server_info["is_python"] = server_info.get("is_python", False) - server_info["license"] = server_info.get("license", "N/A") - server_info["proxy_pass_url"] = server_info.get("proxy_pass_url", None) - server_info["tool_list"] = server_info.get("tool_list", []) # Initialize tool_list if missing - - temp_servers[server_path] = server_info + if SERVERS_DIR.exists() and SERVERS_DIR.is_dir(): + json_server_files = list(SERVERS_DIR.glob("*.json")) + logger.info(f"Found {len(json_server_files)} JSON files in {SERVERS_DIR}") + + for server_file in json_server_files: + try: + # Skip _metadata and _index, these aren't service files + if server_file.name.startswith("service_index_"): + continue + + with open(server_file, "r") as f: + server_data = json.load(f) + + # Check if this is a server definition (has path, server_name, proxy_pass_url) + if "path" in server_data and "server_name" in server_data and "proxy_pass_url" in server_data: + path = server_data["path"] + # Register the service in memory + REGISTERED_SERVERS[path] = server_data + # Mark it as either enabled or disabled, defaulting to disabled if not in state + is_enabled = MOCK_SERVICE_STATE.get(path, False) + MOCK_SERVICE_STATE[path] = is_enabled + # Initialize health status for this service + SERVER_HEALTH_STATUS[path] = "unknown" + logger.info(f"Loaded server definition for {server_data['server_name']} at {path} (enabled={is_enabled})") else: - logger.warning(f"Invalid server entry format found in {server_file}. Skipping.") - except FileNotFoundError: - logger.error(f"Server definition file {server_file} reported by glob not found.") - except json.JSONDecodeError as e: - logger.error(f"Could not parse JSON from {server_file}: {e}.") - except Exception as e: - logger.error(f"An unexpected error occurred loading {server_file}: {e}", exc_info=True) - - REGISTERED_SERVERS = temp_servers - logger.info(f"Successfully loaded {len(REGISTERED_SERVERS)} server definitions.") - - # --- Load persisted mock service state --- START - logger.info(f"Attempting to load persisted state from {STATE_FILE_PATH}...") - loaded_state = {} + logger.warning(f"Skipping incomplete server definition in {server_file}") + except Exception as e: + logger.error(f"Error loading server definition from {server_file}: {e}") + else: + logger.warning(f"Servers directory not found or not a directory: {SERVERS_DIR}") + # --- Load existing server JSON files --- END + + # --- Mark servers as healthy in dev mode --- START + dev_mode_mark_servers_healthy() + # --- Mark servers as healthy in dev mode --- END + + # --- Generate initial Nginx config --- START + logger.info("Generating initial Nginx configuration...") + regenerate_nginx_config() + # --- Generate initial Nginx config --- END + + # --- Start Health Check Background Task --- START + # Start the background health check task + logger.info("Starting health check background task...") + asyncio.create_task(run_health_checks()) + # --- Start Health Check Background Task --- END + + logger.info("Application initialization complete. Ready to serve requests.") + + yield # Application runs here + + # --- Cleanup Code --- + logger.info("Application shutdown. Cleaning up...") + + # Ensure latest service state is saved try: - if STATE_FILE_PATH.exists(): - with open(STATE_FILE_PATH, "r") as f: - loaded_state = json.load(f) - if not isinstance(loaded_state, dict): - logger.warning(f"Invalid state format in {STATE_FILE_PATH}. Expected a dictionary. Resetting state.") - loaded_state = {} # Reset if format is wrong - else: - logger.info("Successfully loaded persisted state.") - else: - logger.info(f"No persisted state file found at {STATE_FILE_PATH}. Initializing state.") - - except json.JSONDecodeError as e: - logger.error(f"Could not parse JSON from {STATE_FILE_PATH}: {e}. Initializing empty state.") - loaded_state = {} + with open(STATE_FILE_PATH, "w") as f: + json.dump(MOCK_SERVICE_STATE, f, indent=2) + logger.info(f"Saved server state to {STATE_FILE_PATH} with {len(MOCK_SERVICE_STATE)} entries") except Exception as e: - logger.error(f"Failed to read state file {STATE_FILE_PATH}: {e}. Initializing empty state.", exc_info=True) - loaded_state = {} - - # Initialize MOCK_SERVICE_STATE: Use loaded state if valid, otherwise default to False. - # Ensure state only contains keys for currently registered servers. - MOCK_SERVICE_STATE = {} - for path in REGISTERED_SERVERS.keys(): - MOCK_SERVICE_STATE[path] = loaded_state.get(path, False) # Default to False if not in loaded state or state was invalid - - logger.info(f"Initial mock service state loaded: {MOCK_SERVICE_STATE}") - # --- Load persisted mock service state --- END - - - # Initialize health status to 'checking' or 'disabled' based on the just loaded state - global SERVER_HEALTH_STATUS - SERVER_HEALTH_STATUS = {} # Start fresh - for path, is_enabled in MOCK_SERVICE_STATE.items(): - if path in REGISTERED_SERVERS: # Should always be true here now - SERVER_HEALTH_STATUS[path] = "checking" if is_enabled else "disabled" - else: - # This case should ideally not happen if MOCK_SERVICE_STATE is built from REGISTERED_SERVERS - logger.warning(f"Path {path} found in loaded state but not in registered servers. Ignoring.") + logger.error(f"Error saving server state: {e}") + + # Any other cleanup goes here... + logger.info("Cleanup complete. Application shutting down.") + +# Create FastAPI application +app = FastAPI( + title="MCP Gateway Registry Service", + description="Registry service for MCP Gateway to manage server registrations", + lifespan=lifespan, +) - logger.info(f"Initialized health status based on loaded state: {SERVER_HEALTH_STATUS}") +# Add exception handler for RedirectToLogin +@app.exception_handler(RedirectToLogin) +async def redirect_to_login_handler(request: Request, exc: RedirectToLogin): + """Handle RedirectToLogin exception by redirecting to login page.""" + logger.info(f"Redirecting unauthenticated user to login page from {request.url.path}") + return RedirectResponse(url="/login", status_code=status.HTTP_303_SEE_OTHER) - # We no longer need the explicit default initialization block below - # print("Initializing mock service state (defaulting to disabled)...") - # MOCK_SERVICE_STATE = {path: False for path in REGISTERED_SERVERS.keys()} - # # TODO: Consider loading initial state from a persistent store if needed - # print(f"Initial mock state: {MOCK_SERVICE_STATE}") +# Mount static files directory +app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") +# Set up templates +templates = Jinja2Templates(directory=TEMPLATES_DIR) -# --- Helper function to save server data --- -def save_server_to_file(server_info): +# --- Set up OAuth 2.1 integration --- START +# Call our integration function to set up OAuth 2.1 with the app +oauth_provider = integrate_oauth(app, templates) +logger.info(f"OAuth 2.1 integration setup: {'ENABLED' if oauth_provider else 'DISABLED'}") +# --- Set up OAuth 2.1 integration --- END + +# -- Authentication Helper Functions -- + +def get_session_fingerprint(session_data: dict) -> str: + """Generate a unique session fingerprint using session_id.""" + username = session_data.get("username", "") + session_id = session_data.get("session_id", "") + return f"{username}:{session_id}" + +def cleanup_logout_times(): + """Clean up old logout times to prevent memory leaks.""" + global SESSION_LOGOUT_TIMES + if len(SESSION_LOGOUT_TIMES) > MAX_LOGOUT_ENTRIES: + # Keep only the most recent entries + sorted_items = sorted(SESSION_LOGOUT_TIMES.items(), key=lambda x: x[1]) + SESSION_LOGOUT_TIMES = dict(sorted_items[-MAX_LOGOUT_ENTRIES//2:]) + +def is_session_logged_out(session_data: dict) -> bool: + """Check if session was logged out after its creation time.""" + fingerprint = get_session_fingerprint(session_data) + logout_time = SESSION_LOGOUT_TIMES.get(fingerprint) + + if logout_time is None: + return False + + # Parse session login time + login_time_str = session_data.get("login_time", "") + if not login_time_str: + return False + try: - # Create servers directory if it doesn't exist - SERVERS_DIR.mkdir(parents=True, exist_ok=True) # Ensure it exists - - # Generate filename based on path - path = server_info["path"] - filename = path_to_filename(path) - file_path = SERVERS_DIR / filename - - with open(file_path, "w") as f: - json.dump(server_info, f, indent=2) - - logger.info(f"Successfully saved server '{server_info['server_name']}' to {file_path}") - return True - except Exception as e: - logger.error(f"Failed to save server '{server_info.get('server_name', 'UNKNOWN')}' data to {filename}: {e}", exc_info=True) + from datetime import datetime + login_time = datetime.fromisoformat(login_time_str.replace('Z', '+00:00')) + login_timestamp = login_time.timestamp() + return logout_time > login_timestamp + except (ValueError, AttributeError): return False +def handle_auth_failure(request: Request, detail: str): + """Handle authentication failure by redirecting browser requests or raising HTTPException for API requests.""" + accept_header = request.headers.get("accept", "") + is_browser_request = "text/html" in accept_header + + if is_browser_request: + logger.info(f"Browser request detected, redirecting to login page. Reason: {detail}") + raise RedirectToLogin() + else: + logger.info(f"API request detected, returning 401. Reason: {detail}") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=detail, + headers={"WWW-Authenticate": "Bearer"}, + ) -# --- MCP Client Function to Get Tool List --- START (Renamed) -async def get_tools_from_server(base_url: str) -> List[dict] | None: # Return list of dicts +def validate_secret_key(secret_key: str) -> None: """ - Connects to an MCP server via SSE, lists tools, and returns their details - (name, description, schema). - + Validate that the SECRET_KEY meets security requirements. + Args: - base_url: The base URL of the MCP server (e.g., http://localhost:8000). - - Returns: - A list of tool detail dictionaries (keys: name, description, schema), - or None if connection/retrieval fails. + secret_key: The secret key to validate + + Raises: + ValueError: If the secret key is weak or insecure """ - # Determine scheme and construct the full /sse URL - if not base_url: - logger.error("MCP Check Error: Base URL is empty.") - return None - - sse_url = base_url.rstrip('/') + "/sse" - # Simple check for https, might need refinement for edge cases - secure_prefix = "s" if sse_url.startswith("https://") else "" - mcp_server_url = f"http{secure_prefix}://{sse_url[len(f'http{secure_prefix}://'):]}" # Ensure correct format for sse_client - - - logger.info(f"Attempting to connect to MCP server at {mcp_server_url} to get tool list...") - try: - # Connect using the sse_client context manager directly - async with sse_client(mcp_server_url) as (read, write): - # Use the ClientSession context manager directly - async with ClientSession(read, write, sampling_callback=None) as session: - # Apply timeout to individual operations within the session - await asyncio.wait_for(session.initialize(), timeout=10.0) # Timeout for initialize - tools_response = await asyncio.wait_for(session.list_tools(), timeout=15.0) # Renamed variable - - # Extract tool details - tool_details_list = [] - if tools_response and hasattr(tools_response, 'tools'): - for tool in tools_response.tools: - # Access attributes directly based on MCP documentation - tool_name = getattr(tool, 'name', 'Unknown Name') # Direct attribute access - tool_desc = getattr(tool, 'description', None) or getattr(tool, '__doc__', None) - - # --- Parse Docstring into Sections --- START - parsed_desc = { - "main": "No description available.", - "args": None, - "returns": None, - "raises": None, - } - if tool_desc: - tool_desc = tool_desc.strip() - # Simple parsing logic (can be refined) - lines = tool_desc.split('\n') - main_desc_lines = [] - current_section = "main" - section_content = [] - - for line in lines: - stripped_line = line.strip() - if stripped_line.startswith("Args:"): - parsed_desc["main"] = "\n".join(main_desc_lines).strip() - current_section = "args" - section_content = [stripped_line[len("Args:"):].strip()] - elif stripped_line.startswith("Returns:"): - if current_section != "main": - parsed_desc[current_section] = "\n".join(section_content).strip() - else: - parsed_desc["main"] = "\n".join(main_desc_lines).strip() - current_section = "returns" - section_content = [stripped_line[len("Returns:"):].strip()] - elif stripped_line.startswith("Raises:"): - if current_section != "main": - parsed_desc[current_section] = "\n".join(section_content).strip() - else: - parsed_desc["main"] = "\n".join(main_desc_lines).strip() - current_section = "raises" - section_content = [stripped_line[len("Raises:"):].strip()] - elif current_section == "main": - main_desc_lines.append(line.strip()) # Keep leading whitespace for main desc if intended - else: - section_content.append(line.strip()) - - # Add the last collected section - if current_section != "main": - parsed_desc[current_section] = "\n".join(section_content).strip() - elif not parsed_desc["main"] and main_desc_lines: # Handle case where entire docstring was just main description - parsed_desc["main"] = "\n".join(main_desc_lines).strip() - - # Ensure main description has content if others were parsed but main was empty - if not parsed_desc["main"] and (parsed_desc["args"] or parsed_desc["returns"] or parsed_desc["raises"]): - parsed_desc["main"] = "(No primary description provided)" - - else: - parsed_desc["main"] = "No description available." - # --- Parse Docstring into Sections --- END - - tool_schema = getattr(tool, 'inputSchema', {}) # Use inputSchema attribute - - tool_details_list.append({ - "name": tool_name, - "parsed_description": parsed_desc, # Store parsed sections - "schema": tool_schema - }) - - logger.info(f"Successfully retrieved details for {len(tool_details_list)} tools from {mcp_server_url}.") - return tool_details_list # Return the list of details - except asyncio.TimeoutError: - logger.error(f"MCP Check Error: Timeout during session operation with {mcp_server_url}.") - return None - except ConnectionRefusedError: - logger.error(f"MCP Check Error: Connection refused by {mcp_server_url}.") - return None - except Exception as e: - logger.error(f"MCP Check Error: Failed to get tool list from {mcp_server_url}: {type(e).__name__} - {e}") - return None - -# --- MCP Client Function to Get Tool List --- END + global _SECRET_KEY_VALIDATED, _LAST_VALIDATED_SECRET_KEY + + # Use cache to avoid repeated validation of the same key + if _SECRET_KEY_VALIDATED and _LAST_VALIDATED_SECRET_KEY == secret_key: + return + + import re + + # Check minimum length (should be at least 32 characters for good entropy) + if len(secret_key) < 32: + raise ValueError(f"SECRET_KEY must be at least 32 characters long. Current length: {len(secret_key)}") + + # Check for sufficient complexity (should contain different character types) + has_upper = bool(re.search(r'[A-Z]', secret_key)) + has_lower = bool(re.search(r'[a-z]', secret_key)) + has_digit = bool(re.search(r'[0-9]', secret_key)) + has_special = bool(re.search(r'[^A-Za-z0-9]', secret_key)) + + complexity_score = sum([has_upper, has_lower, has_digit, has_special]) + + # For randomly generated hex keys, require at least 2 character types (numbers + letters) + # For other keys, require at least 3 character types for good complexity + min_complexity = 2 if all(c in '0123456789abcdefABCDEF' for c in secret_key) else 3 + + if complexity_score < min_complexity: + if min_complexity == 2: + raise ValueError( + "SECRET_KEY lacks sufficient complexity. For hex keys, ensure both letters and numbers are present." + ) + else: + raise ValueError( + "SECRET_KEY lacks sufficient complexity. It should contain at least 3 of: " + "uppercase letters, lowercase letters, digits, special characters" + ) + + # Check for obvious patterns + if secret_key.lower() in secret_key or secret_key == secret_key[::-1]: + # Additional pattern checks could be added here + pass + + # Mark as validated and cache the key + _SECRET_KEY_VALIDATED = True + _LAST_VALIDATED_SECRET_KEY = secret_key + logger.info("SECRET_KEY validation passed - key meets security requirements") + +def get_current_user(request: Request, session: str = Cookie(None)) -> str: + """Get the current user from session cookie, or redirect to login for browser requests.""" + SECRET_KEY = os.environ.get("SECRET_KEY", "insecure-default-key-for-testing-only") + + # Validate SECRET_KEY strength before using it for authentication + try: + validate_secret_key(SECRET_KEY) + except ValueError as e: + logger.error(f"SECRET_KEY validation failed: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Server configuration error: Invalid SECRET_KEY. Please configure a strong secret key." + ) + + session_cookie_name = "mcp_gateway_session" + default_session_expr = 3600 # Default: 1 hour + USERNAME_ENV = os.environ.get("ADMIN_USER", "admin") + PASSWORD_ENV = os.environ.get("ADMIN_PASSWORD", "password") + + # Log detailed debugging information + logger.info(f"Authentication check - Path: {request.url.path}, Host: {request.headers.get('host')}") + logger.info(f"Request cookies: {request.cookies}") + + # DEBUG: Always consider the user as authenticated during testing + disable_auth_env = os.environ.get("MCP_GATEWAY_DISABLE_AUTH", "").lower() + if disable_auth_env in ("1", "true", "yes"): + logger.warning("AUTH DISABLED: Using automatic admin access") + # Add admin user to request context for RBAC + # Create the user object but don't try to set it directly on request + user = SessionUser(USERNAME_ENV, ["mcp-admin"]) + # Store the user in request.state which is designed for this purpose + request.state.user = user + return USERNAME_ENV + + # Check for session in cookies directly if Cookie dependency didn't work + if not session and session_cookie_name in request.cookies: + session = request.cookies[session_cookie_name] + logger.info(f"Using session from request.cookies: {session[:20]}...") + + # No session? Use helper function to handle browser vs API requests + if not session: + logger.warning("No session cookie found.") + handle_auth_failure(request, "Not authenticated") + # Unsign the cookie + try: + s = URLSafeTimedSerializer(SECRET_KEY) + MAX_AGE = int(os.environ.get("SESSION_EXPIRATION_SECONDS", default_session_expr)) + data = s.loads(session, max_age=MAX_AGE) + + # Check if this session has been logged out + if is_session_logged_out(data): + fingerprint = get_session_fingerprint(data) + logger.warning(f"Session {fingerprint} was logged out after creation") + handle_auth_failure(request, "Session has been logged out") + + # If OAuth session + if data.get("is_oauth", False): + logger.info("Validated OAuth session.") + + # Extract groups from session if available + groups = data.get("groups", []) + username = data.get("username", "oauth_user") + + # Create SessionUser and attach to request + if groups: + logger.info(f"Session contains groups: {groups}") + # Create the user object but don't try to set it directly on request + # Since FastAPI's Request object doesn't support attribute setting + user = SessionUser(username, groups) + # Store the user in request.state which is designed for this purpose + request.state.user = user + # Log mapped scopes for debugging + if hasattr(user, "scopes"): + logger.info(f"User {username} groups mapped to scopes: {user.scopes}") + + return username + + # Check if regular session data looks valid + username = data.get("username") + is_authenticated = data.get("authenticated", False) + + if username and is_authenticated: + logger.debug(f"Validated session for {username}") + # Create a standard user with admin permissions for non-OAuth sessions + # Create the user object but don't try to set it directly on request + user = SessionUser(username, ["mcp-admin"]) + # Store the user in request.state which is designed for this purpose + request.state.user = user + return username + + logger.warning(f"Session found but invalid structure: {data}") + handle_auth_failure(request, "Invalid session") + except SignatureExpired: + logger.warning("Session expired") + handle_auth_failure(request, "Session expired") + except BadSignature: + logger.warning("Invalid session signature") + handle_auth_failure(request, "Invalid session") + except Exception as e: + logger.error(f"Error validating session: {e}") + handle_auth_failure(request, "Authentication error") -# --- Single Health Check Logic --- -async def perform_single_health_check(path: str) -> tuple[str, datetime | None]: - """Performs a health check for a single service path and updates global state.""" - global SERVER_HEALTH_STATUS, SERVER_LAST_CHECK_TIME, REGISTERED_SERVERS # Ensure REGISTERED_SERVERS is global - server_info = REGISTERED_SERVERS.get(path) - # --- Store previous status --- START - previous_status = SERVER_HEALTH_STATUS.get(path) # Get status before check - # --- Store previous status --- END +def api_auth(request: Request, session: str = Cookie(None)) -> str: + """Similar to get_current_user but returns UnauthorizedResponse instead of redirects.""" + try: + username = get_current_user(request, session) + + # Log the user's scopes for debugging + if hasattr(request.state, "user") and hasattr(request.state.user, "scopes"): + scopes = request.state.user.scopes + logger.info(f"API Auth - User {username} has scopes: {scopes}") + else: + logger.warning(f"API Auth - User {username} has no scopes attribute") + + return username + except HTTPException as http_exc: + # Log authentication failures + logger.warning(f"API Auth failed: {http_exc.detail}") + + # Convert HTTPException to JSON response + return JSONResponse( + status_code=http_exc.status_code, + content={"detail": http_exc.detail} + ) - if not server_info: - # Should not happen if called correctly, but handle defensively - return "error: server not registered", None - - url = server_info.get("proxy_pass_url") - is_enabled = MOCK_SERVICE_STATE.get(path, False) # Get enabled state for later check - - # --- Record check time --- - last_checked_time = datetime.now(timezone.utc) - SERVER_LAST_CHECK_TIME[path] = last_checked_time - # --- Record check time --- - - if not url: - current_status = "error: missing URL" - SERVER_HEALTH_STATUS[path] = current_status - logger.info(f"Health check skipped for {path}: Missing URL.") - # --- Regenerate Nginx if status affecting it changed --- START - if is_enabled and previous_status == "healthy": # Was healthy, now isn't (due to missing URL) - logger.info(f"Status changed from healthy for {path}, regenerating Nginx config...") - regenerate_nginx_config() - # --- Regenerate Nginx if status affecting it changed --- END - return current_status, last_checked_time - - # Update status to 'checking' before performing the check - # Only print if status actually changes to 'checking' - if previous_status != "checking": - logger.info(f"Setting status to 'checking' for {path} ({url})...") - SERVER_HEALTH_STATUS[path] = "checking" - # Optional: Consider a targeted broadcast here if immediate 'checking' feedback is desired - # await broadcast_specific_update(path, "checking", last_checked_time) - - # --- Append /sse to the health check URL --- START - health_check_url = url.rstrip('/') + "/sse" - # --- Append /sse to the health check URL --- END - - # cmd = ['curl', '--head', '-s', '-f', '--max-time', str(HEALTH_CHECK_TIMEOUT_SECONDS), url] - cmd = ['curl', '--head', '-s', '-f', '--max-time', str(HEALTH_CHECK_TIMEOUT_SECONDS), health_check_url] # Use modified URL - current_status = "checking" # Status will be updated below +# --- Authentication Routes --- +@app.post("/login") +async def login( + request: Request, + username: str = Form(...), + password: str = Form(...), +): + """Handle login form submission.""" + USERNAME_ENV = os.environ.get("ADMIN_USER", "admin") + PASSWORD_ENV = os.environ.get("ADMIN_PASSWORD", "password") + SECRET_KEY = os.environ.get("SECRET_KEY", "insecure-default-key-for-testing-only") + + # Validate SECRET_KEY strength before using it for session creation try: - proc = await asyncio.create_subprocess_exec( - *cmd, - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE + validate_secret_key(SECRET_KEY) + except ValueError as e: + logger.error(f"SECRET_KEY validation failed during login: {e}") + return templates.TemplateResponse( + "login.html", + { + "request": request, + "error": "Server configuration error. Please contact administrator.", + "user_has_toggle_scope": lambda server_path: user_has_toggle_scope(request, server_path), + "user_has_edit_scope": lambda server_path: user_has_edit_scope(request, server_path), + "user_has_admin_scope": lambda: user_has_admin_scope(request) + } + ) + + session_cookie_name = "mcp_gateway_session" + session_max_age = 60 * 60 * 8 # 8 hours + + # Check credentials + if username == USERNAME_ENV and password == PASSWORD_ENV: + # Create session data + s = URLSafeTimedSerializer(SECRET_KEY) + session_data = { + "username": username, + "authenticated": True, + "created_at": datetime.now(timezone.utc).isoformat(), + "session_id": secrets.token_hex(16), + } + session_cookie = s.dumps(session_data) + + # Redirect to home with session cookie + response = RedirectResponse(url="/", status_code=status.HTTP_303_SEE_OTHER) + response.set_cookie( + key=session_cookie_name, + value=session_cookie, + max_age=session_max_age, + httponly=True, + ) + return response + else: + # Show login form with error + return templates.TemplateResponse( + "login.html", + { + "request": request, + "error": "Invalid username or password", + "user_has_toggle_scope": lambda server_path: user_has_toggle_scope(request, server_path), + "user_has_edit_scope": lambda server_path: user_has_edit_scope(request, server_path), + "user_has_admin_scope": lambda: user_has_admin_scope(request) + }, ) - # Use a slightly longer timeout for wait_for to catch process hangs - stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=HEALTH_CHECK_TIMEOUT_SECONDS + 2) - stderr_str = stderr.decode().strip() if stderr else '' - - if proc.returncode == 0: - current_status = "healthy" - logger.info(f"Health check successful for {path} ({url}).") - - # --- Check for transition to healthy state --- START - # Note: Tool list fetching moved inside the status transition check - if previous_status != "healthy": - logger.info(f"Service {path} transitioned to healthy. Regenerating Nginx config and fetching tool list...") - # --- Regenerate Nginx on transition TO healthy --- START - regenerate_nginx_config() - # --- Regenerate Nginx on transition TO healthy --- END - - # Ensure url is not None before attempting connection (redundant check as url is checked above, but safe) - if url: - tool_list = await get_tools_from_server(url) # Get the list of dicts - - if tool_list is not None: # Check if list retrieval was successful - new_tool_count = len(tool_list) - # Get current list (now list of dicts) - current_tool_list = REGISTERED_SERVERS[path].get("tool_list", []) - current_tool_count = REGISTERED_SERVERS[path].get("num_tools", 0) - - # Compare lists more carefully (simple set comparison won't work on dicts) - # Convert to comparable format (e.g., sorted list of JSON strings) - current_tool_list_str = sorted([json.dumps(t, sort_keys=True) for t in current_tool_list]) - new_tool_list_str = sorted([json.dumps(t, sort_keys=True) for t in tool_list]) - - # if set(current_tool_list) != set(tool_list) or current_tool_count != new_tool_count: - if current_tool_list_str != new_tool_list_str or current_tool_count != new_tool_count: - logger.info(f"Updating tool list for {path}. New count: {new_tool_count}.") # Simplified log - REGISTERED_SERVERS[path]["tool_list"] = tool_list # Store the new list of dicts - REGISTERED_SERVERS[path]["num_tools"] = new_tool_count # Update the count - # Save the updated server info to its file - if not save_server_to_file(REGISTERED_SERVERS[path]): - logger.error(f"ERROR: Failed to save updated tool list/count for {path} to file.") - # --- Update FAISS after tool list/count change --- START - # No explicit call here, will be handled by the one at the end of perform_single_health_check - # logger.info(f"Updating FAISS metadata for '{path}' after tool list/count update.") - # await add_or_update_service_in_faiss(path, REGISTERED_SERVERS[path]) # Moved to end - # --- Update FAISS after tool list/count change --- END - else: - logger.info(f"Tool list for {path} remains unchanged. No update needed.") - else: - logger.info(f"Failed to retrieve tool list for healthy service {path}. List/Count remains unchanged.") - # Even if tool list fetch failed, server is healthy. - # FAISS update will occur at the end of this function with current REGISTERED_SERVERS[path]. - else: - # This case should technically not be reachable due to earlier url check - logger.info(f"Cannot fetch tool list for {path}: proxy_pass_url is missing.") - # --- Check for transition to healthy state --- END - # If it was already healthy, and tools changed, the above block (current_tool_list_str != new_tool_list_str) handles it. - # The FAISS update with the latest REGISTERED_SERVERS[path] will happen at the end of this function. - - elif proc.returncode == 28: - current_status = f"error: timeout ({HEALTH_CHECK_TIMEOUT_SECONDS}s)" - logger.info(f"Health check timeout for {path} ({url})") - elif proc.returncode == 22: # HTTP error >= 400 - current_status = "unhealthy (HTTP error)" - logger.info(f"Health check unhealthy (HTTP >= 400) for {path} ({url}). Stderr: {stderr_str}") - elif proc.returncode == 7: # Connection failed - current_status = "error: connection failed" - logger.info(f"Health check connection failed for {path} ({url}). Stderr: {stderr_str}") - else: # Other curl errors - error_msg = f"error: check failed (code {proc.returncode})" - if stderr_str: - error_msg += f" - {stderr_str}" - current_status = error_msg - logger.info(f"Health check failed for {path} ({url}): {error_msg}") - - except asyncio.TimeoutError: - # This catches timeout on asyncio.wait_for, slightly different from curl's --max-time - current_status = "error: check process timeout" - logger.info(f"Health check asyncio.wait_for timeout for {path} ({url})") - except FileNotFoundError: - current_status = "error: command not found" - logger.error(f"ERROR: 'curl' command not found during health check for {path}. Cannot perform check.") - # No need to stop all checks, just this one fails - except Exception as e: - current_status = f"error: {type(e).__name__}" - logger.error(f"ERROR: Unexpected error during health check for {path} ({url}): {e}") - - # Update the global status *after* the check completes - SERVER_HEALTH_STATUS[path] = current_status - logger.info(f"Final health status for {path}: {current_status}") - # --- Update FAISS with final server_info state after health check attempt --- - if path in REGISTERED_SERVERS and embedding_model and faiss_index is not None: - logger.info(f"Updating FAISS metadata for '{path}' post health check (status: {current_status}).") - await add_or_update_service_in_faiss(path, REGISTERED_SERVERS[path]) - # --- Regenerate Nginx if status affecting it changed --- START - # Check if the service is enabled AND its Nginx-relevant status changed - if is_enabled: - if previous_status == "healthy" and current_status != "healthy": - logger.info(f"Status changed FROM healthy for enabled service {path}, regenerating Nginx config...") - regenerate_nginx_config() - # Regeneration on transition TO healthy is handled within the proc.returncode == 0 block above - # elif previous_status != "healthy" and current_status == "healthy": - # print(f"Status changed TO healthy for {path}, regenerating Nginx config...") - # regenerate_nginx_config() # Already handled above - # --- Regenerate Nginx if status affecting it changed --- END +def get_idp_logout_url_fast(provider_type: str, request: Request) -> str: + """Generate IdP logout URL in a provider-agnostic way (optimized for speed).""" + try: + # Quick check for Cognito without heavy AuthSettings initialization + if provider_type == "cognito": + cognito_domain = os.environ.get("MCP_AUTH_COGNITO_DOMAIN") + client_id = os.environ.get("MCP_AUTH_CLIENT_ID") + + if cognito_domain and client_id: + timestamp = int(datetime.now(timezone.utc).timestamp()) + scheme = request.url.scheme or "http" + host = request.headers.get('host', 'localhost:7860') + return_uri = f"{scheme}://{host}/login?t={timestamp}&signed_out=true&complete=true" + logout_url = f"https://{cognito_domain}/logout?client_id={client_id}&logout_uri={urllib.parse.quote(return_uri)}" + return logout_url + + # For other providers, fall back to the original function if needed + # but for now, just return None to avoid blocking + return None + + except Exception: + # Fail silently to avoid blocking logout + return None +def get_idp_logout_url_fast(provider_type: str, request: Request) -> str: + """Generate IdP logout URL in a provider-agnostic way (optimized for speed).""" + try: + # Quick check for Cognito without heavy AuthSettings initialization + if provider_type == "cognito": + cognito_domain = os.environ.get("MCP_AUTH_COGNITO_DOMAIN") + client_id = os.environ.get("MCP_AUTH_CLIENT_ID") + + if cognito_domain and client_id: + timestamp = int(datetime.now(timezone.utc).timestamp()) + scheme = request.url.scheme or "http" + host = request.headers.get('host', 'localhost:7860') + return_uri = f"{scheme}://{host}/login?t={timestamp}&signed_out=true&complete=true" + logout_url = f"https://{cognito_domain}/logout?client_id={client_id}&logout_uri={urllib.parse.quote(return_uri)}" + return logout_url + + # For other providers, fall back to the original function if needed + # but for now, just return None to avoid blocking + return None + + except Exception: + # Fail silently to avoid blocking logout + return None - return current_status, last_checked_time +def get_idp_logout_url(provider_type: str, request: Request) -> str: + """Generate IdP logout URL in a provider-agnostic way.""" + logger.info(f"Generating IdP logout URL for provider: {provider_type}") + try: + auth_settings = AuthSettings() + logger.info(f"Auth settings - enabled: {auth_settings.enabled}, has idp_settings: {auth_settings.idp_settings is not None}") + + if not (auth_settings.enabled and auth_settings.idp_settings): + logger.warning("Auth not enabled or no IdP settings available") + return None + + timestamp = int(datetime.now(timezone.utc).timestamp()) + scheme = request.url.scheme or "http" + host = request.headers.get('host', 'localhost:7860') + return_uri = f"{scheme}://{host}/login?t={timestamp}&signed_out=true&complete=true" + logger.info(f"Generated return URI: {return_uri}") + + if provider_type == "cognito": + client_id = auth_settings.idp_settings.client_id + cognito_domain = os.environ.get("MCP_AUTH_COGNITO_DOMAIN") + logger.info(f"Cognito config - client_id: {client_id is not None}, domain: {cognito_domain}") + if cognito_domain and client_id: + logout_url = f"https://{cognito_domain}/logout?client_id={client_id}&logout_uri={urllib.parse.quote(return_uri)}" + logger.info(f"Generated Cognito logout URL: {logout_url}") + return logout_url + else: + logger.warning("Missing Cognito client_id or domain") + # Add other IdP logout URL generation here + elif provider_type == "azure": + # return azure_logout_url(auth_settings, return_uri) + pass + elif provider_type == "auth0": + # return auth0_logout_url(auth_settings, return_uri) + pass + + except Exception as e: + logger.error(f"Error generating IdP logout URL for {provider_type}: {e}") + + logger.warning(f"No logout URL generated for provider: {provider_type}") + return None +@app.get("/logout") +async def logout_get(request: Request): + """ + Log out by clearing the session cookie and invalidating the session server-side. + Provides IdP logout URLs when available. + """ + session_cookie_name = "mcp_gateway_session" + SECRET_KEY = os.environ.get("SECRET_KEY", "insecure-default-key-for-testing-only") + + # Extract session cookie manually (same approach as get_current_user) + session = request.cookies.get(session_cookie_name) + + # Decode session and invalidate server-side + provider_type = None + username = None + session_invalidated = False + + if session: + try: + s = URLSafeTimedSerializer(SECRET_KEY) + session_data = s.loads(session, max_age=None) # Don't check expiry for logout + provider_type = session_data.get("provider_type") + username = session_data.get("username") + + # Invalidate session server-side using lightweight approach + fingerprint = get_session_fingerprint(session_data) + SESSION_LOGOUT_TIMES[fingerprint] = time.time() + session_invalidated = True + + # Only cleanup if we have too many entries (avoid unnecessary work) + if len(SESSION_LOGOUT_TIMES) > MAX_LOGOUT_ENTRIES: + cleanup_logout_times() + + except Exception as e: + logger.warning(f"Error decoding session during logout: {e}") + + # Create base logout response + timestamp = int(datetime.now(timezone.utc).timestamp()) + logout_url = f"/login?t={timestamp}&signed_out=true" + + # Add IdP logout URL if available (but don't block on it) + if provider_type: + try: + idp_logout_url = get_idp_logout_url_fast(provider_type, request) + if idp_logout_url: + logout_url += f"&idp_logout={urllib.parse.quote(idp_logout_url)}" + logout_url += f"&provider_type={provider_type}" + except Exception: + # Don't let IdP logout URL generation block the logout - fail silently + pass + + response = RedirectResponse(url=logout_url, status_code=status.HTTP_303_SEE_OTHER) + + # Clear session cookie and add cache control headers + response.delete_cookie(key=session_cookie_name, path="/", httponly=True) + response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate' + response.headers['Pragma'] = 'no-cache' + response.headers['Expires'] = '0' + # Removed Clear-Site-Data header to improve logout performance + + # Log completion with minimal info + if username: + logger.debug(f"Logout completed for user: {username}") + + return response -# --- Background Health Check Task --- -async def run_health_checks(): - """Periodically checks the health of registered *enabled* services.""" - while True: - logger.info(f"Running periodic health checks (Interval: {HEALTH_CHECK_INTERVAL_SECONDS}s)...") - paths_to_check = list(REGISTERED_SERVERS.keys()) - needs_broadcast = False # Flag to check if any status actually changed +@app.post("/logout") +async def logout_post(request: Request): + """Handle POST logout requests.""" + return await logout_get(request) - # --- Use a copy of MOCK_SERVICE_STATE for stable iteration --- START - current_enabled_state = MOCK_SERVICE_STATE.copy() - # --- Use a copy of MOCK_SERVICE_STATE for stable iteration --- END - for path in paths_to_check: - if path not in REGISTERED_SERVERS: # Check if server was removed during the loop - continue +# --- Main Routes --- - # --- Use copied state for check --- START - # is_enabled = MOCK_SERVICE_STATE.get(path, False) - is_enabled = current_enabled_state.get(path, False) - # --- Use copied state for check --- END - previous_status = SERVER_HEALTH_STATUS.get(path) - - if not is_enabled: - new_status = "disabled" - if previous_status != new_status: - SERVER_HEALTH_STATUS[path] = new_status - # Also clear last check time when disabling? Or keep it? Keep for now. - # SERVER_LAST_CHECK_TIME[path] = None - needs_broadcast = True - logger.info(f"Service {path} is disabled. Setting status.") - continue # Skip health check for disabled services - - # --- Service is enabled, perform check using the new function --- - logger.info(f"Performing periodic check for enabled service: {path}") +@app.get("/", response_class=HTMLResponse) +async def index(request: Request, username: Annotated[str, Depends(get_current_user)]): + """Render the main index page. Requires authentication.""" + # Directly pass services to the template as in the original implementation + service_data = [] + sorted_server_paths = sorted( + REGISTERED_SERVERS.keys(), key=lambda p: REGISTERED_SERVERS[p]["server_name"] + ) + + # Get user's scopes if authenticated with OAuth 2.1 + user_scopes = set() + has_admin_scope = False + + # Check if the user has OAuth-based authentication with scopes + user_scopes = set() + + if hasattr(request.state, "user") and hasattr(request.state.user, "scopes"): + user_scopes = set(request.state.user.scopes) + # Check for admin scope which grants access to all servers + auth_settings = AuthSettings() + has_admin_scope = auth_settings.registry_admin_scope in user_scopes + logger.info(f"User {request.state.user.display_name} has scopes: {user_scopes}, Admin: {has_admin_scope}") + else: + # Check if we have a session cookie with groups + session_cookie = request.cookies.get("mcp_gateway_session") + if session_cookie: try: - # Call the refactored check function - # We only care if the status *changed* from the beginning of the cycle for broadcast purposes - current_status, _ = await perform_single_health_check(path) - if previous_status != current_status: - needs_broadcast = True + # Use the environment variable for SECRET_KEY + secret_key = os.environ.get("SECRET_KEY", "insecure-default-key-for-testing-only") + s = URLSafeTimedSerializer(secret_key) + data = s.loads(session_cookie) + + # Extract groups from session if available + groups = data.get("groups", []) + if groups: + # Create a SessionUser to extract scopes from groups + session_user = SessionUser(data.get("username", "unknown"), groups) + user_scopes = session_user.scopes + + # Check for admin scope + auth_settings = AuthSettings() + has_admin_scope = auth_settings.registry_admin_scope in user_scopes + + logger.info(f"User {data.get('username')} from session has groups: {groups}") + logger.info(f"Extracted scopes: {user_scopes}, Admin: {has_admin_scope}") + else: + logger.info("No groups found in session cookie") + has_admin_scope = False except Exception as e: - # Log error if the check function itself fails unexpectedly - logger.error(f"ERROR: Unexpected exception calling perform_single_health_check for {path}: {e}") - # Update status to reflect this error? - error_status = f"error: check execution failed ({type(e).__name__})" - if previous_status != error_status: - SERVER_HEALTH_STATUS[path] = error_status - SERVER_LAST_CHECK_TIME[path] = datetime.now(timezone.utc) # Record time of failure - needs_broadcast = True - - - logger.info(f"Finished periodic health checks. Current status map: {SERVER_HEALTH_STATUS}") - # Broadcast status update only if something changed during this cycle - if needs_broadcast: - logger.info("Broadcasting updated health status after periodic check...") - await broadcast_health_status() + logger.error(f"Error processing session cookie: {e}") + has_admin_scope = False else: - logger.info("No status changes detected in periodic check, skipping broadcast.") - - # Wait for the next interval - await asyncio.sleep(HEALTH_CHECK_INTERVAL_SECONDS) - + # No scopes available - rely entirely on IdP for authorization + has_admin_scope = False + logger.info(f"No scopes available for user - access will be restricted") + + for path in sorted_server_paths: + server_info = REGISTERED_SERVERS[path] + server_name = server_info["server_name"] + + # Filter servers based on user's OAuth scopes + if not has_admin_scope: + # Get the required scope for this server + auth_settings = AuthSettings() + required_scope = auth_settings.get_server_execute_scope(path) + + # Skip this server if user doesn't have the required scope + if required_scope not in user_scopes: + logger.info(f"Skipping server {path} - user lacks required scope: {required_scope}") + continue + + # Pass all required fields to the template + service_data.append( + { + "display_name": server_name, + "path": path, + "description": server_info.get("description", ""), + "is_enabled": MOCK_SERVICE_STATE.get(path, False), + "tags": server_info.get("tags", []), + "num_tools": server_info.get("num_tools", 0), + "num_stars": server_info.get("num_stars", 0), + "is_python": server_info.get("is_python", False), + "license": server_info.get("license", "N/A"), + "health_status": SERVER_HEALTH_STATUS.get(path, "unknown"), + "last_checked_iso": SERVER_LAST_CHECK_TIME.get(path).isoformat() if SERVER_LAST_CHECK_TIME.get(path) else None + } + ) + + return templates.TemplateResponse( + "index.html", { + "request": request, + "services": service_data, + "username": username, + "user_has_toggle_scope": lambda server_path: user_has_toggle_scope(request, server_path), + "user_has_edit_scope": lambda server_path: user_has_edit_scope(request, server_path), + "user_has_admin_scope": lambda: user_has_admin_scope(request) + } + ) -# --- Lifespan for Startup Task --- -@asynccontextmanager -async def lifespan(app: FastAPI): - # --- Configure Logging INSIDE lifespan --- START - # Ensure log directory exists - CONTAINER_LOG_DIR.mkdir(parents=True, exist_ok=True) # Should be defined now - - # Configure logging - logging.basicConfig( - level=logging.INFO, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - handlers=[ - logging.FileHandler(LOG_FILE_PATH), # Use correct variable - logging.StreamHandler() # Log to console (stdout/stderr) - ] +@app.get("/debug", response_class=HTMLResponse) +async def debug_index(request: Request, username: Annotated[str, Depends(get_current_user)]): + """Render the debug page for diagnostic purposes.""" + return templates.TemplateResponse( + "debug_index.html", { + "request": request, + "username": username, + "user_has_toggle_scope": lambda server_path: user_has_toggle_scope(request, server_path), + "user_has_edit_scope": lambda server_path: user_has_edit_scope(request, server_path), + "user_has_admin_scope": lambda: user_has_admin_scope(request) + } ) - logger.info("Logging configured. Running startup tasks...") # Now logger is configured - # --- Configure Logging INSIDE lifespan --- END - # 0. Load FAISS data and embedding model - load_faiss_data() # Loads model, empty index or existing index. Synchronous. - # 1. Load server definitions and persisted enabled/disabled state - load_registered_servers_and_state() # This populates REGISTERED_SERVERS. Synchronous. +@app.get("/login", response_class=HTMLResponse) +async def login_form(request: Request, error: str = None): + """Render the login form.""" + return templates.TemplateResponse("login.html", { + "request": request, + "error": error, + "user_has_toggle_scope": lambda server_path: user_has_toggle_scope(request, server_path), + "user_has_edit_scope": lambda server_path: user_has_edit_scope(request, server_path), + "user_has_admin_scope": lambda: user_has_admin_scope(request) + }) + + +@app.get("/api/servers") +async def list_servers( + request: Request, + username: Annotated[str, Depends(api_auth)] +): + """Get all registered servers with their state.""" + servers_list = [] + for path, server_info in REGISTERED_SERVERS.items(): + # Create a copy of the server info and add its enabled status + server_data = server_info.copy() + server_data["is_enabled"] = MOCK_SERVICE_STATE.get(path, False) + server_data["health_status"] = SERVER_HEALTH_STATUS.get(path, "unknown") + last_checked = SERVER_LAST_CHECK_TIME.get(path) + server_data["last_checked"] = last_checked.isoformat() if last_checked else None + servers_list.append(server_data) + return servers_list + + +# --- Function to regenerate Nginx configuration --- +def regenerate_nginx_config(): + """ + Regenerate the Nginx configuration file to include all enabled services. + Updates both HTTP (port 80) and HTTPS (port 443) server blocks. + Returns True if successful, False otherwise. + """ + # Load the existing config file to preserve non-dynamic parts + existing_config = "" + if NGINX_CONFIG_PATH.exists(): + try: + with open(NGINX_CONFIG_PATH, "r") as f: + existing_config = f.read() + except Exception as e: + logger.error(f"Failed to read existing Nginx config: {e}") + return False - # 1.5 Sync FAISS with loaded servers (initial build or update) - if embedding_model and faiss_index is not None: # Check faiss_index is not None - logger.info("Performing initial FAISS synchronization with loaded server definitions...") - sync_tasks = [] - for path, server_info in REGISTERED_SERVERS.items(): - # add_or_update_service_in_faiss is async, can be gathered - sync_tasks.append(add_or_update_service_in_faiss(path, server_info)) + # Extract parts before and after the dynamic locations + start_marker = "# DYNAMIC_LOCATIONS_START" + end_marker = "# DYNAMIC_LOCATIONS_END" + + # Add error handling blocks and auth endpoints for MCP tool execution + error_locations = """ + # Error handling locations for MCP tool execution + location @error401 { + return 401 '{"error":"Unauthorized","detail":"Authentication failed or insufficient permissions"}'; + } + + location @error404 { + return 404 '{"error":"Not Found","detail":"The requested resource was not found"}'; + } + + location @error5xx { + return 500 '{"error":"Server Error","detail":"An unexpected server error occurred"}'; + } + + # Common auth endpoint for tool execution - internal use only + location ~ ^/api/tool_auth/(.+)$ { + internal; # Only for internal use by auth_request + proxy_pass http://localhost:7860/api/tool_auth/$1; + proxy_pass_request_body off; + proxy_set_header Content-Length ""; + proxy_set_header X-Original-URI $request_uri; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + # Pass through authentication cookies + proxy_set_header Cookie $http_cookie; + } +""" + + # Build configuration for each enabled service + dynamic_locations = [] + dynamic_locations.append(start_marker) + + # Add error locations at the top + dynamic_locations.append(error_locations) + + for path, enabled in MOCK_SERVICE_STATE.items(): + if not enabled: + logger.info(f"Skipping disabled service: {path}") + continue + + service_info = REGISTERED_SERVERS.get(path) + if not service_info: + logger.warning(f"Service {path} is enabled but not found in REGISTERED_SERVERS") + continue + + proxy_url = service_info.get("proxy_pass_url") + if not proxy_url: + logger.warning(f"Service {path} has no proxy_pass_url defined") + continue + + # Add health check status info + service_health = SERVER_HEALTH_STATUS.get(path, "unknown") - if sync_tasks: - await asyncio.gather(*sync_tasks) - logger.info("Initial FAISS synchronization complete.") - else: - logger.warning("Skipping initial FAISS synchronization: embedding model or FAISS index not ready.") - - # 2. Perform initial health checks concurrently for *enabled* services - logger.info("Performing initial health checks for enabled services...") - initial_check_tasks = [] - enabled_paths = [path for path, is_enabled in MOCK_SERVICE_STATE.items() if is_enabled] - - global SERVER_HEALTH_STATUS, SERVER_LAST_CHECK_TIME - # Initialize status for all servers (defaults for disabled) - for path in REGISTERED_SERVERS.keys(): - SERVER_LAST_CHECK_TIME[path] = None # Initialize last check time - if path not in enabled_paths: - SERVER_HEALTH_STATUS[path] = "disabled" - else: - # Will be set by the check task below (or remain unset if check fails badly) - SERVER_HEALTH_STATUS[path] = "checking" # Tentative status before check runs - - logger.info(f"Initially enabled services to check: {enabled_paths}") - if enabled_paths: - for path in enabled_paths: - # Create a task for each enabled service check - task = asyncio.create_task(perform_single_health_check(path)) - initial_check_tasks.append(task) - - # Wait for all initial checks to complete - results = await asyncio.gather(*initial_check_tasks, return_exceptions=True) + # Check for dev mode flag to include all enabled servers regardless of health + dev_mode = os.environ.get("MCP_GATEWAY_DEV_MODE", "").lower() in ("1", "true", "yes") + + # In dev mode, include all enabled services; otherwise only include healthy ones + if not dev_mode and service_health != "healthy": + logger.warning(f"Skipping unhealthy service: {path} (status: {service_health})") + continue + + logger.info(f"Adding service to nginx config: {path} (status: {service_health})") + + # Create location entry for this service + # Remove leading slash from path for safer path building + safe_path = path.lstrip("/") + + # 1. Regular location block for the service path + nginx_location = f""" + location /{safe_path}/ {{ + proxy_pass {proxy_url.rstrip('/')}/; + proxy_http_version 1.1; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + + # Pass through authentication headers for service-specific tokens + proxy_pass_request_headers on; + # Preserve headers prefixed with X-Service-Auth- + proxy_set_header X-Service-Auth-Github $http_x_service_auth_github; + proxy_set_header X-Service-Auth-AWS $http_x_service_auth_aws; + proxy_set_header X-Service-Auth-Token $http_x_service_auth_token; + }}""" + dynamic_locations.append(nginx_location) + + # No nginx proxy configuration needed for tool execution - # Log results/errors from initial checks - for i, result in enumerate(results): - path = enabled_paths[i] - if isinstance(result, Exception): - logger.error(f"ERROR during initial health check for {path}: {result}") - # Status might have already been set to an error state within the check function + # Add the end marker + dynamic_locations.append(end_marker) + + # Generate the dynamic content as a string + dynamic_content = "\n".join(dynamic_locations) + + # Generate the dynamic content as a string + dynamic_content = "\n".join(dynamic_locations) + + # Find all occurrences of start and end markers + start_positions = [] + end_positions = [] + start_pos = 0 + + # Find all start markers + while True: + start_pos = existing_config.find(start_marker, start_pos) + if start_pos == -1: + break + start_positions.append(start_pos) + start_pos += len(start_marker) + + # Find all end markers + end_pos = 0 + while True: + end_pos = existing_config.find(end_marker, end_pos) + if end_pos == -1: + break + end_positions.append(end_pos + len(end_marker)) + end_pos += len(end_marker) + + # Verify we have matching pairs of markers + if len(start_positions) != len(end_positions) or len(start_positions) == 0: + logger.error(f"Mismatched or missing markers: {len(start_positions)} starts, {len(end_positions)} ends") + return False + + # Sort positions to ensure correct order + start_positions.sort() + end_positions.sort() + + # Build new config by replacing each section + new_config = existing_config + + # Replace sections in reverse order to avoid position shifts + for i in range(len(start_positions) - 1, -1, -1): + start_pos = start_positions[i] + end_pos = end_positions[i] + + # Replace this section with dynamic content + new_config = new_config[:start_pos] + dynamic_content + new_config[end_pos:] + + logger.info(f"Updated {len(start_positions)} dynamic sections in Nginx config") + + # Write the new configuration + try: + with open(NGINX_CONFIG_PATH, "w") as f: + f.write(new_config) + logger.info(f"Nginx configuration updated at {NGINX_CONFIG_PATH}") + + # Reload Nginx if possible + try: + logger.info("Attempting to reload Nginx configuration...") + result = subprocess.run(['/usr/sbin/nginx', '-s', 'reload'], capture_output=True, text=True) + if result.returncode == 0: + logger.info(f"Nginx reload successful. stdout: {result.stdout.strip()}") + return True else: - status, _ = result # Unpack the result tuple - logger.info(f"Initial health check completed for {path}: Status = {status}") - # Update FAISS with potentially changed server_info (e.g., num_tools from health check) - if path in REGISTERED_SERVERS and embedding_model and faiss_index is not None: - # This runs after each health check result, can be awaited individually - await add_or_update_service_in_faiss(path, REGISTERED_SERVERS[path]) - else: - logger.info("No services are initially enabled.") - - logger.info(f"Initial health status after checks: {SERVER_HEALTH_STATUS}") - - # 3. Generate Nginx config *after* initial checks are done - logger.info("Generating initial Nginx configuration...") - regenerate_nginx_config() # Generate config based on initial health status + logger.error(f"Failed to reload Nginx configuration. Return code: {result.returncode}") + logger.error(f"Nginx reload stderr: {result.stderr.strip()}") + logger.error(f"Nginx reload stdout: {result.stdout.strip()}") + return False + except FileNotFoundError: + logger.error("'nginx' command not found. Cannot reload configuration.") + return False + except subprocess.CalledProcessError as e: + logger.error(f"Failed to reload Nginx configuration. Return code: {e.returncode}") + logger.error(f"Nginx reload stderr: {e.stderr.strip()}") + logger.error(f"Nginx reload stdout: {e.stdout.strip()}") + return False + except Exception as e: + logger.error(f"An unexpected error occurred during Nginx reload: {e}", exc_info=True) + return False - # 4. Start the background periodic health check task - logger.info("Starting background health check task...") - health_check_task = asyncio.create_task(run_health_checks()) + except FileNotFoundError: + logger.error(f"Target Nginx config file not found at {NGINX_CONFIG_PATH}. Cannot regenerate.") + return False + except Exception as e: + logger.error(f"Failed to modify Nginx config at {NGINX_CONFIG_PATH}: {e}", exc_info=True) + return False - # --- Yield to let the application run --- START - yield - # --- Yield to let the application run --- END +COMMENTED_LOCATION_BLOCK_TEMPLATE = """ +# location {path}/ {{ +# proxy_pass {proxy_pass_url}; +# proxy_http_version 1.1; +# proxy_set_header Host $host; +# proxy_set_header X-Real-IP $remote_addr; +# proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; +# proxy_set_header X-Forwarded-Proto $scheme; +# }} +""" - # --- Shutdown tasks --- START - logger.info("Running shutdown tasks...") - logger.info("Cancelling background health check task...") - health_check_task.cancel() - try: - await health_check_task - except asyncio.CancelledError: - logger.info("Health check task cancelled successfully.") - # --- Shutdown tasks --- END +# --- Helper function to normalize a path to a filename --- +def path_to_filename(path): + # Remove leading slash and replace remaining slashes with underscores + normalized = path.lstrip("/").replace("/", "_") + # Append .json extension if not present + if not normalized.endswith(".json"): + normalized += ".json" + return normalized -app = FastAPI(lifespan=lifespan) +# --- Data Loading --- +def load_registered_servers_and_state(): + global REGISTERED_SERVERS, MOCK_SERVICE_STATE + logger.info(f"Loading server definitions from {SERVERS_DIR}...") -# --- Authentication / Session Dependency --- -def get_current_user( - session: Annotated[str | None, Cookie(alias=SESSION_COOKIE_NAME)] = None, -) -> str: - if session is None: - raise HTTPException( - status_code=307, detail="Not authenticated", headers={"Location": "/login"} - ) - try: - data = signer.loads(session, max_age=SESSION_MAX_AGE_SECONDS) - username = data.get("username") - if not username: - raise HTTPException( - status_code=307, - detail="Invalid session data", - headers={"Location": "/login"}, - ) - return username - except (BadSignature, SignatureExpired): - response = RedirectResponse( - url="/login?error=Session+expired+or+invalid", status_code=307 - ) - response.delete_cookie(SESSION_COOKIE_NAME) - raise HTTPException( - status_code=307, - detail="Session expired or invalid", - headers={"Location": "/login"}, - ) - except Exception: - raise HTTPException( - status_code=307, - detail="Authentication error", - headers={"Location": "/login"}, - ) + # Create servers directory if it doesn't exist + SERVERS_DIR.mkdir(parents=True, exist_ok=True) # Added parents=True + temp_servers = {} + server_files = list(SERVERS_DIR.glob("**/*.json")) + logger.info(f"Found {len(server_files)} JSON files in {SERVERS_DIR} and its subdirectories") + for file in server_files: + logger.info(f"[DEBUG] - {file.relative_to(SERVERS_DIR)}") -# --- API Authentication Dependency (returns 401 instead of redirecting) --- -def api_auth( - session: Annotated[str | None, Cookie(alias=SESSION_COOKIE_NAME)] = None, -) -> str: - if session is None: - raise HTTPException(status_code=401, detail="Not authenticated") - try: - data = signer.loads(session, max_age=SESSION_MAX_AGE_SECONDS) - username = data.get("username") - if not username: - raise HTTPException(status_code=401, detail="Invalid session data") - return username - except (BadSignature, SignatureExpired): - raise HTTPException(status_code=401, detail="Session expired or invalid") - except Exception: - raise HTTPException(status_code=401, detail="Authentication error") + if not server_files: + logger.warning(f"No server definition files found in {SERVERS_DIR}. Initializing empty registry.") + REGISTERED_SERVERS = {} + # Don't return yet, need to load state file + # return + for server_file in server_files: + if server_file.name == STATE_FILE_PATH.name: # Skip the state file itself + continue + try: + with open(server_file, "r") as f: + server_info = json.load(f) -# --- Static Files and Templates (Paths relative to this script) --- -app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") -templates = Jinja2Templates(directory=TEMPLATES_DIR) + if ( + isinstance(server_info, dict) + and "path" in server_info + and "server_name" in server_info + ): + server_path = server_info["path"] + if server_path in temp_servers: + logger.warning(f"Duplicate server path found in {server_file}: {server_path}. Overwriting previous definition.") -# --- Routes --- + # Add new fields with defaults + server_info["description"] = server_info.get("description", "") + server_info["tags"] = server_info.get("tags", []) + server_info["num_tools"] = server_info.get("num_tools", 0) + server_info["num_stars"] = server_info.get("num_stars", 0) + server_info["is_python"] = server_info.get("is_python", False) + server_info["license"] = server_info.get("license", "N/A") + server_info["proxy_pass_url"] = server_info.get("proxy_pass_url", None) + server_info["tool_list"] = server_info.get("tool_list", []) # Initialize tool_list if missing + temp_servers[server_path] = server_info + else: + logger.warning(f"Invalid server entry format found in {server_file}. Skipping.") + except FileNotFoundError: + logger.error(f"Server definition file {server_file} reported by glob not found.") + except json.JSONDecodeError as e: + logger.error(f"Could not parse JSON from {server_file}: {e}.") + except Exception as e: + logger.error(f"An unexpected error occurred loading {server_file}: {e}", exc_info=True) -@app.get("/login", response_class=HTMLResponse) -async def login_form(request: Request, error: str | None = None): - return templates.TemplateResponse( - "login.html", {"request": request, "error": error} - ) + REGISTERED_SERVERS = temp_servers + logger.info(f"Successfully loaded {len(REGISTERED_SERVERS)} server definitions.") + # --- Load persisted mock service state --- START + logger.info(f"Attempting to load persisted state from {STATE_FILE_PATH}...") + loaded_state = {} + try: + if STATE_FILE_PATH.exists(): + with open(STATE_FILE_PATH, "r") as f: + loaded_state = json.load(f) + if not isinstance(loaded_state, dict): + logger.warning(f"Invalid state format in {STATE_FILE_PATH}. Expected a dictionary. Resetting state.") + loaded_state = {} # Reset if format is wrong + else: + logger.info(f"Loaded state for {len(loaded_state)} services.") + else: + logger.info(f"No state file found at {STATE_FILE_PATH}. Starting with empty state.") + except json.JSONDecodeError: + logger.warning(f"Could not parse JSON from {STATE_FILE_PATH}. Resetting state.") + except Exception as e: + logger.error(f"Error loading state file: {e}", exc_info=True) + + # Initialize state for all registered servers + for path in REGISTERED_SERVERS: + if path not in MOCK_SERVICE_STATE: + # Default to enabled for new services + MOCK_SERVICE_STATE[path] = loaded_state.get(path, True) + + logger.info(f"Service state initialized with {len(MOCK_SERVICE_STATE)} entries.") + # --- Load persisted mock service state --- END -@app.post("/login") -async def login_submit( - username: Annotated[str, Form()], password: Annotated[str, Form()] -): - # cu = os.environ.get("ADMIN_USER", "admin") - # cp = os.environ.get("ADMIN_PASSWORD", "password") - # logger.info(f"Login attempt with username: {username}, {cu}") - # logger.info(f"Login attempt with password: {password}, {cp}") - correct_username = secrets.compare_digest( - username, os.environ.get("ADMIN_USER", "admin") - ) - correct_password = secrets.compare_digest( - password, os.environ.get("ADMIN_PASSWORD", "password") - ) - if correct_username and correct_password: - session_data = signer.dumps({"username": username}) - response = RedirectResponse(url="/", status_code=status.HTTP_303_SEE_OTHER) - response.set_cookie( - key=SESSION_COOKIE_NAME, - value=session_data, - max_age=SESSION_MAX_AGE_SECONDS, - httponly=True, - samesite="lax", - ) - logger.info(f"User '{username}' logged in successfully.") - return response - else: - logger.info(f"Login failed for user '{username}'.") - return RedirectResponse( - url="/login?error=Invalid+username+or+password", - status_code=status.HTTP_303_SEE_OTHER, +# --- Check function to test if a service is healthy --- +async def perform_single_health_check(path: str): + """ + Perform a health check for a single service. + Updates SERVER_HEALTH_STATUS and SERVER_LAST_CHECK_TIME. + + In development/test mode, we're more tolerant of health check failures to ensure + services remain visible in the UI even if the backend services are not running. + """ + # Update status to checking + SERVER_HEALTH_STATUS[path] = "checking" + SERVER_LAST_CHECK_TIME[path] = datetime.now(timezone.utc) + + # Get info about the service + service_info = REGISTERED_SERVERS.get(path) + if not service_info: + error_msg = f"Service not found in registry: {path}" + SERVER_HEALTH_STATUS[path] = f"error: {error_msg}" + await broadcast_health_status() + return + + # Service must be enabled for health check + if not MOCK_SERVICE_STATE.get(path, False): + error_msg = "Service is disabled" + SERVER_HEALTH_STATUS[path] = f"error: {error_msg}" + await broadcast_health_status() + return + + # Get the proxy pass URL + proxy_url = service_info.get("proxy_pass_url") + if not proxy_url: + error_msg = "No proxy_pass_url defined" + SERVER_HEALTH_STATUS[path] = f"error: {error_msg}" + await broadcast_health_status() + return + + # Check for dev mode flag to bypass actual health checks + dev_mode = os.environ.get("MCP_GATEWAY_DEV_MODE", "").lower() in ("1", "true", "yes") + if dev_mode: + logger.info(f"Dev mode enabled - marking {path} as healthy without checking") + SERVER_HEALTH_STATUS[path] = "healthy" + + # Set real tools for known servers in dev mode + server_name = path.lstrip("/") + server_path = os.path.join(os.environ.get("SERVER_DIR", "/Users/aaronbw/Documents/DEV/v1/mcp-gateway/servers"), server_name) + server_py_path = os.path.join(server_path, "server.py") + + # Check if this server has a server.py file and we haven't set tools yet + if os.path.exists(server_py_path) and not service_info.get("real_tools_set"): + # Try to automatically extract tools from server.py + extracted_tools = try_extract_tools_from_server_py(server_py_path) + + if extracted_tools: + # Only set tools if we successfully extracted them + service_info["num_tools"] = len(extracted_tools) + service_info["tool_list"] = extracted_tools + logger.info(f"Using {len(extracted_tools)} automatically extracted tools for {path}") + else: + # If extraction failed, leave tools as-is - don't show anything until the real tools are available + logger.warning(f"Failed to extract tools from {path}, no tools will be shown in dev mode") + # Make sure we don't have any old placeholder values by explicitly setting to empty + service_info["num_tools"] = 0 + service_info["tool_list"] = [] + + # Mark that we've set real tools for this server + service_info["real_tools_set"] = True + REGISTERED_SERVERS[path] = service_info + logger.info(f"Set real tools for {path} in dev mode") + # Never use placeholder tools - if we don't have real tools, set an empty list + elif not service_info.get("tool_list"): + service_info["num_tools"] = 0 + service_info["tool_list"] = [] + REGISTERED_SERVERS[path] = service_info + logger.info(f"No tools set for {path} - waiting for real tools") + + SERVER_LAST_CHECK_TIME[path] = datetime.now(timezone.utc) + await broadcast_health_status() + regenerate_nginx_config() + return + + # Form the health check URL - this will be the /tools endpoint + health_url = proxy_url.rstrip("/") + "/tools" + + # Use curl for health check since it's installed in the container + try: + # Perform the health check using curl command + cmd = ["curl", "-s", "-m", str(HEALTH_CHECK_TIMEOUT_SECONDS), "-w", "%{http_code}", health_url] + logger.info(f"Running health check: {' '.join(cmd)}") + + # Run curl in a subprocess + process = await asyncio.create_subprocess_exec( + *cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE ) + + stdout, stderr = await process.communicate() + + # Process return code and output + return_code = process.returncode + response = stdout.decode().strip() + error_output = stderr.decode().strip() + + if return_code != 0: + # curl command failed + error_msg = f"Health check failed (curl error): {error_output or 'unknown error'}" + SERVER_HEALTH_STATUS[path] = f"unhealthy" + logger.warning(f"Health check for {path} failed: {error_msg}") + else: + # Check response code + try: + # Try to parse JSON from the first part of the response + # The response format is + status_code = response[-3:] # Last 3 chars should be status code + json_content = response[:-3] # Everything before status code + + if status_code.isdigit() and 200 <= int(status_code) < 300: + try: + tools_data = json.loads(json_content) + # Update tool count in server info + tool_list = tools_data.get("tools", []) + service_info["num_tools"] = len(tool_list) + # Store the tools list + service_info["tool_list"] = tool_list + REGISTERED_SERVERS[path] = service_info + # Update status + SERVER_HEALTH_STATUS[path] = "healthy" + logger.info(f"Health check for {path} succeeded: {len(tool_list)} tools found") + except json.JSONDecodeError: + # Couldn't parse JSON, consider failed + SERVER_HEALTH_STATUS[path] = "unhealthy" + logger.warning(f"Health check for {path} failed: Invalid JSON response") + else: + # Status code not 2xx + SERVER_HEALTH_STATUS[path] = "unhealthy" + logger.warning(f"Health check for {path} failed: HTTP {status_code}") + except Exception as e: + SERVER_HEALTH_STATUS[path] = "unhealthy" + logger.warning(f"Health check for {path} failed to parse response: {e}") + + except Exception as e: + SERVER_HEALTH_STATUS[path] = "unhealthy" + logger.error(f"Error performing health check for {path}: {e}") + + # Update last check time + SERVER_LAST_CHECK_TIME[path] = datetime.now(timezone.utc) + + # Broadcast status update + await broadcast_health_status() + + # Trigger Nginx config regeneration if status changed + regenerate_nginx_config() - -@app.post("/logout") -async def logout(): - logger.info("User logged out.") - response = RedirectResponse(url="/login", status_code=status.HTTP_303_SEE_OTHER) - response.delete_cookie(SESSION_COOKIE_NAME) - return response - - -@app.get("/", response_class=HTMLResponse) -async def read_root( - request: Request, - username: Annotated[str, Depends(get_current_user)], - query: str | None = None, -): - service_data = [] - search_query = query.lower() if query else "" - sorted_server_paths = sorted( - REGISTERED_SERVERS.keys(), key=lambda p: REGISTERED_SERVERS[p]["server_name"] - ) - for path in sorted_server_paths: - server_info = REGISTERED_SERVERS[path] - server_name = server_info["server_name"] - # Include description and tags in search - searchable_text = f"{server_name.lower()} {server_info.get('description', '').lower()} {' '.join(server_info.get('tags', []))}" - if not search_query or search_query in searchable_text: - # Pass all required fields to the template - service_data.append( - { - "display_name": server_name, - "path": path, - "description": server_info.get("description", ""), - "is_enabled": MOCK_SERVICE_STATE.get(path, False), - "tags": server_info.get("tags", []), - "num_tools": server_info.get("num_tools", 0), - "num_stars": server_info.get("num_stars", 0), - "is_python": server_info.get("is_python", False), - "license": server_info.get("license", "N/A"), - "health_status": SERVER_HEALTH_STATUS.get(path, "unknown"), # Get current health status - "last_checked_iso": SERVER_LAST_CHECK_TIME.get(path).isoformat() if SERVER_LAST_CHECK_TIME.get(path) else None - } - ) - # --- End Debug --- - return templates.TemplateResponse( - "index.html", - {"request": request, "services": service_data, "username": username}, - ) +# --- Background task to run health checks periodically --- +async def run_health_checks(): + """ + Background task that periodically runs health checks for all enabled services. + """ + logger.info("Health check background task started.") + + try: + while True: + # Check all enabled services + enabled_services = [path for path, enabled in MOCK_SERVICE_STATE.items() if enabled] + logger.info(f"Running health checks for {len(enabled_services)} enabled services...") + + for path in enabled_services: + try: + # Check if we need to do a health check + last_check = SERVER_LAST_CHECK_TIME.get(path) + now = datetime.now(timezone.utc) + + # Never checked, or it's been longer than the interval + if last_check is None or (now - last_check).total_seconds() >= HEALTH_CHECK_INTERVAL_SECONDS: + logger.info(f"Running health check for {path}") + await perform_single_health_check(path) + else: + time_since = (now - last_check).total_seconds() + logger.debug(f"Skipping health check for {path} (checked {time_since:.1f}s ago)") + except Exception as e: + logger.error(f"Error in health check for {path}: {e}") + + # Sleep a short interval before checking again + await asyncio.sleep(30) # Check every 30 seconds if any service needs checking + + except asyncio.CancelledError: + logger.info("Health check background task cancelled.") + except Exception as e: + logger.error(f"Error in health check task: {e}") + + logger.info("Health check background task ended.") -@app.post("/toggle/{service_path:path}") -async def toggle_service_route( +# --- Handle disabled services --- START +@app.post("/api/services/{service_path:path}/toggle", response_model=None) +async def toggle_service_api( request: Request, service_path: str, - enabled: Annotated[str | None, Form()] = None, - username: Annotated[str, Depends(get_current_user)] = None, + username: Annotated[str, Depends(api_auth)], + enabled: bool = Form(False), ): - if not service_path.startswith("/"): - service_path = "/" + service_path + """ + Toggle a service on or off through the API. + Requires the mcp:server:{service_path}:toggle scope or mcp:registry:admin scope. + """ + # Check authorization + auth_dependency = require_toggle_for_path(service_path) + await run_async_dependency(auth_dependency, {"request": request}) + # Normalize the path + if not service_path.startswith('/'): + service_path = '/' + service_path + + # Check if the service exists if service_path not in REGISTERED_SERVERS: - raise HTTPException(status_code=404, detail="Service path not registered") - - new_state = enabled == "on" - MOCK_SERVICE_STATE[service_path] = new_state - server_name = REGISTERED_SERVERS[service_path]["server_name"] - logger.info( - f"Simulated toggle for '{server_name}' ({service_path}) to {new_state} by user '{username}'" - ) - - # --- Update health status immediately on toggle --- START - new_status = "" - last_checked_iso = None - last_checked_dt = None # Initialize datetime object - - if new_state: - # Perform immediate check when enabling - logger.info(f"Performing immediate health check for {service_path} upon toggle ON...") + raise HTTPException(status_code=404, detail="Service not found") + + # Update the service status + logger.info(f"User '{username}' toggling service {service_path} to {'enabled' if enabled else 'disabled'}") + MOCK_SERVICE_STATE[service_path] = enabled + + # Save the updated state + try: + with open(STATE_FILE_PATH, "w") as f: + json.dump(MOCK_SERVICE_STATE, f, indent=2) + logger.info(f"Updated service state saved to {STATE_FILE_PATH}") + except Exception as e: + logger.error(f"Failed to save service state: {e}") + + # Handle enabled/disabled services differently + if enabled: + # If enabling, update its health status + logger.info(f"Service {service_path} enabled. Running health check...") + # Run the health check and broadcast in that function try: - new_status, last_checked_dt = await perform_single_health_check(service_path) - last_checked_iso = last_checked_dt.isoformat() if last_checked_dt else None - logger.info(f"Immediate check for {service_path} completed. Status: {new_status}") + # Update status and time as the check starts + SERVER_HEALTH_STATUS[service_path] = "checking" + SERVER_LAST_CHECK_TIME[service_path] = datetime.now(timezone.utc) + + # Broadcast the "checking" state first + await broadcast_health_status() + + # Then run the actual check (which will broadcast again) + await perform_single_health_check(service_path) except Exception as e: - # Handle potential errors during the immediate check itself - logger.error(f"ERROR during immediate health check for {service_path}: {e}") - new_status = f"error: immediate check failed ({type(e).__name__})" - # Update global state to reflect this error - SERVER_HEALTH_STATUS[service_path] = new_status - last_checked_dt = SERVER_LAST_CHECK_TIME.get(service_path) # Use time if check started - last_checked_iso = last_checked_dt.isoformat() if last_checked_dt else None + logger.error(f"Error during health check of newly enabled service: {e}") + # Mark as unhealthy if the check fails with an exception + SERVER_HEALTH_STATUS[service_path] = "unhealthy" + SERVER_LAST_CHECK_TIME[service_path] = datetime.now(timezone.utc) + # And make sure to broadcast + await broadcast_health_status() else: - # When disabling, set status to disabled and keep last check time - new_status = "disabled" - # Keep the last check time from when it was enabled - last_checked_dt = SERVER_LAST_CHECK_TIME.get(service_path) - last_checked_iso = last_checked_dt.isoformat() if last_checked_dt else None - # Update global state directly when disabling - SERVER_HEALTH_STATUS[service_path] = new_status - logger.info(f"Service {service_path} toggled OFF. Status set to disabled.") - # --- Update FAISS metadata for disabled service --- START - if embedding_model and faiss_index is not None: - logger.info(f"Updating FAISS metadata for disabled service {service_path}.") - # REGISTERED_SERVERS[service_path] contains the static definition - await add_or_update_service_in_faiss(service_path, REGISTERED_SERVERS[service_path]) - else: - logger.warning(f"Skipped FAISS metadata update for disabled service {service_path}: model or index not ready.") - # --- Update FAISS metadata for disabled service --- END - - # --- Send *targeted* update via WebSocket --- START - # Send immediate feedback for the toggled service only - # Always get the latest num_tools from the registry - current_num_tools = REGISTERED_SERVERS.get(service_path, {}).get("num_tools", 0) - - update_data = { - service_path: { - "status": new_status, - "last_checked_iso": last_checked_iso, - "num_tools": current_num_tools # Include num_tools - } + # If disabling, just mark it as disabled + logger.info(f"Service {service_path} disabled. Removing from configuration...") + SERVER_HEALTH_STATUS[service_path] = "disabled" + SERVER_LAST_CHECK_TIME[service_path] = datetime.now(timezone.utc) + + # Broadcast the disabled state + await broadcast_health_status() + + # Regenerate the Nginx config + regenerate_nginx_config() + + # Return success with the new state + return { + "service_path": service_path, + "enabled": MOCK_SERVICE_STATE[service_path], + "health_status": SERVER_HEALTH_STATUS.get(service_path, "unknown") } - message = json.dumps(update_data) - logger.info(f"--- TOGGLE: Sending targeted update: {message}") - - # Create task to send without blocking the request - async def send_specific_update(): - disconnected_clients = set() - current_connections = list(active_connections) - send_tasks = [] - for conn in current_connections: - send_tasks.append((conn, conn.send_text(message))) +# --- Handle disabled services --- END - results = await asyncio.gather(*(task for _, task in send_tasks), return_exceptions=True) - - for i, result in enumerate(results): - conn, _ = send_tasks[i] - if isinstance(result, Exception): - logger.warning(f"Error sending toggle update to WebSocket client {conn.client}: {result}. Marking for removal.") - disconnected_clients.add(conn) - if disconnected_clients: - logger.info(f"Removing {len(disconnected_clients)} disconnected clients after toggle update.") - for conn in disconnected_clients: - if conn in active_connections: - active_connections.remove(conn) - - asyncio.create_task(send_specific_update()) - # --- Send *targeted* update via WebSocket --- END - - # --- Persist the updated state --- START +# --- Frontend Toggle Handler --- START +@app.post("/toggle/{service_path:path}") +async def toggle_service_frontend( + request: Request, + service_path: str, + username: Annotated[str, Depends(get_current_user)], + enabled: bool = Form(False), +): + """ + Toggle a service on or off via the frontend form. + Requires the mcp:server:{service_path}:toggle scope or mcp:registry:admin scope. + """ + # Check authorization + auth_dependency = require_toggle_for_path(service_path) + await run_async_dependency(auth_dependency, {"request": request}) + + # Normalize the path + if not service_path.startswith('/'): + service_path = '/' + service_path + + # Check if the service exists + if service_path not in REGISTERED_SERVERS: + raise HTTPException(status_code=404, detail="Service not found") + + # Update the service status + logger.info(f"User '{username}' toggling service {service_path} to {'enabled' if enabled else 'disabled'}") + MOCK_SERVICE_STATE[service_path] = enabled + + # Save the updated state + try: + with open(STATE_FILE_PATH, "w") as f: + json.dump(MOCK_SERVICE_STATE, f, indent=2) + logger.info(f"Updated service state saved to {STATE_FILE_PATH}") + except Exception as e: + logger.error(f"Failed to save service state: {e}") + + # Handle enabled/disabled services differently + if enabled: + # If enabling, update its health status + logger.info(f"Service {service_path} enabled. Running health check...") + # Run the health check and broadcast in that function + try: + # Update status and time as the check starts + SERVER_HEALTH_STATUS[service_path] = "checking" + SERVER_LAST_CHECK_TIME[service_path] = datetime.now(timezone.utc) + + # Broadcast the "checking" state first + await broadcast_health_status() + + # Then run the actual check (which will broadcast again) + await perform_single_health_check(service_path) + except Exception as e: + logger.error(f"Error during health check of newly enabled service: {e}") + # Mark as unhealthy if the check fails with an exception + SERVER_HEALTH_STATUS[service_path] = "unhealthy" + SERVER_LAST_CHECK_TIME[service_path] = datetime.now(timezone.utc) + # And make sure to broadcast + await broadcast_health_status() + else: + # If disabling, just mark it as disabled + logger.info(f"Service {service_path} disabled. Removing from configuration...") + SERVER_HEALTH_STATUS[service_path] = "disabled" + SERVER_LAST_CHECK_TIME[service_path] = datetime.now(timezone.utc) + + # Broadcast the disabled state + await broadcast_health_status() + + # Regenerate the Nginx config + regenerate_nginx_config() + + # Return a JSON response instead of the default dictionary + # The frontend expects a JSON response format + return JSONResponse(content={ + "service_path": service_path, + "enabled": MOCK_SERVICE_STATE[service_path], + "health_status": SERVER_HEALTH_STATUS.get(service_path, "unknown") + }) +# --- Frontend Toggle Handler --- END + +# --- Service Search --- START +@app.get("/api/search") +async def search_services( + query: str, + username: Annotated[str, Depends(api_auth)], + filter_enabled: bool = False +): + """ + Search for services based on text query using FAISS. + + Args: + query: Search query text + filter_enabled: Only return enabled services (default: False) + + Returns: + List of matching services with scores + """ + global embedding_model, faiss_index, faiss_metadata_store + + # Handle empty query + if not query or query.strip() == "": + logger.info("Empty search query. Returning all services.") + + # Just return all services instead + all_services = [] + for path, server_info in REGISTERED_SERVERS.items(): + is_enabled = MOCK_SERVICE_STATE.get(path, False) + + # Apply enabled filter if requested + if filter_enabled and not is_enabled: + continue + + service_copy = server_info.copy() + service_copy["is_enabled"] = is_enabled + service_copy["relevance_score"] = 0.0 # No relevance score for unranked results + service_copy["health_status"] = SERVER_HEALTH_STATUS.get(path, "unknown") + all_services.append(service_copy) + + # Sort alphabetically by name as fallback + all_services.sort(key=lambda x: x.get("server_name", "").lower()) + return all_services + + # Check if search is ready + if embedding_model is None or faiss_index is None: + logger.warning("FAISS search not ready (model or index not loaded)") + raise HTTPException( + status_code=503, + detail="Search functionality not available yet. Please try again later." + ) + try: - with open(STATE_FILE_PATH, "w") as f: - json.dump(MOCK_SERVICE_STATE, f, indent=2) - logger.info(f"Persisted state to {STATE_FILE_PATH}") + # Encode the query + query_embedding = await asyncio.to_thread(embedding_model.encode, [query.strip()]) + query_embedding_np = np.array([query_embedding[0]], dtype=np.float32) + + # Configure search (number of results, etc) + k = min(50, faiss_index.ntotal) # Return up to 50 results, or fewer if index is smaller + if k == 0: + logger.info("No services in search index.") + return [] + + # Search the index + distances, indices = faiss_index.search(query_embedding_np, k) + + # Process search results + results = [] + seen_paths = set() + + for i, (distance, idx) in enumerate(zip(distances[0], indices[0])): + if idx == -1: # -1 means no more matches + break + + # Find the service path for this index + service_path = None + server_info = None + + for path, metadata in faiss_metadata_store.items(): + if metadata.get("id") == idx: + service_path = path + server_info = metadata.get("full_server_info", {}) + break + + if not service_path or not server_info: + logger.warning(f"Found result with idx {idx} but no matching service in metadata store") + continue + + # Skip if we've already seen this service + if service_path in seen_paths: + continue + seen_paths.add(service_path) + + # Get enabled status + is_enabled = MOCK_SERVICE_STATE.get(service_path, False) + + # Apply enabled filter if requested + if filter_enabled and not is_enabled: + continue + + # Compute a relevance score (invert the distance) + # L2 distance might be arbitrarily large, so we apply a transformation + # to get a score between 0 and 1 (closer to 1 is better) + # This formula gives reasonable distribution for sentence transformer embeddings + relevance_score = 1.0 / (1.0 + distance/10) + + # Add to results + health_status = SERVER_HEALTH_STATUS.get(service_path, "unknown") + result = server_info.copy() # Start with the server info + result["is_enabled"] = is_enabled + result["relevance_score"] = relevance_score + result["health_status"] = health_status + results.append(result) + + # Sort by relevance score + results.sort(key=lambda x: x["relevance_score"], reverse=True) + + logger.info(f"Search for '{query}' returned {len(results)} results") + return results + except Exception as e: - logger.error(f"ERROR: Failed to persist state to {STATE_FILE_PATH}: {e}") - # Decide if we should raise an error or just log - # --- Persist the updated state --- END - - # Regenerate Nginx config after toggling state - if not regenerate_nginx_config(): - logger.error("ERROR: Failed to update Nginx configuration after toggle.") - - # --- Return JSON instead of Redirect --- START - final_status = SERVER_HEALTH_STATUS.get(service_path, "unknown") - final_last_checked_dt = SERVER_LAST_CHECK_TIME.get(service_path) - final_last_checked_iso = final_last_checked_dt.isoformat() if final_last_checked_dt else None - final_num_tools = REGISTERED_SERVERS.get(service_path, {}).get("num_tools", 0) - - return JSONResponse( - status_code=200, - content={ - "message": f"Toggle request for {service_path} processed.", - "service_path": service_path, - "new_enabled_state": new_state, # The state it was set to - "status": final_status, # The status after potential immediate check - "last_checked_iso": final_last_checked_iso, - "num_tools": final_num_tools - } - ) - # --- Return JSON instead of Redirect --- END - - # query_param = request.query_params.get("query", "") - # redirect_url = f"/?query={query_param}" if query_param else "/" - # return RedirectResponse(url=redirect_url, status_code=status.HTTP_303_SEE_OTHER) + logger.error(f"Error performing search: {e}", exc_info=True) + raise HTTPException( + status_code=500, + detail=f"Search failed: {str(e)}" + ) +# --- Service Search --- END +# --- Save Service Helper Function --- +def save_server_to_file(server_entry) -> bool: + """ + Save a server entry to disk as JSON. + + Args: + server_entry: Dictionary with server information + + Returns: + bool: True if successful, False otherwise + """ + if not server_entry or "path" not in server_entry: + logger.error("Invalid server entry to save, missing path") + return False + + # Create safe filename from path (replace slashes and some other chars) + path = server_entry["path"] + safe_name = path.lstrip('/').replace('/', '_').replace(':', '_') + + if not safe_name: + safe_name = "root" # In case path is just "/" + + filename = SERVERS_DIR / f"{safe_name}.json" + + try: + SERVERS_DIR.mkdir(parents=True, exist_ok=True) + with open(filename, "w") as f: + json.dump(server_entry, f, indent=2) + logger.info(f"Saved server info to {filename}") + return True + except Exception as e: + logger.error(f"Error saving server info to {filename}: {e}") + return False -@app.post("/register") +# --- Register a new service --- START +@app.post("/api/register", response_model=None) async def register_service( + request: Request, + username: Annotated[str, Depends(api_auth)], name: Annotated[str, Form()], - description: Annotated[str, Form()], path: Annotated[str, Form()], proxy_pass_url: Annotated[str, Form()], + description: Annotated[str, Form()] = "", tags: Annotated[str, Form()] = "", num_tools: Annotated[int, Form()] = 0, num_stars: Annotated[int, Form()] = 0, - is_python: Annotated[bool, Form()] = False, + is_python: Annotated[bool | None, Form()] = False, license_str: Annotated[str, Form(alias="license")] = "N/A", - username: Annotated[str, Depends(api_auth)] = None, ): - logger.info("[DEBUG] register_service() called with parameters:") - logger.info(f"[DEBUG] - name: {name}") - logger.info(f"[DEBUG] - description: {description}") - logger.info(f"[DEBUG] - path: {path}") - logger.info(f"[DEBUG] - proxy_pass_url: {proxy_pass_url}") - logger.info(f"[DEBUG] - tags: {tags}") - logger.info(f"[DEBUG] - num_tools: {num_tools}") - logger.info(f"[DEBUG] - num_stars: {num_stars}") - logger.info(f"[DEBUG] - is_python: {is_python}") - logger.info(f"[DEBUG] - license_str: {license_str}") - logger.info(f"[DEBUG] - username: {username}") - + """Register a new service with the gateway.""" + # Check authorization + auth_dependency = require_registry_admin() + await run_async_dependency(auth_dependency, {"request": request}) # Ensure path starts with a slash - if not path.startswith("/"): - path = "/" + path - logger.info(f"[DEBUG] Path adjusted to start with slash: {path}") - + if not path.startswith('/'): + path = '/' + path + # Check if path already exists if path in REGISTERED_SERVERS: - logger.error(f"[ERROR] Service registration failed: path '{path}' already exists") - return JSONResponse( - status_code=400, - content={"error": f"Service with path '{path}' already exists"}, - ) - - # Process tags: split string, strip whitespace, filter empty - tag_list = [tag.strip() for tag in tags.split(",") if tag.strip()] - logger.info(f"[DEBUG] Processed tags: {tag_list}") + raise HTTPException(status_code=400, detail="Service path already registered") + + # Process tags + tag_list = [tag.strip() for tag in tags.split(',') if tag.strip()] - # Create new server entry with all fields + # Create server entry server_entry = { "server_name": name, "description": description, @@ -1431,46 +2128,28 @@ async def register_service( "tags": tag_list, "num_tools": num_tools, "num_stars": num_stars, - "is_python": is_python, + "is_python": bool(is_python), # Convert checkbox value "license": license_str, - "tool_list": [] # Initialize tool list } - logger.info(f"[DEBUG] Created server entry: {json.dumps(server_entry, indent=2)}") - - # Save to individual file - logger.info("[DEBUG] Attempting to save server data to file...") + + # Save to disk storage success = save_server_to_file(server_entry) if not success: - logger.error("[ERROR] Failed to save server data to file") - return JSONResponse( - status_code=500, content={"error": "Failed to save server data"} - ) - logger.info("[DEBUG] Successfully saved server data to file") - - # Add to in-memory registry and default to disabled - logger.info("[DEBUG] Adding server to in-memory registry...") + raise HTTPException(status_code=500, detail="Failed to save server information") + + # Add to in-memory registry REGISTERED_SERVERS[path] = server_entry - logger.info("[DEBUG] Setting initial service state to disabled") + + # Set up default state (disabled by default) MOCK_SERVICE_STATE[path] = False - # Set initial health status for the new service (always start disabled) - logger.info("[DEBUG] Setting initial health status to 'disabled'") - SERVER_HEALTH_STATUS[path] = "disabled" # Start disabled - SERVER_LAST_CHECK_TIME[path] = None # No check time yet - # Ensure num_tools is present in the in-memory dict immediately - if "num_tools" not in REGISTERED_SERVERS[path]: - logger.info("[DEBUG] Adding missing num_tools field to in-memory registry") - REGISTERED_SERVERS[path]["num_tools"] = 0 - - # Regenerate Nginx config after successful registration - logger.info("[DEBUG] Attempting to regenerate Nginx configuration...") - if not regenerate_nginx_config(): - logger.error("[ERROR] Failed to update Nginx configuration after registration") - else: - logger.info("[DEBUG] Successfully regenerated Nginx configuration") - + + # Initialize health status + SERVER_HEALTH_STATUS[path] = "unknown" + SERVER_LAST_CHECK_TIME[path] = None + # --- Add to FAISS Index --- START - logger.info(f"[DEBUG] Adding/updating service '{path}' in FAISS index after registration...") - if embedding_model and faiss_index is not None: + logger.info(f"[DEBUG] Adding service '{path}' to FAISS index...") + if embedding_model is not None and faiss_index is not None: await add_or_update_service_in_faiss(path, server_entry) # server_entry is the new service info logger.info(f"[DEBUG] Service '{path}' processed for FAISS index.") else: @@ -1502,11 +2181,24 @@ async def register_service( }, ) -@app.get("/api/server_details/{service_path:path}") +@app.get("/api/server_details/{service_path:path}", response_model=None) async def get_server_details( + request: Request, service_path: str, - username: Annotated[str, Depends(api_auth)] + username: Annotated[str, Depends(api_auth)], ): + """ + Get detailed information about a server. + Requires the mcp:server:{service_path}:edit scope or mcp:registry:admin scope. + """ + # Check authorization + if service_path != 'all': + auth_dependency = require_edit_for_path(service_path) + await run_async_dependency(auth_dependency, {"request": request}) + else: + auth_dependency = check_admin_scope() + await run_async_dependency(auth_dependency, {"request": request}) + # Normalize the path to ensure it starts with '/' if not service_path.startswith('/'): service_path = '/' + service_path @@ -1525,12 +2217,551 @@ async def get_server_details( return server_info +# --- API endpoint for Tool Execution via SSE Transport --- START +@app.post("/api/execute/{service_path:path}", response_model=None) +async def execute_tool( + request: Request, + service_path: str, + username: Annotated[str, Depends(api_auth)], +): + """ + Execute a tool on a specific service using MCP client with SSE transport. + + endpoint acts as an MCP client to backend servers, providing OAuth-protected + access to tools while maintaining proper MCP protocol compliance. + + Transport: Server-Sent Events (SSE) + Auth required: mcp:server:{service_path}:execute scope or mcp:registry:admin scope + + Flow: + 1. Authenticate and authorize the request + 2. Validate service exists and is enabled + 3. Establish MCP client session with backend server + 4. Execute tool via proper MCP protocol + 5. Return JSON-RPC compliant response + """ + try: + # Normalize the service path + if not service_path.startswith('/'): + service_path = '/' + service_path + + # Check for authenticated user in request.state + if not hasattr(request.state, "user") or not request.state.user or not getattr(request.state.user, "is_authenticated", False): + logger.warning(f"Unauthorized attempt to execute tool on '{service_path}': No valid authentication") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Check if the service exists and is enabled + if service_path not in REGISTERED_SERVERS: + logger.warning(f"Service not found for tool execution: '{service_path}'") + raise HTTPException(status_code=404, detail=f"Service '{service_path}' not found") + + if not MOCK_SERVICE_STATE.get(service_path, False): + logger.warning(f"Service disabled for tool execution: '{service_path}'") + raise HTTPException(status_code=403, detail=f"Service '{service_path}' is disabled") + + # Get service info and determine port + service_info = REGISTERED_SERVERS.get(service_path) + proxy_url = service_info.get("proxy_pass_url") + + if not proxy_url: + logger.error(f"No proxy URL configured for service '{service_path}'") + raise HTTPException(status_code=500, detail=f"No proxy URL configured for service '{service_path}'") + + # Check required scopes + auth_settings = AuthSettings() + execute_scope = auth_settings.get_server_execute_scope(service_path) + + if not (request.state.user.has_scope(auth_settings.registry_admin_scope) or + request.state.user.has_scope(execute_scope)): + logger.warning(f"User '{username}' denied access to execute tool on '{service_path}' - missing scope: {execute_scope}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Missing required scope for service access: {execute_scope}", + ) + + # Parse JSON-RPC request + try: + body = await request.json() + if not all(key in body for key in ["jsonrpc", "method", "params", "id"]): + raise ValueError("Invalid JSON-RPC format") + + tool_name = body["params"]["name"] + tool_arguments = body["params"]["arguments"] + request_id = body["id"] + + logger.info(f"Tool execution: '{tool_name}' on '{service_path}' by '{username}' with args: {tool_arguments}") + except Exception as e: + logger.error(f"Failed to parse JSON-RPC request: {e}") + raise HTTPException(status_code=400, detail="Invalid JSON-RPC request format") + + # Establish MCP client session and execute tool + # Route through nginx to handle mount_path properly + # Nginx strips the service path prefix when proxying to backend + nginx_base = f"http://localhost{service_path}" # e.g., http://localhost/currenttime + + # Try SSE endpoint through nginx + sse_endpoints = [ + f"{nginx_base}/sse", # e.g., http://localhost/currenttime/sse -> proxied to backend + ] + + last_error = None + for sse_url in sse_endpoints: + try: + logger.info(f"Attempting SSE connection to: {sse_url}") + # Connect to MCP server with timeout + async with sse_client(sse_url, timeout=10.0) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + await asyncio.wait_for(session.initialize(), timeout=5.0) + + # Execute the tool + result = await asyncio.wait_for( + session.call_tool(tool_name, tool_arguments), + timeout=30.0 + ) + + # Extract content from MCP result and ensure it's serializable + if hasattr(result, 'content'): + if hasattr(result.content, 'text'): + # Handle TextContent objects + result_content = result.content.text + elif isinstance(result.content, list): + # Handle list of content objects + result_content = [] + for item in result.content: + if hasattr(item, 'text'): + result_content.append(item.text) + else: + result_content.append(str(item)) + else: + result_content = str(result.content) + else: + result_content = str(result) + + # Return JSON-RPC response + # Ensure result is always an array for consistency + if isinstance(result_content, list): + result_array = result_content + else: + result_array = [result_content] if result_content else [] + + return JSONResponse( + content={ + "jsonrpc": "2.0", + "result": result_array, + "id": request_id + }, + headers={ + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0" + } + ) + + except asyncio.TimeoutError as e: + logger.warning(f"Timeout connecting to MCP server at {sse_url}") + last_error = e + continue # Try next endpoint + except Exception as e: + logger.warning(f"Error connecting to {sse_url}: {e}") + last_error = e + continue # Try next endpoint + + # If we've tried all endpoints and none worked, raise the last error + if last_error: + if isinstance(last_error, asyncio.TimeoutError): + raise HTTPException(status_code=504, detail=f"Timeout connecting to service '{service_path}'") + else: + raise HTTPException(status_code=502, detail=f"Failed to execute tool on service '{service_path}': {str(last_error)}") + + except Exception as e: + # Log the error with appropriate level and context + if isinstance(e, HTTPException): + # For expected HTTP exceptions, we log at warning level + if e.status_code >= 500: + logger.error(f"Server error during tool execution on '{service_path}': {e}") + else: + logger.warning(f"Client error during tool execution on '{service_path}': {e}") + raise + else: + # For unexpected exceptions, log as error with traceback + logger.error(f"Unexpected error during tool execution on '{service_path}': {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Internal server error during tool execution") + +@app.post("/api/streamable/{service_path:path}", response_model=None) +async def execute_tool_streamable( + request: Request, + service_path: str, + username: Annotated[str, Depends(api_auth)], +): + """ + Execute a tool on a specific service using MCP client with StreamableHTTP transport. + + This endpoint acts as an MCP client to backend servers, providing OAuth-protected + access to tools while maintaining proper MCP protocol compliance. + + Transport: StreamableHTTP + Auth required: mcp:server:{service_path}:execute scope or mcp:registry:admin scope + + Flow identical to execute_tool but for StreamableHTTP transport. + """ + try: + # Normalize the service path + if not service_path.startswith('/'): + service_path = '/' + service_path + + # Check for authenticated user in request.state + if not hasattr(request.state, "user") or not request.state.user or not getattr(request.state.user, "is_authenticated", False): + logger.warning(f"Unauthorized attempt to execute tool (streamable) on '{service_path}': No valid authentication") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Check if the service exists and is enabled + if service_path not in REGISTERED_SERVERS: + logger.warning(f"Service not found for tool execution (streamable): '{service_path}'") + raise HTTPException(status_code=404, detail=f"Service '{service_path}' not found") + + if not MOCK_SERVICE_STATE.get(service_path, False): + logger.warning(f"Service disabled for tool execution (streamable): '{service_path}'") + raise HTTPException(status_code=403, detail=f"Service '{service_path}' is disabled") + + # Get service info and determine port + service_info = REGISTERED_SERVERS.get(service_path) + proxy_url = service_info.get("proxy_pass_url") + + if not proxy_url: + logger.error(f"No proxy URL configured for service '{service_path}'") + raise HTTPException(status_code=500, detail=f"No proxy URL configured for service '{service_path}'") + + # Check required scopes + auth_settings = AuthSettings() + execute_scope = auth_settings.get_server_execute_scope(service_path) + + if not (request.state.user.has_scope(auth_settings.registry_admin_scope) or + request.state.user.has_scope(execute_scope)): + logger.warning(f"User '{username}' denied access to execute tool (streamable) on '{service_path}' - missing scope: {execute_scope}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Missing required scope for service access: {execute_scope}", + ) + + # Parse JSON-RPC request + try: + body = await request.json() + if not all(key in body for key in ["jsonrpc", "method", "params", "id"]): + raise ValueError("Invalid JSON-RPC format") + + tool_name = body["params"]["name"] + tool_arguments = body["params"]["arguments"] + request_id = body["id"] + + logger.info(f"StreamableHTTP tool execution: '{tool_name}' on '{service_path}' by '{username}' with args: {tool_arguments}") + except Exception as e: + logger.error(f"Failed to parse JSON-RPC request: {e}") + raise HTTPException(status_code=400, detail="Invalid JSON-RPC request format") + + # Import MCP client with proper error handling + try: + from mcp import ClientSession + import httpx + except ImportError as e: + logger.error(f"MCP SDK not available: {e}") + raise HTTPException(status_code=500, detail="MCP SDK not properly installed") + + # Check if this server supports StreamableHTTP transport + # Route through nginx to handle mount_path properly + nginx_base = f"http://localhost{service_path}" # e.g., http://localhost/currenttime + + # Try the MCP endpoint through nginx + mcp_endpoints = [ + f"{nginx_base}/mcp", # e.g., http://localhost/currenttime/mcp -> proxied to backend + ] + + streamable_url = None + for test_url in mcp_endpoints: + try: + # Test if StreamableHTTP endpoint exists + async with httpx.AsyncClient(timeout=5.0) as test_client: + test_response = await test_client.get(test_url) + if test_response.status_code != 404: + streamable_url = test_url + logger.info(f"Found StreamableHTTP endpoint at: {streamable_url}") + break + except Exception as e: + logger.debug(f"Failed to test {test_url}: {e}") + continue + + try: + if streamable_url is None: + # Server doesn't support StreamableHTTP, fall back to SSE approach + logger.info(f"Server '{service_path}' doesn't support StreamableHTTP, using SSE approach") + + # Use SSE transport for this request + # Route through nginx to handle mount_path properly + nginx_base = f"http://localhost{service_path}" # e.g., http://localhost/currenttime + + # Try SSE endpoint through nginx + sse_endpoints = [ + f"{nginx_base}/sse", # e.g., http://localhost/currenttime/sse -> proxied to backend + ] + + last_error = None + for sse_url in sse_endpoints: + try: + logger.info(f"StreamableHTTP fallback - attempting SSE connection to: {sse_url}") + # Connect to MCP server with timeout + async with sse_client(sse_url, timeout=10.0) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + await asyncio.wait_for(session.initialize(), timeout=5.0) + + # Execute the tool + result = await asyncio.wait_for( + session.call_tool(tool_name, tool_arguments), + timeout=30.0 + ) + + # Extract content from MCP result and ensure it's serializable + if hasattr(result, 'content'): + if hasattr(result.content, 'text'): + # Handle TextContent objects + result_content = result.content.text + elif isinstance(result.content, list): + # Handle list of content objects + result_content = [] + for item in result.content: + if hasattr(item, 'text'): + result_content.append(item.text) + else: + result_content.append(str(item)) + else: + result_content = str(result.content) + else: + result_content = str(result) + + # Return JSON-RPC response + # Ensure result is always an array for consistency + if isinstance(result_content, list): + result_array = result_content + else: + result_array = [result_content] if result_content else [] + + return JSONResponse( + content={ + "jsonrpc": "2.0", + "result": result_array, + "id": request_id + }, + headers={ + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0" + } + ) + + except asyncio.TimeoutError as e: + logger.warning(f"StreamableHTTP fallback - timeout connecting to MCP server at {sse_url}") + last_error = e + continue # Try next endpoint + except Exception as e: + logger.warning(f"StreamableHTTP fallback - error connecting to {sse_url}: {e}") + last_error = e + continue # Try next endpoint + + # If we've tried all endpoints and none worked, raise the last error + if last_error: + if isinstance(last_error, asyncio.TimeoutError): + raise HTTPException(status_code=504, detail=f"Timeout connecting to service '{service_path}'") + else: + raise HTTPException(status_code=502, detail=f"Failed to execute tool on service '{service_path}': {str(last_error)}") + + # Server supports StreamableHTTP, proceed with original implementation + else: + # Use the discovered streamable URL + # Make direct HTTP request to StreamableHTTP endpoint + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.post( + streamable_url, + json={ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "mcp-gateway", + "version": "1.0.0" + } + }, + "id": 1 + }, + headers={ + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream" + } + ) + + if response.status_code != 200: + raise Exception(f"Failed to initialize MCP session: {response.status_code}") + + # Now execute the tool + tool_response = await client.post( + streamable_url, + json={ + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": tool_name, + "arguments": tool_arguments + }, + "id": request_id + }, + headers={ + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream" + } + ) + + if tool_response.status_code != 200: + raise Exception(f"Tool execution failed: {tool_response.status_code} - {tool_response.text}") + + # Return the tool response + result = tool_response.json() + return JSONResponse( + content=result, + headers={ + "Cache-Control": "no-cache, no-store, must-revalidate", + "Pragma": "no-cache", + "Expires": "0" + } + ) + + except asyncio.TimeoutError: + logger.error(f"Timeout connecting to MCP server at {streamable_url}") + raise HTTPException(status_code=504, detail=f"Timeout connecting to service '{service_path}'") + except Exception as e: + logger.error(f"Error executing tool via StreamableHTTP: {e}", exc_info=True) + raise HTTPException(status_code=502, detail=f"Failed to execute tool on service '{service_path}': {str(e)}") + + except Exception as e: + # Log the error with appropriate level and context + if isinstance(e, HTTPException): + # For expected HTTP exceptions, we log at warning level + if e.status_code >= 500: + logger.error(f"Server error during StreamableHTTP tool execution on '{service_path}': {e}") + else: + logger.warning(f"Client error during StreamableHTTP tool execution on '{service_path}': {e}") + raise + else: + # For unexpected exceptions, log as error with traceback + logger.error(f"Unexpected error during StreamableHTTP tool execution on '{service_path}': {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Internal server error during tool execution") + +@app.post("/api/tool_auth/{service_path:path}", response_model=None) +async def auth_tool_request( + request: Request, + service_path: str, + username: Annotated[str, Depends(api_auth)], +): + """ + Authentication endpoint for tool execution via Nginx auth_request. + + This endpoint is used by Nginx to check if a user has permission to execute a tool + on a specific service. It returns 200 if the user has permission, and an + appropriate error code otherwise. + + This implementation leverages the MCP SDK's authentication mechanisms to verify + user permissions against service-specific scopes following the MCP protocol standards. + + Auth flow: + 1. Verify user is authenticated using the MCP auth context + 2. Check service path exists and is enabled + 3. Verify user has appropriate scope for the service + 4. Return 200 OK if authorized, appropriate error code otherwise + """ + try: + # Normalize the service path for consistent handling + if not service_path.startswith('/'): + service_path = '/' + service_path + + # Check for authenticated user in request.state + if not hasattr(request.state, "user") or not request.state.user or not getattr(request.state.user, "is_authenticated", False): + logger.warning(f"Unauthorized attempt to access service '{service_path}': No valid authentication") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not authenticated", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Check if the service exists and is enabled + if service_path not in REGISTERED_SERVERS: + logger.warning(f"Service not found: '{service_path}'") + raise HTTPException(status_code=404, detail=f"Service '{service_path}' not found") + + if not MOCK_SERVICE_STATE.get(service_path, False): + logger.warning(f"Service disabled: '{service_path}'") + raise HTTPException(status_code=403, detail=f"Service '{service_path}' is disabled") + + # Get settings and determine required scope for this service + auth_settings = AuthSettings() + # First check for admin scope - grants access to all services + if request.state.user.has_scope(auth_settings.registry_admin_scope): + logger.info(f"User '{username}' granted execute access to '{service_path}' via admin scope") + return JSONResponse(status_code=200, content={"status": "authorized"}) + + # Check for service-specific execute scope + execute_scope = auth_settings.get_server_execute_scope(service_path) + if request.state.user.has_scope(execute_scope): + logger.info(f"User '{username}' granted execute access to '{service_path}' via execute scope") + return JSONResponse(status_code=200, content={"status": "authorized"}) + + # No valid scope found + logger.warning(f"User '{username}' denied access to '{service_path}' - missing scope: {execute_scope}") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Missing required scope for service access: {execute_scope}", + ) + except Exception as e: + # Log the error with appropriate level and context + if isinstance(e, HTTPException): + # For expected HTTP exceptions, we log at warning level + if e.status_code >= 500: + logger.error(f"Server error during auth check for '{service_path}': {e}") + else: + logger.warning(f"Client error during auth check for '{service_path}': {e}") + raise + else: + # For unexpected exceptions, log as error with traceback + logger.error(f"Unexpected error during auth check for '{service_path}': {e}", exc_info=True) + raise HTTPException(status_code=500, detail=f"Internal server error during authorization") +# --- API endpoint for Tool Execution --- END + # --- Endpoint to get tool list for a service --- START -@app.get("/api/tools/{service_path:path}") +@app.get("/api/tools/{service_path:path}", response_model=None) async def get_service_tools( + request: Request, service_path: str, - username: Annotated[str, Depends(api_auth)] # Requires authentication + username: Annotated[str, Depends(api_auth)], # Requires authentication ): + """ + Get the list of tools for a specific server. + Requires the mcp:server:{service_path}:read scope or mcp:registry:admin scope. + """ + # Check authorization + if service_path != 'all': + auth_dependency = require_access_for_path(service_path) + await run_async_dependency(auth_dependency, {"request": request}) + else: + auth_dependency = check_admin_scope() + await run_async_dependency(auth_dependency, {"request": request}) + if not service_path.startswith('/'): service_path = '/' + service_path @@ -1581,8 +2812,20 @@ async def get_service_tools( # --- Refresh Endpoint --- START -@app.post("/api/refresh/{service_path:path}") -async def refresh_service(service_path: str, username: Annotated[str, Depends(api_auth)]): +@app.post("/api/refresh/{service_path:path}", response_model=None) +async def refresh_service( + request: Request, + service_path: str, + username: Annotated[str, Depends(api_auth)], +): + """ + Refresh a service by running a health check. + Requires the mcp:server:{service_path}:toggle scope or mcp:registry:admin scope. + """ + # Check authorization + auth_dependency = require_toggle_for_path(service_path) + await run_async_dependency(auth_dependency, {"request": request}) + if not service_path.startswith('/'): service_path = '/' + service_path @@ -1656,8 +2899,15 @@ async def edit_server_form( raise HTTPException(status_code=404, detail="Service path not found") return templates.TemplateResponse( - "edit_server.html", - {"request": request, "server": server_info, "username": username} + "edit_server.html", + { + "request": request, + "server": server_info, + "username": username, + "user_has_toggle_scope": lambda server_path: user_has_toggle_scope(request, server_path), + "user_has_edit_scope": lambda server_path: user_has_edit_scope(request, server_path), + "user_has_admin_scope": lambda: user_has_admin_scope(request) + } ) @app.post("/edit/{service_path:path}") @@ -1813,11 +3063,49 @@ async def websocket_endpoint(websocket: WebSocket): # --- Run (for local testing) --- -# Use: uvicorn registry.main:app --reload --host 0.0.0.0 --port 7860 --root-path /home/ubuntu/mcp-gateway -# (Running from parent dir) - -# If running directly (python registry/main.py): -# if __name__ == "__main__": -# import uvicorn -# # Running this way makes relative paths tricky, better to use uvicorn command from parent -# uvicorn.run(app, host="0.0.0.0", port=7860) \ No newline at end of file +# Use: uvicorn registry.main:app --reload --host 0.0.0.0 --port 7860 + +if __name__ == "__main__": + # Get port from environment variable or use default + port = int(os.environ.get("REGISTRY_PORT", 7860)) + uvicorn.run(app, host="0.0.0.0", port=port) +# Helper functions for template context +def user_has_toggle_scope(request, server_path): + """Check if the current user has toggle permission for a server.""" + if not hasattr(request.state, "user") or not hasattr(request.state.user, "has_scope"): + return False + + auth_settings = AuthSettings() + + # Admin scope grants all permissions + if request.state.user.has_scope(auth_settings.registry_admin_scope): + return True + + # Check for server-specific toggle scope + base_scope = auth_settings.server_execute_scope_prefix + server_path.lstrip("/") + toggle_scope = f"{base_scope}:toggle" + return request.state.user.has_scope(toggle_scope) + +def user_has_edit_scope(request, server_path): + """Check if the current user has edit permission for a server.""" + if not hasattr(request.state, "user") or not hasattr(request.state.user, "has_scope"): + return False + + auth_settings = AuthSettings() + + # Admin scope grants all permissions + if request.state.user.has_scope(auth_settings.registry_admin_scope): + return True + + # Check for server-specific edit scope + base_scope = auth_settings.server_execute_scope_prefix + server_path.lstrip("/") + edit_scope = f"{base_scope}:edit" + return request.state.user.has_scope(edit_scope) + +def user_has_admin_scope(request): + """Check if the current user has admin scope.""" + if not hasattr(request.state, "user") or not hasattr(request.state.user, "has_scope"): + return False + + auth_settings = AuthSettings() + return request.state.user.has_scope(auth_settings.registry_admin_scope) diff --git a/registry/templates/edit_server.html b/registry/templates/edit_server.html index a95dee6..ec7bd59 100644 --- a/registry/templates/edit_server.html +++ b/registry/templates/edit_server.html @@ -115,7 +115,7 @@

Edit Server: {{ server.server_name }}

- Cancel + Cancel
diff --git a/registry/templates/index.html b/registry/templates/index.html index e0edc29..0975728 100644 --- a/registry/templates/index.html +++ b/registry/templates/index.html @@ -1646,35 +1646,44 @@ body: new URLSearchParams(formData) // Send as x-www-form-urlencoded }); - const responseData = await response.json(); // Always try to parse JSON - if (!response.ok) { - // Log error from backend response if possible - const errorMsg = responseData.detail || `HTTP error ${response.status}`; - console.error(`Error toggling service ${servicePath}: ${errorMsg}`); - alert(`Error toggling service: ${errorMsg}`); - // Revert checkbox state on error? Maybe not, let WS handle the authoritative state. + // Handle HTTP error responses + console.error(`HTTP error ${response.status} while toggling service ${servicePath}`); + + // Try to get error details if available in JSON format + try { + const errorData = await response.json(); + const errorMsg = errorData.detail || `HTTP error ${response.status}`; + console.error(`Error details: ${errorMsg}`); + alert(`Error toggling service: ${errorMsg}`); + } catch (jsonError) { + // If response is not JSON, just show HTTP status + console.error(`Error response was not valid JSON: ${jsonError}`); + alert(`Error toggling service: HTTP ${response.status}`); + } } else { - console.log(`Toggle request successful for ${servicePath}. Backend response:`, responseData); - // Backend will trigger WebSocket update, which calls updateServiceDisplay - // updateServiceDisplay will hide spinner, set correct checkbox state, label, etc. + // Handle successful response + try { + const responseData = await response.json(); + console.log(`Toggle request successful for ${servicePath}. Backend response:`, responseData); + // Backend will trigger WebSocket update, which calls updateServiceDisplay + // updateServiceDisplay will hide spinner, set correct checkbox state, label, etc. + } catch (jsonError) { + console.warn(`Success response was not valid JSON: ${jsonError}`); + console.log(`Toggle request completed for ${servicePath} but response was not JSON.`); + } } - } catch (error) { console.error(`Network or fetch error toggling service ${servicePath}:`, error); alert(`Failed to send toggle request: ${error}`); - // If fetch fails completely, maybe revert spinner/disable state here? - // But relying on WS might still be better for consistency. } finally { - // Re-enable the checkbox ONLY IF the WS hasn't already done so - // updateServiceDisplay handles the definitive enabling/disabling and check state - // We might remove this re-enable here and solely rely on updateServiceDisplay. - // Let's test with it first. If double-enabling causes issues, remove this. - // Checkbox might still be disabled if WS update came through quickly. - // if (checkboxElement.disabled) { - // checkboxElement.disabled = false; - // } - // Let's rely on updateServiceDisplay to manage the final enabled state. + // Always try to reset UI state regardless of success/failure + if (spinner) { + spinner.style.display = 'none'; + } + if (checkboxElement.disabled) { + checkboxElement.disabled = false; + } } } @@ -2156,7 +2165,9 @@ {% endif %} - User: {{ username }} + + Logged in as {{ username }} +
@@ -2493,20 +2504,40 @@

Register by Pasting JSON

body: new URLSearchParams(formData) // Send as x-www-form-urlencoded }); - const responseData = await response.json(); // Always try to parse JSON - if (!response.ok) { - // Log error from backend response if possible - const errorMsg = responseData.detail || `HTTP error ${response.status}`; - console.error(`Error toggling service ${servicePath}: ${errorMsg}`); - alert(`Error toggling service: ${errorMsg}`); + // Handle HTTP error responses + console.error(`HTTP error ${response.status} while toggling service ${servicePath}`); + + // Try to get error details if available in JSON format + try { + const errorData = await response.json(); + const errorMsg = errorData.detail || `HTTP error ${response.status}`; + console.error(`Error details: ${errorMsg}`); + alert(`Error toggling service: ${errorMsg}`); + } catch (jsonError) { + // If response is not JSON, just show HTTP status + console.error(`Error response was not valid JSON: ${jsonError}`); + alert(`Error toggling service: HTTP ${response.status}`); + } } else { - console.log(`Toggle request successful for ${servicePath}. Backend response:`, responseData); + // Handle successful response + try { + const responseData = await response.json(); + console.log(`Toggle request successful for ${servicePath}. Backend response:`, responseData); + } catch (jsonError) { + console.warn(`Success response was not valid JSON: ${jsonError}`); + console.log(`Toggle request completed for ${servicePath} but response was not JSON.`); + } } - } catch (error) { console.error(`Network or fetch error toggling service ${servicePath}:`, error); alert(`Failed to send toggle request: ${error}`); + } finally { + // Always re-enable the checkbox and hide spinner after request completes + if (spinner) { + spinner.style.display = 'none'; + } + checkboxElement.disabled = false; } } diff --git a/registry/templates/login.html b/registry/templates/login.html index 799bd12..a6bb515 100644 --- a/registry/templates/login.html +++ b/registry/templates/login.html @@ -3,14 +3,332 @@ + + + + + + Login - MCP Gateway - + + + + + - \ No newline at end of file + \ No newline at end of file