Add Idefics2/3 and SmolVLM Fast image processors + improvements for fast image processors #38157

Open · wants to merge 5 commits into base: main

9 changes: 6 additions & 3 deletions docs/source/en/model_doc/idefics2.md
@@ -162,7 +162,7 @@ To load and run a model using Flash Attention-2, simply change the code snippet
```diff
model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b",
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_2",
).to(device)
```
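
For reference, a minimal end-to-end sketch of what the resulting loading code looks like (assumes a CUDA device and the `flash-attn` package are installed; the checkpoint name is taken from the snippet above):

```python
import torch
from transformers import Idefics2ForConditionalGeneration

device = "cuda:0"

# Load in half precision with the Flash Attention-2 kernels enabled.
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
).to(device)
```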
@@ -184,7 +184,7 @@ Quantizing a model is as simple as passing a `quantization_config` to the model.
+ )
model = Idefics2ForConditionalGeneration.from_pretrained(
"HuggingFaceM4/idefics2-8b",
+ torch_dtype=torch.float16,
+ quantization_config=quantization_config,
).to(device)
```
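
Likewise, a hedged sketch of the 4-bit loading path; the exact `BitsAndBytesConfig` arguments below are illustrative rather than copied from this page, and `bitsandbytes` must be installed:

```python
import torch
from transformers import BitsAndBytesConfig, Idefics2ForConditionalGeneration

# Illustrative 4-bit config; tune the bnb_* options for your hardware.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
)
```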
@@ -218,7 +218,10 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h
[[autodoc]] Idefics2ImageProcessor
- preprocess

## Idefics2ImageProcessorFast
[[autodoc]] Idefics2ImageProcessorFast
- preprocess

## Idefics2Processor
[[autodoc]] Idefics2Processor
- __call__
3 changes: 3 additions & 0 deletions docs/source/en/model_doc/idefics3.md
@@ -80,6 +80,9 @@ This model was contributed by [amyeroberts](https://huggingface.co/amyeroberts)
[[autodoc]] Idefics3ImageProcessor
- preprocess

## Idefics3ImageProcessorFast
[[autodoc]] Idefics3ImageProcessorFast
- preprocess

## Idefics3Processor
[[autodoc]] Idefics3Processor
7 changes: 5 additions & 2 deletions docs/source/en/model_doc/smolvlm.md
@@ -32,7 +32,7 @@ SmolVLM2 is an adaptation of the Idefics3 model with two main differences:

Input images are processed either by upsampling (if resizing is enabled) or at their original resolution. The resizing behavior depends on two parameters: `do_resize` and `size`.

Videos should not be upsampled.

If `do_resize` is set to `True`, the model resizes images so that the longest edge is 4*512 pixels by default.
The default resizing behavior can be customized by passing a dictionary to the `size` parameter. For example, `{"longest_edge": 4 * 512}` is the default, but you can change it to a different value if needed.
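
A short sketch of overriding that default at call time (the checkpoint name is an assumption for illustration; any SmolVLM checkpoint with an image processor config should behave the same way):

```python
import requests
from PIL import Image
from transformers import AutoImageProcessor

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)

# Assumed example checkpoint.
image_processor = AutoImageProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")

# The default is equivalent to size={"longest_edge": 4 * 512}; a smaller target
# trades resolution for speed.
inputs = image_processor(image, do_resize=True, size={"longest_edge": 2 * 512}, return_tensors="pt")
print(inputs["pixel_values"].shape)
```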
@@ -192,11 +192,14 @@ print(generated_texts[0])
[[autodoc]] SmolVLMForConditionalGeneration
- forward


## SmolVLMImageProcessor
[[autodoc]] SmolVLMImageProcessor
- preprocess

## SmolVLMImageProcessorFast
[[autodoc]] SmolVLMImageProcessorFast
- preprocess

## SmolVLMVideoProcessor
[[autodoc]] SmolVLMVideoProcessor
- preprocess
2 changes: 1 addition & 1 deletion src/transformers/commands/add_fast_image_processor.py
@@ -396,7 +396,7 @@ def add_fast_image_processor_file(

content_header = get_fast_image_processing_content_header(content_base_file)
content_base_file = (
f"@auto_docstring(\n"
f"@auto_docstring\n"
f"class {fast_image_processor_name}(BaseImageProcessorFast):\n"
" # This generated class can be used as a starting point for the fast image processor.\n"
" # if the image processor is only used for simple augmentations, such as resizing, center cropping, rescaling, or normalizing,\n"
91 changes: 31 additions & 60 deletions src/transformers/image_processing_utils_fast.py
@@ -18,11 +18,7 @@

import numpy as np

- from .image_processing_utils import (
- BaseImageProcessor,
- BatchFeature,
- get_size_dict,
- )
+ from .image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from .image_transforms import (
convert_to_rgb,
get_resize_output_image_size,
@@ -188,6 +184,7 @@ class DefaultFastImageProcessorKwargs(TypedDict, total=False):
data_format: Optional[ChannelDimension]
input_data_format: Optional[Union[str, ChannelDimension]]
device: Optional["torch.device"]
disable_grouping: Optional[bool]


@auto_docstring
@@ -481,18 +478,35 @@ def _prepare_input_images(
) -> list["torch.Tensor"]:
"""
Prepare the input images for processing.

Args:
images (`ImageInput`):
The input images to process.
do_convert_rgb (`bool`, *optional*):
Whether to convert the images to RGB.
input_data_format (`str` or `ChannelDimension`, *optional*):
The input data format of the images.
device (`torch.device`, *optional*):
The device to put the processed images on.

Returns:
List[`torch.Tensor`]: The processed images.
"""

# Get structured images (potentially nested)
images = self._prepare_images_structure(images)
- process_image_fn = partial(
- self._process_image,
- do_convert_rgb=do_convert_rgb,
- input_data_format=input_data_format,
- device=device,
- )
- # todo: yoni - check if we can parallelize this efficiently
- processed_images = []
- for image in images:
- processed_images.append(process_image_fn(image))

+ process_image_partial = partial(
+ self._process_image, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format, device=device
+ )

+ # Check if we have nested structure, assuming the nesting is consistent
+ has_nested_structure = len(images) > 0 and isinstance(images[0], (list, tuple))

+ if has_nested_structure:
+ processed_images = [[process_image_partial(img) for img in nested_list] for nested_list in images]
+ else:
+ processed_images = [process_image_partial(img) for img in images]

return processed_images
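
A standalone sketch (not the library code) of the structure check used above: inputs must be uniformly flat or uniformly nested, and only the first element is inspected to decide which case applies.

```python
import torch


def process_maybe_nested(images, fn):
    # Mirror of the logic above: a list of tensors is processed directly, while a
    # list of lists (e.g. one sub-list of crops per sample) keeps its nesting.
    if len(images) > 0 and isinstance(images[0], (list, tuple)):
        return [[fn(img) for img in sub] for sub in images]
    return [fn(img) for img in images]


flat = [torch.rand(3, 224, 224), torch.rand(3, 336, 336)]
nested = [[torch.rand(3, 224, 224)], [torch.rand(3, 336, 336), torch.rand(3, 224, 224)]]

print(len(process_maybe_nested(flat, lambda t: t.half())))    # 2 tensors
print(len(process_maybe_nested(nested, lambda t: t.half())))  # 2 sub-lists
```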

@@ -618,11 +632,12 @@ def _preprocess(
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
# Group images by size for batched resizing
- grouped_images, grouped_images_index = group_images_by_shape(images)
+ grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_resize:
@@ -632,7 +647,7 @@

# Group images by size for further processing
# Needed in case do_resize is False, or resize returns images with different sizes
- grouped_images, grouped_images_index = group_images_by_shape(resized_images)
+ grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
for shape, stacked_images in grouped_images.items():
if do_center_crop:
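
A hedged usage sketch for the new kwarg (the checkpoint name is an assumption): `disable_grouping=True` processes each image on its own instead of stacking same-shaped images into one batch, and when it is left unset this version falls back to disabling grouping for CPU inputs.

```python
import numpy as np
from PIL import Image
from transformers import AutoImageProcessor

images = [
    Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8)),
    Image.fromarray(np.zeros((384, 384, 3), dtype=np.uint8)),
]

processor = AutoImageProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", use_fast=True)
# disable_grouping is forwarded through the DefaultFastImageProcessorKwargs entry added above.
inputs = processor(images=images, disable_grouping=True, return_tensors="pt")
print(inputs["pixel_values"].shape)
```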
@@ -652,47 +667,3 @@ def to_dict(self):
encoder_dict = super().to_dict()
encoder_dict.pop("_valid_processor_keys", None)
return encoder_dict


class SemanticSegmentationMixin:
def post_process_semantic_segmentation(self, outputs, target_sizes: Optional[list[tuple]] = None):
"""
Converts the output of [`MobileNetV2ForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

Args:
outputs ([`MobileNetV2ForSemanticSegmentation`]):
Raw outputs of the model.
target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
List of tuples corresponding to the requested final size (height, width) of each prediction. If unset,
predictions will not be resized.

Returns:
semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
"""
logits = outputs.logits

# Resize logits and compute semantic segmentation maps
if target_sizes is not None:
if len(logits) != len(target_sizes):
raise ValueError(
"Make sure that you pass in as many target sizes as the batch dimension of the logits"
)

# if is_torch_tensor(target_sizes):
# target_sizes = target_sizes.numpy()

semantic_segmentation = []

for idx in range(len(logits)):
resized_logits = torch.nn.functional.interpolate(
logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
)
semantic_map = resized_logits[0].argmax(dim=0)
semantic_segmentation.append(semantic_map)
else:
semantic_segmentation = logits.argmax(dim=1)
semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]

return semantic_segmentation
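
For reference, an illustrative use of the post-processing helper that this hunk removes from the file (model and checkpoint names are assumptions; any semantic-segmentation model whose image processor exposes `post_process_semantic_segmentation` is used the same way):

```python
import torch
from PIL import Image
from transformers import AutoImageProcessor, MobileNetV2ForSemanticSegmentation

checkpoint = "google/deeplabv3_mobilenet_v2_1.0_513"  # assumed example checkpoint
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = MobileNetV2ForSemanticSegmentation.from_pretrained(checkpoint)

image = Image.new("RGB", (640, 480))
inputs = image_processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Logits are resized back to each requested (height, width) and argmax'd per pixel.
segmentation_map = image_processor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)[0]
print(segmentation_map.shape)  # torch.Size([480, 640])
```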
135 changes: 116 additions & 19 deletions src/transformers/image_transforms.py
@@ -841,37 +841,134 @@ def _cast_tensor_to_float(x):
return x.float()


def _flatten_nested_images(nested_images):
"""Helper function to flatten a single level of nested image structures and group by shape."""
grouped_images = {}
grouped_images_index = {}

for i, sublist in enumerate(nested_images):
for j, image in enumerate(sublist):
shape = image.shape[1:]
if shape not in grouped_images:
grouped_images[shape] = []
grouped_images[shape].append(image)
grouped_images_index[(i, j)] = (shape, len(grouped_images[shape]) - 1)

return grouped_images, grouped_images_index


def _reconstruct_nested_structure(indices, processed_images):
"""Helper function to reconstruct a single level nested structure."""
# Find the maximum outer index
max_outer_idx = max(idx[0] for idx in indices.keys())

# Create the outer list
result = [None] * (max_outer_idx + 1)

# Group indices by outer index
nested_indices = {}
for i, j in indices.keys():
if i not in nested_indices:
nested_indices[i] = []
nested_indices[i].append(j)

for i in range(max_outer_idx + 1):
if i in nested_indices:
inner_max_idx = max(nested_indices[i])
inner_list = [None] * (inner_max_idx + 1)
for j in range(inner_max_idx + 1):
if (i, j) in indices:
shape, idx = indices[(i, j)]
inner_list[j] = processed_images[shape][idx]
result[i] = inner_list

return result


def group_images_by_shape(
- images: list["torch.Tensor"],
- ) -> tuple[dict[tuple[int, int], list["torch.Tensor"]], dict[int, tuple[tuple[int, int], int]]]:
+ images: Union[list["torch.Tensor"], "torch.Tensor"],
+ is_nested: bool = False,
+ disable_grouping: Optional[bool] = None,
+ ) -> tuple[
+ dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]
+ ]:
"""
Groups images by shape.
Returns a dictionary with the shape as key and a list of images with that shape as value,
and a dictionary with the index of the image in the original list as key and the shape and index in the grouped list as value.

The function supports both flat lists of tensors and nested structures.
The input must be either all flat or all nested, not a mix of both.

Args:
images (Union[list["torch.Tensor"], "torch.Tensor"]):
A list of images or a single tensor

Returns:
tuple[dict[tuple[int, int], list["torch.Tensor"]], dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]]:
- A dictionary with shape as key and list of images with that shape as value
- A dictionary mapping original indices to (shape, index) tuples
"""
- grouped_images = {}
- grouped_images_index = {}
- for i, image in enumerate(images):
- shape = image.shape[1:]
- if shape not in grouped_images:
- grouped_images[shape] = []
- grouped_images[shape].append(image)
- grouped_images_index[i] = (shape, len(grouped_images[shape]) - 1)
- # stack images with the same shape
- grouped_images = {shape: torch.stack(images, dim=0) for shape, images in grouped_images.items()}
if not is_nested:
disable_grouping = images[0].device == "cpu" if disable_grouping is None else disable_grouping
if disable_grouping:
return {i: images[i].unsqueeze(0) for i in range(len(images))}, {i: (i, 0) for i in range(len(images))}

grouped_images = {}
grouped_images_index = {}
for i, image in enumerate(images):
shape = image.shape[1:]
if shape not in grouped_images:
grouped_images[shape] = []
grouped_images[shape].append(image)
grouped_images_index[i] = (shape, len(grouped_images[shape]) - 1)

# Stack images with the same shape
grouped_images = {shape: torch.stack(imgs, dim=0) for shape, imgs in grouped_images.items()}
return grouped_images, grouped_images_index

disable_grouping = images[0][0].device == "cpu" if disable_grouping is None else disable_grouping

if disable_grouping:
return {(i, j): images[i][j].unsqueeze(0) for i in range(len(images)) for j in range(len(images[i]))}, {
(i, j): ((i, j), 0) for i in range(len(images)) for j in range(len(images[i]))
}

# Handle single level nested structure
grouped_images, grouped_images_index = _flatten_nested_images(images)

# Stack images with the same shape
grouped_images = {shape: torch.stack(imgs, dim=0) for shape, imgs in grouped_images.items()}

return grouped_images, grouped_images_index
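
A usage sketch of the flat path (assuming the helper stays importable from `transformers.image_transforms` as in this file); grouping is forced on here because, in this version, it is switched off by default for CPU tensors:

```python
import torch
from transformers.image_transforms import group_images_by_shape

images = [
    torch.rand(3, 224, 224),
    torch.rand(3, 336, 336),
    torch.rand(3, 224, 224),
]

grouped, index = group_images_by_shape(images, disable_grouping=False)
for shape, batch in grouped.items():
    # Same-shaped images are stacked into a single batch tensor.
    print(tuple(shape), tuple(batch.shape))
# (224, 224) (2, 3, 224, 224)
# (336, 336) (1, 3, 336, 336)

# `index` maps each original position to (shape key, position inside that batch),
# which is what reorder_images uses to restore the original order.
```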


def reorder_images(
- processed_images: dict[tuple[int, int], "torch.Tensor"], grouped_images_index: dict[int, tuple[int, int]]
- ) -> list["torch.Tensor"]:
+ processed_images: dict[tuple[int, int], "torch.Tensor"],
+ grouped_images_index: dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]],
+ is_nested: bool = False,
+ ) -> Union[list["torch.Tensor"], "torch.Tensor"]:
"""
- Reconstructs a list of images in the original order.
+ Reconstructs images in the original order, preserving the original structure (nested or not).
The input structure is either all flat or all nested.

Args:
processed_images (dict[tuple[int, int], "torch.Tensor"]):
Dictionary mapping shapes to batched processed images.
grouped_images_index (dict[Union[int, tuple[int, int]], tuple[tuple[int, int], int]]):
Dictionary mapping original indices to (shape, index) tuples.

Returns:
Union[list["torch.Tensor"], "torch.Tensor"]:
Images in the original structure.
"""
- return [
- processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
- for i in range(len(grouped_images_index))
- ]
if not is_nested:
return [
processed_images[grouped_images_index[i][0]][grouped_images_index[i][1]]
for i in range(len(grouped_images_index))
]

return _reconstruct_nested_structure(grouped_images_index, processed_images)
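
And a round-trip sketch of the nested path (same import assumption): group, process each same-shape batch, then restore the original list-of-lists with `is_nested=True`:

```python
import torch
from transformers.image_transforms import group_images_by_shape, reorder_images

nested = [
    [torch.rand(3, 224, 224), torch.rand(3, 336, 336)],  # sample 0: two crops
    [torch.rand(3, 224, 224)],                            # sample 1: one crop
]

grouped, index = group_images_by_shape(nested, is_nested=True, disable_grouping=False)

# Example per-batch step: operate on every stacked batch in one go.
processed = {shape: batch * 0.5 for shape, batch in grouped.items()}

restored = reorder_images(processed, index, is_nested=True)
print(len(restored), len(restored[0]), len(restored[1]))  # 2 2 1
```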


class NumpyToTensor:
5 changes: 3 additions & 2 deletions src/transformers/models/auto/image_processing_auto.py
@@ -95,8 +95,8 @@
("groupvit", ("CLIPImageProcessor", "CLIPImageProcessorFast")),
("hiera", ("BitImageProcessor", "BitImageProcessorFast")),
("idefics", ("IdeficsImageProcessor",)),
("idefics2", ("Idefics2ImageProcessor",)),
("idefics3", ("Idefics3ImageProcessor",)),
("idefics2", ("Idefics2ImageProcessor", "Idefics2ImageProcessorFast")),
("idefics3", ("Idefics3ImageProcessor", "Idefics3ImageProcessorFast")),
("ijepa", ("ViTImageProcessor", "ViTImageProcessorFast")),
("imagegpt", ("ImageGPTImageProcessor",)),
("instructblip", ("BlipImageProcessor", "BlipImageProcessorFast")),
@@ -147,6 +147,7 @@
("shieldgemma2", ("Gemma3ImageProcessor", "Gemma3ImageProcessorFast")),
("siglip", ("SiglipImageProcessor", "SiglipImageProcessorFast")),
("siglip2", ("Siglip2ImageProcessor", "Siglip2ImageProcessorFast")),
("smolvlm", ("SmolVLMImageProcessor", "SmolVLMImageProcessorFast")),
("superglue", ("SuperGlueImageProcessor",)),
("swiftformer", ("ViTImageProcessor", "ViTImageProcessorFast")),
("swin", ("ViTImageProcessor", "ViTImageProcessorFast")),