Skip to content

Commit 76b13ca

Browse files
kristian-georgiev and Alaa Khaddaj authored
0.3.0 (#50)
* allow scoring with features only (gradients deleted)
* migrate to black codestyle
* update scores_finalized in JSON file
* blockwise scoring (relevant when scoring large datasets, i.e. many targets)
* move vectorize function to projectors; fast projector (incoming feature) will not vectorize
* save on I/O overhead by only writing once to disk when scoring
* custom CudaProjector for large models to avoid overflow error in CUDA kernel
* allow computing TRAK with respect to specified parameter groups

---------

Co-authored-by: Alaa Khaddaj <alaakh@mit.edu>
1 parent 5cbe528 commit 76b13ca

31 files changed

+2796
-1491
lines changed

.git-blame-ignore-revs

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# https://black.readthedocs.io/en/stable/guides/introducing_black_to_your_project.html
2+
# Migrate code style to Black
3+
3141015f3687dc11c311f1270c7dff80f1299fe3

.github/workflows/python-package.yml

-5
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,6 @@ jobs:
3333
uses: actions/setup-python@v3
3434
with:
3535
python-version: ${{ matrix.python-version }}
36-
- name: cuda-toolkit
37-
uses: Jimver/cuda-toolkit@v0.2.8
38-
id: cuda-toolkit
39-
with:
40-
cuda: '11.7.0'
4136
- name: Install dependencies
4237
run: |
4338
python -m pip install --upgrade pip

README.md

+3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
[![arXiv](https://img.shields.io/badge/arXiv-2303.14186-b31b1b.svg?style=flat-square)](https://arxiv.org/abs/2303.14186)
22
[![PyPI version](https://badge.fury.io/py/traker.svg)](https://badge.fury.io/py/traker)
33
[![Documentation Status](https://readthedocs.org/projects/trak/badge/?version=latest)](https://trak.readthedocs.io/en/latest/?badge=latest)
4+
[![Code style:
5+
black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
6+
47

58
# TRAK: Attributing Model Behavior at Scale
69

docs/source/bert.rst

+57-30
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ to fit our API signatures.
6363
6464
We slightly redefine the :code:`forward` function so that we can pass in the inputs (:code:`input_ids`, etc.) as positional arguments instead of as keyword arguments.
6565

66-
For data loading, we adapt the code from Hugging Face example:
66+
For data loading, we adapt the code from the Hugging Face example:
6767

6868
.. raw:: html
6969

@@ -132,7 +132,7 @@ For data loading, we adapt the code from Hugging Face example:
132132
133133
# NOTE: CHANGE THIS IF YOU WANT TO RUN ON FULL DATASET
134134
TRAIN_SET_SIZE = 5_000
135-
VAL_SET_SIZE = 1_00
135+
VAL_SET_SIZE = 10
136136
137137
def init_loaders(batch_size=16):
138138
ds_train = get_dataset('train')
@@ -180,38 +180,59 @@ The model output function is implemented as follows:
180180

181181
.. code-block:: python
182182
183-
def get_output(func_model,
184-
weights: Iterable[Tensor],
185-
buffers: Iterable[Tensor],
186-
input_id: Tensor,
187-
token_type_id: Tensor,
188-
attention_mask: Tensor,
189-
label: Tensor,
190-
) -> Tensor:
191-
logits = func_model(weights, buffers, input_id.unsqueeze(0),
192-
token_type_id.unsqueeze(0),
193-
attention_mask.unsqueeze(0))
183+
def get_output(
184+
model,
185+
weights: Iterable[Tensor],
186+
buffers: Iterable[Tensor],
187+
input_id: Tensor,
188+
token_type_id: Tensor,
189+
attention_mask: Tensor,
190+
label: Tensor,
191+
) -> Tensor:
192+
kw_inputs = {
193+
"input_ids": input_id.unsqueeze(0),
194+
"token_type_ids": token_type_id.unsqueeze(0),
195+
"attention_mask": attention_mask.unsqueeze(0),
196+
}
197+
198+
logits = ch.func.functional_call(
199+
model, (weights, buffers), args=(), kwargs=kw_inputs
200+
)
194201
bindex = ch.arange(logits.shape[0]).to(logits.device, non_blocking=False)
195202
logits_correct = logits[bindex, label.unsqueeze(0)]
196203
197204
cloned_logits = logits.clone()
198-
cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor(-ch.inf).to(logits.device)
205+
cloned_logits[bindex, label.unsqueeze(0)] = ch.tensor(
206+
-ch.inf, device=logits.device, dtype=logits.dtype
207+
)
199208
200209
margins = logits_correct - cloned_logits.logsumexp(dim=-1)
201210
return margins.sum()
202211
203-
The implementation is identical to the standard classification example in :ref:`MODELOUTPUT tutorial`,
204-
except here the signature of the method and the :code:`func_model` is slightly different
205-
as the language model takes in three inputs instead of just one.
212+
The implementation is identical to the standard classification example in
213+
:ref:`MODELOUTPUT tutorial`, except here the signature of the method and the
214+
:code:`model` call are slightly different, as the language model takes in three
215+
inputs instead of just one.
206216

207217
Similarly, the gradient function is implemented as follows:
208218

209219
.. code-block:: python
210220
211-
def get_out_to_loss_grad(self, func_model, weights, buffers, batch: Iterable[Tensor]) -> Tensor:
221+
def get_out_to_loss_grad(
222+
self, model, weights, buffers, batch: Iterable[Tensor]
223+
) -> Tensor:
212224
input_ids, token_type_ids, attention_mask, labels = batch
213-
logits = func_model(weights, buffers, input_ids, token_type_ids, attention_mask)
214-
ps = self.softmax(logits / self.loss_temperature)[ch.arange(logits.size(0)), labels]
225+
kw_inputs = {
226+
"input_ids": input_ids,
227+
"token_type_ids": token_type_ids,
228+
"attention_mask": attention_mask,
229+
}
230+
logits = ch.func.functional_call(
231+
model, (weights, buffers), args=(), kwargs=kw_inputs
232+
)
233+
ps = self.softmax(logits / self.loss_temperature)[
234+
ch.arange(logits.size(0)), labels
235+
]
215236
return (1 - ps).clone().detach().unsqueeze(-1)
216237
217238
Putting it together
@@ -221,12 +242,14 @@ Using the above :code:`TextClassificationModelOutput` implementation, we can com
221242

222243
.. code-block:: python
223244
224-
traker = TRAKer(model=model,
225-
task=TextClassificationModelOutput, # you can also just pass in "text_classification"
226-
train_set_size=TRAIN_SET_SIZE,
227-
save_dir=args.out,
228-
device=device,
229-
proj_dim=1024)
245+
traker = TRAKer(
246+
model=model,
247+
task=TextClassificationModelOutput, # you can also just pass in "text_classification"
248+
train_set_size=TRAIN_SET_SIZE,
249+
save_dir=SAVE_DIR,
250+
device=DEVICE,
251+
proj_dim=1024,
252+
)
230253
231254
def process_batch(batch):
232255
return batch['input_ids'], batch['token_type_ids'], batch['attention_mask'], batch['labels']
@@ -235,18 +258,21 @@ Using the above :code:`TextClassificationModelOutput` implementation, we can com
235258
for batch in tqdm(loader_train, desc='Featurizing..'):
236259
# process batch into compatible form for TRAKer TextClassificationModelOutput
237260
batch = process_batch(batch)
238-
batch = [x.cuda() for x in batch]
261+
batch = [x.to(DEVICE) for x in batch]
239262
traker.featurize(batch=batch, num_samples=batch[0].shape[0])
240263
241264
traker.finalize_features()
242265
243-
traker.start_scoring_checkpoint(model.state_dict(), model_id=0, num_targets=VAL_SET_SIZE)
266+
traker.start_scoring_checkpoint(exp_name='qnli',
267+
checkpoint=model.state_dict(),
268+
model_id=0,
269+
num_targets=VAL_SET_SIZE)
244270
for batch in tqdm(loader_val, desc='Scoring..'):
245271
batch = process_batch(batch)
246272
batch = [x.cuda() for x in batch]
247273
traker.score(batch=batch, num_samples=batch[0].shape[0])
248274
249-
scores = traker.finalize_scores()
275+
scores = traker.finalize_scores(exp_name='qnli')
250276
251277
We use :code:`process_batch` to transform the batch from a dictionary (which is the form used by Hugging Face dataloaders) to a tuple.
252278

@@ -256,4 +282,5 @@ That's all! You can find this tutorial as a complete script in `here <https://gi
256282
Extending to other tasks
257283
----------------------------------
258284

259-
For a more involved example that is *not* classification, see :ref:`CLIP tutorial`.
285+
For a more involved example that is *not* classification, see :ref:`CLIP
286+
tutorial`.

docs/source/clip.rst

+12-7
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ Now we are ready to implement :meth:`.CLIPModelOutput.get_output`:
6161
buffers: Iterable[Tensor],
6262
image: Tensor,
6363
label: Tensor):
64+
# tailored for open_clip
65+
# https://github.com/mlfoundations/open_clip/blob/fb72f4db1b17133befd6c67c9cf32a533b85a321/src/open_clip/model.py#L242-L245
66+
clip_inputs = {"image": image.unsqueeze(0), "text": label.unsqueeze(0)}
6467
image_embeddings, text_embeddings, _ = ch.func.functional_call(model,
6568
(weights, buffers),
6669
args=(),
@@ -116,24 +119,26 @@ Using the above :code:`CLIPModelOutput` implementation, we can compute
116119
device=device,
117120
proj_dim=1024)
118121
119-
traker.task.get_embeddings(model, loader_train, batch_size=...,
122+
traker.task.get_embeddings(model, ds_train, batch_size=1, size=600, embedding_dim=1024,
120123
preprocess_fn_img=lambda x: preprocess(x).to(device).unsqueeze(0),
121124
preprocess_fn_txt=lambda x: tokenizer(x[0]).to(device))
122125
123126
traker.load_checkpoint(model.state_dict(), model_id=0)
124-
for batch in tqdm(loader_train, desc='Featurizing...'):
125-
batch = [x.cuda() for x in batch]
126-
traker.featurize(batch=batch, num_samples=batch[0].shape[0])
127+
for (img, captions) in tqdm(loader_train, desc='Featurizing...'):
128+
x = preprocess(img).to('cuda').unsqueeze(0)
129+
y = tokenizer(captions).to('cuda')
130+
traker.featurize(batch=(x, y), num_samples=x.shape[0])
127131
128132
traker.finalize_features()
129133
130134
traker.start_scoring_checkpoint(exp_name='clip_example',
131135
checkpoint=model.state_dict(),
132136
model_id=0,
133137
num_targets=VAL_SET_SIZE)
134-
for batch in tqdm(loader_val, desc='Scoring...'):
135-
batch = [x.cuda() for x in batch]
136-
traker.score(batch=batch, num_samples=batch[0].shape[0])
138+
for (img, captions) in tqdm(loader_val, desc='Scoring...'):
139+
x = preprocess(img).to('cuda').unsqueeze(0)
140+
y = tokenizer(captions).to('cuda')
141+
traker.score(batch=(x, y), num_samples=x.shape[0])
137142
138143
scores = traker.finalize_scores(exp_name='clip_example')
139144

docs/source/conf.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
author = 'Kristian Georgiev'
2323

2424
# The full version, including alpha/beta/rc tags
25-
release = '0.2.2'
26-
version = '0.2.2'
25+
release = '0.3.0'
26+
version = '0.3.0'
2727

2828

2929
# -- General configuration ---------------------------------------------------
@@ -46,11 +46,13 @@
4646
# This pattern also affects html_static_path and html_extra_path.
4747
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
4848

49+
4950
def skip(app, what, name, obj, would_skip, options):
5051
if name == "__init__":
5152
return False
5253
return would_skip
5354

55+
5456
def setup(app):
5557
app.connect("autodoc-skip-member", skip)
5658

docs/source/modeloutput.rst

-3
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,6 @@ the :code:`task` when instantiating :class:`.TRAKer`:
8080
def get_output(...):
8181
# Implement
8282
83-
def forward(...):
84-
# Implement
85-
8683
def get_out_to_loss_grad(...):
8784
# Implement
8885

setup.py

+28-26
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,31 @@
11
#!/usr/bin/env python
22
from setuptools import setup
33

4-
setup(name="traker",
5-
version="0.2.2",
6-
description="TRAK: Attributing Model Behavior at Scale",
7-
long_description="Check https://trak.csail.mit.edu/ to learn more about TRAK",
8-
author="MadryLab",
9-
author_email='trak@mit.edu',
10-
license_files=('LICENSE.txt', ),
11-
packages=['trak'],
12-
install_requires=[
13-
"torch>=2.0.0",
14-
"numpy",
15-
"tqdm",
16-
],
17-
extras_require={
18-
'tests':
19-
["assertpy",
20-
"torchvision",
21-
"open_clip_torch",
22-
"wget",
23-
"scipy",
24-
],
25-
'fast':
26-
["fast_jl"
27-
]},
28-
include_package_data=True,
29-
)
4+
setup(
5+
name="traker",
6+
version="0.3.0",
7+
description="TRAK: Attributing Model Behavior at Scale",
8+
long_description="Check https://trak.csail.mit.edu/ to learn more about TRAK",
9+
author="MadryLab",
10+
author_email="trak@mit.edu",
11+
license_files=("LICENSE.txt",),
12+
packages=["trak"],
13+
install_requires=[
14+
"torch>=2.0.0",
15+
"numpy",
16+
"tqdm",
17+
],
18+
extras_require={
19+
"tests": [
20+
"assertpy",
21+
"torchvision",
22+
"open_clip_torch",
23+
"wget",
24+
"scipy",
25+
"datasets",
26+
"transformers",
27+
],
28+
"fast": ["fast_jl"],
29+
},
30+
include_package_data=True,
31+
)

tests/autocast.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -23,33 +23,33 @@ def compute_loss_autocast(params, inputs, targets):
2323

2424
print("1. Without autocast")
2525
grads = ch.func.grad(compute_loss)(weights, inputs, targets)
26-
print(f'grads are {grads}')
26+
print(f"grads are {grads}")
2727
print(f"grads dtype: {grads['weight'].dtype}")
28-
print('='*50)
28+
print("=" * 50)
2929

3030
inputs = inputs.half()
3131
targets = targets.half()
3232

33-
print('2. With autocast for forward pass')
33+
print("2. With autocast for forward pass")
3434
grads = ch.func.grad(compute_loss_autocast)(weights, inputs, targets)
35-
print(f'grads are {grads}')
35+
print(f"grads are {grads}")
3636
print(f"grads dtype: {grads['weight'].dtype}")
37-
print('='*50)
37+
print("=" * 50)
3838

39-
print('3. With autocast for forward pass and backward pass')
39+
print("3. With autocast for forward pass and backward pass")
4040
with autocast(device_type="cuda", dtype=ch.float16):
4141
grads = ch.func.grad(compute_loss)(weights, inputs, targets)
42-
print(f'inside grads are {grads}')
42+
print(f"inside grads are {grads}")
4343
print(f"inside grads dtype: {grads['weight'].dtype}")
44-
print('exiting autocast')
45-
print(f'grads are {grads}')
44+
print("exiting autocast")
45+
print(f"grads are {grads}")
4646
print(f"grads dtype: {grads['weight'].dtype}")
47-
print('='*50)
47+
print("=" * 50)
4848

49-
print('4. .half() the model')
49+
print("4. .half() the model")
5050
model = model.half()
5151
grads = ch.func.grad(compute_loss)(weights, inputs, targets)
52-
print(f'grads are {grads}')
52+
print(f"grads are {grads}")
5353
print(f"grads dtype: {grads['weight'].dtype}")
5454

5555
"""

0 commit comments

Comments
 (0)