From e4ff72cb377ec1f6e74484fe4525c2f8c205ad41 Mon Sep 17 00:00:00 2001
From: Hanzhi Zhou
Date: Mon, 23 Dec 2024 17:21:14 -0800
Subject: [PATCH] Fix softmax scale (#903)

---
 axlearn/common/flash_attention/tpu_attention.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/axlearn/common/flash_attention/tpu_attention.py b/axlearn/common/flash_attention/tpu_attention.py
index 7b44266b..0edc6c0c 100644
--- a/axlearn/common/flash_attention/tpu_attention.py
+++ b/axlearn/common/flash_attention/tpu_attention.py
@@ -594,7 +594,7 @@ def lm_index_map(batch_index, head_index, q_seq_index, _):
         _flash_attention_kernel,
         causal=causal,
         mask_value=DEFAULT_MASK_VALUE,
-        softmax_scale=softmax_scale,
+        sm_scale=softmax_scale,
         block_k=block_k,
         kv_seq_len=kv_seq_len,
     )
@@ -878,7 +878,7 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _):
         _flash_attention_dkv_kernel,
         block_q=block_q,
         block_k=block_k,
-        softmax_scale=softmax_scale,
+        sm_scale=softmax_scale,
         causal=causal,
         mask_value=mask_value,
         q_seq_len=q_seq_len,
@@ -1068,7 +1068,7 @@ def kv_segment_ids_index_map(batch_index, head_index, q_seq_index, kv_seq_index)
     kernel = functools.partial(
         _flash_attention_dq_kernel,
-        softmax_scale=softmax_scale,
+        sm_scale=softmax_scale,
         causal=causal,
         mask_value=mask_value,
         block_k=block_k,
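
Note (illustrative, not part of the patch): the rename matters because functools.partial binds keyword
arguments without checking them against the target function's signature, so a stale keyword such as
`softmax_scale` is only rejected when the wrapped kernel is finally invoked. A minimal sketch with a
hypothetical stand-in `kernel` function (not the axlearn Pallas kernel):

    import functools

    def kernel(x, *, sm_scale=1.0):
        # Stand-in for a kernel that expects the softmax scale as `sm_scale`.
        return x * sm_scale

    # functools.partial accepts the stale keyword name without complaint here...
    stale = functools.partial(kernel, softmax_scale=0.125)
    # ...and the mismatch only surfaces at call time:
    #   stale(8.0) -> TypeError: kernel() got an unexpected keyword argument 'softmax_scale'

    # With the corrected keyword, the scale is actually applied:
    fixed = functools.partial(kernel, sm_scale=0.125)
    print(fixed(8.0))  # 1.0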