
Commit d6fa1be
[Quality] Add code formatter and linter (vllm-project#326)
1 parent: 0ffded8


47 files changed: +1549 -619 lines

.pylintrc
+434 (large diff not rendered by default)
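pylint reads .pylintrc automatically when run from the repository root, which is how the bare `pylint vllm` call in format.sh (below) picks up this configuration; passing it explicitly also works. A minimal sketch, assuming the repo root as working directory:

    pylint --rcfile=.pylintrc vllm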

CONTRIBUTING.md
+4 -1

@@ -49,12 +49,15 @@ If not, please file a new issue, providing as much relevant information as possible.
 
 In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
 
+We include a formatting script [`format.sh`](./format.sh) to format the code.
+
 ### Pull Requests
 
 When submitting a pull request:
 
 1. Make sure your code has been rebased on top of the latest commit on the main branch.
-2. Include a detailed description of the changes in the pull request.
+2. Ensure code is properly formatted by running [`format.sh`](./format.sh).
+3. Include a detailed description of the changes in the pull request.
 Explain why you made the changes you did.
 If your pull request fixes an open issue, please include a reference to it in the description.
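For reference, a typical pre-PR pass with the new script (a sketch; the --files and --all modes are defined in format.sh below):

    # Format files that differ from origin/main (the default mode).
    bash format.sh
    # Format specific files, e.g. one of the examples touched in this commit.
    bash format.sh --files examples/api_client.py
    # Format everything under vllm/.
    bash format.sh --all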

examples/api_client.py
+5 -2

@@ -14,7 +14,9 @@ def clear_line(n: int = 1) -> None:
         print(LINE_UP, end=LINE_CLEAR, flush=True)
 
 
-def post_http_request(prompt: str, api_url: str, n: int = 1,
+def post_http_request(prompt: str,
+                      api_url: str,
+                      n: int = 1,
                       stream: bool = False) -> requests.Response:
     headers = {"User-Agent": "Test Client"}
     pload = {
@@ -30,7 +32,8 @@ def post_http_request(prompt: str, api_url: str, n: int = 1,
 
 
 def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
-    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
+    for chunk in response.iter_lines(chunk_size=8192,
+                                     decode_unicode=False,
                                      delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode("utf-8"))
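The rewrapped calls above are plain yapf output. To preview such a rewrite without modifying the file, yapf's standard --diff flag works directly (a sketch; format.sh itself only uses --in-place):

    # Print the would-be changes as a diff, leaving the file untouched.
    yapf --diff examples/api_client.py
    # Apply them in place, as format.sh does.
    yapf --in-place examples/api_client.py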

examples/gradio_webserver.py
+16 -9

@@ -12,9 +12,14 @@ def http_bot(prompt):
         "stream": True,
         "max_tokens": 128,
     }
-    response = requests.post(args.model_url, headers=headers, json=pload, stream=True)
-
-    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
+    response = requests.post(args.model_url,
+                             headers=headers,
+                             json=pload,
+                             stream=True)
+
+    for chunk in response.iter_lines(chunk_size=8192,
+                                     decode_unicode=False,
+                                     delimiter=b"\0"):
         if chunk:
             data = json.loads(chunk.decode("utf-8"))
             output = data["text"][0]
@@ -23,11 +28,11 @@ def http_bot(prompt):
 
 def build_demo():
     with gr.Blocks() as demo:
-        gr.Markdown(
-            "# vLLM text completion demo\n"
-        )
-        inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
-        outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
+        gr.Markdown("# vLLM text completion demo\n")
+        inputbox = gr.Textbox(label="Input",
+                              placeholder="Enter text and press ENTER")
+        outputbox = gr.Textbox(label="Output",
+                               placeholder="Generated result from the model")
         inputbox.submit(http_bot, [inputbox], [outputbox])
     return demo
 
@@ -36,7 +41,9 @@ def build_demo():
     parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8001)
-    parser.add_argument("--model-url", type=str, default="http://localhost:8000/generate")
+    parser.add_argument("--model-url",
+                        type=str,
+                        default="http://localhost:8000/generate")
     args = parser.parse_args()
 
     demo = build_demo()

examples/llm_engine_example.py
+7 -2

@@ -14,9 +14,14 @@ def main(args: argparse.Namespace):
        ("To be or not to be,",
         SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
        ("What is the meaning of life?",
-        SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95, frequency_penalty=0.1)),
+        SamplingParams(n=2,
+                       best_of=5,
+                       temperature=0.8,
+                       top_p=0.95,
+                       frequency_penalty=0.1)),
        ("It is only with the heart that one can see rightly",
-        SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)),
+        SamplingParams(n=3, best_of=3, use_beam_search=True,
+                       temperature=0.0)),
     ]
 
     # Run the engine by calling `engine.step()` manually.

examples/offline_inference.py
-1

@@ -1,6 +1,5 @@
 from vllm import LLM, SamplingParams
 
-
 # Sample prompts.
 prompts = [
     "Hello, my name is",

examples/openai_client.py
+7 -2

@@ -12,8 +12,13 @@
 # Test completion API
 stream = True
 completion = openai.Completion.create(
-    model=model, prompt="A robot may not injure a human being", echo=False, n=2,
-    best_of=3, stream=stream, logprobs=3)
+    model=model,
+    prompt="A robot may not injure a human being",
+    echo=False,
+    n=2,
+    best_of=3,
+    stream=stream,
+    logprobs=3)
 
 # print the completion
 if stream:

format.sh
+108 (new file)

@@ -0,0 +1,108 @@
+#!/usr/bin/env bash
+# YAPF formatter, adapted from ray and skypilot.
+#
+# Usage:
+#    # Do work and commit your work.
+
+#    # Format files that differ from origin/main.
+#    bash format.sh
+
+#    # Commit changed files with message 'Run yapf and pylint'
+#
+#
+# YAPF + Clang formatter (if installed). This script formats all changed files from the last mergebase.
+# You are encouraged to run this locally before pushing changes for review.
+
+# Cause the script to exit if a single command fails
+set -eo pipefail
+
+# this stops git rev-parse from failing if we run this from the .git directory
+builtin cd "$(dirname "${BASH_SOURCE:-$0}")"
+ROOT="$(git rev-parse --show-toplevel)"
+builtin cd "$ROOT" || exit 1
+
+YAPF_VERSION=$(yapf --version | awk '{print $2}')
+PYLINT_VERSION=$(pylint --version | head -n 1 | awk '{print $2}')
+MYPY_VERSION=$(mypy --version | awk '{print $2}')
+
+# params: tool name, tool version, required version
+tool_version_check() {
+    if [[ $2 != $3 ]]; then
+        echo "Wrong $1 version installed: $3 is required, not $2."
+        exit 1
+    fi
+}
+
+tool_version_check "yapf" $YAPF_VERSION "$(grep yapf requirements-dev.txt | cut -d'=' -f3)"
+tool_version_check "pylint" $PYLINT_VERSION "$(grep "pylint==" requirements-dev.txt | cut -d'=' -f3)"
+tool_version_check "mypy" "$MYPY_VERSION" "$(grep mypy requirements-dev.txt | cut -d'=' -f3)"
+
+YAPF_FLAGS=(
+    '--recursive'
+    '--parallel'
+)
+
+YAPF_EXCLUDES=(
+    '--exclude' 'build/**'
+    '--exclude' 'vllm/model_executor/parallel_utils/**'
+)
+
+# Format specified files
+format() {
+    yapf --in-place "${YAPF_FLAGS[@]}" "$@"
+}
+
+# Format files that differ from main branch. Ignores dirs that are not slated
+# for autoformat yet.
+format_changed() {
+    # The `if` guard ensures that the list of filenames is not empty, which
+    # could cause yapf to receive 0 positional arguments, making it hang
+    # waiting for STDIN.
+    #
+    # `diff-filter=ACM` and $MERGEBASE is to ensure we only format files that
+    # exist on both branches.
+    MERGEBASE="$(git merge-base origin/main HEAD)"
+
+    if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
+        git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs -P 5 \
+            yapf --in-place "${YAPF_EXCLUDES[@]}" "${YAPF_FLAGS[@]}"
+    fi
+
+}
+
+# Format all files
+format_all() {
+    yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
+}
+
+## This flag formats individual files. --files *must* be the first command line
+## arg to use this option.
+if [[ "$1" == '--files' ]]; then
+    format "${@:2}"
+# If `--all` is passed, then any further arguments are ignored and the
+# entire python directory is formatted.
+elif [[ "$1" == '--all' ]]; then
+    format_all
+else
+    # Format only the files that changed in last commit.
+    format_changed
+fi
+echo 'vLLM yapf: Done'
+
+# Run mypy
+# TODO(zhuohan): Enable mypy
+# echo 'vLLM mypy:'
+# mypy
+
+# Run Pylint
+echo 'vLLM Pylint:'
+pylint vllm
+
+if ! git diff --quiet &>/dev/null; then
+    echo 'Reformatted files. Please review and stage the changes.'
+    echo 'Changes not staged for commit:'
+    echo
+    git --no-pager diff --name-only
+
+    exit 1
+fi
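Note the final git diff check: if yapf rewrote anything, the script lists the touched files and exits 1, and set -eo pipefail makes pylint failures fatal too, so the script can double as a CI gate. A sketch of such a step (the CI wiring is an assumption, not part of this commit):

    # In CI: fail the job on pylint errors or unformatted files.
    pip install -r requirements-dev.txt
    bash format.sh   # exits non-zero on lint errors or reformatted files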

requirements-dev.txt
+11 -1

@@ -1,2 +1,12 @@
-mypy
+# formatting
+yapf==0.32.0
+pylint==2.8.2
+
+# type checking
+mypy==0.991
+types-PyYAML
+types-requests
+types-setuptools
+
+# testing
 pytest
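The exact pins matter: format.sh greps this file to derive the versions it enforces via tool_version_check, so installing from it keeps local tools in sync (a sketch):

    pip install -r requirements-dev.txt
    yapf --version                  # expected: yapf 0.32.0
    pylint --version | head -n 1    # expected: pylint 2.8.2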

tests/kernels/test_attention.py
+33 -20

@@ -60,7 +60,7 @@ def ref_single_query_cached_kv_attention(
         keys = torch.stack(keys, dim=0)
         values = torch.stack(values, dim=0)
 
-        scale = 1.0 / (head_size ** 0.5)
+        scale = 1.0 / (head_size**0.5)
         out = ref_masked_attention(q, keys, values, scale)
         out = out.view(num_heads, head_size)
         output[i].copy_(out, non_blocking=True)
@@ -74,7 +74,7 @@ def ref_multi_query_kv_attention(
     dtype: torch.dtype,
 ) -> torch.Tensor:
     head_size = query.shape[-1]
-    scale = 1.0 / (head_size ** 0.5)
+    scale = 1.0 / (head_size**0.5)
 
     num_seqs = len(cu_seq_lens) - 1
     ref_outputs = []
@@ -84,8 +84,8 @@ def ref_multi_query_kv_attention(
         seq_len = end_idx - start_idx
 
         # Create attention mask.
-        attn_mask = torch.triu(
-            torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1)
+        attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
+                               diagonal=1)
         attn_mask = attn_mask * torch.finfo(dtype).min
         attn_mask = attn_mask.to(dtype=dtype, device='cuda')
@@ -113,7 +113,7 @@ def ref_multi_query_cached_kv_attention(
     num_heads = value_cache.shape[1]
     head_size = value_cache.shape[2]
     block_size = value_cache.shape[3]
-    scale = 1.0 / (head_size ** 0.5)
+    scale = 1.0 / (head_size**0.5)
 
     num_queries = len(cu_query_lens) - 1
     ref_outputs = []
@@ -125,8 +125,8 @@ def ref_multi_query_cached_kv_attention(
         block_table = block_tables[i]
 
         # Create attention mask
-        attn_mask = torch.triu(
-            torch.ones(query_len, context_len), diagonal=context_len - query_len + 1) * -1e5
+        attn_mask = torch.triu(torch.ones(query_len, context_len),
+                               diagonal=context_len - query_len + 1) * -1e5
         attn_mask = attn_mask.to(dtype=dtype, device='cuda')
 
         keys = []
@@ -165,22 +165,28 @@ def run_single_query_cached_kv_attention(
     num_blocks: int,
     dtype: torch.dtype,
 ) -> None:
-    qkv = torch.empty(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    qkv = torch.empty(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device='cuda')
     qkv.uniform_(-1e-3, 1e-3)
     query, _, _ = qkv.unbind(dim=1)
 
     x = 16 // torch.tensor([], dtype=dtype).element_size()
     key_block_shape = (num_heads, head_size // x, block_size, x)
-    key_cache = torch.empty(
-        size=(num_blocks, *key_block_shape), dtype=dtype, device='cuda')
+    key_cache = torch.empty(size=(num_blocks, *key_block_shape),
+                            dtype=dtype,
+                            device='cuda')
     key_cache.uniform_(-1e-3, 1e-3)
     value_block_shape = (num_heads, head_size, block_size)
-    value_cache = torch.empty(
-        size=(num_blocks, *value_block_shape), dtype=dtype, device='cuda')
+    value_cache = torch.empty(size=(num_blocks, *value_block_shape),
+                              dtype=dtype,
+                              device='cuda')
     value_cache.uniform_(-1e-3, 1e-3)
 
-    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
+    context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_tokens)]
     max_context_len = max(context_lens)
     context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
 
@@ -194,9 +200,12 @@ def run_single_query_cached_kv_attention(
         block_tables.append(block_table)
     block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
 
-    scale = float(1.0 / (head_size ** 0.5))
-    output = torch.empty(
-        num_tokens, num_heads, head_size, dtype=dtype, device='cuda')
+    scale = float(1.0 / (head_size**0.5))
+    output = torch.empty(num_tokens,
+                         num_heads,
+                         head_size,
+                         dtype=dtype,
+                         device='cuda')
     attention_ops.single_query_cached_kv_attention(
         output,
         query,
@@ -235,9 +244,13 @@ def run_multi_query_kv_attention(
     seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
     num_tokens = sum(seq_lens)
 
-    scale = float(1.0 / (head_size ** 0.5))
-    qkv = torch.empty(
-        num_tokens, 3, num_heads, head_size, dtype=dtype, device='cuda')
+    scale = float(1.0 / (head_size**0.5))
+    qkv = torch.empty(num_tokens,
+                      3,
+                      num_heads,
+                      head_size,
+                      dtype=dtype,
+                      device='cuda')
     qkv.uniform_(-1e-3, 1e-3)
     query, key, value = qkv.unbind(dim=1)
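The head_size ** 0.5 to head_size**0.5 rewrites above come from yapf's handling of the power operator (the SPACES_AROUND_POWER_OPERATOR knob, which defaults to false). A quick check on a one-liner, as a sketch using yapf's stdin mode:

    echo 'scale = 1.0 / (head_size ** 0.5)' | yapf
    # -> scale = 1.0 / (head_size**0.5)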
