From 91c8975f162415bdfbb7fc83b8bf4d12c4afca23 Mon Sep 17 00:00:00 2001
From: Timothy <72055086+Infernaught@users.noreply.github.com>
Date: Mon, 22 Jan 2024 17:02:39 -0500
Subject: [PATCH] Add batch size tuning for LLMs (#3871)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Arnav Garg
---
 ludwig/models/embedder.py           |  4 +-
 ludwig/models/retrieval.py          |  2 +-
 ludwig/trainers/trainer.py          |  7 +--
 ludwig/trainers/trainer_llm.py      | 78 +++++++++++++++++++++++++++++
 ludwig/utils/batch_size_tuner.py    | 15 ++++--
 tests/integration_tests/test_llm.py | 34 +++++++++++++
 6 files changed, 130 insertions(+), 10 deletions(-)

diff --git a/ludwig/models/embedder.py b/ludwig/models/embedder.py
index 7909762f962..a5a152e7c93 100644
--- a/ludwig/models/embedder.py
+++ b/ludwig/models/embedder.py
@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List
+from typing import Callable, Dict, List, Optional
 
 import numpy as np
 import pandas as pd
@@ -73,7 +73,7 @@ def __init__(self):
             self.embedder = embedder.to(self.device)
             self.embedder.eval()
 
-        def step(self, batch_size: int):
+        def step(self, batch_size: int, global_max_sequence_length: Optional[int] = None):
             inputs = {
                 input_feature_name: input_feature.create_sample_input(batch_size=batch_size).to(self.device)
                 for input_feature_name, input_feature in self.embedder.input_features.items()
diff --git a/ludwig/models/retrieval.py b/ludwig/models/retrieval.py
index 46920643238..553fe85f3a6 100644
--- a/ludwig/models/retrieval.py
+++ b/ludwig/models/retrieval.py
@@ -186,7 +186,7 @@ def __init__(self):
             self.model = model.to(get_torch_device())
             self.samples = samples
 
-        def step(self, batch_size: int):
+        def step(self, batch_size: int, global_max_sequence_length: Optional[int] = None):
             self.model.encode(self.samples[:batch_size], batch_size=batch_size, show_progress_bar=False)
 
     return _RetrievalModelEvaluator
diff --git a/ludwig/trainers/trainer.py b/ludwig/trainers/trainer.py
index 66ff61c194b..777c3d8602d 100644
--- a/ludwig/trainers/trainer.py
+++ b/ludwig/trainers/trainer.py
@@ -552,6 +552,7 @@ def tune_batch_size(
         snapshot_weights: bool = True,
         on_best_batch_size_updated: Optional[Callable[[int, float, int], None]] = None,
         tune_for_training: bool = True,
+        global_max_sequence_length: Optional[int] = None,
     ) -> int:
         logger.info("Tuning batch size...")
         skip_save_model = self.skip_save_model
@@ -592,7 +593,7 @@ def tune_batch_size(
                     checkpoint.save(os.path.join(tmpdir, "latest.ckpt"), global_step=0)
             try:
                 best_batch_size = evaluator.select_best_batch_size(
-                    len(training_set), max_batch_size, max_trials, self.is_coordinator()
+                    len(training_set), max_batch_size, max_trials, self.is_coordinator(), global_max_sequence_length
                 )
                 best_batch_size = self.distributed.broadcast_object(best_batch_size)
 
@@ -626,7 +627,7 @@ def reset(self):
                 trainer.model.reset_metrics()
                 trainer.optimizer.zero_grad()
 
-            def step(self, batch_size: int):
+            def step(self, batch_size: int, global_max_sequence_length: Optional[int] = None):
                 trainer.distributed.set_batch_size(trainer.dist_model, batch_size)
                 inputs = {
                     input_feature_name: input_feature.create_sample_input(batch_size=batch_size).to(trainer.device)
@@ -648,7 +649,7 @@ def reset(self):
                 trainer.model.reset_metrics()
                 trainer.optimizer.zero_grad()
 
-            def step(self, batch_size: int):
+            def step(self, batch_size: int, global_max_sequence_length: Optional[int] = None):
                 trainer.distributed.set_batch_size(trainer.dist_model, batch_size)
                 inputs = {
                     input_feature_name: input_feature.create_sample_input(batch_size=batch_size).to(trainer.device)
diff --git a/ludwig/trainers/trainer_llm.py b/ludwig/trainers/trainer_llm.py
index 287cbe6bd7a..3db12947ab6 100644
--- a/ludwig/trainers/trainer_llm.py
+++ b/ludwig/trainers/trainer_llm.py
@@ -3,6 +3,7 @@
 import time
 from typing import Callable, Dict, List, Optional, Union
 
+import torch
 from torch.utils.tensorboard import SummaryWriter
 
 from ludwig.constants import MINIMUM_BATCH_SIZE, TEST, TRAINING, VALIDATION
@@ -18,6 +19,7 @@
 from ludwig.trainers.trainer import Trainer
 from ludwig.types import ModelConfigDict
 from ludwig.utils import time_utils
+from ludwig.utils.batch_size_tuner import BatchSizeEvaluator
 from ludwig.utils.defaults import default_random_seed
 from ludwig.utils.metric_utils import TrainerMetric
 from ludwig.utils.metrics_printed_table import print_metrics_table
@@ -471,6 +473,82 @@ def evaluation(self, dataset, dataset_name, metrics_log, batch_size, progress_tr
             progress_tracker.llm_eval_examples = llm_eval_examples
         return append_metrics(self.model, dataset_name, metrics, metrics_log, progress_tracker)
 
+    def tune_batch_size(
+        self,
+        config: ModelConfigDict,
+        training_set: Dataset,
+        random_seed: int = default_random_seed,
+        max_trials: int = 20,
+        halving_limit: int = 3,
+        snapshot_weights: bool = True,
+        on_best_batch_size_updated: Optional[Callable[[int, float, int], None]] = None,
+        tune_for_training: bool = True,
+        global_max_sequence_length: Optional[int] = None,
+    ) -> int:
+        if global_max_sequence_length is None:
+            global_max_sequence_length = self.model.global_max_sequence_length
+        return super().tune_batch_size(
+            config,
+            training_set,
+            random_seed,
+            max_trials,
+            halving_limit,
+            snapshot_weights,
+            on_best_batch_size_updated,
+            tune_for_training,
+            global_max_sequence_length,
+        )
+
+    def _create_batch_size_evaluator(self) -> BatchSizeEvaluator:
+        trainer = self
+
+        class _TrainerBatchSizeEvaluator(BatchSizeEvaluator):
+            def __init__(self):
+                self.input_feature_name, self.input_feature = trainer.model.input_features.items()[0]
+                self.output_feature_name, self.output_feature = trainer.model.output_features.items()[0]
+
+                # Get the length of the longest input sequence from the training data
+                self.input_msl = self.input_feature.input_shape[0]
+                # Get the length of the longest output sequence from the training data
+                self.output_msl = self.output_feature.output_shape[0]
+                # max_sequence_length here is the smaller value between the global max sequence length of the model
+                # and the model's context length
+                if trainer.model.config_obj.output_features[0].preprocessing.max_sequence_length:
+                    self.output_msl = trainer.model.config_obj.output_features[0].preprocessing.max_sequence_length
+
+                # This is useful to create the synthetic input and target data which will be a
+                # random sequence of integers between 0 and vocab_size
+                self.vocab_size = len(trainer.model.config_obj.input_features[0].encoder.vocab)
+
+            def reset(self):
+                trainer.model.reset_metrics()
+                trainer.optimizer.zero_grad()
+
+            def step(self, batch_size: int, global_max_sequence_length: int):
+                trainer.distributed.set_batch_size(trainer.dist_model, batch_size)
+
+                input_msl = self.input_msl
+                output_msl = self.output_msl
+                if self.input_msl + self.output_msl > global_max_sequence_length:
+                    # In this case, we just need to make sure that the length of the synthetic data exceeds
+                    # max_sequence_length by at most a small amount
+                    input_msl = global_max_sequence_length // 2 + 1
+                    output_msl = global_max_sequence_length // 2 + 1
+
+                inputs = {
+                    self.input_feature_name: torch.randint(0, self.vocab_size, size=(batch_size, input_msl))
+                    .to(self.input_feature.input_dtype)
+                    .to(trainer.device)
+                }
+                targets = {
+                    self.output_feature_name: torch.randint(0, self.vocab_size, size=(batch_size, output_msl))
+                    .to(self.output_feature.get_output_dtype())
+                    .to(trainer.device)
+                }
+                trainer.train_step(inputs, targets)
+
+        return _TrainerBatchSizeEvaluator()
+
 
 class RemoteLLMTrainer(NoneTrainer):
     def __init__(self, gpus=None, gpu_memory_limit=None, allow_parallel_threads=True, **kwargs):
diff --git a/ludwig/utils/batch_size_tuner.py b/ludwig/utils/batch_size_tuner.py
index 74ccc777b51..e55187e0d7d 100644
--- a/ludwig/utils/batch_size_tuner.py
+++ b/ludwig/utils/batch_size_tuner.py
@@ -12,6 +12,8 @@
 
 logger = logging.getLogger(__name__)
 
+TOTAL_STEPS = 5
+
 
 @DeveloperAPI
 class BatchSizeEvaluator(ABC):
@@ -21,6 +23,7 @@ def select_best_batch_size(
         max_batch_size: Optional[int] = None,
         max_trials: int = 20,
         is_coordinator: Optional[bool] = True,
+        global_max_sequence_length: Optional[int] = None,
     ) -> int:
         """Returns optimal batch size as measured by throughput (samples / sec)."""
         logger.info("Tuning batch size...")
@@ -51,7 +54,9 @@ def _is_valid_batch_size(batch_size):
             gc.collect()
 
             try:
-                samples_per_sec = self.evaluate(batch_size, total_steps=5)
+                samples_per_sec = self.evaluate(
+                    batch_size, total_steps=TOTAL_STEPS, global_max_sequence_length=global_max_sequence_length
+                )
                 if is_coordinator:
                     logger.info(f"Throughput at batch_size={batch_size}: {samples_per_sec:.5f} samples/s")
                 if samples_per_sec < best_samples_per_sec:
@@ -88,7 +93,9 @@ def _is_valid_batch_size(batch_size):
         logger.info(f"Selected batch_size={best_batch_size}")
         return best_batch_size
 
-    def evaluate(self, batch_size: int, total_steps: int = 5) -> float:
+    def evaluate(
+        self, batch_size: int, total_steps: int = 5, global_max_sequence_length: Optional[int] = None
+    ) -> float:
         """Evaluates throughput of the given batch size.
 
         Return:
@@ -98,7 +105,7 @@ def evaluate(self, batch_size: int, total_steps: int = 5) -> float:
         for _ in range(total_steps):
             self.reset()
             start_ts = time.time()
-            self.step(batch_size)
+            self.step(batch_size, global_max_sequence_length=global_max_sequence_length)
             durations.append(time.time() - start_ts)
 
         med_duration_s = statistics.median(durations)
@@ -111,6 +118,6 @@ def reset(self):
         """Called at the beginning of each evaluation step."""
         pass
 
-    def step(self, batch_size: int):
+    def step(self, batch_size: int, global_max_sequence_length: Optional[int] = None):
         """Called each step to evaluate the given batch size."""
         raise NotImplementedError("`step` must be implemented by concrete evaluator.")
diff --git a/tests/integration_tests/test_llm.py b/tests/integration_tests/test_llm.py
index daed5d20931..eccb71f651c 100644
--- a/tests/integration_tests/test_llm.py
+++ b/tests/integration_tests/test_llm.py
@@ -10,6 +10,7 @@
 import pandas as pd
 import pytest
 import torch
+import yaml
 
 import ludwig.error as ludwig_error
 from ludwig.api import LudwigModel
@@ -1254,6 +1255,39 @@ def test_llm_encoding(llm_encoder_config, adapter, quantization, tmpdir):
     model.train(dataset=dataset_path, output_directory=str(tmpdir))
 
 
+def test_llm_batch_size_tuning():
+    dataset = pd.DataFrame({"instruction": ["a"] * 100, "output": ["a"] * 100})
+    config = yaml.safe_load(
+        """
+        model_type: llm
+        input_features:
+          - name: instruction
+            type: text
+        output_features:
+          - name: output
+            type: text
+        prompt:
+          template: >-
+            {instruction}
+        adapter:
+          type: lora
+        trainer:
+          type: finetune
+          optimizer:
+            type: adam
+          train_steps: 1
+          learning_rate: 0.0002
+          eval_batch_size: 2
+        backend:
+          type: local
+        base_model: HuggingFaceH4/tiny-random-LlamaForCausalLM
+        """
+    )
+    model = LudwigModel(config=config)
+    model.train(dataset=dataset)
+    assert model.config_obj.trainer.batch_size > 1
+
+
 @pytest.mark.llm
 def test_llm_used_tokens(tmpdir):
     input_features = [text_feature(name="input", encoder={"type": "passthrough"})]
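
Not part of the patch: a minimal sketch of how the extended BatchSizeEvaluator API can be
exercised on its own, using only what the diff above shows (select_best_batch_size now accepts
global_max_sequence_length and forwards it through evaluate to step). The toy embedding model,
vocabulary size, and sequence lengths below are hypothetical placeholders, not Ludwig internals.

    import torch

    from ludwig.utils.batch_size_tuner import BatchSizeEvaluator


    class ToyLLMEvaluator(BatchSizeEvaluator):
        """Mimics _TrainerBatchSizeEvaluator: synthetic token batches capped near the global limit."""

        def __init__(self, model: torch.nn.Module, vocab_size: int, input_msl: int, output_msl: int):
            self.model = model
            self.vocab_size = vocab_size
            self.input_msl = input_msl
            self.output_msl = output_msl

        def reset(self):
            self.model.zero_grad()

        def step(self, batch_size: int, global_max_sequence_length=None):
            input_msl, output_msl = self.input_msl, self.output_msl
            if global_max_sequence_length is not None and input_msl + output_msl > global_max_sequence_length:
                # Same capping rule as the LLM trainer: synthetic sequences may exceed the
                # global limit only by a small amount.
                input_msl = global_max_sequence_length // 2 + 1
                output_msl = global_max_sequence_length // 2 + 1
            tokens = torch.randint(0, self.vocab_size, size=(batch_size, input_msl + output_msl))
            # Stand-in for trainer.train_step: one forward/backward pass at this batch size.
            self.model(tokens).sum().backward()


    model = torch.nn.Embedding(32000, 16)  # hypothetical stand-in for an LLM
    evaluator = ToyLLMEvaluator(model, vocab_size=32000, input_msl=400, output_msl=300)
    best = evaluator.select_best_batch_size(1000, max_batch_size=64, max_trials=5, global_max_sequence_length=512)
    print(f"selected batch_size: {best}")

With input_msl + output_msl = 700 and a global limit of 512, the capping branch fires and the
synthetic batches are built at 257 + 257 tokens, mirroring what the LLM trainer does during tuning.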