diff --git a/.idea/FXUtils.iml b/.idea/FXUtils.iml
index 5970d29..ffda314 100644
--- a/.idea/FXUtils.iml
+++ b/.idea/FXUtils.iml
@@ -1,10 +1,8 @@
-
-
-
-
+
+
diff --git a/.idea/deployment.xml b/.idea/deployment.xml
new file mode 100644
index 0000000..7339fa5
--- /dev/null
+++ b/.idea/deployment.xml
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
index a49a390..a6218fe 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -3,5 +3,5 @@
-
+
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 0fc55af..afe8cd0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,7 @@
 requires-python = ">=3.8"
 readme = "README.md"

-version = "0.0.56"
+version = "0.0.57"

 dependencies = [
     "jax>=0.4.20",
diff --git a/src/fjformer/__init__.py b/src/fjformer/__init__.py
index 5d71305..8b1cb3e 100644
--- a/src/fjformer/__init__.py
+++ b/src/fjformer/__init__.py
@@ -68,7 +68,7 @@
 from . import optimizers as optimizers
 from . import linen as linen

-__version__ = "0.0.56"
+__version__ = "0.0.57"

 __all__ = (
     # Loss and extra function
diff --git a/src/fjformer/pallas_operations/__init__.py b/src/fjformer/pallas_operations/__init__.py
index b2eee8a..aa7d023 100644
--- a/src/fjformer/pallas_operations/__init__.py
+++ b/src/fjformer/pallas_operations/__init__.py
@@ -6,7 +6,7 @@
 """

 from .efficient_attention import efficient_attention as efficient_attention
-from .flash_attention import (
+from .tpu_flash_attention import (
     flash_attention as tpu_flash_attention,
     mha as gpu_flash_attention,
     BlockSizes
diff --git a/src/fjformer/pallas_operations/pallas_flash_attention/attention.py b/src/fjformer/pallas_operations/pallas_flash_attention/attention.py
index b8dd18c..955ad9c 100644
--- a/src/fjformer/pallas_operations/pallas_flash_attention/attention.py
+++ b/src/fjformer/pallas_operations/pallas_flash_attention/attention.py
@@ -1,14 +1,10 @@
-# Modified Implementation of Flash attention or MHA from org jax authors
+# Modified Implementation of Flash attention
 from __future__ import annotations
+import math
 import os
-os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
-os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".99"
-os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
-
 import functools
-import math
 from typing import Any, Optional

 import jax
@@ -20,99 +16,43 @@
 DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)


-def attention_mask_movement(
-        q_attention_mask: jax.Array,
-        kv_attention_mask: jax.Array,
-        is_left_padded: bool = False
-):
-    if is_left_padded:
-        q_attention_mask = jnp.atleast_2d(q_attention_mask)
-        kv_attention_mask = jnp.atleast_2d(kv_attention_mask)
-        combined_mask = jnp.atleast_2d(jnp.logical_or(q_attention_mask, kv_attention_mask.transpose(1, 0)))
-
-        # jax.debug.print("*************************************\nCOMB:\n{x}", x=combined_mask)
-        return combined_mask
-    else:
-        return jnp.bitwise_or(q_attention_mask, kv_attention_mask).astype(jnp.bool_)
-
-
-def control_combination(
-        causal_mask,
-        mask=None,
-        is_left_padded: bool = False
-):
-    if mask is None:
-        return causal_mask
-    if is_left_padded:
-        # jax.debug.print("*************************************\nMASK:\n{x}", x=mask)
-        mask = jnp.logical_or(mask[0, :], mask[-1, :])
-        # jax.debug.print("*************************************\nCASK:\n{x}", x=mask)
-        mask = jnp.logical_and(mask, causal_mask)
-        # jax.debug.print("*************************************\nFMSK:\n{x} " + "--" * 20, x=mask)
-
-    else:
-        mask = jnp.logical_and(mask, causal_mask)
-    return mask.astype(jnp.bool_)
-
-
 def flash_attention_forward_kernel(
         q_ref, k_ref, v_ref,  # Input arrays
-        attention_mask_ref: jax.Array | None,  # segment_id arrays
+        b_ref: jax.Array | None,  # bias
         o_ref: Any,  # Output
         *residual_refs: Any,  # Residual outputs
         num_heads: int,
         sm_scale: float,
-        is_left_padded: bool,
-        causal: bool,
         block_q: int,
         block_d: int,
         block_k: int,
 ):
     seq_len = q_ref.shape[0]
     start_q = pl.program_id(0)
-
+    if sm_scale is None:
+        sm_scale = 1 / math.sqrt(q_ref.shape[-1])
     m_i = jnp.zeros(block_q, dtype=jnp.float32) - float("inf")
     l_i = jnp.zeros(block_q, dtype=jnp.float32)
     o = jnp.zeros((block_q, block_d), dtype=jnp.float32)
     curr_q_slice = pl.dslice(start_q * block_q, block_q)
     q = pl.load(q_ref, (curr_q_slice, pl.dslice(None)))
-    q_attention_mask = (
-        None
-        if attention_mask_ref is None
-        else pl.load(attention_mask_ref, (curr_q_slice,))
-    )

     def body(start_k, carry):
         o_prev, m_prev, l_prev = carry
         curr_k_slice = pl.dslice(start_k * block_k, block_k)
         k = pl.load(k_ref, (curr_k_slice, slice(None)))
-        kv_attention_mask = (
-            None
-            if attention_mask_ref is None
-            else pl.load(attention_mask_ref, (curr_k_slice,))
-        )
+
         qk = pl.dot(q, k.T)
         if sm_scale != 1.:
             qk *= sm_scale
-        if causal or attention_mask_ref is not None:
-            mask = None
-            if attention_mask_ref is not None:
-                mask = attention_mask_movement(
-                    q_attention_mask,
-                    kv_attention_mask,
-                    is_left_padded
-                )
-            if causal:
-                span_q = start_q * block_q + jnp.arange(block_q)
-                span_k = start_k * block_k + jnp.arange(block_k)
-                causal_mask = span_q[:, None] >= span_k[None, :]
-                mask = control_combination(causal_mask, mask, is_left_padded)
-            qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
+        if b_ref is not None:
+            b = pl.load(b_ref, (curr_q_slice, curr_k_slice))
+            qk = jnp.add(b, qk, )
         m_curr = qk.max(axis=-1)
         m_next = jnp.maximum(m_prev, m_curr)
@@ -130,10 +70,7 @@ def body(start_k, carry):
         o_next = o_prev_corr + o_curr
         return o_next, m_next, l_next

-    if causal:
-        upper_bound = lax.div(block_q * (start_q + 1) + block_k - 1, block_k)
-    else:
-        upper_bound = pl.cdiv(seq_len, block_k)
+    upper_bound = pl.cdiv(seq_len, block_k)
     o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i))

     if residual_refs:
@@ -146,14 +83,12 @@ def body(start_k, carry):


 @functools.partial(
-    jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
+    jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12]
 )
 @functools.partial(
     jax.jit,
     static_argnames=[
         "sm_scale",
-        "is_left_padded",
-        "causal",
         "block_q",
         "block_k",
         "backward_pass_impl",
@@ -168,9 +103,8 @@ def flash_attention(
         query,
         key,
         value,
-        attention_mask: Optional[jnp.ndarray] = None,
-        sm_scale: float = 1.0,
-        causal: bool = False,
+        bias: Optional[jnp.ndarray] = None,
+        sm_scale: Optional[float] = None,
         block_q: int = 128,
         block_k: int = 128,
         backward_pass_impl: str = "triton",
@@ -178,17 +112,13 @@ def flash_attention(
         num_stages: int = 2,
         grid: Optional[tuple[int, ...]] = None,
         interpret: Optional[bool] = None,
-        is_left_padded: Optional[bool] = None,
         debug: bool = False,
 ):
     del backward_pass_impl
     batch_size, seq_len, num_heads, head_dim = query.shape
-
-    if is_left_padded is None:
-        if attention_mask is None:
-            is_left_padded = False
-        # else:
+    if sm_scale is None:
+        sm_scale = 1 / math.sqrt(head_dim)
     if interpret is None:
         interpret = not (seq_len / 16).is_integer() or jax.lib.xla_bridge.get_backend().platform == "cpu"
@@ -204,20 +134,18 @@ def flash_attention(
     num_warps_ = 4 if head_dim <= 64 else 8
     kernel = functools.partial(
         flash_attention_forward_kernel,
-        is_left_padded=is_left_padded,
         num_heads=num_heads,
         sm_scale=sm_scale,
         block_q=block_q,
         block_k=block_k,
         block_d=head_dim,
-        causal=causal,
     )
     in_specs = [
         pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
         pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
         pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        None if attention_mask is None else pl.BlockSpec(lambda _, j, k: (j, 0), (None, seq_len))
+        None if bias is None else pl.BlockSpec(lambda _, j, k: (j, 0, 0, 0), (None, None, seq_len, seq_len))
     ]
     out_shape = jax.ShapeDtypeStruct(shape=query.shape, dtype=query.dtype)
     return pl.pallas_call(
@@ -234,17 +162,15 @@ def flash_attention(
         debug=debug,
         interpret=interpret,
         name="flash_attention_forward",
-    )(query, key, value, attention_mask)
+    )(query, key, value, bias)


 def _flash_attention_forward(
-        q,
-        k,
-        v,
-        attention_mask: jax.Array | None,
+        query,
+        key,
+        value,
+        bias: jax.Array | None,
         sm_scale: float,
-        is_left_padded: bool,
-        causal: bool,
         block_q: int,
         block_k: int,
         backward_pass_impl: str,
@@ -255,7 +181,11 @@ def _flash_attention_forward(
         debug: bool,
 ):
     del backward_pass_impl
-    batch_size, seq_len, num_heads, head_dim = q.shape
+    batch_size, seq_len, num_heads, head_dim = query.shape
+    if sm_scale is None:
+        sm_scale = 1 / math.sqrt(head_dim)
+    if interpret is None:
+        interpret = not (seq_len / 16).is_integer() or jax.lib.xla_bridge.get_backend().platform == "cpu"
     block_q = min(block_q, seq_len)
     block_k = min(block_k, seq_len)
     # Heuristics.
@@ -270,14 +200,12 @@ def _flash_attention_forward(
         flash_attention_forward_kernel,
         num_heads=num_heads,
         sm_scale=sm_scale,
-        is_left_padded=is_left_padded,
-        causal=causal,
         block_q=block_q,
         block_k=block_k,
         block_d=head_dim
     )
     out_shape = [
-        jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),  # out
+        jax.ShapeDtypeStruct(shape=query.shape, dtype=query.dtype),  # out
         jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), dtype=jnp.float32),
         jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), dtype=jnp.float32)
     ]
@@ -285,7 +213,7 @@ def _flash_attention_forward(
         pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
         pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
         pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
-        None if attention_mask is None else pl.BlockSpec(lambda _, j, k: (j, 0), (None, seq_len))
+        None if bias is None else pl.BlockSpec(lambda _, j, k: (j, 0, 0, 0), (None, None, seq_len, seq_len))
     ]
     out_specs = [
         pl.BlockSpec(
@@ -306,8 +234,8 @@ def _flash_attention_forward(
         debug=debug,
         interpret=interpret,
         name="flash_attention_forward",
-    )(q, k, v, attention_mask)
-    return out, (q, k, v, attention_mask, out, l, m)
+    )(query, key, value, bias)
+    return out, (query, key, value, bias, out, l, m)


 def _preprocess_backward_kernel(
@@ -375,7 +303,7 @@ def flash_attention_backward_kernel(
         q_ref,
         k_ref,
         v_ref,
-        attention_mask_ref: jax.Array | None,
+        b_ref: jax.Array | None,
         out_ref,
         do_scaled_ref,
         l_ref,
@@ -388,8 +316,6 @@ def flash_attention_backward_kernel(
         dv_ref,
         *,
         sm_scale: float,
-        is_left_padded: bool,
-        causal: bool,
         block_q: int,
         block_d: int,
         block_k: int,
@@ -403,12 +329,6 @@ def outer_loop(start_k, _):
         dk = jnp.zeros([block_k, block_d], dtype=jnp.float32)
         k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
         v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
-        span_k = start_k * block_k + jnp.arange(block_k)
-        kv_attention_mask = (
-            None
-            if attention_mask_ref is None
-            else pl.load(attention_mask_ref, (pl.ds(start_k * block_k, block_k),))
-        )

         def inner_loop(start_q, carry):
             dv, dk = carry
@@ -418,24 +338,9 @@ def inner_loop(start_q, carry):
             qk = qk.astype(jnp.float32)
             if sm_scale != 1.0:
                 qk *= sm_scale
-
-            q_attention_mask = (
-                None
-                if attention_mask_ref is None
-                else pl.load(attention_mask_ref, (pl.ds(start_q * block_q, block_q),))
-            )
-
-            if causal or attention_mask_ref is not None:
-                mask = None
-                if attention_mask_ref is not None:
-                    mask = attention_mask_movement(q_attention_mask, kv_attention_mask, is_left_padded)
-
-                if causal:
-                    span_q = start_q * block_q + jnp.arange(block_q)
-                    causal_mask = span_q[:, None] >= span_k[None, :]
-                    mask = control_combination(causal_mask, mask, is_left_padded)
-
-                qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
+            if b_ref is not None:
+                b = pl.load(b_ref, (pl.ds(start_q * block_q, block_q), pl.ds(start_k * block_k, block_k)))
+                qk = jnp.add(b, qk)

             m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),))
             p = jnp.exp(qk - m[:, None])
@@ -453,46 +358,29 @@ def inner_loop(start_q, carry):
             pl.store(dq_ref, (pl.ds(start_q * block_q, block_q), slice(None)), dq, eviction_policy="evict_last")
             return dv, dk

-        if causal:
-            lower_bound = lax.div(start_k * block_k, block_q)
-        else:
-            lower_bound = 0
-        dv, dk = lax.fori_loop(lower_bound, pl.cdiv(seq_len, block_q), inner_loop, (dv, dk))
+        dv, dk = lax.fori_loop(0, pl.cdiv(seq_len, block_q), inner_loop, (dv, dk))
         pl.store(dv_ref, (pl.ds(start_k * block_k, block_k), slice(None)), dv.astype(dv_ref.dtype))
         pl.store(dk_ref, (pl.ds(start_k * block_k, block_k), slice(None)), dk.astype(dk_ref.dtype))

     lax.fori_loop(0, pl.cdiv(seq_len, block_k), outer_loop, None)


-@functools.partial(jax.jit, static_argnames=["sm_scale", "causal"])
+@functools.partial(jax.jit, static_argnames=["sm_scale"])
 def _flash_attention_reference(
         q,
         k,
         v,
-        attention_mask: jnp.ndarray | None,
+        b: Optional[jax.Array] = None,
         sm_scale=1.0,
-        causal: bool = False,
 ):
-    q_seq_len = q.shape[1]
-    kv_seq_len = k.shape[1]
     logits = jnp.einsum("bqhc,bkhc->bhqk", q, k).astype(jnp.float32)
-    mask = None
-    if attention_mask is not None:
-        mask = jnp.expand_dims(attention_mask_movement(attention_mask, attention_mask), 1)
-        mask = jnp.broadcast_to(mask, logits.shape)
-    if causal:
-        causal_mask = jnp.tril(jnp.ones((1, 1, q_seq_len, kv_seq_len), dtype=bool))
-        causal_mask = jnp.broadcast_to(causal_mask, logits.shape)
-        mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
-    logits = logits if mask is None else jnp.where(mask, logits, float("-inf"))
+    logits = b + logits
     weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
     return jnp.einsum("bhqk,bkhc->bqhc", weights, v)


 def _flash_attention_backward(
         sm_scale: float,
-        is_left_padded: bool,
-        causal: bool,
         block_q: int,
         block_k: int,
         backward_pass_impl: str,
@@ -505,22 +393,23 @@ def _flash_attention_backward(
         do
 ):
     del num_warps, num_stages, grid
-    q, k, v, attention_mask, out, l, m = res
+    q, k, v, b, out, l, m = res
+    if sm_scale is None:
+        sm_scale = 1 / math.sqrt(q.shape[-1])
     if backward_pass_impl == "xla":
         return jax.vjp(
-            functools.partial(_flash_attention_reference, sm_scale=sm_scale, causal=causal),
+            functools.partial(_flash_attention_reference, sm_scale=sm_scale),
             q,
             k,
             v,
-            attention_mask,
+            b,
         )[1](do)
     elif backward_pass_impl == "triton":
         batch_size, seq_len, num_heads, head_dim = q.shape
         block_q = min(block_q, seq_len)
         block_k = min(block_k, seq_len)
         do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret)
-        # We accumulate into dq so we need to initialize it to zeros.
         dq = jnp.zeros(q.shape, jnp.float32)
         out_shapes = [
             jax.ShapeDtypeStruct(dq.shape, dq.dtype),
@@ -551,11 +440,11 @@ def _flash_attention_backward(
                 lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
             ),
         ]
-        if attention_mask is None:
+        if b is None:
             in_specs.insert(3, None)  # type: ignore[arg-type]
             input_output_aliases = {8: 0}
         else:
-            in_specs.insert(3, pl.BlockSpec(lambda j, k: (j, 0), (None, seq_len)))
+            in_specs.insert(3, pl.BlockSpec(lambda j, k: (j, 0, 0, 0), (None, None, seq_len, seq_len)))
             input_output_aliases = {9: 0}
         grid = (batch_size, num_heads)
         num_warps = 8
@@ -577,8 +466,6 @@ def _flash_attention_backward(
                 block_d=head_dim,
                 block_k=block_k,
                 sm_scale=sm_scale,
-                is_left_padded=is_left_padded,
-                causal=causal,
             ),
             out_specs=out_specs,  # type:ignore
             grid=grid,
@@ -593,7 +480,7 @@ def _flash_attention_backward(
                 )
             ),
             input_output_aliases=input_output_aliases,
-        )(q, k, v, attention_mask, out, do_scaled, l, m, delta, dq)
+        )(q, k, v, b, out, do_scaled, l, m, delta, dq)
     else:
         raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
     return dq.astype(q.dtype), dk, dv, None
diff --git a/src/fjformer/pallas_operations/flash_attention/__init__.py b/src/fjformer/pallas_operations/tpu_flash_attention/__init__.py
similarity index 100%
rename from src/fjformer/pallas_operations/flash_attention/__init__.py
rename to src/fjformer/pallas_operations/tpu_flash_attention/__init__.py
diff --git a/src/fjformer/pallas_operations/flash_attention/gpu/__init__.py b/src/fjformer/pallas_operations/tpu_flash_attention/gpu/__init__.py
similarity index 100%
rename from src/fjformer/pallas_operations/flash_attention/gpu/__init__.py
rename to src/fjformer/pallas_operations/tpu_flash_attention/gpu/__init__.py
diff --git a/src/fjformer/pallas_operations/flash_attention/gpu/jax_flash_attn_gpu.py b/src/fjformer/pallas_operations/tpu_flash_attention/gpu/jax_flash_attn_gpu.py
similarity index 100%
rename from src/fjformer/pallas_operations/flash_attention/gpu/jax_flash_attn_gpu.py
rename to src/fjformer/pallas_operations/tpu_flash_attention/gpu/jax_flash_attn_gpu.py
diff --git a/src/fjformer/pallas_operations/flash_attention/tpu/__init__.py b/src/fjformer/pallas_operations/tpu_flash_attention/tpu/__init__.py
similarity index 100%
rename from src/fjformer/pallas_operations/flash_attention/tpu/__init__.py
rename to src/fjformer/pallas_operations/tpu_flash_attention/tpu/__init__.py
diff --git a/src/fjformer/pallas_operations/flash_attention/tpu/jax_flash_attn_tpu.py b/src/fjformer/pallas_operations/tpu_flash_attention/tpu/jax_flash_attn_tpu.py
similarity index 100%
rename from src/fjformer/pallas_operations/flash_attention/tpu/jax_flash_attn_tpu.py
rename to src/fjformer/pallas_operations/tpu_flash_attention/tpu/jax_flash_attn_tpu.py
diff --git a/test/attention_test.py b/test/attention_test.py
index 599edf8..ecbbe9e 100644
--- a/test/attention_test.py
+++ b/test/attention_test.py
@@ -1,5 +1,9 @@
 import math
+import os
+import jax
+
+os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
 from flax.linen.attention import dot_product_attention, make_attention_mask, make_causal_mask, combine_masks
 from src.fjformer.pallas_operations import flash_attention
@@ -36,18 +40,17 @@ def main():
     # print(a)
     csm = make_causal_mask(jnp.ones((batch, seq)))
     mask = combine_masks(csm, a[:, None, None, :])
-    out = dot_product_attention(q, k, v, mask=mask)
+    b = jnp.where(mask, 0, jnp.finfo(jnp.float32).min)
+    out = dot_product_attention(q, k, v, b)
     cnk = seq // 2
     out_flash = flash_attention(
         q,
         k,
         v,
-        attention_mask=a,
-        causal=True,
-        sm_scale=1 / math.sqrt(hd),
+        b,
+        # sm_scale=1 / math.sqrt(hd),
        block_k=cnk,
        block_q=cnk,
-        is_left_padded=True
     )
     print(jnp.mean(jnp.sum(out_flash)))
     print(jnp.mean(jnp.sum(out)))
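
API note on the attention.py changes above: flash_attention no longer takes attention_mask, causal, or is_left_padded; callers now pass a single additive bias (0 where attention is allowed, a large negative value where it is masked), and sm_scale defaults to 1 / sqrt(head_dim) when left as None. The sketch below mirrors test/attention_test.py; the dummy padding mask a, the tensor shapes, and the installed-package import path fjformer.pallas_operations are assumptions for illustration, not part of this diff.

import jax
import jax.numpy as jnp
from flax.linen.attention import dot_product_attention, make_causal_mask, combine_masks

# Assumed installed-package path; the test above imports the same symbol as
# `from src.fjformer.pallas_operations import flash_attention`.
from fjformer.pallas_operations import flash_attention

batch, seq, heads, hd = 2, 128, 8, 64
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (batch, seq, heads, hd), dtype=jnp.float32)
k = jax.random.normal(kk, (batch, seq, heads, hd), dtype=jnp.float32)
v = jax.random.normal(kv, (batch, seq, heads, hd), dtype=jnp.float32)

# Hypothetical padding mask (True = attend), combined with a causal mask and
# then converted into the additive bias the reworked kernel expects.
a = jnp.ones((batch, seq), dtype=jnp.bool_)
mask = combine_masks(make_causal_mask(jnp.ones((batch, seq))), a[:, None, None, :])
bias = jnp.where(mask, 0, jnp.finfo(jnp.float32).min)  # shape (batch, 1, seq, seq)

out_ref = dot_product_attention(q, k, v, bias)  # flax reference, bias is 4th positional arg
out_flash = flash_attention(q, k, v, bias, block_q=seq // 2, block_k=seq // 2)  # sm_scale defaults to 1/sqrt(hd)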
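
The flash_attention -> tpu_flash_attention directory rename only moves the subpackage; the names re-exported from fjformer.pallas_operations (tpu_flash_attention, gpu_flash_attention, BlockSizes) are unchanged, per the __init__.py hunk above. A small import sketch, assuming the installed package name fjformer rather than the in-repo src.fjformer path:

# Package-level re-exports keep their old names after the rename.
from fjformer.pallas_operations import tpu_flash_attention, gpu_flash_attention, BlockSizes

# Only imports that referenced the old subpackage path directly need updating, e.g.
# before: from fjformer.pallas_operations.flash_attention import flash_attention
from fjformer.pallas_operations.tpu_flash_attention import flash_attention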