Commit b8591a1

feat: add llama3 demo code
1 parent b1d9a68 commit b8591a1

10 files changed: +294 -2 lines

.flake8

+5
@@ -0,0 +1,5 @@
+[flake8]
+max-line-length = 88
+select = C,E,F,W,B,B9
+ignore = E203, E501, W503, E301
+exclude = __init__.py

.gitignore

+195
@@ -0,0 +1,195 @@
+# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python
+# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python
+
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+### Python Patch ###
+# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+poetry.toml
+
+# ruff
+.ruff_cache/
+
+# LSP config files
+pyrightconfig.json
+
+### VisualStudioCode ###
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+!.vscode/*.code-snippets
+
+# Local History for Visual Studio Code
+.history/
+
+# Built Visual Studio Code Extensions
+*.vsix
+
+### VisualStudioCode Patch ###
+# Ignore all local history of files
+.history
+.ionide
+
+# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python

.isort.cfg

+3
@@ -0,0 +1,3 @@
+[settings]
+profile=black
+multi_line_output=3

HakaseCore/llm/hakase_prompt.json

+5
@@ -0,0 +1,5 @@
+[
+    {
+        "role": "system", "content": "너는 Hakase Project의 버츄얼 스트리머 Hakase이다. 유머있게 응답해라."
+    }
+]

(In English, the system prompt reads: "You are Hakase, the virtual streamer of the Hakase Project. Respond with humor.")
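load_prompt() in llama3.py below reads this file and appends the user's turn; the tokenizer's chat template then flattens the message list into the actual prompt string. A quick way to inspect that string — a sketch, not part of this commit, assuming the checkpoint ships a Llama-3 chat template:

import json

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("maum-ai/Llama-3-MAAL-8B-Instruct-v0.1")
with open("HakaseCore/llm/hakase_prompt.json") as prompt_file:
    messages = json.load(prompt_file)
messages.append({"role": "user", "content": "Hello!"})
# tokenize=False returns the templated prompt as plain text instead of token ids.
print(tokenizer.apply_chat_template(messages, tokenize=False))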

HakaseCore/llm/llama3.py

+60
@@ -0,0 +1,60 @@
+import json
+import os.path
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
+
+
+class LLama3(object):
+    def __init__(self, accelerate_engine: str = "cuda", debug: bool = False) -> None:
+        self.prompt: list[dict[str, str]] = []
+        self.model_id = "maum-ai/Llama-3-MAAL-8B-Instruct-v0.1"
+        self.accelerate_engine = accelerate_engine
+        if debug:
+            # Note: the engine name is only validated when debug=True.
+            match self.accelerate_engine:
+                case "mps":
+                    print("MPS built : ", torch.backends.mps.is_built())
+                    print("MPS available : ", torch.backends.mps.is_available())
+                case "cuda":
+                    print("CUDA built : ", torch.backends.cuda.is_built())
+                case "mkl":
+                    print("MKL available : ", torch.backends.mkl.is_available())
+                case _:
+                    raise ValueError(
+                        f"{accelerate_engine} is not a valid accelerate_engine"
+                    )
+
+        self.model = AutoModelForCausalLM.from_pretrained(self.model_id).to(
+            self.accelerate_engine
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        self.streamer = TextStreamer(
+            self.tokenizer, skip_prompt=True, skip_special_tokens=True
+        )
+
+    def load_prompt(self) -> list[dict[str, str]]:
+        # Resolve hakase_prompt.json relative to this file.
+        prompt_path = os.path.join(
+            os.path.dirname(os.path.abspath(__file__)), "hakase_prompt.json"
+        )
+        with open(prompt_path, "r") as prompt_file:
+            prompt = json.load(prompt_file)
+        return prompt
+
+    def generate_instruction(self, instruction: str) -> None:
+        self.prompt = self.load_prompt()
+        self.prompt.append({"role": "user", "content": instruction})
+
+    def generate_text(self, instruction: str) -> str:
+        self.generate_instruction(instruction=instruction)
+        inputs = self.tokenizer.apply_chat_template(
+            self.prompt, tokenize=True, return_tensors="pt"
+        ).to(self.accelerate_engine)
+        outputs = self.model.generate(
+            inputs,
+            streamer=self.streamer,
+            max_new_tokens=1024,
+            pad_token_id=self.tokenizer.eos_token_id,
+        )
+        # Decode only the newly generated tokens so the method returns a str.
+        return self.tokenizer.decode(
+            outputs[0][inputs.shape[-1] :], skip_special_tokens=True
+        )
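Since accelerate_engine is passed straight to torch's .to(...), it has to name a device that actually exists on the host, and run.py below hard-codes "mps". A minimal helper for choosing an engine at runtime — a sketch, not part of this commit; pick_accelerate_engine is a hypothetical name:

import torch


def pick_accelerate_engine() -> str:
    # Hypothetical helper: prefer CUDA, then Apple-silicon MPS, else CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

Note that the constructor's match statement only recognizes "mps", "cuda", and "mkl" when debug=True, so "cpu" would need debug=False or an extra case there.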

requirements-dev.txt

+3-1
@@ -1,2 +1,4 @@
 black==24.4.2
-flake8==7.0.0
+flake8==7.0.0
+isort==5.13.2
+autoflake==2.3.1

requirements.txt

+2-1
@@ -1,4 +1,5 @@
 fastapi==0.111.0
 torch==2.2.2
 torchvision==0.17.2
-torchaudio==2.2.2
+torchaudio==2.2.2
+transformers==4.40.1

run.py

+6
@@ -0,0 +1,6 @@
+from HakaseCore.llm.llama3 import LLama3
+
+core = LLama3(accelerate_engine="mps", debug=True)
+while True:
+    message = input(">> ")
+    core.generate_text(message)

scripts/format.sh

+7
@@ -0,0 +1,7 @@
+#!/bin/sh -e
+set -x
+
+isort run.py HakaseAPI HakaseCore --force-single-line-imports
+autoflake --remove-all-unused-imports --recursive --remove-unused-variables run.py HakaseAPI HakaseCore --exclude=__init__.py
+black run.py HakaseAPI HakaseCore
+isort run.py HakaseAPI HakaseCore

scripts/list.sh

+8
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+
+set -e
+set -x
+
+flake8 run.py HakaseAPI HakaseCore
+black run.py HakaseAPI HakaseCore --check
+isort run.py HakaseAPI HakaseCore --check-only

0 commit comments