From 24071f5d9cd9ac60d6e27f4cac4144e0e44f0d88 Mon Sep 17 00:00:00 2001
From: Hanzhi Zhou
Date: Mon, 13 Jan 2025 12:31:50 -0800
Subject: [PATCH] Address comments

---
 axlearn/common/flash_attention/gpu_attention.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py
index 43e62d8b..d39608ab 100644
--- a/axlearn/common/flash_attention/gpu_attention.py
+++ b/axlearn/common/flash_attention/gpu_attention.py
@@ -83,6 +83,9 @@ def _mha_forward_kernel(
 
     See also `_mha_backward_kernel` for the backward pass.
 
+    Note: the kernel name is used to do string matching for rematerialization in `remat.py`. Be
+        careful when renaming this.
+
     Args:
         q_ref: Input query ref.
         k_ref: Input key ref.
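
As background on why the renaming caveat matters, below is a minimal sketch of how a
rematerialization policy could string-match on this kernel name. It assumes a
`jax.checkpoint`-style custom policy (a callable `policy(prim, *_, **params) -> bool`);
the policy name `save_flash_attention_outputs`, the `_FLASH_KERNEL_NAME` constant, and
the `"name"` parameter it inspects are illustrative assumptions, not the actual
implementation in `remat.py`.

# Illustrative sketch only; not part of the patch or of axlearn's remat.py.
import jax

# Assumed to mirror the Pallas kernel function's name in gpu_attention.py.
_FLASH_KERNEL_NAME = "_mha_forward_kernel"


def save_flash_attention_outputs(prim, *_, **params) -> bool:
    # String-match the kernel name embedded in the primitive's params so the
    # flash-attention output is saved rather than recomputed in the backward pass.
    # The exact param that carries the kernel name may differ across JAX versions;
    # "name" is an assumption here.
    return _FLASH_KERNEL_NAME in str(params.get("name", ""))


# Example usage (hypothetical): wrap the attention function with jax.checkpoint
# and this policy.
# attention = jax.checkpoint(attention_fn, policy=save_flash_attention_outputs)

Matching on the name string keeps the remat policy decoupled from the kernel's
implementation details, which is also why renaming the kernel would silently break
the match rather than fail loudly.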