fixbug: #1703 & #1709 & #1721 & replace #1735 #1710

Open: wants to merge 4 commits into base branch main.
metagpt/configs/llm_config.py (0 additions & 1 deletion)
@@ -80,7 +80,6 @@ class LLMConfig(YamlModel):
     frequency_penalty: float = 0.0
     best_of: Optional[int] = None
     n: Optional[int] = None
-    stream: bool = True
     seed: Optional[int] = None
     # https://cookbook.openai.com/examples/using_logprobs
     logprobs: Optional[bool] = None
metagpt/llm.py (10 additions & 0 deletions)
@@ -18,3 +18,13 @@ def LLM(llm_config: Optional[LLMConfig] = None, context: Context = None) -> Base
     if llm_config is not None:
         return ctx.llm_with_cost_manager_from_llm_config(llm_config)
     return ctx.llm()
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    llm = LLM()
+    rsp = asyncio.run(llm.aask("hello world", stream=False))
+    print(f"{rsp}")
+    rsp = asyncio.run(llm.aask("hello world", stream=True))
+    print(f"{rsp}")
metagpt/provider/base_llm.py (2 additions & 4 deletions)
@@ -131,8 +131,8 @@ async def aask(
         system_msgs: Optional[list[str]] = None,
         format_msgs: Optional[list[dict[str, str]]] = None,
         images: Optional[Union[str, list[str]]] = None,
-        timeout=USE_CONFIG_TIMEOUT,
-        stream=None,
+        timeout: int = USE_CONFIG_TIMEOUT,
+        stream: bool = True,
     ) -> str:
         if system_msgs:
             message = self._system_msgs(system_msgs)
@@ -146,8 +146,6 @@
             message.append(self._user_msg(msg, images=images))
         else:
             message.extend(msg)
-        if stream is None:
-            stream = self.config.stream
         logger.debug(message)
         rsp = await self.acompletion_text(message, stream=stream, timeout=self.get_timeout(timeout))
         return rsp
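Not part of the diff, but a minimal sketch of what the reworked aask signature implies for callers: streaming is now opt-out per call rather than read from LLMConfig.stream, so code that previously set stream: false in its config would pass stream=False explicitly. Assumes a provider can be constructed from the default config via LLM().

import asyncio

from metagpt.llm import LLM  # assumes a usable default LLM config (e.g. ~/.metagpt/config2.yaml)


async def demo() -> None:
    llm = LLM()
    # Streaming is now the default; chunks are logged as they arrive.
    streamed = await llm.aask("Summarize unified diffs in one sentence.")
    # Opt out per call instead of relying on a `stream` field in LLMConfig.
    non_streamed = await llm.aask("Summarize unified diffs in one sentence.", stream=False)
    print(streamed)
    print(non_streamed)


if __name__ == "__main__":
    asyncio.run(demo())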
metagpt/provider/ollama_api.py (29 additions & 33 deletions)
@@ -1,10 +1,12 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 # @Desc : self-host open llm model with ollama which isn't openai-api-compatible
+# @Modified by : mashenquan. Tested with llama 3.2, https://www.ollama.com/library/llama3.2;
+# nomic-embed-text, https://www.ollama.com/library/nomic-embed-text

 import json
 from enum import Enum, auto
-from typing import AsyncGenerator, Optional, Tuple
+from typing import AsyncGenerator, List, Optional, Tuple

 from metagpt.configs.llm_config import LLMConfig, LLMType
 from metagpt.const import USE_CONFIG_TIMEOUT
@@ -38,7 +40,21 @@ def apply(self, messages: list[dict]) -> dict:
         raise NotImplementedError

     def decode(self, response: OpenAIResponse) -> dict:
-        return json.loads(response.data.decode("utf-8"))
+        data = response.data.decode("utf-8")
+        rsp = {}
+        content = ""
+        for val in data.splitlines():
+            if not val:
+                continue
+            m = json.loads(val)
+            if "embedding" in m:
+                return m
+            content += m.get("message", {}).get("content", "")
+            rsp.update(m)
+        if "message" not in rsp:
+            rsp["message"] = {}
+        rsp["message"]["content"] = content
+        return rsp

     def get_choice(self, to_choice_dict: dict) -> str:
         raise NotImplementedError
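A standalone sketch (mine, not from the PR) of the merging rule the new decode implements: when streaming is not disabled, Ollama's /api/chat answers with newline-delimited JSON, so the chunks' message.content fields are concatenated and the metadata of the final done chunk wins via dict.update, while embedding responses come back as a single object and are returned as-is. The sample payload below is illustrative.

import json

# Illustrative NDJSON payload in the shape Ollama's /api/chat returns when streaming.
NDJSON = (
    '{"model": "llama3.2", "message": {"role": "assistant", "content": "Hel"}, "done": false}\n'
    '{"model": "llama3.2", "message": {"role": "assistant", "content": "lo"}, "done": false}\n'
    '{"model": "llama3.2", "message": {"role": "assistant", "content": ""}, "done": true,'
    ' "prompt_eval_count": 5, "eval_count": 2}\n'
)


def merge_ndjson_chat(data: str) -> dict:
    """Same merging rule as decode: accumulate content, let the last chunk's keys win."""
    rsp: dict = {}
    content = ""
    for line in data.splitlines():
        if not line:
            continue
        chunk = json.loads(line)
        if "embedding" in chunk:  # /api/embeddings replies with one JSON object
            return chunk
        content += chunk.get("message", {}).get("content", "")
        rsp.update(chunk)
    if "message" not in rsp:
        rsp["message"] = {}
    rsp["message"]["content"] = content
    return rsp


merged = merge_ndjson_chat(NDJSON)
assert merged["message"]["content"] == "Hello"
assert merged["eval_count"] == 2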
@@ -204,16 +220,14 @@ def __init__(self, config: LLMConfig):
     def _llama_api_inuse(self) -> OllamaMessageAPI:
         return OllamaMessageAPI.CHAT

-    @property
-    def _llama_api_kwargs(self) -> dict:
-        return {"options": {"temperature": 0.3}, "stream": self.config.stream}
-
     def __init_ollama(self, config: LLMConfig):
         assert config.base_url, "ollama base url is required!"
         self.model = config.model
         self.pricing_plan = self.model
         ollama_message = OllamaMessageMeta.get_message(self._llama_api_inuse)
-        self.ollama_message = ollama_message(model=self.model, **self._llama_api_kwargs)
+        options = {"temperature": config.temperature}
+        self.ollama_message = ollama_message(model=self.model, options=options)
+        self.ollama_stream = ollama_message(model=self.model, options=options, stream=True)

     def get_usage(self, resp: dict) -> dict:
         return {"prompt_tokens": resp.get("prompt_eval_count", 0), "completion_tokens": resp.get("eval_count", 0)}
@@ -225,12 +239,7 @@ async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFI
             params=self.ollama_message.apply(messages=messages),
             request_timeout=self.get_timeout(timeout),
         )
-        if isinstance(resp, AsyncGenerator):
-            return await self._processing_openai_response_async_generator(resp)
-        elif isinstance(resp, OpenAIResponse):
-            return self._processing_openai_response(resp)
-        else:
-            raise ValueError
+        return self._processing_openai_response(resp)

     def get_choice_text(self, rsp):
         return self.ollama_message.get_choice(rsp)
@@ -241,17 +250,12 @@ async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT) ->
     async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
         resp, _, _ = await self.client.arequest(
             method=self.http_method,
-            url=self.ollama_message.api_suffix,
-            params=self.ollama_message.apply(messages=messages),
+            url=self.ollama_stream.api_suffix,
+            params=self.ollama_stream.apply(messages=messages),
             request_timeout=self.get_timeout(timeout),
             stream=True,
         )
-        if isinstance(resp, AsyncGenerator):
-            return await self._processing_openai_response_async_generator(resp)
-        elif isinstance(resp, OpenAIResponse):
-            return self._processing_openai_response(resp)
-        else:
-            raise ValueError
+        return await self._processing_openai_response_async_generator(resp)

     def _processing_openai_response(self, openai_resp: OpenAIResponse):
         resp = self.ollama_message.decode(openai_resp)
@@ -263,10 +267,10 @@ async def _processing_openai_response_async_generator(self, ag_openai_resp: Asyn
         collected_content = []
         usage = {}
         async for raw_chunk in ag_openai_resp:
-            chunk = self.ollama_message.decode(raw_chunk)
+            chunk = self.ollama_stream.decode(raw_chunk)

             if not chunk.get("done", False):
-                content = self.ollama_message.get_choice(chunk)
+                content = self.ollama_stream.get_choice(chunk)
                 collected_content.append(content)
                 log_llm_stream(content)
             else:
@@ -285,26 +289,18 @@ class OllamaGenerate(OllamaLLM):
     def _llama_api_inuse(self) -> OllamaMessageAPI:
         return OllamaMessageAPI.GENERATE

-    @property
-    def _llama_api_kwargs(self) -> dict:
-        return {"options": {"temperature": 0.3}, "stream": self.config.stream}
-

 @register_provider(LLMType.OLLAMA_EMBEDDINGS)
 class OllamaEmbeddings(OllamaLLM):
     @property
     def _llama_api_inuse(self) -> OllamaMessageAPI:
         return OllamaMessageAPI.EMBEDDINGS

-    @property
-    def _llama_api_kwargs(self) -> dict:
-        return {"options": {"temperature": 0.3}}
-
     @property
     def _llama_embedding_key(self) -> str:
         return "embedding"

-    async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> dict:
+    async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> List[float]:
         resp, _, _ = await self.client.arequest(
             method=self.http_method,
             url=self.ollama_message.api_suffix,
@@ -313,7 +309,7 @@ async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFI
         )
         return self.ollama_message.decode(resp)[self._llama_embedding_key]

-    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
+    async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> List[float]:
         return await self._achat_completion(messages, timeout=self.get_timeout(timeout))

     def get_choice_text(self, rsp):
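Also a reviewer-side note rather than part of the change: the final done chunk is where get_usage reads its token counts, mapping Ollama's field names onto the OpenAI-style keys the cost manager expects. A small illustration, with a made-up chunk:

# Hypothetical final chunk; the field names are the ones get_usage reads.
final_chunk = {"done": True, "prompt_eval_count": 26, "eval_count": 298}

usage = {
    "prompt_tokens": final_chunk.get("prompt_eval_count", 0),
    "completion_tokens": final_chunk.get("eval_count", 0),
}
assert usage == {"prompt_tokens": 26, "completion_tokens": 298}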
metagpt/provider/openai_api.py (9 additions & 5 deletions)
@@ -126,19 +126,16 @@ async def _achat_completion_stream(self, messages: list[dict], timeout=USE_CONFI
         return full_reply_content

     def _cons_kwargs(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT, **extra_kwargs) -> dict:
+        max_token_key = self._get_max_tokens_key()
         kwargs = {
             "messages": messages,
-            "max_tokens": self._get_max_tokens(messages),
+            max_token_key: self._get_max_tokens(messages),
             # "n": 1, # Some services do not provide this parameter, such as mistral
             # "stop": None, # default it's None and gpt4-v can't have this one
             "temperature": self.config.temperature,
             "model": self.model,
             "timeout": self.get_timeout(timeout),
         }
-        if "o1-" in self.model:
-            # compatible to openai o1-series
-            kwargs["temperature"] = 1
-            kwargs.pop("max_tokens")
         if extra_kwargs:
             kwargs.update(extra_kwargs)
         return kwargs
@@ -309,3 +306,10 @@ async def gen_image(
             img_url_or_b64 = item.url if resp_format == "url" else item.b64_json
             imgs.append(decode_image(img_url_or_b64))
         return imgs
+
+    def _get_max_tokens_key(self) -> str:
+        pattern = r"^o\d+(-\w+)*$"
+        if re.match(pattern, self.model):
+            # o1 series, see more https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens
+            return "max_completion_tokens"
+        return "max_tokens"
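To make the regex in _get_max_tokens_key concrete, a quick check (not from the PR) of which model names it routes to max_completion_tokens; the model strings below are only examples:

import re

PATTERN = r"^o\d+(-\w+)*$"  # same pattern as _get_max_tokens_key

# o-series names match and get "max_completion_tokens"; everything else keeps "max_tokens".
assert re.match(PATTERN, "o1")
assert re.match(PATTERN, "o1-mini")
assert re.match(PATTERN, "o3-mini-2025-01-31")
assert not re.match(PATTERN, "gpt-4o")
assert not re.match(PATTERN, "gpt-4o-mini")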
metagpt/utils/token_counter.py (1 addition & 1 deletion)
@@ -483,7 +483,7 @@ def count_output_tokens(string: str, model: str) -> int:
     try:
         encoding = tiktoken.encoding_for_model(model)
     except KeyError:
-        logger.info(f"Warning: model {model} not found in tiktoken. Using cl100k_base encoding.")
+        logger.debug(f"Warning: model {model} not found in tiktoken. Using cl100k_base encoding.")
         encoding = tiktoken.get_encoding("cl100k_base")
     return len(encoding.encode(string))