@@ -441,7 +441,11 @@ def _segment_ids_pos_to_seqlens_offsets(
441
441
# This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
442
442
# examine only O(Q+KV) elements.
443
443
# TODO(huah): This fast path does not work for CP + THD + unbalanced, need to fix later
444
- if context_parallel_load_balanced and attn_mask_type .is_causal () and (window_size is None or window_size == (- 1 , - 1 )):
444
+ if (
445
+ context_parallel_load_balanced
446
+ and attn_mask_type .is_causal ()
447
+ and (window_size is None or window_size == (- 1 , - 1 ))
448
+ ):
445
449
return _segment_ids_pos_to_seqlens_offsets_fast_causal_path (
446
450
segment_ids_q , segment_ids_kv , segment_pos_q , segment_pos_kv , max_segments_per_seq
447
451
)
@@ -529,7 +533,11 @@ def tree_unflatten(cls, aux_data, children):
529
533
return cls (* children )
530
534
531
535
def get_seqlens_and_offsets (
532
- self , attn_mask_type , qkv_layout , window_size , max_segments_per_seq ,
536
+ self ,
537
+ attn_mask_type ,
538
+ qkv_layout ,
539
+ window_size ,
540
+ max_segments_per_seq ,
533
541
context_parallel_load_balanced : bool = False ,
534
542
):
535
543
"""
@@ -552,7 +560,7 @@ def get_seqlens_and_offsets(
552
560
attn_mask_type ,
553
561
window_size ,
554
562
max_segments_per_seq ,
555
- context_parallel_load_balanced
563
+ context_parallel_load_balanced ,
556
564
)
557
565
else :
558
566
q_seqlens , kv_seqlens = _segment_ids_to_seqlens (
0 commit comments