Describe the bug
I have a large dataset that I sharded into 1024 shards and saved to disk during pre-processing. During training, I load the dataset using load_from_disk(), convert it into an iterable dataset, shuffle it, and split the shards across the DDP nodes using the recommended method.
However, when the training is resumed mid-epoch, I get thousands of identical warning messages:
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
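In short, the pipeline looks roughly like this (a simplified sketch of what I do; the path, buffer size, and rank/world_size values are placeholders):

import datasets
from datasets.distributed import split_dataset_by_node

rank, world_size = 0, 2  # placeholders; in practice taken from the trainer / torch.distributed

# pre-processing saved the dataset with save_to_disk("dataset_dir", num_shards=1024)
ds = datasets.load_from_disk("dataset_dir").to_iterable_dataset(num_shards=1024)
ds = ds.shuffle(seed=42, buffer_size=10_000)
ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)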
Steps to reproduce the bug
Run a multi-node training job using the following Python script and interrupt the training after a few seconds to save a mid-epoch checkpoint.
#!/usr/bin/env python
import os
import time
from typing import Dict, List

import torch
import lightning as pl
from torch.utils.data import DataLoader
from datasets import Dataset
from datasets.distributed import split_dataset_by_node
import datasets
from transformers import AutoTokenizer
from more_itertools import flatten, chunked
from torchdata.stateful_dataloader import StatefulDataLoader
from lightning.pytorch.callbacks.on_exception_checkpoint import (
    OnExceptionCheckpoint,
)

datasets.logging.set_verbosity_debug()


def dummy_generator():
    # 8 base sequences of different lengths, repeated 8 times with an offset
    # -> 64 examples in total
    dataset = [
        list(range(3, 10)),
        list(range(10, 15)),
        list(range(15, 21)),
        list(range(21, 27)),
        list(range(27, 31)),
        list(range(31, 36)),
        list(range(36, 45)),
        list(range(45, 50)),
    ]
    for i in range(8):
        for j, ids in enumerate(dataset):
            yield {"token_ids": [idx + i * 50 for idx in ids]}


def group_texts(
    examples: Dict[str, List[List[int]]],
    block_size: int,
    eos_token_id: int,
    bos_token_id: int,
    pad_token_id: int,
) -> Dict[str, List[List[int]]]:
    real_block_size = block_size - 2  # make space for bos and eos
    # collapse the sequences into a single list of tokens,
    # then create blocks of real_block_size
    input_ids = []
    attention_mask = []
    for block in chunked(flatten(examples["token_ids"]), real_block_size):
        s = [bos_token_id] + list(block) + [eos_token_id]
        ls = len(s)
        attn = [True] * ls
        s += [pad_token_id] * (block_size - ls)
        attn += [False] * (block_size - ls)
        input_ids.append(s)
        attention_mask.append(attn)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def collate_fn(batch):
    return {
        "input_ids": torch.tensor(
            [item["input_ids"] for item in batch], dtype=torch.long
        ),
        "attention_mask": torch.tensor(
            [item["attention_mask"] for item in batch], dtype=torch.long
        ),
    }


class DummyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # A dummy linear layer (not used for actual computation)
        self.layer = torch.nn.Linear(1, 1)
        self.ds = None
        self.prepare_data_per_node = False

    def on_train_start(self):
        # This hook is called once training begins on each process.
        print(f"[Rank {self.global_rank}] Training started.", flush=True)
        self.data_file = open(f"data_{self.global_rank}.txt", "w")

    def on_train_end(self):
        self.data_file.close()

    def training_step(self, batch, batch_idx):
        # Print batch information to verify data loading.
        time.sleep(5)
        # print("batch", batch, flush=True)
        print(
            f"\n[Rank {self.global_rank}] Training step, epoch {self.trainer.current_epoch}, batch {batch_idx}: {batch['input_ids']}",
            flush=True,
        )
        self.data_file.write(
            f"[Rank {self.global_rank}] Training step, epoch {self.trainer.current_epoch}, batch {batch_idx}: {batch['input_ids']}\n"
        )
        # Compute a dummy loss (here, simply a constant tensor)
        loss = torch.tensor(0.0, requires_grad=True)
        return loss

    def on_train_epoch_start(self):
        epoch = self.trainer.current_epoch
        print(
            f"[Rank {self.global_rank}] Training epoch {epoch} started.",
            flush=True,
        )
        self.data_file.write(
            f"[Rank {self.global_rank}] Training epoch {epoch} started.\n"
        )

    def configure_optimizers(self):
        # Return a dummy optimizer.
        return torch.optim.SGD(self.parameters(), lr=0.001)


class DM(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.ds = None
        self.prepare_data_per_node = False

    def set_epoch(self, epoch: int):
        self.ds.set_epoch(epoch)

    def prepare_data(self):
        # create the dataset
        dataset = Dataset.from_generator(dummy_generator)
        # save the dataset to disk in 4 shards
        dataset.save_to_disk("dataset", num_shards=4)

    def setup(self, stage: str):
        # load the dataset and convert it to an iterable dataset
        ds = datasets.load_from_disk("dataset").to_iterable_dataset(
            num_shards=4
        )
        ds = ds.map(
            group_texts,
            batched=True,
            batch_size=5,
            fn_kwargs={
                "block_size": 5,
                "eos_token_id": 1,
                "bos_token_id": 0,
                "pad_token_id": 2,
            },
            remove_columns=["token_ids"],
        ).shuffle(seed=42, buffer_size=8)
        ds = split_dataset_by_node(
            ds,
            rank=self.trainer.global_rank,
            world_size=self.trainer.world_size,
        )
        self.ds = ds

    def train_dataloader(self):
        print(
            f"[Rank {self.trainer.global_rank}] Preparing train_dataloader...",
            flush=True,
        )
        rank = self.trainer.global_rank
        print(
            f"[Rank {rank}] Global rank: {self.trainer.global_rank}",
            flush=True,
        )
        world_size = self.trainer.world_size
        print(f"[Rank {rank}] World size: {world_size}", flush=True)
        return StatefulDataLoader(
            self.ds,
            batch_size=2,
            num_workers=2,
            collate_fn=collate_fn,
            drop_last=True,
            persistent_workers=True,
        )


if __name__ == "__main__":
    print("Starting Lightning training", flush=True)
    # Optionally, print some SLURM environment info for debugging.
    print(f"SLURM_NNODES: {os.environ.get('SLURM_NNODES', '1')}", flush=True)
    # Determine the number of nodes from SLURM (defaulting to 1 if not set)
    num_nodes = int(os.environ.get("SLURM_NNODES", "1"))

    model = DummyModule()
    dm = DM()
    on_exception = OnExceptionCheckpoint(
        dirpath="checkpoints",
        filename="on_exception",
    )
    # Configure the Trainer to use distributed data parallel (DDP).
    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        strategy=(
            "ddp" if num_nodes > 1 else "auto"
        ),  # Use DDP strategy for multi-node training.
        num_nodes=num_nodes,
        max_epochs=2,
        logger=False,
        enable_checkpointing=True,
        num_sanity_val_steps=0,
        enable_progress_bar=False,
        callbacks=[on_exception],
    )
    # resume (uncomment to resume)
    # trainer.fit(model, datamodule=dm, ckpt_path="checkpoints/on_exception.ckpt")
    # train
    trainer.fit(model, datamodule=dm)
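The SLURM batch script that launches the job on two nodes (one process per node):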
#!/bin/bash
#SBATCH --job-name=pl_ddp_test
#SBATCH --nodes=2                # Adjust number of nodes as needed
#SBATCH --ntasks-per-node=1      # One GPU (process) per node
#SBATCH --cpus-per-task=3        # At least as many dataloader workers as required
#SBATCH --gres=gpu:1             # Request one GPU per node
#SBATCH --time=00:10:00          # Job runtime (adjust as needed)
#SBATCH --partition=gpu-preempt  # Partition or queue name
#SBATCH -o script.out

# Disable Python output buffering.
export PYTHONUNBUFFERED=1

echo "SLURM job starting on $(date)"
echo "Running on nodes: $SLURM_NODELIST"
echo "Current directory: $(pwd)"
ls -l
# Launch the script using srun so that each process starts the Lightning module.
srun script.py
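I submit the job with sbatch script.sh (assuming the two files above are saved as script.py and script.sh in the same directory).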
Then uncomment the "resume" trainer.fit call (second-to-last line), comment out the original trainer.fit call (last line), and rerun the job. It produces the following log:
[Rank 0] Preparing train_dataloader...
[Rank 0] Global rank: 0
[Rank 0] World size: 2
[Rank 1] Preparing train_dataloader...
[Rank 1] Global rank: 1
[Rank 1] World size: 2
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Assigning 2 shards (or data sources) of the dataset to each node.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#0 dataloader worker#1, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#0 dataloader worker#0, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
node#0 dataloader worker#1, ': Finished iterating over 1/1 shards.
node#0 dataloader worker#0, ': Finished iterating over 1/1 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
[Rank 0] Training started.
[Rank 0] Training epoch 0 started.
[Rank 0] Training epoch 1 started.
Assigning 2 shards (or data sources) of the dataset to each node.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#0 dataloader worker#1, ': Starting to iterate over 1/2 shards.
node#0 dataloader worker#0, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#1 dataloader worker#1, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#1 dataloader worker#0, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
node#0 dataloader worker#1, ': Finished iterating over 1/1 shards.
node#0 dataloader worker#0, ': Finished iterating over 1/1 shards.
`Trainer.fit` stopped: `max_epochs=2` reached.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
node#1 dataloader worker#1, ': Finished iterating over 1/1 shards.
node#1 dataloader worker#0, ': Finished iterating over 1/1 shards.
[Rank 1] Training started.
[Rank 1] Training epoch 0 started.
[Rank 1] Training epoch 1 started.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#1 dataloader worker#1, ': Starting to iterate over 1/2 shards.
node#1 dataloader worker#0, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to arrow for no columns (when key is int or slice) and don't output other (un-formatted) columns.
node#1 dataloader worker#0, ': Finished iterating over 1/1 shards.
node#1 dataloader worker#1, ': Finished iterating over 1/1 shards.
I'm also attaching the relevant state_dict to make sure that the state is being checkpointed as expected.
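For reference, a state dict like the attached one can be dumped directly from the StatefulDataLoader, e.g. (a minimal sketch outside the training script; make_dataset() is a hypothetical stand-in for the setup() logic above):

import json
from torchdata.stateful_dataloader import StatefulDataLoader

# make_dataset() is a hypothetical helper returning the split iterable dataset
loader = StatefulDataLoader(
    make_dataset(), batch_size=2, num_workers=2, collate_fn=collate_fn
)
for _i, _batch in zip(range(3), loader):
    pass  # consume a few batches so the workers build up resumable state
print(json.dumps(loader.state_dict(), indent=2, default=str))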
Expected behavior
Since I'm following all the recommended steps, I don't expect to see any warnings when resuming. Am I doing something wrong? Also, can someone explain why I see 20 identical messages in the log in this reproduction setting? I'm trying to understand why I see thousands of these messages with the actual dataset.
One more surprising thing I noticed in the logs is that the number of shards per worker changes. In the following messages, the denominator changes from 2 to 1:
node#1 dataloader worker#1, ': Starting to iterate over 1/2 shards.
...
node#1 dataloader worker#1, ': Finished iterating over 1/1 shards.
Environment info
python: 3.11.10
datasets: 3.3.2
lightning: 2.3.1