feat: support max_prompt_tokens for chat completion requests (#13) (#14)
vladisavvv authored Nov 6, 2023
1 parent f793236 commit 57dcb1b
Showing 5 changed files with 247 additions and 16 deletions.
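Before the per-file diffs, a quick illustration of how a client exercises the new parameter may help. This is a sketch only: the adapter's base URL, port, and deployment name are hypothetical, while the max_prompt_tokens field, the X-UPSTREAM-KEY / X-UPSTREAM-ENDPOINT headers, and the statistics block come from the changes below.

```python
import requests  # hypothetical client; any HTTP client works

# Hypothetical adapter address and deployment; only the payload shape and the
# X-UPSTREAM-* headers are taken from this commit's code.
resp = requests.post(
    "http://localhost:5000/openai/deployments/gpt-4/chat/completions",
    headers={
        "X-UPSTREAM-KEY": "<azure-api-key>",
        "X-UPSTREAM-ENDPOINT": "https://<resource>.openai.azure.com/openai/deployments/gpt-4/chat/completions",
    },
    json={
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        # New in this commit: oldest non-system messages are discarded until
        # the prompt fits into this many tokens.
        "max_prompt_tokens": 1000,
    },
)
print(resp.json())  # may include {"statistics": {"discarded_messages": N}}
```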
43 changes: 39 additions & 4 deletions aidial_adapter_openai/app.py
@@ -1,11 +1,13 @@
import json
import logging.config
import os
from typing import Dict

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
from openai import ChatCompletion, Embedding, error
from openai.openai_object import OpenAIObject

from aidial_adapter_openai.openai_override import OpenAIException
from aidial_adapter_openai.utils.exceptions import HTTPException
@@ -16,10 +18,11 @@
parse_upstream,
)
from aidial_adapter_openai.utils.streaming import generate_stream
from aidial_adapter_openai.utils.tokens import discard_messages

logging.config.dictConfig(LogConfig().dict())
app = FastAPI()
- model_aliases = json.loads(os.getenv("MODEL_ALIASES", "{}"))
model_aliases: Dict[str, str] = json.loads(os.getenv("MODEL_ALIASES", "{}"))
azure_api_version = os.getenv("AZURE_API_VERSION", "2023-03-15-preview")


@@ -41,12 +44,34 @@ async def chat_completion(deployment_id: str, request: Request):
data = await parse_body(request)

is_stream = data.get("stream", False)
openai_model_name = model_aliases.get(deployment_id, deployment_id)
dial_api_key = request.headers["X-UPSTREAM-KEY"]

api_base, upstream_deployment = parse_upstream(
request.headers["X-UPSTREAM-ENDPOINT"], ApiType.CHAT_COMPLETION
)

discarded_messages = None
if "max_prompt_tokens" in data:
max_prompt_tokens = data["max_prompt_tokens"]
if type(max_prompt_tokens) != int:
raise HTTPException(
f"'{max_prompt_tokens}' is not of type 'integer' - 'max_prompt_tokens'",
400,
"invalid_request_error",
)
if max_prompt_tokens < 1:
raise HTTPException(
f"'{max_prompt_tokens}' is less than the minimum of 1 - 'max_prompt_tokens'",
400,
"invalid_request_error",
)
del data["max_prompt_tokens"]

data["messages"], discarded_messages = discard_messages(
data["messages"], openai_model_name, max_prompt_tokens
)

response = await handle_exceptions(
ChatCompletion().acreate(
engine=upstream_deployment,
@@ -55,7 +80,7 @@ async def chat_completion(deployment_id: str, request: Request):
api_type="azure",
api_version=azure_api_version,
request_timeout=(10, 600), # connect timeout and total timeout
- **data
**data,
)
)

@@ -67,12 +92,22 @@ async def chat_completion(deployment_id: str, request: Request):
generate_stream(
data["messages"],
response,
- model_aliases.get(deployment_id, deployment_id),
openai_model_name,
deployment_id,
discarded_messages,
),
media_type="text/event-stream",
)
else:
if discarded_messages is not None:
assert type(response) == OpenAIObject

response_with_statistics = response.to_dict() | {
"statistics": {"discarded_messages": discarded_messages}
}

return response_with_statistics

return response


@@ -93,7 +128,7 @@ async def embedding(deployment_id: str, request: Request):
api_type="azure",
api_version=azure_api_version,
request_timeout=(10, 600), # connect timeout and total timeout
- **data
**data,
)
)

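In the chat_completion handler above, the non-streaming branch merges the upstream payload with the new statistics block via a plain dict union. A minimal sketch with an invented upstream payload; only the "statistics" key and the union expression mirror the code in app.py:

```python
# Invented upstream payload; the union with "statistics" mirrors app.py above.
upstream_response = {
    "id": "chatcmpl-abc123",
    "object": "chat.completion",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hi!"},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 19, "completion_tokens": 2, "total_tokens": 21},
}
discarded_messages = 1

# dict | dict requires Python 3.9+, matching response.to_dict() | {...} above.
response_with_statistics = upstream_response | {
    "statistics": {"discarded_messages": discarded_messages}
}
print(response_with_statistics["statistics"])  # {'discarded_messages': 1}
```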
4 changes: 2 additions & 2 deletions aidial_adapter_openai/utils/parsers.py
@@ -1,7 +1,7 @@
import re
from enum import Enum
from json import JSONDecodeError
- from typing import Any, Mapping
from typing import Any, Dict

from fastapi import Request

@@ -32,7 +32,7 @@ def parse_upstream(

async def parse_body(
request: Request,
- ) -> Mapping[str, Any]:
) -> Dict[str, Any]:
try:
data = await request.json()
except JSONDecodeError as e:
12 changes: 10 additions & 2 deletions aidial_adapter_openai/utils/streaming.py
@@ -1,6 +1,6 @@
import json
from time import time
- from typing import Any, Mapping
from typing import Any, Mapping, Optional
from uuid import uuid4

import tiktoken
@@ -21,7 +21,11 @@ def chunk_format(data: str | Mapping[str, Any]):


async def generate_stream(
- messages: list[Any], response, model: str, deployment: str
messages: list[Any],
response,
model: str,
deployment: str,
discarded_messages: Optional[int],
):
encoding = tiktoken.encoding_for_model(model)

@@ -42,6 +46,10 @@ async def generate_stream(
"prompt_tokens": prompt_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
if discarded_messages is not None:
chunk_dict["statistics"] = {
"discarded_messages": discarded_messages
}
else:
total_content += chunk_dict["choices"][0]["delta"].get(
"content", ""
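When the response is streamed, generate_stream now attaches the same counter to the final chunk that carries the token usage totals. The shape below is illustrative (field values and the overall chunk layout are assumptions); only the placement of "statistics" alongside the usage totals reflects the code above:

```python
# Illustrative final SSE chunk; only the "statistics" key and the usage totals
# correspond to the generate_stream changes above.
final_chunk = {
    "id": "chatcmpl-abc123",
    "object": "chat.completion.chunk",
    "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
    "usage": {
        "completion_tokens": 2,
        "prompt_tokens": 19,
        "total_tokens": 21,
    },
    "statistics": {"discarded_messages": 1},
}
print(final_chunk["statistics"])  # {'discarded_messages': 1}
```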
88 changes: 80 additions & 8 deletions aidial_adapter_openai/utils/tokens.py
@@ -3,27 +3,99 @@
"""
from typing import Any, List

- from tiktoken import Encoding
from tiktoken import Encoding, encoding_for_model

from aidial_adapter_openai.utils.exceptions import HTTPException


def calculate_prompt_tokens(
messages: List[Any], model: str, encoding: Encoding
- ):
) -> int:
prompt_tokens = 3

for message in messages:
prompt_tokens += calculate_tokens_per_message(message, encoding, model)

return prompt_tokens


def calculate_tokens_per_message(
message: Any,
encoding: Encoding,
model: str,
) -> int:
if model == "gpt-3.5-turbo-0301":
tokens_per_message = 4
tokens_per_name = -1
else:
tokens_per_message = 3
tokens_per_name = 1

prompt_tokens = tokens_per_message
for key, value in message.items():
prompt_tokens += len(encoding.encode(value))
if key == "name":
prompt_tokens += tokens_per_name

return prompt_tokens


def discard_messages(
messages: List[Any], model: str, max_prompt_tokens: int
) -> tuple[List[Any], int]:
if len(messages) == 0:
return messages, 0 # will be rejected by the upstream

encoding = encoding_for_model(model)

prompt_tokens = 3

non_system_messages_count = 0
for message in messages:
- prompt_tokens += tokens_per_message
if message["role"] != "system":
non_system_messages_count += 1
continue

- for key, value in message.items():
- prompt_tokens += len(encoding.encode(value))
- if key == "name":
- prompt_tokens += tokens_per_name
prompt_tokens += calculate_tokens_per_message(message, encoding, model)

- return prompt_tokens
if max_prompt_tokens < prompt_tokens:
raise HTTPException(
message=f"The token size of system messages ({prompt_tokens}) exceeds prompt token limit ({max_prompt_tokens})"
)

discarded_messages = non_system_messages_count
for message in reversed(messages):
if message["role"] == "system":
continue

prompt_tokens += calculate_tokens_per_message(message, encoding, model)

if max_prompt_tokens < prompt_tokens:
break

discarded_messages -= 1

if (
discarded_messages == non_system_messages_count
and non_system_messages_count > 0
):
raise HTTPException(
message=f"The token size of system messages and the last user message ({prompt_tokens}) exceeds prompt token limit ({max_prompt_tokens})",
status_code=400,
type="invalid_request_error",
)

messages_without_discarded = []

remaining_discarded_messages = discarded_messages
for message in messages:
if message["role"] == "system":
messages_without_discarded.append(message)
continue

if remaining_discarded_messages > 0:
remaining_discarded_messages -= 1
else:
messages_without_discarded.append(message)

return messages_without_discarded, discarded_messages
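Taken together, discard_messages keeps every system message, then walks the non-system messages from newest to oldest and keeps as many as fit into max_prompt_tokens; whatever remains over budget is dropped from the front of the conversation. A minimal usage sketch, reusing the fixtures from the tests below (per those fixtures, each message costs 8 tokens on top of a base of 3):

```python
# Usage sketch mirroring the fourth test case below (requires tiktoken).
from aidial_adapter_openai.utils.tokens import discard_messages

messages = [
    {"role": "system", "message": "This is four tokens"},
    {"role": "user", "message": "This is four tokens"},
    {"role": "assistant", "message": "This is four tokens"},
    {"role": "user", "message": "This is four tokens"},
]

# The full prompt is 3 + 4 * 8 = 35 tokens; with a budget of 27 the oldest
# user message no longer fits, so it is discarded.
kept, discarded = discard_messages(messages, "gpt-4", 27)
print(discarded)  # 1
print(len(kept))  # 3: system, assistant, and the latest user message
```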
116 changes: 116 additions & 0 deletions tests/test_discard_messages.py
@@ -0,0 +1,116 @@
import pytest

from aidial_adapter_openai.utils.exceptions import HTTPException
from aidial_adapter_openai.utils.tokens import discard_messages

gpt4_testdata = [
(
[],
0,
([], 0),
),
(
[{"role": "system", "message": "This is four tokens"}],
11,
([{"role": "system", "message": "This is four tokens"}], 0),
),
(
[
{"role": "system", "message": "This is four tokens"},
{"role": "user", "message": "This is four tokens"},
],
18,
"The token size of system messages and the last user message (19) exceeds prompt token limit (18)",
),
(
[
{"role": "system", "message": "This is four tokens"},
{"role": "system", "message": "This is four tokens"},
{"role": "system", "message": "This is four tokens"},
{"role": "user", "message": "This is four tokens"},
],
11,
"The token size of system messages (27) exceeds prompt token limit (11)",
),
(
[
{"role": "system", "message": "This is four tokens"},
{"role": "user", "message": "This is four tokens"},
{"role": "assistant", "message": "This is four tokens"},
{"role": "user", "message": "This is four tokens"},
],
27,
(
[
{"role": "system", "message": "This is four tokens"},
{"role": "assistant", "message": "This is four tokens"},
{"role": "user", "message": "This is four tokens"},
],
1,
),
),
(
[
{"role": "system", "message": "This is four tokens"},
{"role": "user", "message": "This is four tokens"},
{"role": "assistant", "message": "This is four tokens"},
{"role": "user", "message": "This is four tokens"},
],
34,
(
[
{"role": "system", "message": "This is four tokens"},
{"role": "assistant", "message": "This is four tokens"},
{"role": "user", "message": "This is four tokens"},
],
1,
),
),
(
[
{"role": "system", "message": "This is four tokens"},
{"role": "user", "message": "This is four tokens"},
{"role": "assistant", "message": "This is four tokens"},
{"role": "system", "message": "This is four tokens"},
{"role": "user", "message": "This is four tokens"},
],
27,
(
[
{"role": "system", "message": "This is four tokens"},
{"role": "system", "message": "This is four tokens"},
{"role": "user", "message": "This is four tokens"},
],
2,
),
),
(
[
{"role": "system", "message": "This is four tokens"},
{"role": "user", "message": "This is four tokens"},
{"role": "assistant", "message": "This is four tokens"},
{"role": "system", "message": "This is four tokens"},
{"role": "user", "message": "This is four tokens"},
],
35,
(
[
{"role": "system", "message": "This is four tokens"},
{"role": "assistant", "message": "This is four tokens"},
{"role": "system", "message": "This is four tokens"},
{"role": "user", "message": "This is four tokens"},
],
1,
),
),
]


@pytest.mark.parametrize("messages, max_prompt_tokens, response", gpt4_testdata)
def test_discarded_messages(messages, max_prompt_tokens, response):
try:
assert (
discard_messages(messages, "gpt-4", max_prompt_tokens) == response
)
except HTTPException as e:
assert e.message == response
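The literal budgets in gpt4_testdata (11, 18, 27, 34, 35) follow directly from the tokenizer. Assuming tiktoken maps gpt-4 to cl100k_base, the accounting is 3 base tokens per prompt, plus, for each message, 3 tokens of overhead, the encoded role, and the encoded content. A quick check:

```python
import tiktoken

enc = tiktoken.encoding_for_model("gpt-4")  # cl100k_base

content = len(enc.encode("This is four tokens"))  # expected: 4
role = len(enc.encode("system"))                  # expected: 1

# One system message: 3 (base) + 3 (per-message) + 1 (role) + 4 (content) = 11,
# which is exactly the first max_prompt_tokens budget in gpt4_testdata.
print(3 + 3 + role + content)  # 11
```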
