Skip to content

Commit

Permalink
MaskFnAttentionBias._bool_value passes the same rank position tensors…
Browse files Browse the repository at this point in the history
… to mask_fn. (#888)

When target_positions is set, a rank 3 target_positions and a rank 2
source_positions are passed to mask_fn. From the perspective of a downstream
defining mask_fn, this is a big surprise.
  • Loading branch information
ds-hwang authored Dec 17, 2024
1 parent 92205bc commit a7e2a95
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
5 changes: 3 additions & 2 deletions axlearn/common/attention_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ def __call__(self, query_position: Tensor, key_position: Tensor) -> Tensor:
x = f(jnp.asarray([1,2]), jnp.asarray([3,4]))
assert x[0] == f(jnp.asarray(1), jnp.asarray(3))[None]
```
* Both tensors have the same rank (either 2 or 3), as batch dim is optional.
* If given non-scalar arguments of different shapes, the result must be the same if we
first broadcast the arguments against each other to make them have the same shape.
* Beyond requiring broadcastability, must not impose any constraints on the shapes of its
Expand Down Expand Up @@ -494,14 +495,14 @@ def _bool_value(self) -> Optional[Tensor]:
NotImplementedError. If `target_positions.ndim not in [1,2]`.
"""
target_positions, source_positions = jnp.indices(self.shape, sparse=True)
# Shape: [batch, target_len, source_len].
# Shape: [1, target_len, 1], [1, 1, source_len].
target_positions, source_positions = target_positions[None], source_positions[None]
if self.target_positions is not None:
target_positions = self.target_positions
if target_positions.ndim not in [1, 2]:
raise NotImplementedError(f"Shape of target_positions: {target_positions.shape}.")
if target_positions.ndim == 1:
# Shape: [batch, target_len].
# Shape: [batch, 1] + [target_len] = [batch, target_len]
# pylint: disable-next=unsubscriptable-object
target_positions = target_positions[:, None] + jnp.arange(self.shape[0])
elif target_positions.ndim == 2:
Expand Down
25 changes: 24 additions & 1 deletion axlearn/common/attention_bias_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import chex
import jax.numpy as jnp
import jax.util
from absl.testing import parameterized
from absl.testing import absltest, parameterized
from jax.sharding import PartitionSpec

from axlearn.common import attention_bias, test_utils
Expand Down Expand Up @@ -287,6 +287,25 @@ def test_mask_fn_attention_bias_target_positions_ndim(self):
)
self.assertNestedEqual(bias.bool_value(), expected)

def test_mask_fn_attention_bias_with_target_positions(self):
# Ensure that MaskFnAttentionBias provides the mask_fn callback with target_positions and
# source_positions tensors of the same rank.
batch, target_len, source_len = 2, 5, 4
time_step = jnp.arange(batch)

def mask_fn(target_positions, source_positions):
self.assertEqual(target_positions.shape, (batch, target_len, 1))
self.assertEqual(source_positions.shape, (1, 1, source_len))
return attention_bias.causal_mask(target_positions, source_positions)

bias = attention_bias.MaskFnAttentionBias(
mask=mask_fn, shape=(target_len, source_len), target_positions=time_step
)
ref_bias = attention_bias.MaskFnAttentionBias(
attention_bias.causal_mask, shape=(target_len, source_len), target_positions=time_step
)
chex.assert_trees_all_close(bias.value(), ref_bias.value())

def test_bool_tensor_attention_bias(self):
bias = attention_bias.BoolTensorAttentionBias.from_tensor(jnp.ones((5, 7), dtype=bool))
self.assertNestedEqual(
Expand All @@ -298,3 +317,7 @@ def test_astype(self):
self.assertEqual(bias.value().dtype, jnp.float32)
bias = bias.astype(jnp.bfloat16)
self.assertEqual(bias.value().dtype, jnp.bfloat16)


if __name__ == "__main__":
absltest.main()

0 comments on commit a7e2a95

Please sign in to comment.