Skip to content

Commit

Permalink
Enabled running Pallas Flash Attention on CPU. (#922)
Browse files Browse the repository at this point in the history
Pallas supports CPU simulation (`interpret=True`), so we can use the same
TPU Pallas kernel on CPU — making code debugging easier.

This change lets the following unittests run on CPU as if they were on TPU,
enabling easier testing and debugging:
- `axlearn/common/flash_attention/tpu_attention_test.py`

Similarly, `gpu_attention_test.py` can also be run on CPU as if they were on GPU.
- `axlearn/common/flash_attention/gpu_attention_test.py`

Now CI covers those tests on CPU as well.
In M3 Max MacBook Pro, test coverages and processing time are as follows,
* axlearn/common/flash_attention/gpu_attention_test.py: 3024 passed, 1345 skipped in 200.38s (0:03:20)
* axlearn/common/flash_attention/tpu_attention_test.py: 18 passed, 435 skipped in 34.82s
  • Loading branch information
ds-hwang authored Jan 16, 2025
1 parent b43f854 commit 0b9af56
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 41 deletions.
44 changes: 40 additions & 4 deletions axlearn/common/flash_attention/gpu_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from axlearn.common.flash_attention.utils import _repeat_kv_heads, mha_reference
from axlearn.common.test_utils import TestCase

if jax.default_backend() != "gpu":
if jax.default_backend() not in ("gpu", "cpu"):
pytest.skip(reason="Incompatible hardware", allow_module_level=True)


Expand Down Expand Up @@ -69,6 +69,8 @@ def test_triton_fwd_only_against_ref(
kv_seq_len = seq_len
if kv_seq_len != seq_len and use_segment_ids:
pytest.skip()
if jax.default_backend() == "cpu" and kv_seq_len > 128:
pytest.skip(reason="CI got OOM.")
k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=input_dtype)
k = jax.random.normal(k2, (batch_size, kv_seq_len, num_heads, per_head_dim), dtype=input_dtype)
Expand Down Expand Up @@ -101,6 +103,7 @@ def test_triton_fwd_only_against_ref(
causal=causal,
softmax_scale=softmax_scale,
dropout_rate=dropout_rate,
interpret=(jax.default_backend() == "cpu"),
)
o_ref = mha_reference(
q,
Expand Down Expand Up @@ -152,6 +155,8 @@ def test_decode_against_ref(
kv_head_factor: int,
window_len: int,
):
if jax.default_backend() == "cpu" and seq_len > 1024:
pytest.skip(reason="Too slow on CPU.")
self.assertEqual(num_heads % kv_head_factor, 0)
assert num_heads % kv_head_factor == 0
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4)
Expand Down Expand Up @@ -180,7 +185,14 @@ def test_decode_against_ref(
if window_len > 0:
mask_fn = sliding_window_causal_mask(window_len)
o = flash_decoding(
q, k, v, bias=bias, softmax_scale=softmax_scale, kv_seq_len=seq_len, mask_fn=mask_fn
q,
k,
v,
bias=bias,
softmax_scale=softmax_scale,
kv_seq_len=seq_len,
mask_fn=mask_fn,
interpret=(jax.default_backend() == "cpu"),
)
if bias is not None:
bias = bias[:, :, :, :seq_len]
Expand Down Expand Up @@ -269,6 +281,7 @@ def test_triton_against_xla_ref(
block_q=block_size,
block_k=block_size,
dropout_rate=dropout_rate,
interpret=(jax.default_backend() == "cpu"),
)
jax_out = call_flash(
q,
Expand Down Expand Up @@ -346,6 +359,9 @@ def test_cudnn_against_triton_ref(
causal: bool,
dtype: jnp.dtype,
):
if jax.default_backend() == "cpu":
pytest.skip(reason="cudnn function needs GPU.")

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype)
k = jax.random.normal(k2, (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype)
Expand All @@ -357,7 +373,15 @@ def test_cudnn_against_triton_ref(
jax_out = cudnn_dot_product_attention(
q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale
)
jax_ref_out = flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale)
jax_ref_out = flash_attention(
q,
k,
v,
bias=None,
causal=causal,
softmax_scale=softmax_scale,
interpret=(jax.default_backend() == "cpu"),
)
if dtype == jnp.bfloat16:
# We relax the atol to support bf16 in the unit test.
chex.assert_trees_all_close(jax_out, jax_ref_out, atol=0.02, rtol=1e-5)
Expand All @@ -372,7 +396,15 @@ def fn(q, k, v):
).sum()

def ref_fn(q, k, v):
return flash_attention(q, k, v, bias=None, causal=causal, softmax_scale=softmax_scale).sum()
return flash_attention(
q,
k,
v,
bias=None,
causal=causal,
softmax_scale=softmax_scale,
interpret=(jax.default_backend() == "cpu"),
).sum()

# Compare gradients.
jax_grads = jax.grad(fn, argnums=(0, 1, 2))(q, k, v)
Expand Down Expand Up @@ -414,6 +446,8 @@ def test_cudnn_dropout_against_xla_dropout(
by setting V to the identity matrix. However, this only works when seq_len == per_head_dim,
i.e. when the shape of output is the same as the shape of the dropout mask.
"""
if jax.default_backend() == "cpu":
pytest.skip(reason="cudnn function needs GPU.")
qkv_shape = (batch_size, seq_len, num_heads, per_head_dim)
softmax_scale = 1.0
cudnn_attn = functools.partial(
Expand Down Expand Up @@ -481,6 +515,8 @@ def ref_fn(q, k, v):

def test_cudnn_dropout_determinism():
"""Tests that cuDNN dropout produces identical outputs across runs."""
if jax.default_backend() == "cpu":
pytest.skip(reason="cudnn function needs GPU.")
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(3), 3)
q = jax.random.normal(k1, (1, 128, 2, 64), dtype=jnp.float16)
k = jax.random.normal(k2, (1, 128, 2, 64), dtype=jnp.float16)
Expand Down
9 changes: 8 additions & 1 deletion axlearn/common/flash_attention/layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import jax
import jax.numpy as jnp
import pytest
from absl.testing import parameterized
from absl.testing import absltest, parameterized
from jax.experimental import mesh_utils
from jax.sharding import Mesh

Expand Down Expand Up @@ -98,6 +98,7 @@ def _prepare_layers(
sliding_window_size,
inference=False,
set_layer_bias_recursively=False,
tpu_block_size=512,
dropout_rate=0.0,
):
hidden_dim = num_heads * per_head_dim
Expand All @@ -119,6 +120,7 @@ def _prepare_layers(
.set(
mha_dim_to_partition_spec=default_mha_dim_to_partition_spec(mesh_axis_names),
output_dim_to_partition_spec=default_output_dim_to_partition_spec(mesh_axis_names),
tpu_block_size=tpu_block_size,
)
)
if inference:
Expand Down Expand Up @@ -458,6 +460,7 @@ def test_forward(
causal=causal,
sliding_window_size=sliding_window_size,
dropout_rate=dropout_rate,
tpu_block_size=128,
)

query_len = int(query_len_multiplier * seq_len)
Expand Down Expand Up @@ -815,3 +818,7 @@ def test_extend_step(
atol=2e-2,
)
jax.extend.backend.clear_backends()


if __name__ == "__main__":
absltest.main()
37 changes: 19 additions & 18 deletions axlearn/common/flash_attention/tpu_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

"""Wrappers for FlashAttention on TPU in JAX with logit bias support."""
import functools
from typing import Optional, Union
from typing import Optional

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -40,6 +40,8 @@
)
from axlearn.common.utils import Tensor

MaskFnOrZero = MaskFnAttentionBias | ZeroAttentionBias


def tpu_flash_attention(
query: Tensor, # [batch_size, target_len, num_heads, head_dim]
Expand All @@ -48,7 +50,7 @@ def tpu_flash_attention(
bias: Tensor = None, # [batch_size, num_heads, target_len, source_len]
segment_ids: Tensor = None, # [batch_size, target_len]
*,
mask: Optional[MaskFnAttentionBias] = None,
mask: MaskFnOrZero,
softmax_scale: float = 1.0,
block_size: int = 128,
interpret: bool = False,
Expand Down Expand Up @@ -113,16 +115,17 @@ def tpu_flash_attention(
f"Source seq len {key.shape[1]} must be divisible by block size {block_size}."
)

mask: Union[MaskFnAttentionBias | ZeroAttentionBias] = as_attention_bias(mask)
mask: MaskFnOrZero = as_attention_bias(mask)

# Switch num_heads and seq_len axes.
query = jnp.einsum("btnh->bnth", query)
key = jnp.einsum("bsnh->bnsh", key)
value = jnp.einsum("bsnh->bnsh", value)
try:
check_tpu_splash_attention(
query=query,
key=key,
target_len=query.shape[2],
source_len=key.shape[2],
head_dim=query.shape[3],
mask=mask,
has_segment_ids=(segment_ids is not None),
has_bias=(bias is not None),
Expand Down Expand Up @@ -199,7 +202,7 @@ def _legacy_tpu_flash_attention(
bias: Tensor = None, # [batch_size, num_heads, target_len, source_len]
segment_ids: Tensor = None, # [batch_size, target_len]
*,
mask: MaskFnAttentionBias,
mask: MaskFnOrZero,
block_sizes: Optional[LegacyBlockSizes] = None,
interpret: bool = False,
) -> Tensor: # [batch_size, num_heads, target_len, head_dim].
Expand Down Expand Up @@ -253,17 +256,19 @@ class SplashAttentionUnsupportedError(NotImplementedError):

def check_tpu_splash_attention(
*,
query: Tensor, # [batch_size, num_heads, source_len, head_dim]
key: Tensor, # [batch_size, num_heads, target_len, head_dim]
mask: Union[MaskFnAttentionBias | ZeroAttentionBias],
target_len: int,
source_len: int,
head_dim: int,
mask: MaskFnOrZero,
has_segment_ids: bool = False,
has_bias: bool = False,
):
"""Checks if splash attention is supported on TPU for the given arguments.
Args:
query: The query tensor, of shape [batch_size, num_heads, target_len, head_dim].
key: The key tensor, of shape [batch_size, num_heads, source_len, head_dim].
target_len: The length of the target sequence.
source_len: The length of the source sequence.
head_dim: The dimension of each head.
mask: The mask to apply. This is more compute efficient compared to setting bias = -inf.
has_segment_ids: Whether segment_ids is None or not.
has_bias: Whether attention involves a bias.
Expand All @@ -272,12 +277,8 @@ def check_tpu_splash_attention(
SplashAttentionUnsupportedError: If splash attention is not supported for the given
arguments.
"""
target_len = query.shape[2]
source_len = key.shape[2]
head_dim = query.shape[3]

if has_bias:
return False # SplashAttention does not support specifying a bias.
raise SplashAttentionUnsupportedError("SplashAttention does not support specifying a bias.")
with jax.ensure_compile_time_eval():
if jnp.any(
jnp.asarray([target_len, source_len, head_dim]) % splash_attention_kernel.NUM_LANES != 0
Expand Down Expand Up @@ -305,7 +306,7 @@ def check_tpu_splash_attention(


def _to_splash_mask(
mask: Union[MaskFnAttentionBias | ZeroAttentionBias],
mask: MaskFnOrZero,
*,
mask_shape: tuple[int, int],
q_seq_shards: int = 1,
Expand Down Expand Up @@ -344,7 +345,7 @@ def _tpu_splash_attention(
key: Tensor, # [batch_size, num_heads, source_len, head_dim]
value: Tensor, # [batch_size, num_heads, source_len, head_dim]
*,
mask: Union[MaskFnAttentionBias | ZeroAttentionBias],
mask: MaskFnOrZero,
segment_ids: Optional[Tensor] = None, # [batch_size, target_len]
block_sizes: Optional[splash_attention_kernel.BlockSizes] = None,
interpret: bool = False,
Expand Down
19 changes: 8 additions & 11 deletions axlearn/common/flash_attention/tpu_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import unittest

import chex
import jax
import jax.numpy as jnp
import numpy as np
Expand All @@ -29,14 +28,10 @@
from axlearn.common.test_utils import TestCase, is_supported_mesh_shape
from axlearn.common.utils import Tensor

# Comment out to test on CPU manually. Technically, this test runs on the CPU, albeit very slowly.
if jax.default_backend() != "tpu":
pytest.skip(reason="Incompatible hardware", allow_module_level=True)


def setUpModule():
# If on CPU, emulate 4 devices.
chex.set_n_cpu_devices(4)
if jax.default_backend() not in ("tpu", "cpu"):
pytest.skip(reason="Incompatible hardware", allow_module_level=True)


def jax_fn_mask(query_position: Tensor, key_position: Tensor) -> Tensor:
Expand Down Expand Up @@ -102,7 +97,6 @@ def test_to_splash_mask(self, mask, expected):
sliding_window_size=[1024],
num_heads=[4],
per_head_dim=[256],
mesh=[(4, 1)],
mesh_axis_names=[("data", "model")],
)
def test_forward(
Expand All @@ -113,11 +107,12 @@ def test_forward(
per_head_dim,
mask_fn,
sliding_window_size,
mesh,
mesh_axis_names,
):
if not is_supported_mesh_shape(mesh):
pytest.skip(reason=f"Unsupported mesh {mesh}.")
if jax.default_backend() == "cpu" and seq_len > 1024:
pytest.skip(reason="Too slow on CPU.")
mesh = (1, 1) if jax.default_backend() == "cpu" else (4, 1)
self.assertTrue(is_supported_mesh_shape(mesh))

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(
Expand Down Expand Up @@ -254,6 +249,8 @@ def ref_fn(q, k, v, bias, ids):

if mask is not None:
mask = MaskFnAttentionBias(mask, shape=(query_len, kv_len))
else:
mask = ZeroAttentionBias()

def fn(q, k, v, bias, ids):
record_legacy_call = unittest.mock.patch.object(
Expand Down
13 changes: 6 additions & 7 deletions axlearn/common/flash_attention/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
MaskFnAttentionBias,
SegmentIdAttentionBias,
TensorAttentionBias,
ZeroAttentionBias,
split,
)
from axlearn.common.flash_attention.gpu_attention import cudnn_dot_product_attention
Expand Down Expand Up @@ -203,6 +202,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
mask_fn=mask_fn,
kv_seq_len=kv_seq_len,
softmax_scale=softmax_scale,
interpret=(backend == "cpu"),
)

key = _repeat_kv_heads(query.shape[2], key)
Expand Down Expand Up @@ -237,6 +237,7 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
softmax_scale=softmax_scale,
causal=causal.has_value(),
dropout_rate=dropout_rate,
interpret=(backend == "cpu"),
)
else:
explicit_bias += segment_ids
Expand Down Expand Up @@ -268,20 +269,18 @@ def get_segment_ids(segment_ids: SegmentIdAttentionBias) -> Optional[Tensor]:
value,
bias=explicit_bias.value(),
segment_ids=get_segment_ids(segment_ids),
# The `from_sequence()` function guarantees that if there is only one
# mask, it is returned without modification.
# This allows the `causal` path in `_legacy_tpu_flash_attention()` to work.
mask=mask if not isinstance(mask, ZeroAttentionBias) else None,
mask=mask,
softmax_scale=softmax_scale,
block_size=block_size,
interpret=(backend == "cpu"),
)

elif backend in ("cpu", "xla"):
key = _repeat_kv_heads(query.shape[2], key)
value = _repeat_kv_heads(query.shape[2], value)
if backend == "cpu":
logging.warning("Flash attention CPU backend is for testing only.")
logging.warning("Flash attention falling back using plain MHA implementation")
logging.info("Flash attention CPU backend is for testing only.")
logging.info("Flash attention falling back using plain MHA implementation")

# `causal` is supported.
# `segment_ids` is supported.
Expand Down

0 comments on commit 0b9af56

Please sign in to comment.