Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

📦 [SFT] Deprecate batched formatting_func #3147

Merged
merged 13 commits into from
Apr 8, 2025

Conversation

YeFD
Copy link

@YeFD YeFD commented Mar 24, 2025

What does this PR do?

When I defined formatting_func to return a list of processed strings following the documentation of SFTTrainer, an IndexError was raised.

File ~/.conda/envs/rrag/lib/python3.12/site-packages/trl/trainer/sft_trainer.py:413, in SFTTrainer._prepare_dataset(self, dataset, processing_class, args, packing, formatting_func, dataset_name)
    410 if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
    411     map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset"
--> 413 batched = isinstance(formatting_func(next(iter(dataset))), list)
    415 def _func(example):
    416     return {"text": formatting_func(example)}

Cell In[17], line 4, in formatting_prompts_func(example)
      2 output_texts = []
      3 for i in range(len(example['instruction'])):
----> 4     text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
      5     output_texts.append(text)
      6 return output_texts

IndexError: string index out of range

I initially thought it was a documentation error and submitted a PR (#3141).

However, after reviewing the code, I found that it was a bug in the batched judgement. The issue was located in the following line:

batched = isinstance(formatting_func(next(iter(dataset))), list)

After fixing the bug, the code is now compatible with formatting_func returning either a processed string or a list.

def formatting_func(example):
    output_texts = []
    for i in range(len(example['instruction'])):
        text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}"
        output_texts.append(text)
    return output_texts

batched = isinstance(formatting_func(dataset), list) # True

def formatting_func(example):
    text = f"### Question: {example['instruction']}\n ### Answer: {example['output']}"
    return text

batched = isinstance(formatting_func(dataset), list) # False

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

@qgallouedec

@qgallouedec
Copy link
Member

Applying the function to the entire dataset just to check if it is batched doesn’t seem reasonable. In fact, formatting_func is quite cumbersome, and I would advise against using it. Instead, we can simply preprocess the dataset before passing it to the trainer.

That said, determining whether the function is batched is trickier than it seems. Also, I realize that we never document or test support for a non-batched function. The simplest approach would be to assume that the function is always batched, right?

@qgallouedec
Copy link
Member

            if formatting_func is not None and not is_processed:
                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
                    map_kwargs["desc"] = f"Applying formatting function to {dataset_name} dataset"

                def _func(example):
                    return {"text": formatting_func(example)}

                dataset = dataset.map(_func, batched=True, **map_kwargs)

@YeFD
Copy link
Author

YeFD commented Mar 26, 2025

Applying the function to the entire dataset just to check if it is batched doesn’t seem reasonable. In fact, formatting_func is quite cumbersome, and I would advise against using it. Instead, we can simply preprocess the dataset before passing it to the trainer.

That said, determining whether the function is batched is trickier than it seems. Also, I realize that we never document or test support for a non-batched function. The simplest approach would be to assume that the function is always batched, right?

In fact, there is an instance of formatting_func using a non-batched implementation:

def formatting_func(example):
text = f"### Question: {example['question']}\n ### Answer: {example['answer']}"
return text

If only the batched function is supported, we should update the documentation.

@qgallouedec
Copy link
Member

Indeed! Let's update the doc then

@YeFD
Copy link
Author

YeFD commented Apr 1, 2025

@qgallouedec Please review the code.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@YeFD
Copy link
Author

YeFD commented Apr 2, 2025

The test script is still using a non-batched formatting_func. Should I modify the test script to only test batched formatting_func?

@qgallouedec
Copy link
Member

I revisited this issue and explored some open codebases. Unfortunately, both forms are widely used.

Here’s my suggestion:

try:
    dataset = dataset.map(_func, batched=False, **map_kwargs)
except Exception as e:
    warnings.warn(
        f"Failed to apply the formatting function due to the following error: {e}. This may be "
        "because the function is designed for batched input. Please update it to process one example "
        "at a time (i.e., accept and return a single example). For now, we will attempt to apply the "
        "function in batched mode, but note that batched formatting is deprecated and will be removed "
        "in version 0.21.",
        DeprecationWarning,
    )
    dataset = dataset.map(_func, batched=True, **map_kwargs)

Apologies for the change of direction and any extra work this may cause.

@qgallouedec qgallouedec changed the title Fix: Compatibility for formatting_func returning a list 📦 [SFT] Deprecate batched formatting_func Apr 6, 2025
@YeFD
Copy link
Author

YeFD commented Apr 6, 2025

Thank you for updating the code and submitting the PR. I really appreciate your help during my break!

@qgallouedec qgallouedec merged commit 559724e into huggingface:main Apr 8, 2025
8 of 9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants