Skip to content

Commit f3d92a5

Browse files
pre-commit-ci[bot]huanghua1994
authored andcommitted
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 3cf1b22 commit f3d92a5

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

transformer_engine/jax/attention.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,11 @@ def _segment_ids_pos_to_seqlens_offsets(
441441
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
442442
# examine only O(Q+KV) elements.
443443
# TODO(huah): This fast path does not work for CP + THD + unbalanced, need to fix later
444-
if context_parallel_load_balanced and attn_mask_type.is_causal() and (window_size is None or window_size == (-1, -1)):
444+
if (
445+
context_parallel_load_balanced
446+
and attn_mask_type.is_causal()
447+
and (window_size is None or window_size == (-1, -1))
448+
):
445449
return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
446450
segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
447451
)
@@ -529,7 +533,11 @@ def tree_unflatten(cls, aux_data, children):
529533
return cls(*children)
530534

531535
def get_seqlens_and_offsets(
532-
self, attn_mask_type, qkv_layout, window_size, max_segments_per_seq,
536+
self,
537+
attn_mask_type,
538+
qkv_layout,
539+
window_size,
540+
max_segments_per_seq,
533541
context_parallel_load_balanced: bool = False,
534542
):
535543
"""
@@ -552,7 +560,7 @@ def get_seqlens_and_offsets(
552560
attn_mask_type,
553561
window_size,
554562
max_segments_per_seq,
555-
context_parallel_load_balanced
563+
context_parallel_load_balanced,
556564
)
557565
else:
558566
q_seqlens, kv_seqlens = _segment_ids_to_seqlens(

transformer_engine/jax/cpp_extensions/attention.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -465,7 +465,7 @@ def impl(
465465
config.qkv_layout,
466466
config.window_size,
467467
config.max_segments_per_seq,
468-
config.context_parallel_load_balanced
468+
config.context_parallel_load_balanced,
469469
)
470470
)
471471

@@ -864,7 +864,7 @@ def impl(
864864
config.qkv_layout,
865865
config.window_size,
866866
config.max_segments_per_seq,
867-
config.context_parallel_load_balanced
867+
config.context_parallel_load_balanced,
868868
)
869869
)
870870

0 commit comments

Comments
 (0)