Excessive warnings when resuming an IterableDataset+buffered shuffle+DDP. #7444

dhruvdcoder opened this issue Mar 11, 2025 · 0 comments
dhruvdcoder commented Mar 11, 2025

Describe the bug

I have a large dataset that I shard into 1024 shards and save to disk during pre-processing. During training, I load the dataset using load_from_disk(), convert it into an iterable dataset, shuffle it, and split the shards across the DDP nodes using the recommended method.
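Concretely, the loading pattern I mean is roughly the following (a condensed sketch of what the full reproduction script below does; rank, world_size and the shard count are placeholders):

import datasets
from datasets.distributed import split_dataset_by_node

rank, world_size = 0, 2  # in the real job these come from the trainer / SLURM

ds = datasets.load_from_disk("dataset").to_iterable_dataset(num_shards=4)  # 1024 in the real setup
ds = ds.shuffle(seed=42, buffer_size=8)
ds = split_dataset_by_node(ds, rank=rank, world_size=world_size)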

However, when the training is resumed mid-epoch, I get thousands of identical warning messages:

Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
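As far as I can tell (I haven't traced the Lightning internals), the resume path goes through torchdata's StatefulDataLoader, so with Lightning stripped out the flow that produces the message is roughly the sketch below, where ds stands for the shuffled IterableDataset built in the script that follows:

# Rough sketch of the save/resume flow, without Lightning. `ds` is a shuffled
# IterableDataset as in the reproduction script below.
from torchdata.stateful_dataloader import StatefulDataLoader

loader = StatefulDataLoader(ds, batch_size=2, num_workers=2)
it = iter(loader)
for _ in range(3):           # consume a few batches, then "interrupt"
    next(it)
state = loader.state_dict()  # this is what ends up in the checkpoint

# later, in a fresh process:
resumed = StatefulDataLoader(ds, batch_size=2, num_workers=2)
resumed.load_state_dict(state)
for batch in resumed:        # each worker restores the dataset state when iteration
    break                    # restarts; the shuffle-buffer message is logged around here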

Steps to reproduce the bug

  1. Run a multi-node training job using the following Python script (launched with the SLURM submission script shown after it) and interrupt the training after a few seconds so that a mid-epoch checkpoint is saved.
#!/usr/bin/env python
import os
import time
from typing import Dict, List
import torch
import lightning as pl
from torch.utils.data import DataLoader
from datasets import Dataset
from datasets.distributed import split_dataset_by_node
import datasets
from transformers import AutoTokenizer
from more_itertools import flatten, chunked
from torchdata.stateful_dataloader import StatefulDataLoader
from lightning.pytorch.callbacks.on_exception_checkpoint import (
    OnExceptionCheckpoint,
)

datasets.logging.set_verbosity_debug()


def dummy_generator():
    # Generate 64 token-id sequences of varying lengths:
    # 8 base sequences, repeated 8 times with an offset of i * 50.
    dataset = [
        list(range(3, 10)),
        list(range(10, 15)),
        list(range(15, 21)),
        list(range(21, 27)),
        list(range(27, 31)),
        list(range(31, 36)),
        list(range(36, 45)),
        list(range(45, 50)),
    ]
    for i in range(8):
        for j, ids in enumerate(dataset):
            yield {"token_ids": [idx + i * 50 for idx in ids]}


def group_texts(
    examples: Dict[str, List[List[int]]],
    block_size: int,
    eos_token_id: int,
    bos_token_id: int,
    pad_token_id: int,
) -> Dict[str, List[List[int]]]:
    real_block_size = block_size - 2  # make space for bos and eos
    # collapse the sequences into a single list of tokens and then create blocks of real_block_size
    input_ids = []
    attention_mask = []
    for block in chunked(flatten(examples["token_ids"]), real_block_size):
        s = [bos_token_id] + list(block) + [eos_token_id]
        ls = len(s)
        attn = [True] * ls
        s += [pad_token_id] * (block_size - ls)
        attn += [False] * (block_size - ls)
        input_ids.append(s)
        attention_mask.append(attn)

    return {"input_ids": input_ids, "attention_mask": attention_mask}


def collate_fn(batch):
    return {
        "input_ids": torch.tensor(
            [item["input_ids"] for item in batch], dtype=torch.long
        ),
        "attention_mask": torch.tensor(
            [item["attention_mask"] for item in batch], dtype=torch.long
        ),
    }


class DummyModule(pl.LightningModule):

    def __init__(self):
        super().__init__()
        # A dummy linear layer (not used for actual computation)
        self.layer = torch.nn.Linear(1, 1)
        self.ds = None
        self.prepare_data_per_node = False

    def on_train_start(self):
        # This hook is called once training begins on each process.
        print(f"[Rank {self.global_rank}] Training started.", flush=True)
        self.data_file = open(f"data_{self.global_rank}.txt", "w")

    def on_train_end(self):
        self.data_file.close()

    def training_step(self, batch, batch_idx):
        # Print batch information to verify data loading.
        time.sleep(5)
        # print("batch", batch, flush=True)
        print(
            f"\n[Rank {self.global_rank}] Training step, epoch {self.trainer.current_epoch}, batch {batch_idx}: {batch['input_ids']}",
            flush=True,
        )
        self.data_file.write(
            f"[Rank {self.global_rank}] Training step, epoch {self.trainer.current_epoch}, batch {batch_idx}: {batch['input_ids']}\n"
        )
        # Compute a dummy loss (here, simply a constant tensor)
        loss = torch.tensor(0.0, requires_grad=True)
        return loss

    def on_train_epoch_start(self):
        epoch = self.trainer.current_epoch
        print(
            f"[Rank {self.global_rank}] Training epoch {epoch} started.",
            flush=True,
        )
        self.data_file.write(
            f"[Rank {self.global_rank}] Training epoch {epoch} started.\n"
        )

    def configure_optimizers(self):
        # Return a dummy optimizer.
        return torch.optim.SGD(self.parameters(), lr=0.001)


class DM(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.ds = None
        self.prepare_data_per_node = False

    def set_epoch(self, epoch: int):
        self.ds.set_epoch(epoch)

    def prepare_data(self):
        # generate the dummy dataset
        dataset = Dataset.from_generator(dummy_generator)
        # save the dataset
        dataset.save_to_disk("dataset", num_shards=4)

    def setup(self, stage: str):
        # load the dataset
        ds = datasets.load_from_disk("dataset").to_iterable_dataset(
            num_shards=4
        )
        ds = ds.map(
            group_texts,
            batched=True,
            batch_size=5,
            fn_kwargs={
                "block_size": 5,
                "eos_token_id": 1,
                "bos_token_id": 0,
                "pad_token_id": 2,
            },
            remove_columns=["token_ids"],
        ).shuffle(seed=42, buffer_size=8)
        ds = split_dataset_by_node(
            ds,
            rank=self.trainer.global_rank,
            world_size=self.trainer.world_size,
        )
        self.ds = ds

    def train_dataloader(self):
        print(
            f"[Rank {self.trainer.global_rank}] Preparing train_dataloader...",
            flush=True,
        )
        rank = self.trainer.global_rank
        print(
            f"[Rank {rank}] Global rank: {self.trainer.global_rank}",
            flush=True,
        )
        world_size = self.trainer.world_size
        print(f"[Rank {rank}] World size: {world_size}", flush=True)
        return StatefulDataLoader(
            self.ds,
            batch_size=2,
            num_workers=2,
            collate_fn=collate_fn,
            drop_last=True,
            persistent_workers=True,
        )


if __name__ == "__main__":
    print("Starting Lightning training", flush=True)
    # Optionally, print some SLURM environment info for debugging.
    print(f"SLURM_NNODES: {os.environ.get('SLURM_NNODES', '1')}", flush=True)

    # Determine the number of nodes from SLURM (defaulting to 1 if not set)
    num_nodes = int(os.environ.get("SLURM_NNODES", "1"))
    model = DummyModule()
    dm = DM()
    on_exception = OnExceptionCheckpoint(
        dirpath="checkpoints",
        filename="on_exception",
    )

    # Configure the Trainer to use distributed data parallel (DDP).
    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        devices=1,
        strategy=(
            "ddp" if num_nodes > 1 else "auto"
        ),  # Use DDP strategy for multi-node training.
        num_nodes=num_nodes,
        max_epochs=2,
        logger=False,
        enable_checkpointing=True,
        num_sanity_val_steps=0,
        enable_progress_bar=False,
        callbacks=[on_exception],
    )

    # resume (uncomment to resume)
    # trainer.fit(model, datamodule=dm, ckpt_path="checkpoints/on_exception.ckpt")
    # train

    trainer.fit(model, datamodule=dm)

The SLURM submission script:

#!/bin/bash
#SBATCH --job-name=pl_ddp_test
#SBATCH --nodes=2                     # Adjust number of nodes as needed
#SBATCH --ntasks-per-node=1           # One GPU (process) per node
#SBATCH --cpus-per-task=3             # Enough CPUs for the 2 dataloader workers plus the main process
#SBATCH --gres=gpu:1                  # Request one GPU per node
#SBATCH --time=00:10:00               # Job runtime (adjust as needed)
#SBATCH --partition=gpu-preempt               # Partition or queue name
#SBATCH -o script.out

# Disable Python output buffering.
export PYTHONUNBUFFERED=1

echo "SLURM job starting on $(date)"
echo "Running on nodes: $SLURM_NODELIST"
echo "Current directory: $(pwd)"
ls -l

# Launch the script using srun so that each process starts the Lightning module.
srun script.py
  2. Uncomment the "resume" line (second to last) and comment out the original trainer.fit call (last line). Resuming the job produces the following log.
[Rank 0] Preparing train_dataloader...
[Rank 0] Global rank: 0
[Rank 0] World size: 2
[Rank 1] Preparing train_dataloader...
[Rank 1] Global rank: 1
[Rank 1] World size: 2
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Assigning 2 shards (or data sources) of the dataset to each node.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#0 dataloader worker#1, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#0 dataloader worker#0, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Set __getitem__(key) output type to arrow for no columns  (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to arrow for no columns  (when key is int or slice) and don't output other (un-formatted) columns.
node#0 dataloader worker#1, ': Finished iterating over 1/1 shards.
node#0 dataloader worker#0, ': Finished iterating over 1/1 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
[Rank 0] Training started.
[Rank 0] Training epoch 0 started.
[Rank 0] Training epoch 1 started.
Assigning 2 shards (or data sources) of the dataset to each node.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#0 dataloader worker#1, ': Starting to iterate over 1/2 shards.
node#0 dataloader worker#0, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#1 dataloader worker#1, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#1 dataloader worker#0, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Set __getitem__(key) output type to arrow for no columns  (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to arrow for no columns  (when key is int or slice) and don't output other (un-formatted) columns.
node#0 dataloader worker#1, ': Finished iterating over 1/1 shards.
node#0 dataloader worker#0, ': Finished iterating over 1/1 shards.
`Trainer.fit` stopped: `max_epochs=2` reached.
Set __getitem__(key) output type to arrow for no columns  (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to arrow for no columns  (when key is int or slice) and don't output other (un-formatted) columns.
node#1 dataloader worker#1, ': Finished iterating over 1/1 shards.
node#1 dataloader worker#0, ': Finished iterating over 1/1 shards.
[Rank 1] Training started.
[Rank 1] Training epoch 0 started.
[Rank 1] Training epoch 1 started.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
node#1 dataloader worker#1, ': Starting to iterate over 1/2 shards.
node#1 dataloader worker#0, ': Starting to iterate over 1/2 shards.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Loading a state dict of a shuffle buffer of a dataset without the buffer content.The shuffle buffer will be refilled before starting to yield new examples.
Set __getitem__(key) output type to arrow for no columns  (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to arrow for no columns  (when key is int or slice) and don't output other (un-formatted) columns.
node#1 dataloader worker#0, ': Finished iterating over 1/1 shards.
node#1 dataloader worker#1, ': Finished iterating over 1/1 shards.

I'm also attaching the relevant state_dict to make sure that the state is being checkpointed as expected.

{'_iterator_finished': True,
 '_snapshot': {'_last_yielded_worker_id': 1,
               '_main_snapshot': {'_IterableDataset_len_called': None,
                                  '_base_seed': 3992758080362545099,
                                  '_index_sampler_state': {'samples_yielded': 64},
                                  '_num_workers': 2,
                                  '_sampler_iter_state': None,
                                  '_sampler_iter_yielded': 32,
                                  '_shared_seed': None},
               '_snapshot_step': 32,
               '_worker_snapshots': {'worker_0': {'dataset_state': {'ex_iterable': {'shard_example_idx': 0, 'shard_idx': 1},
                                                                    'num_examples_since_previous_state': 0,
                                                                    'previous_state': {'shard_example_idx': 0, 'shard_idx': 1},
                                                                    'previous_state_example_idx': 33},
                                                  'fetcher_state': {'dataset_iter_state': None, 'fetcher_ended': False},
                                                  'worker_id': 0},
                                     'worker_1': {'dataset_state': {'ex_iterable': {'shard_example_idx': 0, 'shard_idx': 1},
                                                                    'num_examples_since_previous_state': 0,
                                                                    'previous_state': {'shard_example_idx': 0, 'shard_idx': 1},
                                                                    'previous_state_example_idx': 33},
                                                  'fetcher_state': {'dataset_iter_state': None, 'fetcher_ended': False},
                                                  'worker_id': 1}}},
 '_steps_since_snapshot': 0}
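(In case it helps, this is roughly how the loop/dataloader state can be pulled back out of the on-exception checkpoint; the exact nesting inside the Lightning checkpoint may differ between versions:)

import torch
from pprint import pprint

ckpt = torch.load("checkpoints/on_exception.ckpt", map_location="cpu", weights_only=False)
pprint(list(ckpt.keys()))      # locate the fit-loop / dataloader state
pprint(ckpt.get("loops", {}))  # loop state is stored under "loops" in Lightning 2.x checkpoints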

Expected behavior

Since I'm following all the recommended steps, I don't expect to see any warning when resuming. Am I doing something wrong? Also, can someone explain why I'm seeing 20 identical messages in the log in this reproduction setting? I'm trying to understand why I see thousands of these messages with the actual dataset.
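For the record, lowering the datasets log level should silence the repeated message (it is emitted through the datasets logger), but that only hides the symptom; I'd still like to know whether the shuffle-buffer state is expected to be dropped on resume:

import datasets

# Workaround only: suppress the repeated shuffle-buffer message (and other warnings).
datasets.logging.set_verbosity_error()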

One more surprising thing I noticed in the logs is the change in the number of shards per worker. In the following messages, the denominator changes from 2 to 1.

node#1 dataloader worker#1, ': Starting to iterate over 1/2 shards.
...
node#1 dataloader worker#1, ': Finished iterating over 1/1 shards.
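For reference, the shard arithmetic in this reproduction is: 4 shards total ÷ world_size 2 = 2 shards per node, and 2 shards per node ÷ 2 dataloader workers = 1 shard per worker, so I would have expected both messages to report the same denominator.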

Environment info

python: 3.11.10
datasets: 3.3.2
lightning: 2.3.1
