Commit

Merge changes
hanzhi713 committed Jan 14, 2025
2 parents bff51dc + 3405a6e commit b5baff9
Showing 62 changed files with 2,530 additions and 1,210 deletions.
22 changes: 8 additions & 14 deletions Dockerfile
@@ -5,19 +5,17 @@ ARG BASE_IMAGE=python:3.10-slim
 
 FROM ${BASE_IMAGE} AS base
 
-RUN apt-get update
-RUN apt-get install -y apt-transport-https ca-certificates gnupg curl gcc g++
-
-# Install git.
-RUN apt-get install -y git
+# Install curl and gnupg first so that we can use them to install google-cloud-cli.
+# Any RUN apt-get install step needs its own apt-get update, otherwise a stale package
+# list may be used when a previous apt-get update step is cached. See here for more info:
+# https://docs.docker.com/build/building/best-practices/#apt-get
+RUN apt-get update && apt-get install -y curl gnupg
 
 # Install gcloud. https://cloud.google.com/sdk/docs/install
 RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \
     curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg && \
-    apt-get update -y && apt-get install google-cloud-cli -y
-
-# Install screen and other utils for launch script.
-RUN apt-get install -y jq screen ca-certificates
+    apt-get update -y && \
+    apt-get install -y apt-transport-https ca-certificates gcc g++ \
+        git screen ca-certificates google-perftools google-cloud-cli
 
 # Setup.
 RUN mkdir -p /root
@@ -88,8 +86,6 @@ FROM base AS tpu
 
 ARG EXTRAS=
 
-RUN apt-get install -y google-perftools
-
 ENV PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.html
 # Ensure we install the TPU version, even if building locally.
 # Jax will fall back to CPU when run on a machine without TPU.
@@ -103,8 +99,6 @@ COPY . .
 
 FROM base AS gpu
 
-RUN apt-get install -y google-perftools
-
 # TODO(markblee): Support extras.
 ENV PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 RUN pip install .[core,gpu]
94 changes: 77 additions & 17 deletions axlearn/common/attention.py

Large diffs are not rendered by default.

35 changes: 29 additions & 6 deletions axlearn/common/attention_bias.py
@@ -59,6 +59,28 @@ class BaseAttentionBias:
     # If None, do not cast the dtype.
     dtype: Optional[jnp.dtype] = struct.field(kw_only=True, default=None, pytree_node=False)
 
+    @final
+    def eval_shape(self) -> tuple[int, int, int, int]:
+        """Return the shape of the bias tensor.
+
+        Note: this doesn't materialize the value. jax.eval_shape calls value(), but it only
+        does so using tracers.
+
+        Returns:
+            shape: [batch or 1, num_heads or 1, target_len, source_len].
+
+        Raises:
+            ValueError: If the bias has no value.
+        """
+        if not self.has_value():
+            raise ValueError("AttentionBias has no value.")
+        return jax.eval_shape(self.value).shape
+
+    @final
+    def has_value(self) -> bool:
+        """Return whether the bias has a value."""
+        return jax.eval_shape(self.value) is not None
+
     @final
     def value(self) -> Optional[Tensor]:
         """Return a tensor with the biases or None if there are no biases.
@@ -116,9 +138,6 @@ def _broadcast_value(cls, value: OpT) -> OpT:
             return value[:, None, :, :]
         raise ValueError(f"Invalid attention_logit_biases shape: {value.shape}.")
 
-    def eval_shape(self):
-        return jax.eval_shape(self.value).shape
-
     def partition_spec(
         self, mha_dim_to_partition_spec: dict[str, PartitionSpec]
     ) -> Union["BaseAttentionBias", PartitionSpec]:
@@ -233,7 +252,7 @@ def _nonzero(self) -> Sequence[BaseAttentionBias]:
 
         Returned biases are not guaranteed to be nonzero, but are guaranteed to not return None.
         """
-        filt = lambda b: b.value() is not None
+        filt = lambda b: b.has_value()
         return list(filter(filt, self.biases))
 
def bias_and_residual(self, cls: Type[B]) -> "BiasAndResidual[B]":
Expand All @@ -260,7 +279,7 @@ def bias_and_residual(self, cls: Type[B]) -> "BiasAndResidual[B]":
send_residual_to = remaining_biases
else:
send_residual_to = residuals
if bias_and_residual.residual.value() is not None:
if bias_and_residual.residual.has_value():
send_residual_to.append(bias_and_residual.residual)
return BiasAndResidual(
bias=cls.from_sequence(cls_biases), residual=CompositeAttentionBias(residuals)
@@ -687,7 +706,11 @@ def sliding_window_causal_mask(sliding_window_size: int) -> MaskFn:
     def mask(query_position: Tensor, key_position: Tensor):
         return query_position - key_position <= sliding_window_size
 
-    return and_masks(causal_mask, mask)
+    fun = and_masks(causal_mask, mask)
+    # Flash attention needs to recognize the sliding window size in _to_splash_mask().
+    # pylint: disable-next=protected-access
+    fun._sliding_window_size = sliding_window_size
+    return fun
 
 
 def make_causal_biases(seq_len: int) -> Tensor:
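A self-contained sketch of the function-attribute pattern this hunk introduces. The and_masks below is a simplified stand-in for the library's combinator, and the "backend" comment describes _to_splash_mask() only as the diff characterizes it:

    import jax.numpy as jnp

    def and_masks(*fns):
        # Simplified stand-in for the library's and_masks combinator.
        def combined(q, k):
            out = fns[0](q, k)
            for fn in fns[1:]:
                out = out & fn(q, k)
            return out
        return combined

    def causal_mask(q, k):
        return q >= k

    def sliding_window_causal_mask(sliding_window_size: int):
        def mask(q, k):
            return q - k <= sliding_window_size

        fun = and_masks(causal_mask, mask)
        # Tag the closure so a backend lowering step can recover the window
        # size later via getattr, without probing the mask numerically.
        fun._sliding_window_size = sliding_window_size
        return fun

    m = sliding_window_causal_mask(2)
    q = jnp.arange(6)[:, None]
    k = jnp.arange(6)[None, :]
    print(m(q, k).astype(int))  # lower-triangular band of width 3
    print(getattr(m, "_sliding_window_size", None))  # 2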
67 changes: 63 additions & 4 deletions axlearn/common/attention_bias_test.py
@@ -21,6 +21,15 @@
 
 
 class AttentionBiasTest(test_utils.TestCase):
+    @parameterized.parameters(
+        [attention_bias.ZeroAttentionBias(), False],
+        [attention_bias.CausalAttentionBias(shape=(5, 5)), True],
+        [attention_bias.MaskFnAttentionBias(attention_bias.causal_mask, shape=(5, 5)), True],
+        [attention_bias.TensorAttentionBias.from_tensor(jnp.ones((5, 5))), True],
+    )
+    def test_has_bias(self, bias, expected):
+        self.assertEqual(bias.has_value(), expected)
+
     def test_causal_attention_bias(self):
         bias = attention_bias.CausalAttentionBias(shape=(5, 5))
         chex.assert_trees_all_close(bias.value(), attention_bias.make_causal_biases(5)[None, None])
@@ -45,19 +54,19 @@ def test_base_attention_bias_value(self):
         # pylint: disable=function-redefined
 
         class TestAttentionBias(attention_bias.BaseAttentionBias):
-            def _value(self) -> Optional[Tensor]:
+            def _value(self) -> Tensor:
                 return jnp.ones((5, 7))
 
         self.assertEqual(TestAttentionBias().value().shape, (1, 1, 5, 7))
 
         class TestAttentionBias(attention_bias.BaseAttentionBias):
-            def _value(self) -> Optional[Tensor]:
+            def _value(self) -> Tensor:
                 return jnp.ones((3, 5, 7))
 
         self.assertEqual(TestAttentionBias().value().shape, (3, 1, 5, 7))
 
         class TestAttentionBias(attention_bias.BaseAttentionBias):
-            def _value(self) -> Optional[Tensor]:
+            def _value(self) -> Tensor:
                 return jnp.ones((2, 3, 5, 7))
 
         self.assertEqual(TestAttentionBias().value().shape, (2, 3, 5, 7))
@@ -77,6 +86,56 @@ def test_base_attention_bias_and_residual(self):
             bias.bias_and_residual(int), attention_bias.BiasAndResidual(bias=None, residual=bias)
         )
 
+    @parameterized.parameters(
+        [
+            attention_bias.CompositeAttentionBias(
+                [attention_bias.ZeroAttentionBias(), attention_bias.ZeroAttentionBias()]
+            ),
+            False,
+        ],
+        [
+            attention_bias.CompositeAttentionBias(
+                [
+                    attention_bias.CausalAttentionBias(shape=(5, 5)),
+                    attention_bias.CausalAttentionBias(shape=(5, 5)),
+                ]
+            ),
+            True,
+        ],
+        [
+            attention_bias.CompositeAttentionBias(
+                [
+                    attention_bias.CausalAttentionBias(shape=(5, 5)),
+                    attention_bias.ZeroAttentionBias(),
+                ]
+            ),
+            True,
+        ],
+        [
+            attention_bias.CompositeAttentionBias(
+                [
+                    attention_bias.ZeroAttentionBias(),
+                    attention_bias.CausalAttentionBias(shape=(5, 5)),
+                ]
+            ),
+            True,
+        ],
+    )
+    def test_composite_attention_has_bias(self, bias, expected):
+        self.assertEqual(bias.has_value(), expected)
+
+    def test_bias_and_residual_has_bias(self):
+        bias = attention_bias.CompositeAttentionBias(
+            [
+                attention_bias.CausalAttentionBias(shape=(5, 5)),
+                attention_bias.MaskFnAttentionBias(attention_bias.causal_mask, shape=(5, 5)),
+            ]
+        )
+        bias_and_residual = bias.bias_and_residual(attention_bias.CausalAttentionBias)
+        self.assertTrue(bias_and_residual.has_value())
+        bias_and_residual = bias.bias_and_residual(attention_bias.MaskFnAttentionBias)
+        self.assertTrue(bias_and_residual.has_value())
+
     def test_composite_attention_bias_zero(self):
         # Test handling of zero biases.
         bias = attention_bias.CompositeAttentionBias(
@@ -191,7 +250,7 @@ def test_split_subsets(
             attention_bias.SegmentIdAttentionBias,
             attention_bias.MaskFnAttentionBias,
         )
-        new_bias_list = [b if b.value() is not None else None for b in new_bias_list]
+        new_bias_list = [b if b.has_value() else None for b in new_bias_list]
         expected = [causal, segment_ids, mask, None]
         for b1, b2 in jax.util.safe_zip(new_bias_list, expected):
             self.assertIs(b1, b2)