|
1 | 1 | from collections import deque
|
| 2 | +from dataclasses import dataclass, field |
2 | 3 | from itertools import cycle
|
3 | 4 | from random import Random
|
4 | 5 | from typing import Iterable, Iterator, Optional
|
|
11 | 12 | from ..utils.typing import assert_type
|
12 | 13 |
|
13 | 14 |
|
| 15 | +@dataclass |
14 | 16 | class BalancedSampler(TorchIterableDataset):
|
15 | 17 | """
|
16 |
| - Approximately balances a binary classification dataset in a streaming fashion. |
17 |
| -
|
18 |
| - Args: |
19 |
| - dataset (IterableDataset): The HuggingFace IterableDataset to balance. |
20 |
| - label_col (Optional[str], optional): The name of the column containing the |
21 |
| - binary label. If not provided, the label column will be inferred from |
22 |
| - the dataset features. Defaults to None. |
23 |
| - buffer_size (int, optional): The total buffer size to use for balancing the |
24 |
| - dataset. This value should be divisible by 2, as it will be equally |
25 |
| - divided between the two binary label values (0 and 1). Defaults to 1000. |
| 18 | + A sampler that approximately balances a multi-class classification dataset in a |
| 19 | + streaming fashion. |
| 20 | +
|
| 21 | + Attributes: |
| 22 | + data: The input dataset to balance. |
| 23 | + num_classes: The total number of classes expected in the data. |
| 24 | + buffer_size: The maximum number of samples to buffer per class. Each class |
| 25 | + will have its own buffer with this size. |
26 | 26 | """
|
27 | 27 |
|
28 |
| - def __init__(self, data: Iterable[dict], buffer_size: int = 1000): |
29 |
| - self.data = data |
| 28 | + data: Iterable[dict] |
| 29 | + num_classes: int |
| 30 | + buffer_size: int = 1000 |
| 31 | + buffers: dict[int, deque[dict]] = field(default_factory=dict, init=False) |
| 32 | + label_col: str = "label" |
30 | 33 |
|
31 |
| - self.neg_buffer = deque(maxlen=buffer_size) |
32 |
| - self.pos_buffer = deque(maxlen=buffer_size) |
| 34 | + def __post_init__(self): |
| 35 | + # Initialize empty buffers |
| 36 | + self.buffers = { |
| 37 | + label: deque(maxlen=self.buffer_size) for label in range(self.num_classes) |
| 38 | + } |
33 | 39 |
|
34 | 40 | def __iter__(self):
|
35 | 41 | for sample in self.data:
|
36 |
| - label = sample["label"] |
| 42 | + label = sample[self.label_col] |
37 | 43 |
|
38 |
| - # Add the sample to the appropriate buffer |
39 |
| - if label == 0: |
40 |
| - self.neg_buffer.append(sample) |
41 |
| - else: |
42 |
| - self.pos_buffer.append(sample) |
| 44 | + # Samples whose label is not an integer bypass balancing and are yielded as-is |
| 45 | + if not isinstance(label, int): |
| 46 | + yield sample |
| 47 | + continue |
| 48 | + |
| 49 | + # Add the sample to the buffer for its class label |
| 50 | + self.buffers[label].append(sample) |
43 | 51 |
|
44 |
| - while self.neg_buffer and self.pos_buffer: |
45 |
| - yield self.neg_buffer.popleft() |
46 |
| - yield self.pos_buffer.popleft() |
| 52 | + # Check if all buffers have at least one sample |
| 53 | + while all(len(buffer) > 0 for buffer in self.buffers.values()): |
| 54 | + # Yield one sample from each buffer in a round-robin fashion |
| 55 | + for buf in self.buffers.values(): |
| 56 | + yield buf.popleft() |
47 | 57 |
|
48 | 58 |
|
49 | 59 | class FewShotSampler:
|
|
0 commit comments