
Commit b11b28c

uminaty and aurelien.lac authored
Hotfix: Flash Attention 2 support in Pixtral (huggingface#38146)
Sets attention_mask to None when flash_attention_2 is selected.

Co-authored-by: aurelien.lac <aurelien.lac@lighton.ai>
1 parent 0e0e5c1 commit b11b28c
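
For context: the diff comment below notes that Pixtral packs the patch sequences of several images into one batch row, and that with flash_attention_2 the model relies on position_ids rather than a dense mask. The sketch that follows is illustrative only (the helper name is hypothetical, not part of Transformers); it shows how the cumulative sequence lengths consumed by FlashAttention's variable-length kernels can be recovered from packed position_ids that reset to 0 at each sub-sequence start, which is why the dense attention_mask becomes redundant on that path.

import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper for illustration only (not the Transformers internals):
    # packed position_ids restart at 0 at the start of every sub-sequence,
    # e.g. [0, 1, 2, 0, 1, 0, 1, 2, 3] packs sequences of lengths 3, 2 and 4.
    flat = position_ids.flatten()
    starts = torch.nonzero(flat == 0, as_tuple=False).flatten().to(torch.int32)
    total = torch.tensor([flat.numel()], dtype=torch.int32)
    # Cumulative sequence lengths in the layout FlashAttention's varlen kernels expect:
    # [0, len_0, len_0 + len_1, ..., total_tokens]
    return torch.cat([starts, total])

pos = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])
print(cu_seqlens_from_position_ids(pos))  # tensor([0, 3, 5, 9], dtype=torch.int32)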

File tree

1 file changed: +5 -0 lines changed

src/transformers/models/pixtral/modeling_pixtral.py

Lines changed: 5 additions & 0 deletions
@@ -211,6 +211,11 @@ def forward(
         else:
             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
 
+        # Since we use packing, if flash_attention_2 is selected we rely on position_ids
+        if self.config._attn_implementation == "flash_attention_2":
+            kwargs["position_ids"] = kwargs["position_ids"].to(hidden_states.device, non_blocking=True)
+            attention_mask = None
+
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
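
As a standalone mirror of the hunk above (function and parameter names are placeholders, not the Transformers API): when flash_attention_2 is active the dense mask is dropped and position_ids are moved asynchronously to the activations' device, so the variable-length path can recover sequence boundaries on its own.

import torch

def select_mask_and_kwargs(hidden_states, attention_mask, attn_implementation, **kwargs):
    # Placeholder sketch of the guard added in this commit, for illustration only.
    if attn_implementation == "flash_attention_2":
        kwargs["position_ids"] = kwargs["position_ids"].to(hidden_states.device, non_blocking=True)
        attention_mask = None
    return attention_mask, kwargs

mask, kw = select_mask_and_kwargs(
    torch.zeros(1, 4, 8),
    attention_mask=torch.ones(1, 1, 4, 4),
    attn_implementation="flash_attention_2",
    position_ids=torch.arange(4).unsqueeze(0),
)
print(mask)                # None
print(kw["position_ids"])  # tensor([[0, 1, 2, 3]])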
