Commit aa8592a (2 parents: f242347 + 6af7d01)
Author: kyobrien
Commit message: Minor refactor

37 files changed: +1263 −618 lines

.pre-commit-config.yaml (+1 −1)

@@ -24,4 +24,4 @@ repos:
     hooks:
       - id: codespell
         # The promptsource templates spuriously get flagged without this
-        args: ["--skip=*.yaml"]
+        args: ["-L fpr", "--skip=*.yaml"]

.vscode/launch.json (+2 −2)

@@ -9,11 +9,11 @@
       "type": "python",
       "request": "launch",
       "module": "elk",
-      "args": ["elicit", "RWKV", "imdb", "--max_examples=5"],
+      "args": ["elicit", "rwkv", "imdb", "--max_examples=5"],
       "env": {
         "CUDA_VISIBLE_DEVICES": "0",
       },
       "justMyCode": true
     }
   ]
-}
+}

README.md (+6 −0)

@@ -32,6 +32,12 @@ The following command will evaluate the probe from the run naughty-northcutt on
 elk eval naughty-northcutt microsoft/deberta-v2-xxlarge-mnli imdb
 ```
 
+The following command runs `elicit` on the Cartesian product of the listed models and datasets, storing the results in a dedicated folder ELK_DIR/sweeps/<memorable_name>. The `--add_pooled` flag adds an extra dataset that pools all of the listed datasets together.
+
+```bash
+elk sweep --models gpt2-{medium,large,xl} --datasets imdb amazon_polarity --add_pooled
+```
+
 ## Caching
 
 The hidden states resulting from `elk elicit` are cached as a HuggingFace dataset to avoid having to recompute them every time we want to train a probe. The cache is stored in the same place as all other HuggingFace datasets, which is usually `~/.cache/huggingface/datasets`.
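Since the cache follows standard HuggingFace conventions, the cached extractions can be inspected like any other datasets cache. A minimal sketch, assuming the default cache path (the `HF_DATASETS_CACHE` environment variable may override it):

```python
# Hypothetical sketch: list the entries in the default HuggingFace datasets
# cache, where the hidden states computed by `elk elicit` end up.
from pathlib import Path

cache_dir = Path.home() / ".cache" / "huggingface" / "datasets"
if cache_dir.exists():
    for entry in sorted(cache_dir.iterdir()):
        print(entry.name)
```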

elk/__main__.py (+3 −2)

@@ -5,21 +5,22 @@
 from simple_parsing import ArgumentParser
 
 from elk.evaluation.evaluate import Eval
+from elk.training.sweep import Sweep
 from elk.training.train import Elicit
 
 
 @dataclass
 class Command:
     """Some top-level command"""
 
-    command: Elicit | Eval
+    command: Elicit | Eval | Sweep
 
     def execute(self):
         return self.command.execute()
 
 
 def run():
-    parser = ArgumentParser(add_help=False, add_config_path_arg=True)
+    parser = ArgumentParser(add_help=False)
     parser.add_arguments(Command, dest="run")
     args = parser.parse_args()
     run: Command = args.run
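The union type on the `command` field is what lets simple_parsing dispatch between subcommands. A minimal, self-contained sketch of the same pattern, with a hypothetical `Greet` command standing in for `Elicit | Eval | Sweep`:

```python
# Sketch of the subcommand pattern from elk/__main__.py above.
# `Greet` is a hypothetical stand-in for Elicit / Eval / Sweep.
from dataclasses import dataclass

from simple_parsing import ArgumentParser


@dataclass
class Greet:
    """Print a greeting."""

    name: str = "world"

    def execute(self):
        print(f"Hello, {self.name}!")


@dataclass
class Command:
    """Some top-level command"""

    command: Greet  # elk uses a union here: Elicit | Eval | Sweep

    def execute(self):
        return self.command.execute()


if __name__ == "__main__":
    parser = ArgumentParser(add_help=False)
    parser.add_arguments(Command, dest="run")
    args = parser.parse_args()
    args.run.execute()
```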

elk/debug_logging.py (+3 −3)

@@ -27,17 +27,17 @@ def save_debug_log(datasets: list[DatasetDict], out_dir: Path) -> None:
         )
 
         train_split, val_split = select_train_val_splits(ds)
-        text_inputs = ds[val_split][0]["text_inputs"]
+        text_questions = ds[val_split][0]["text_questions"]
         template_ids = ds[val_split][0]["variant_ids"]
         label = ds[val_split][0]["label"]
 
         # log the train size and val size
         logging.info(f"Train size: {len(ds[train_split])}")
         logging.info(f"Val size: {len(ds[val_split])}")
 
-        templates_text = f"{len(text_inputs)} templates used:\n"
+        templates_text = f"{len(text_questions)} templates used:\n"
         trailing_whitespace = False
-        for (text0, text1), id in zip(text_inputs, template_ids):
+        for (text0, text1), id in zip(text_questions, template_ids):
             templates_text += (
                 f'***---TEMPLATE "{id}"---***\n'
                 f"{'false' if label else 'true'}:\n"

elk/evaluation/evaluate.py (+14 −20)

@@ -9,9 +9,9 @@
 
 from ..extraction.extraction import Extract
 from ..files import elk_reporter_dir
+from ..metrics import evaluate_preds
 from ..run import Run
 from ..training import Reporter
-from ..training.supervised import evaluate_supervised
 from ..utils import select_usable_devices
 
 
@@ -43,6 +43,8 @@ class Eval(Serializable):
     out_dir: Path | None = None
     skip_supervised: bool = False
 
+    disable_cache: bool = field(default=False, to_dict=False)
+
     def execute(self):
         transfer_dir = elk_reporter_dir() / self.source / "transfer_eval"
 
@@ -69,34 +71,26 @@ def evaluate_reporter(
         reporter.eval()
 
         row_buf = []
-        for ds_name, (val_x0, val_x1, val_gt, _) in val_output.items():
-            val_result = reporter.score(
-                val_gt,
-                val_x0,
-                val_x1,
-            )
-
-            stats_row = pd.Series(
-                {
-                    "dataset": ds_name,
-                    "layer": layer,
-                    **val_result._asdict(),
-                }
-            )
+        for ds_name, (val_h, val_gt, _) in val_output.items():
+            val_result = evaluate_preds(val_gt, reporter(val_h))
+
+            stats_row = {
+                "dataset": ds_name,
+                "layer": layer,
+                **val_result.to_dict(),
+            }
 
             lr_dir = experiment_dir / "lr_models"
             if not self.cfg.skip_supervised and lr_dir.exists():
                 with open(lr_dir / f"layer_{layer}.pt", "rb") as f:
                     lr_model = torch.load(f, map_location=device).eval()
 
-                lr_auroc, lr_acc = evaluate_supervised(lr_model, val_x0, val_x1, val_gt)
-
-                stats_row["lr_auroc"] = lr_auroc
-                stats_row["lr_acc"] = lr_acc
+                lr_result = evaluate_preds(val_gt, lr_model(val_h))
+                stats_row.update(lr_result.to_dict(prefix="lr_"))
 
             row_buf.append(stats_row)
 
-        return pd.DataFrame(row_buf)
+        return pd.DataFrame.from_records(row_buf)
 
     def evaluate(self):
         """Evaluate the reporter on all layers."""

elk/extraction/balanced_sampler.py (+33 −23)

@@ -1,4 +1,5 @@
 from collections import deque
+from dataclasses import dataclass, field
 from itertools import cycle
 from random import Random
 from typing import Iterable, Iterator, Optional
@@ -11,39 +12,48 @@
 from ..utils.typing import assert_type
 
 
+@dataclass
 class BalancedSampler(TorchIterableDataset):
     """
-    Approximately balances a binary classification dataset in a streaming fashion.
-
-    Args:
-        dataset (IterableDataset): The HuggingFace IterableDataset to balance.
-        label_col (Optional[str], optional): The name of the column containing the
-            binary label. If not provided, the label column will be inferred from
-            the dataset features. Defaults to None.
-        buffer_size (int, optional): The total buffer size to use for balancing the
-            dataset. This value should be divisible by 2, as it will be equally
-            divided between the two binary label values (0 and 1). Defaults to 1000.
+    A sampler that approximately balances a multi-class classification dataset in a
+    streaming fashion.
+
+    Attributes:
+        data: The input dataset to balance.
+        num_classes: The total number of classes expected in the data.
+        buffer_size: The total buffer size to use for balancing the dataset. Each class
+            will have its own buffer with this size.
     """
 
-    def __init__(self, data: Iterable[dict], buffer_size: int = 1000):
-        self.data = data
+    data: Iterable[dict]
+    num_classes: int
+    buffer_size: int = 1000
+    buffers: dict[int, deque[dict]] = field(default_factory=dict, init=False)
+    label_col: str = "label"
 
-        self.neg_buffer = deque(maxlen=buffer_size)
-        self.pos_buffer = deque(maxlen=buffer_size)
+    def __post_init__(self):
+        # Initialize empty buffers
+        self.buffers = {
+            label: deque(maxlen=self.buffer_size) for label in range(self.num_classes)
+        }
 
     def __iter__(self):
         for sample in self.data:
-            label = sample["label"]
+            label = sample[self.label_col]
 
-            # Add the sample to the appropriate buffer
-            if label == 0:
-                self.neg_buffer.append(sample)
-            else:
-                self.pos_buffer.append(sample)
+            # This whole class is a no-op if the label is not an integer
+            if not isinstance(label, int):
+                yield sample
+                continue
+
+            # Add the sample to the buffer for its class label
+            self.buffers[label].append(sample)
 
-            while self.neg_buffer and self.pos_buffer:
-                yield self.neg_buffer.popleft()
-                yield self.pos_buffer.popleft()
+            # Check if all buffers have at least one sample
+            while all(len(buffer) > 0 for buffer in self.buffers.values()):
+                # Yield one sample from each buffer in a round-robin fashion
+                for buf in self.buffers.values():
+                    yield buf.popleft()
 
 
 class FewShotSampler: