Skip to content

Commit

Permalink
add messages preloading
Browse files Browse the repository at this point in the history
move media downloading to another class and another async task
  • Loading branch information
RuslanUC committed Sep 23, 2023
1 parent 4be3f99 commit 4b9a2b1
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 54 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "t-export"
version = "0.1.0"
version = "0.1.1"
description = "Telegram chats export tool."
authors = ["RuslanUC <dev_ruslan_uc@protonmail.com>"]
readme = "README.md"
Expand Down
1 change: 1 addition & 0 deletions texport/export_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class ExportConfig:
from_date: datetime = datetime(1970, 1, 1)
to_date: datetime = datetime.now()
print: bool = False
preload: bool = False

def excluded_media(self) -> set[MessageMediaType]:
result = set()
Expand Down
80 changes: 38 additions & 42 deletions texport/exporter.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,29 @@
import asyncio
from datetime import date
from os.path import relpath
from typing import Union, Optional
from typing import Union

from pyrogram import Client
from pyrogram.types import Message as PyroMessage
from pyrogram.utils import zero_datetime

from texport.export_config import ExportConfig
from texport.media import MEDIA_TYPES
from texport.messages_saver import MessagesSaver
from texport.progress_print import ProgressPrint


class ExportStatus:
def __init__(self):
self.approx_messages_count = None
self.last_message_id = None
self.last_date = None
from .export_config import ExportConfig
from .media import MEDIA_TYPES
from .media_downloader import MediaExporter
from .messages_preloader import Preloader
from .messages_saver import MessagesSaver
from .progress_print import ProgressPrint


class Exporter:
def __init__(self, client: Client, export_config: ExportConfig=None):
self._config = export_config or ExportConfig()
self._client = client
self._task = None
self.status: Optional[ExportStatus] = None
self._progress: ProgressPrint = ProgressPrint(disabled=not self._config.print)
self.progress: ProgressPrint = ProgressPrint(disabled=not self._config.print)
self._messages: list[PyroMessage] = []
self._media: dict[Union[int, str], str] = {}
self._saver = MessagesSaver(self._messages, self._media, export_config)
self._media_downloader = MediaExporter(client, export_config, self._media, self.progress)
self._excluded_media = self._config.excluded_media()

async def _export_media(self, message: PyroMessage) -> None:
Expand All @@ -40,57 +34,59 @@ async def _export_media(self, message: PyroMessage) -> None:
if media.file_size > self._config.size_limit * 1024 * 1024:
return

path = await message.download(file_name=f"{self._config.output_dir.absolute()}/{m.dir_name}/")
path = relpath(path, self._config.output_dir.absolute())
self._media[message.id] = path
self._media_downloader.add(media.file_id, f"{self._config.output_dir.absolute()}/{m.dir_name}/", message.id)

if hasattr(media, "thumbs") and media.thumbs:
path = await self._client.download_media(media.thumbs[0].file_id,
file_name=f"{self._config.output_dir.absolute()}/thumbs/")
path = relpath(path, self._config.output_dir.absolute())
self._media[f"{message.id}_thumb"] = path
self._media_downloader.add(media.thumbs[0].file_id, f"{self._config.output_dir.absolute()}/thumbs/",
f"{message.id}_thumb")

async def _write(self, wait_media: list[int]) -> None:
self.progress.status = "Waiting for all media to be downloaded..."
await self._media_downloader.wait(wait_media)
self.progress.status = "Writing messages to file..."
await self._saver.save()

async def _export(self, chat_id: Union[int, str]):
await self._media_downloader.run()

offset_date = zero_datetime() if self._config.to_date.date() >= date.today() else self._config.to_date
loaded = 0
self._progress.approx_messages_count = await self._client.get_chat_history_count(chat_id)
async for message in self._client.get_chat_history(chat_id, offset_date=offset_date):
medias = []
self.progress.approx_messages_count = await self._client.get_chat_history_count(chat_id)
messages_iter = Preloader(self._client, self.progress, self._export_media) \
if self._config.preload else self._client.get_chat_history
async for message in messages_iter(chat_id, offset_date=offset_date):
if message.date < self._config.from_date:
break

loaded += 1
with self._progress.update():
self._progress.status = "Exporting messages..."
self._progress.messages_exported = loaded

if self.status.approx_messages_count is None:
self.status.approx_messages_count = message.id
self.status.last_message_id = message.id
self.status.last_date = message.date
with self.progress.update():
self.progress.status = "Exporting messages..."
self.progress.messages_exported = loaded

if message.media:
self._progress.status = "Downloading media..."
medias.append(message.id)
medias.append(f"{message.id}_thumb")
await self._export_media(message)

if not message.text and not message.caption and message.id not in self._media:
continue

self._messages.append(message)
if len(self._messages) > 5000:
self._progress.status = "Writing messages to file..."
await self._saver.save()
if len(self._messages) > 1000:
await self._write(medias)

if self._messages:
self._progress.status = "Writing messages to file..."
await self._saver.save()
self.status = self._task = None
await self._write(medias)
self._task = None

self._progress.status = "Done!"
self.progress.status = "Stopping media downloader..."
await self._media_downloader.stop()
self.progress.status = "Done!"

async def export(self, block: bool=True) -> None:
if self._task is not None or self.status is not None:
if self._task is not None:
return
self.status = ExportStatus()
coro = self._export(self._config.chat_id)
if block:
await coro
Expand Down
9 changes: 6 additions & 3 deletions texport/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ async def _main(session_name: str, api_id: int, api_hash: str, config: ExportCon
async with Client(f"{Path.home()}/.texport/{session_name}", api_id=api_id, api_hash=api_hash) as client:
exporter = Exporter(client, config)
await exporter.export()
print("Export complete!")
if config.print:
print("Export complete!")


@click.command()
Expand All @@ -44,11 +45,12 @@ async def _main(session_name: str, api_id: int, api_hash: str, config: ExportCon
@click.option("--stickers/--no-stickers", default=True, help="Download stickers or not.")
@click.option("--gifs/--no-gifs", default=True, help="Download gifs or not.")
@click.option("--documents/--no-documents", default=True, help="Download documents or not.")
@click.option("--quiet", default=False, help="Do not print progress to console.")
@click.option("--quiet", "-q", is_flag=True, default=False, help="Do not print progress to console.")
@click.option("--no-preload", is_flag=True, default=False, help="Do not preload all messages.")
def main(
session_name: str, api_id: int, api_hash: str, chat_id: str, output: str, size_limit: int, from_date: str,
to_date: str, photos: bool, videos: bool, voice: bool, video_notes: bool, stickers: bool, gifs: bool,
documents: bool, quiet: bool,
documents: bool, quiet: bool, no_preload: bool,
) -> None:
home = Path.home()
texport_dir = home / ".texport"
Expand All @@ -69,6 +71,7 @@ def main(
export_gifs=gifs,
export_files=documents,
print=not quiet,
preload=not no_preload,
)

if session_name.endswith(".session"):
Expand Down
72 changes: 72 additions & 0 deletions texport/media_downloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import asyncio
from asyncio import sleep
from os.path import relpath
from typing import Union, Optional

from pyrogram import Client
from pyrogram.errors import RPCError

from .export_config import ExportConfig
from .progress_print import ProgressPrint


class MediaExporter:
def __init__(self, client: Client, config: ExportConfig, media_dict: dict, progress: ProgressPrint):
self.client = client
self.config = config
self.output = media_dict
self.task = None
self.queue: list[tuple[str, str, Union[str, int]]] = []
self.ids: set[Union[str, int]] = set()
self.all_ids: set[Union[str, int]] = set()
self.progress = progress

self._running = False

def add(self, file_id: str, download_dir: str, out_id: Union[str, int]) -> None:
if out_id in self.all_ids: return
self.queue.append((file_id, download_dir, out_id))
self.ids.add(out_id)
self.all_ids.add(out_id)
self._status()

async def _download(self, file_id: str, download_dir: str, out_id: Union[str, int]) -> None:
try:
path = await self.client.download_media(file_id, file_name=download_dir)
except RPCError:
return
path = relpath(path, self.config.output_dir.absolute())
self.output[out_id] = path

def _status(self, status: str=None) -> None:
with self.progress.update():
self.progress.media_status = status or self.progress.media_status
self.progress.media_queue = len(self.queue)

async def _task(self) -> None:
while self._running:
if not self.queue:
self._status("Idle...")
await sleep(.1)
continue
self._status("Downloading...")
await self._download(*self.queue[0])
_, _, task_id = self.queue.pop(0)
self.ids.discard(task_id)

self._status("Stopped...")

async def run(self) -> None:
self._running = True
self.task = asyncio.get_event_loop().create_task(self._task())

async def stop(self) -> None:
await self.wait()
self._running = False

async def wait(self, messages: Optional[list[int]]=None) -> None:
messages = set(messages) if messages is not None else None
while self._running and self.queue:
if messages is not None and not messages.intersection(self.ids):
break
await sleep(.1)
54 changes: 54 additions & 0 deletions texport/messages_preloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from asyncio import sleep, get_event_loop

from pyrogram import Client
from pyrogram.types import Message as PyroMessage

from .progress_print import ProgressPrint


class Preloader:
def __init__(self, client: Client, progress: ProgressPrint, media_cb):
self.client = client
self.progress = progress
self.finished = False
self.messages: list[PyroMessage] = []
self.messages_loaded = 0
self.media_cb = media_cb

self._task = None
self._pyro_args = ()
self._pyro_kwargs = {}

def __call__(self, *pyrogram_args, **pyrogram_kwargs):
self._pyro_args = pyrogram_args
self._pyro_kwargs = pyrogram_kwargs
return self

def __aiter__(self):
return self

async def _preload(self) -> None:
async for message in self.client.get_chat_history(*self._pyro_args, **self._pyro_kwargs):
self.messages.append(message)
self.messages_loaded += 1

if message.media and self.media_cb:
await self.media_cb(message)

with self.progress.update():
self.progress.status = "Preloading messages and media..."
self.progress.messages_loaded = self.messages_loaded

self.finished = True

async def __anext__(self) -> PyroMessage:
if self._task is None: self._task = get_event_loop().create_task(self._preload())

while not self.finished and not self.messages:
await sleep(.01)

if self.finished and not self.messages:
raise StopAsyncIteration

return self.messages.pop(0)

2 changes: 2 additions & 0 deletions texport/messages_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def _save(self) -> None:
output = Export(prev.chat.first_name, output).to_html()
with open(f"{out_dir}/messages{self.part}.html", "w", encoding="utf8") as f:
f.write(output)

self.part += 1

async def save(self) -> None:
loop = get_running_loop()
Expand Down
Loading

0 comments on commit 4b9a2b1

Please sign in to comment.