Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
hanzhi713 committed Jan 13, 2025
1 parent 9850e2e commit 24071f5
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions axlearn/common/flash_attention/gpu_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def _mha_forward_kernel(
See also `_mha_backward_kernel` for the backward pass.
Note: the kernel name is used to do string matching for rematerialization in `remat.py`. Be
careful when renaming this.
Args:
q_ref: Input query ref.
k_ref: Input key ref.
Expand Down

0 comments on commit 24071f5

Please sign in to comment.