Skip to content

Commit 3ca0ce4

Browse files
authored
Add TOVA press (#12)
* Add TOVA press * update README * update docstring * Address PR comment
1 parent 896f610 commit 3ca0ce4

File tree

6 files changed

+59
-6
lines changed

6 files changed

+59
-6
lines changed

README.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,11 @@ All current presses are training free. We provide the following presses associat
5555

5656
- `RandomPress`: random score
5757
- `KnormPress`: inverse norm of the key ([paper](https://arxiv.org/abs/2406.11430))
58-
- `ObservedAttentionPress`: average attention weight observed during the pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048) or [TOVA](https://arxiv.org/abs/2401.06104))
58+
- `ObservedAttentionPress`: average attention weight observed during the pre-filling phase (similar to [H2O](https://arxiv.org/abs/2306.14048))
5959
- `SnapKVPress`: average attention weight of the last 64 queries ([paper](https://arxiv.org/abs/2404.14469))
6060
- `ExpectedAttentionPress` (ours): expected attention weight during the generation phase (see [this notebook](notebooks/expected_attention.ipynb))
6161
- `StreamingLLMPress`: keep only the first and last tokens ([paper](https://arxiv.org/abs/2309.17453))
62+
- `TOVAPress`: attention weight of the last query averaged across heads ([paper](https://arxiv.org/abs/2401.06104))
6263

6364
For a detailed list of existing KV cache compression methods, check [Awesome-KV-Cache-Compression](https://github.com/October2001/Awesome-KV-Cache-Compression) or [Awesome-LLM-Compression](https://github.com/HuangOwen/Awesome-LLM-Compression?tab=readme-ov-file#kv-cache-compression)
6465

@@ -186,5 +187,3 @@ press = apply_per_layer_compression(press, compression_ratios=[...])
186187

187188
Check the [demo notebook](notebooks/per_layer_compression_demo.ipynb) for more details.
188189
</details>
189-
190-
<details><summary>

kvpress/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from kvpress.presses.random_press import RandomPress
1212
from kvpress.presses.snapkv_press import SnapKVPress
1313
from kvpress.presses.streaming_llm_press import StreamingLLMPress
14+
from kvpress.presses.tova_press import TOVAPress
1415

1516
__all__ = [
1617
"BasePress",
@@ -20,6 +21,7 @@
2021
"RandomPress",
2122
"SnapKVPress",
2223
"StreamingLLMPress",
24+
"TOVAPress",
2325
"KVPressTextGenerationPipeline",
2426
"apply_per_layer_compression",
2527
]

kvpress/presses/observed_attention_press.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
class ObservedAttentionPress(BasePress):
1818
"""The observed attention score is defined as the average attention weight over all prompt tokens
1919
Requires output_attentions=True and attn_implementation="eager" to have access to attentions
20-
This approach is related to H2O (https://arxiv.org/abs/2306.14048) and TOVA (https://arxiv.org/abs/2401.06104)
20+
This approach is related to H2O (https://arxiv.org/abs/2306.14048).
2121
"""
2222

2323
compression_ratio: float = 0.0

kvpress/presses/tova_press.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from dataclasses import dataclass
5+
6+
import torch
7+
from torch import nn
8+
import torch.nn.functional as F
9+
10+
from kvpress.presses.snapkv_press import SnapKVPress
11+
12+
13+
@dataclass
14+
class TOVAPress(SnapKVPress):
15+
"""
16+
TOVA (https://arxiv.org/abs/2401.06104) use the attention of the last token averaged across heads
17+
to estimate the importance of the previous KV pairs. This press was reviewed by Michael Hassid,
18+
one of the authors of the TOVA paper.
19+
20+
Official implementation can be found here: https://github.com/schwartz-lab-NLP/TOVA/blob/main/src/tova_cache.py
21+
"""
22+
23+
compression_ratio: float = 0.0
24+
window_size: int = 1 # re-use the attention weight computation from SnapKVPress for last token
25+
26+
def score(
27+
self,
28+
module: nn.Module,
29+
hidden_states: torch.Tensor,
30+
keys: torch.Tensor,
31+
values: torch.Tensor,
32+
attentions: torch.Tensor,
33+
kwargs,
34+
) -> torch.Tensor:
35+
36+
if attentions is not None:
37+
attn_weights = attentions[..., -1:, :-1]
38+
else:
39+
attn_weights = self.compute_window_attention(module, hidden_states, keys)
40+
41+
# Average across heads and repeat num_key_value_head times
42+
scores = attn_weights.mean(1)
43+
scores = scores.repeat(1, keys.shape[1], 1)
44+
45+
# Add back the last token. Use max score to make sure the window is not pruned.
46+
# This is a very slight difference from TOVA that don't enforce it, but the
47+
# last attention weight is usually very high so it should not change the results.
48+
scores = F.pad(scores, (0, 1), value=scores.max().item())
49+
50+
return scores

notebooks/wikipedia_demo.ipynb

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
},
1010
{
1111
"cell_type": "code",
12-
"execution_count": 1,
12+
"execution_count": null,
1313
"metadata": {},
1414
"outputs": [],
1515
"source": [
@@ -26,6 +26,7 @@
2626
" RandomPress,\n",
2727
" SnapKVPress,\n",
2828
" StreamingLLMPress,\n",
29+
" TOVAPress,\n",
2930
")"
3031
]
3132
},

tests/presses/test_presses.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,13 @@
1414
RandomPress,
1515
SnapKVPress,
1616
StreamingLLMPress,
17+
TOVAPress,
1718
)
1819
from tests.fixtures import unit_test_model, unit_test_model_output_attention # noqa: F401
1920

2021

2122
def test_presses_run(unit_test_model): # noqa: F811
22-
for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress]:
23+
for cls in [KnormPress, ExpectedAttentionPress, RandomPress, StreamingLLMPress, SnapKVPress, TOVAPress]:
2324
for compression_ratio in [0.2, 0.4, 0.6, 0.8]:
2425
press = cls(compression_ratio=compression_ratio)
2526
if cls == SnapKVPress:

0 commit comments

Comments
 (0)