Skip to content

Commit

Permalink
feat: fixed printing of public links (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Sep 6, 2024
1 parent 664101b commit 2dd4667
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
9 changes: 6 additions & 3 deletions aidial_adapter_openai/gpt4_multi_modal/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,17 @@ class Config:
async def get_attachment_name(
file_storage: Optional[FileStorage], attachment: dict
) -> str:
if title := attachment.get("title"):
return title

if "data" in attachment:
return attachment.get("title") or "data attachment"

if "url" in attachment:
attachment_link = attachment["url"]
link = attachment["url"]
if file_storage is not None:
return await file_storage.get_human_readable_name(attachment_link)
return attachment_link
return await file_storage.get_human_readable_name(link)
return link

return "invalid attachment"

Expand Down
22 changes: 18 additions & 4 deletions aidial_adapter_openai/utils/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import mimetypes
import os
from typing import Mapping, Optional, TypedDict
from urllib.parse import urljoin
from urllib.parse import unquote, urljoin

import aiohttp

Expand Down Expand Up @@ -111,11 +111,25 @@ def attachment_link_to_url(self, link: str) -> str:

return urljoin(base_url, link)

def _url_to_attachment_link(self, url: str) -> str:
if core_api_version == "0.6":
return url.removeprefix(f"{self.dial_url}/v1/files/")
else:
return url.removeprefix(f"{self.dial_url}/v1/")

async def get_human_readable_name(self, link: str) -> str:
url = self.attachment_link_to_url(link)
async with aiohttp.ClientSession() as session:
bucket = await self._get_user_bucket(session)
return url.removeprefix(f"{self.dial_url}/v1/files/{bucket}/")
link = self._url_to_attachment_link(url)

if link.startswith("public/"):
bucket = "public"
else:
async with aiohttp.ClientSession() as session:
bucket = await self._get_user_bucket(session)

link = link.removeprefix(f"{bucket}/")
decoded_link = unquote(link)
return link if link == decoded_link else repr(decoded_link)

async def download_file_as_base64(self, url: str) -> str:
headers: Mapping[str, str] = {}
Expand Down

0 comments on commit 2dd4667

Please sign in to comment.