Add rate limiting configuration for LLM providers #276

Open · wants to merge 6 commits into base: main
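In short: `get_llm_model` now builds a LangChain `InMemoryRateLimiter` from two new keyword arguments, `rate_limit_rps` (token-bucket refill rate, default 1.0) and `rate_limit_bucket` (maximum burst size, default 10), and attaches it to every provider's chat model; the web UI exposes both as numeric inputs. A minimal sketch of the resulting call path — the provider and model values below are illustrative, not prescribed by the PR:

```python
# Sketch only: provider/model values are illustrative.
from src.utils.utils import get_llm_model

llm = get_llm_model(
    provider="openai",
    model_name="gpt-4o",
    temperature=0.0,
    rate_limit_rps=1.0,    # bucket refills at one request per second
    rate_limit_bucket=10,  # up to ten tokens may accumulate for a burst
)
# Each invoke()/ainvoke() on `llm` now waits for a rate-limiter token first.
```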
23 changes: 23 additions & 0 deletions src/utils/utils.py
@@ -11,6 +11,7 @@
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI
import gradio as gr
from langchain_core.rate_limiters import InMemoryRateLimiter

from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama

@@ -31,6 +32,15 @@ def get_llm_model(provider: str, **kwargs):
:param kwargs:
:return:
"""
rate_limit_rps = kwargs.get("rate_limit_rps", 1.0)
rate_limit_bucket = kwargs.get("rate_limit_bucket", 10)
# Create rate limiter
rate_limiter = InMemoryRateLimiter(
requests_per_second=rate_limit_rps,
check_every_n_seconds=0.1,
max_bucket_size=rate_limit_bucket
)

if provider not in ["ollama"]:
env_var = f"{provider.upper()}_API_KEY"
api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
@@ -49,6 +59,7 @@ def get_llm_model(provider: str, **kwargs):
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
rate_limiter=rate_limiter,
)
elif provider == 'mistral':
if not kwargs.get("base_url", ""):
@@ -65,6 +76,7 @@ def get_llm_model(provider: str, **kwargs):
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
rate_limiter=rate_limiter,
)
elif provider == "openai":
if not kwargs.get("base_url", ""):
@@ -77,6 +89,7 @@ def get_llm_model(provider: str, **kwargs):
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
rate_limiter=rate_limiter,
)
elif provider == "deepseek":
if not kwargs.get("base_url", ""):
@@ -90,19 +103,24 @@ def get_llm_model(provider: str, **kwargs):
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
rate_limiter=rate_limiter,
)
else:

return ChatOpenAI(
model=kwargs.get("model_name", "deepseek-chat"),
temperature=kwargs.get("temperature", 0.0),
base_url=base_url,
api_key=api_key,
rate_limiter=rate_limiter,
)
elif provider == "google":

return ChatGoogleGenerativeAI(
model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
temperature=kwargs.get("temperature", 0.0),
google_api_key=api_key,
rate_limiter=rate_limiter,
)
elif provider == "ollama":
if not kwargs.get("base_url", ""):
@@ -111,19 +129,23 @@ def get_llm_model(provider: str, **kwargs):
base_url = kwargs.get("base_url")

if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):

return DeepSeekR1ChatOllama(
model=kwargs.get("model_name", "deepseek-r1:14b"),
temperature=kwargs.get("temperature", 0.0),
num_ctx=kwargs.get("num_ctx", 32000),
base_url=base_url,
rate_limiter=rate_limiter,
)
else:

return ChatOllama(
model=kwargs.get("model_name", "qwen2.5:7b"),
temperature=kwargs.get("temperature", 0.0),
num_ctx=kwargs.get("num_ctx", 32000),
num_predict=kwargs.get("num_predict", 1024),
base_url=base_url,
rate_limiter=rate_limiter,
)
elif provider == "azure_openai":
if not kwargs.get("base_url", ""):
@@ -137,6 +159,7 @@ def get_llm_model(provider: str, **kwargs):
api_version=api_version,
azure_endpoint=base_url,
api_key=api_key,
rate_limiter=rate_limiter,
)
elif provider == "alibaba":
if not kwargs.get("base_url", ""):
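For context on the parameters used above: `InMemoryRateLimiter` is a process-local token bucket — tokens accrue at `requests_per_second` up to `max_bucket_size`, and `acquire()` polls every `check_every_n_seconds` until a token is free, so it only throttles requests within a single process. A standalone sketch of that behavior (values chosen for the demo, not taken from the diff):

```python
import time

from langchain_core.rate_limiters import InMemoryRateLimiter

limiter = InMemoryRateLimiter(
    requests_per_second=2.0,    # two tokens added per second
    check_every_n_seconds=0.1,  # poll for a free token every 100 ms
    max_bucket_size=5,          # at most five tokens can accumulate
)

start = time.monotonic()
for i in range(6):
    limiter.acquire()  # blocks until a token is available
    print(f"request {i} released at t={time.monotonic() - start:.2f}s")
# The bucket starts empty, so requests release at roughly 0.5 s intervals.
```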
40 changes: 33 additions & 7 deletions webui.py
@@ -119,7 +119,9 @@ async def run_browser_agent(
max_steps,
use_vision,
max_actions_per_step,
tool_calling_method
tool_calling_method,
rate_limit_rps,
rate_limit_bucket
):
global _global_agent_state
_global_agent_state.clear_stop() # Clear any previous stop requests
@@ -149,6 +151,8 @@
temperature=llm_temperature,
base_url=llm_base_url,
api_key=llm_api_key,
rate_limit_rps=rate_limit_rps,
rate_limit_bucket=rate_limit_bucket
)
if agent_type == "org":
final_result, errors, model_actions, model_thoughts, trace_file, history_file = await run_org_agent(
@@ -456,7 +460,9 @@ async def run_with_stream(
max_steps,
use_vision,
max_actions_per_step,
tool_calling_method
tool_calling_method,
rate_limit_rps,
rate_limit_bucket
):
global _global_agent_state
stream_vw = 80
@@ -485,7 +491,9 @@
max_steps=max_steps,
use_vision=use_vision,
max_actions_per_step=max_actions_per_step,
tool_calling_method=tool_calling_method
tool_calling_method=tool_calling_method,
rate_limit_rps=rate_limit_rps,
rate_limit_bucket=rate_limit_bucket
)
# Add HTML content at the start of the result array
html_content = f"<h1 style='width:{stream_vw}vw; height:{stream_vh}vh'>Using browser...</h1>"
@@ -518,7 +526,9 @@
max_steps=max_steps,
use_vision=use_vision,
max_actions_per_step=max_actions_per_step,
tool_calling_method=tool_calling_method
tool_calling_method=tool_calling_method,
rate_limit_rps=rate_limit_rps,
rate_limit_bucket=rate_limit_bucket
)
)

@@ -632,7 +642,7 @@ async def close_global_browser():
await _global_browser.close()
_global_browser = None

async def run_deep_search(research_task, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless):
async def run_deep_search(research_task, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless, rate_limit_rps, rate_limit_bucket):
from src.utils.deep_research import deep_research
global _global_agent_state

@@ -646,6 +656,8 @@ async def run_deep_search(research_task, max_search_iteration_input, max_query_p
temperature=llm_temperature,
base_url=llm_base_url,
api_key=llm_api_key,
rate_limit_rps=rate_limit_rps,
rate_limit_bucket=rate_limit_bucket
)
markdown_content, file_path = await deep_research(research_task, llm, _global_agent_state,
max_search_iterations=max_search_iteration_input,
@@ -775,6 +787,19 @@ def create_ui(config, theme_name="Ocean"):
value=config['llm_api_key'],
info="Your API key (leave blank to use .env)"
)
with gr.Row():
rate_limit_rps = gr.Number(
label="Requests/sec",
value=config.get('rate_limit_rps', 1),
precision=1,
info="Max requests per second"
)
rate_limit_bucket = gr.Number(
label="Max Bucket Size",
value=config.get('rate_limit_bucket', 10),
precision=0,
info="Maximum burst capacity"
)

# Change event to update context length slider
def update_llm_num_ctx_visibility(llm_provider):
@@ -932,7 +957,8 @@ def update_llm_num_ctx_visibility(llm_provider):
agent_type, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key,
use_own_browser, keep_browser_open, headless, disable_security, window_w, window_h,
save_recording_path, save_agent_history_path, save_trace_path, # Include the new path
enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_calling_method
enable_recording, task, add_infos, max_steps, use_vision, max_actions_per_step, tool_calling_method,
rate_limit_rps, rate_limit_bucket
],
outputs=[
browser_view, # Browser view
@@ -951,7 +977,7 @@ def update_llm_num_ctx_visibility(llm_provider):
# Run Deep Research
research_button.click(
fn=run_deep_search,
inputs=[research_task_input, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless],
inputs=[research_task_input, max_search_iteration_input, max_query_per_iter_input, llm_provider, llm_model_name, llm_num_ctx, llm_temperature, llm_base_url, llm_api_key, use_vision, use_own_browser, headless, rate_limit_rps, rate_limit_bucket],
outputs=[markdown_output_display, markdown_download, stop_research_button, research_button]
)
# Bind the stop button click event after errors_output is defined
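One caveat on the UI wiring, flagged here rather than asserted: `gr.Number` yields `None` when the field is cleared, and `kwargs.get("rate_limit_rps", 1.0)` in `get_llm_model` only falls back to its default when the key is absent, not when its value is `None`. A defensive variant (hypothetical, not part of this diff) would coerce empty values before building the limiter:

```python
# Hypothetical hardening, not in the PR: treat a cleared UI field (None)
# the same as an omitted kwarg before constructing the limiter.
rate_limit_rps = float(kwargs.get("rate_limit_rps") or 1.0)
rate_limit_bucket = int(kwargs.get("rate_limit_bucket") or 10)

rate_limiter = InMemoryRateLimiter(
    requests_per_second=rate_limit_rps,
    check_every_n_seconds=0.1,
    max_bucket_size=rate_limit_bucket,
)
```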