From 5571101a50804406ef0fe23e7ea6795b3c4a1bcb Mon Sep 17 00:00:00 2001
From: Milo Cress
Date: Sun, 9 Jun 2024 10:50:54 -0400
Subject: [PATCH] fix linting (#1270)

* fix linting

* fix
---
 llmfoundry/data/dataloader.py    | 4 ++--
 llmfoundry/utils/config_utils.py | 3 +--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/llmfoundry/data/dataloader.py b/llmfoundry/data/dataloader.py
index 83a9a7d8ea..e7521bc343 100644
--- a/llmfoundry/data/dataloader.py
+++ b/llmfoundry/data/dataloader.py
@@ -3,7 +3,7 @@
 
 """Dataloader builder utilities."""
 
-from typing import Any, Dict
+from typing import Any, Dict, Union
 
 from composer import DataSpec
 from transformers import PreTrainedTokenizerBase
@@ -19,7 +19,7 @@
 def build_dataloader(
     cfg: Dict[str, Any],
     tokenizer: PreTrainedTokenizerBase,
-    device_batch_size: int,
+    device_batch_size: Union[int, float],
 ) -> DataSpec:
     """Builds a dataloader from a config.
 
diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 5ab148bbe8..5c1ec9114a 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -100,7 +100,7 @@ class TrainConfig:
     optimizer: Dict[str, Any] = MISSING
     scheduler: Dict[str, Any] = MISSING
     train_loader: Dict[str, Any] = MISSING
-    device_train_batch_size: int = MISSING
+    device_train_batch_size: Union[int, float] = MISSING
     device_eval_batch_size: int = MISSING
     max_duration: Union[int, str] = MISSING
     eval_interval: Union[int, str] = MISSING
@@ -183,7 +183,6 @@ class TrainConfig:
 
     # Fields created by `update_batch_size_info`
     n_gpus: int = MISSING
-    device_train_batch_size: int = MISSING
     device_train_grad_accum: str = MISSING
 
 
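Context for reviewers (not part of the patch): the patch widens the batch-size annotations from int to Union[int, float] because the per-device batch size can be fractional, e.g. when a global batch size does not divide evenly across devices, so the annotation must match at every call site for type checking to pass. It also removes a duplicate device_train_batch_size field from TrainConfig, which already declares that field above the "Fields created by `update_batch_size_info`" block. A minimal runnable sketch of how a fractional value can arise; the helper name split_batch_size is hypothetical, and llm-foundry's actual logic lives in update_batch_size_info (referenced in the last hunk above):

from typing import Union


def split_batch_size(global_batch_size: int, num_devices: int) -> Union[int, float]:
    # Hypothetical helper for illustration only; llm-foundry computes this
    # inside `update_batch_size_info`. Returns an int when the global batch
    # size divides evenly across devices, otherwise a float.
    per_device = global_batch_size / num_devices
    return int(per_device) if per_device.is_integer() else per_device


print(split_batch_size(12, 4))  # 3   -> the common, evenly divisible case
print(split_batch_size(10, 4))  # 2.5 -> float, hence Union[int, float]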