Skip to content

Commit

Permalink
Add batch size tuning for LLMs (#3871)
Browse files: browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Arnav Garg <arnav@predibase.com>
  • Loading branch information
3 people authored Jan 22, 2024
1 parent 138cc4a commit 91c8975
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 10 deletions.
4 changes: 2 additions & 2 deletions ludwig/models/embedder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Dict, List
from typing import Callable, Dict, List, Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(self):
self.embedder = embedder.to(self.device)
self.embedder.eval()

def step(self, batch_size: int):
def step(self, batch_size: int, global_max_sequence_length: Optional[int] = None):
inputs = {
input_feature_name: input_feature.create_sample_input(batch_size=batch_size).to(self.device)
for input_feature_name, input_feature in self.embedder.input_features.items()
Expand Down
2 changes: 1 addition & 1 deletion ludwig/models/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def __init__(self):
self.model = model.to(get_torch_device())
self.samples = samples

def step(self, batch_size: int):
def step(self, batch_size: int, global_max_sequence_length: Optional[int] = None):
    """Encode the first ``batch_size`` stored samples as one batch.

    ``global_max_sequence_length`` is accepted so the signature matches the other
    evaluators' ``step`` methods; it is not used here.
    """
    batch = self.samples[:batch_size]
    self.model.encode(batch, batch_size=batch_size, show_progress_bar=False)

return _RetrievalModelEvaluator
Expand Down
7 changes: 4 additions & 3 deletions ludwig/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,7 @@ def tune_batch_size(
snapshot_weights: bool = True,
on_best_batch_size_updated: Optional[Callable[[int, float, int], None]] = None,
tune_for_training: bool = True,
global_max_sequence_length: Optional[int] = None,
) -> int:
logger.info("Tuning batch size...")
skip_save_model = self.skip_save_model
Expand Down Expand Up @@ -592,7 +593,7 @@ def tune_batch_size(
checkpoint.save(os.path.join(tmpdir, "latest.ckpt"), global_step=0)
try:
best_batch_size = evaluator.select_best_batch_size(
len(training_set), max_batch_size, max_trials, self.is_coordinator()
len(training_set), max_batch_size, max_trials, self.is_coordinator(), global_max_sequence_length
)
best_batch_size = self.distributed.broadcast_object(best_batch_size)

Expand Down Expand Up @@ -626,7 +627,7 @@ def reset(self):
trainer.model.reset_metrics()
trainer.optimizer.zero_grad()

def step(self, batch_size: int):
def step(self, batch_size: int, global_max_sequence_length: Optional[int] = None):
trainer.distributed.set_batch_size(trainer.dist_model, batch_size)
inputs = {
input_feature_name: input_feature.create_sample_input(batch_size=batch_size).to(trainer.device)
Expand All @@ -648,7 +649,7 @@ def reset(self):
trainer.model.reset_metrics()
trainer.optimizer.zero_grad()

def step(self, batch_size: int):
def step(self, batch_size: int, global_max_sequence_length: Optional[int] = None):
trainer.distributed.set_batch_size(trainer.dist_model, batch_size)
inputs = {
input_feature_name: input_feature.create_sample_input(batch_size=batch_size).to(trainer.device)
Expand Down
78 changes: 78 additions & 0 deletions ludwig/trainers/trainer_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
from typing import Callable, Dict, List, Optional, Union

import torch
from torch.utils.tensorboard import SummaryWriter

from ludwig.constants import MINIMUM_BATCH_SIZE, TEST, TRAINING, VALIDATION
Expand All @@ -18,6 +19,7 @@
from ludwig.trainers.trainer import Trainer
from ludwig.types import ModelConfigDict
from ludwig.utils import time_utils
from ludwig.utils.batch_size_tuner import BatchSizeEvaluator
from ludwig.utils.defaults import default_random_seed
from ludwig.utils.metric_utils import TrainerMetric
from ludwig.utils.metrics_printed_table import print_metrics_table
Expand Down Expand Up @@ -471,6 +473,82 @@ def evaluation(self, dataset, dataset_name, metrics_log, batch_size, progress_tr
progress_tracker.llm_eval_examples = llm_eval_examples
return append_metrics(self.model, dataset_name, metrics, metrics_log, progress_tracker)

def tune_batch_size(
    self,
    config: ModelConfigDict,
    training_set: Dataset,
    random_seed: int = default_random_seed,
    max_trials: int = 20,
    halving_limit: int = 3,
    snapshot_weights: bool = True,
    on_best_batch_size_updated: Optional[Callable[[int, float, int], None]] = None,
    tune_for_training: bool = True,
    global_max_sequence_length: Optional[int] = None,
) -> int:
    """Tune the batch size, defaulting the sequence-length value to the model's own.

    When ``global_max_sequence_length`` is ``None``, the value of
    ``self.model.global_max_sequence_length`` is substituted. Every argument is then
    forwarded, in order, to the parent implementation and its result is returned.
    """
    sequence_length_cap = (
        self.model.global_max_sequence_length if global_max_sequence_length is None else global_max_sequence_length
    )
    forwarded = (
        config,
        training_set,
        random_seed,
        max_trials,
        halving_limit,
        snapshot_weights,
        on_best_batch_size_updated,
        tune_for_training,
        sequence_length_cap,
    )
    return super().tune_batch_size(*forwarded)

def _create_batch_size_evaluator(self) -> BatchSizeEvaluator:
    """Build an evaluator that times `train_step` on synthetic token batches.

    The evaluator does not read real examples: each step feeds `trainer.train_step`
    random integer tensors of shape `(batch_size, sequence_length)` for the first
    input feature and the first output feature.

    Returns:
        A `BatchSizeEvaluator` whose `step` runs one training step at the requested batch size.
    """
    trainer = self

    class _TrainerBatchSizeEvaluator(BatchSizeEvaluator):
        def __init__(self):
            # Only the first input and first output feature are used. `next(iter(...))`
            # works whether `items()` returns a list or a dict-style view.
            self.input_feature_name, self.input_feature = next(iter(trainer.model.input_features.items()))
            self.output_feature_name, self.output_feature = next(iter(trainer.model.output_features.items()))

            # Sequence lengths for the synthetic batch, taken from the feature shapes.
            # NOTE(review): presumably these reflect the longest sequences seen in the
            # training data; confirm against the feature implementations.
            self.input_msl = self.input_feature.input_shape[0]
            self.output_msl = self.output_feature.output_shape[0]

            # An explicit (truthy) output `preprocessing.max_sequence_length` replaces the
            # length derived from the output feature's shape.
            configured_output_msl = trainer.model.config_obj.output_features[0].preprocessing.max_sequence_length
            if configured_output_msl:
                self.output_msl = configured_output_msl

            # Exclusive upper bound for the random token ids used as both synthetic
            # inputs and synthetic targets.
            self.vocab_size = len(trainer.model.config_obj.input_features[0].encoder.vocab)

        def reset(self):
            """Clear metrics and gradients before each timed step."""
            trainer.model.reset_metrics()
            trainer.optimizer.zero_grad()

        def step(self, batch_size: int, global_max_sequence_length: Optional[int] = None):
            """Run one training step on a synthetic batch of `batch_size` rows.

            Args:
                batch_size: Number of rows in the synthetic batch.
                global_max_sequence_length: Optional cap on the combined input and output
                    length. When `None`, no cap is applied and the lengths computed in
                    `__init__` are used as they are.
            """
            trainer.distributed.set_batch_size(trainer.dist_model, batch_size)

            input_msl = self.input_msl
            output_msl = self.output_msl
            # The `None` guard keeps this method compatible with the base signature, where
            # the argument is optional; comparing an int against `None` would raise.
            if global_max_sequence_length is not None and input_msl + output_msl > global_max_sequence_length:
                # The synthetic data only needs to exceed the cap by a small amount, so
                # each side gets just over half of it.
                input_msl = output_msl = global_max_sequence_length // 2 + 1

            synthetic_inputs = torch.randint(0, self.vocab_size, size=(batch_size, input_msl))
            synthetic_targets = torch.randint(0, self.vocab_size, size=(batch_size, output_msl))
            inputs = {
                self.input_feature_name: synthetic_inputs.to(self.input_feature.input_dtype).to(trainer.device)
            }
            targets = {
                self.output_feature_name: synthetic_targets.to(self.output_feature.get_output_dtype()).to(
                    trainer.device
                )
            }
            trainer.train_step(inputs, targets)

    return _TrainerBatchSizeEvaluator()


class RemoteLLMTrainer(NoneTrainer):
def __init__(self, gpus=None, gpu_memory_limit=None, allow_parallel_threads=True, **kwargs):
Expand Down
15 changes: 11 additions & 4 deletions ludwig/utils/batch_size_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

logger = logging.getLogger(__name__)

TOTAL_STEPS = 5


@DeveloperAPI
class BatchSizeEvaluator(ABC):
Expand All @@ -21,6 +23,7 @@ def select_best_batch_size(
max_batch_size: Optional[int] = None,
max_trials: int = 20,
is_coordinator: Optional[bool] = True,
global_max_sequence_length: Optional[int] = None,
) -> int:
"""Returns optimal batch size as measured by throughput (samples / sec)."""
logger.info("Tuning batch size...")
Expand Down Expand Up @@ -51,7 +54,9 @@ def _is_valid_batch_size(batch_size):
gc.collect()

try:
samples_per_sec = self.evaluate(batch_size, total_steps=5)
samples_per_sec = self.evaluate(
batch_size, total_steps=TOTAL_STEPS, global_max_sequence_length=global_max_sequence_length
)
if is_coordinator:
logger.info(f"Throughput at batch_size={batch_size}: {samples_per_sec:.5f} samples/s")
if samples_per_sec < best_samples_per_sec:
Expand Down Expand Up @@ -88,7 +93,9 @@ def _is_valid_batch_size(batch_size):
logger.info(f"Selected batch_size={best_batch_size}")
return best_batch_size

def evaluate(self, batch_size: int, total_steps: int = 5) -> float:
def evaluate(
self, batch_size: int, total_steps: int = 5, global_max_sequence_length: Optional[int] = None
) -> float:
"""Evaluates throughput of the given batch size.
Return:
Expand All @@ -98,7 +105,7 @@ def evaluate(self, batch_size: int, total_steps: int = 5) -> float:
for _ in range(total_steps):
self.reset()
start_ts = time.time()
self.step(batch_size)
self.step(batch_size, global_max_sequence_length=global_max_sequence_length)
durations.append(time.time() - start_ts)

med_duration_s = statistics.median(durations)
Expand All @@ -111,6 +118,6 @@ def reset(self):
"""Called at the beginning of each evaluation step."""
pass

def step(self, batch_size: int):
def step(self, batch_size: int, global_max_sequence_length: Optional[int] = None):
    """Called each step to evaluate the given batch size.

    Args:
        batch_size: Batch size to evaluate.
        global_max_sequence_length: Optional value passed through unchanged by `evaluate`;
            how it is used is left to the concrete evaluator.

    Raises:
        NotImplementedError: Always; concrete evaluators must override this method.
    """
    raise NotImplementedError("`step` must be implemented by concrete evaluator.")
34 changes: 34 additions & 0 deletions tests/integration_tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pandas as pd
import pytest
import torch
import yaml

import ludwig.error as ludwig_error
from ludwig.api import LudwigModel
Expand Down Expand Up @@ -1254,6 +1255,39 @@ def test_llm_encoding(llm_encoder_config, adapter, quantization, tmpdir):
model.train(dataset=dataset_path, output_directory=str(tmpdir))


@pytest.mark.llm
def test_llm_batch_size_tuning(tmpdir):
    """Train a tiny LLM without an explicit batch size and check that tuning picked one above 1.

    The trainer config sets no `batch_size`, so the assertion at the end checks the value
    left on the config after training. Training output goes to pytest's `tmpdir` so the
    test leaves nothing behind in the working directory.
    """
    dataset = pd.DataFrame({"instruction": ["a"] * 100, "output": ["a"] * 100})
    config = yaml.safe_load(
        """
        model_type: llm
        input_features:
          - name: instruction
            type: text
        output_features:
          - name: output
            type: text
        prompt:
          template: >-
            {instruction}
        adapter:
          type: lora
        trainer:
          type: finetune
          optimizer:
            type: adam
          train_steps: 1
          learning_rate: 0.0002
          eval_batch_size: 2
        backend:
          type: local
        base_model: HuggingFaceH4/tiny-random-LlamaForCausalLM
        """
    )
    model = LudwigModel(config=config)
    model.train(dataset=dataset, output_directory=str(tmpdir))
    assert model.config_obj.trainer.batch_size > 1


@pytest.mark.llm
def test_llm_used_tokens(tmpdir):
input_features = [text_feature(name="input", encoder={"type": "passthrough"})]
Expand Down

0 comments on commit 91c8975

Please sign in to comment.