Fix aot compilation with grain inputs.
markblee committed Jan 15, 2025
1 parent 3405a6e commit e2ac4da
Showing 9 changed files with 107 additions and 8 deletions.
4 changes: 4 additions & 0 deletions axlearn/common/compiler_options.py
@@ -174,6 +174,9 @@ def infer_tpu_version(tpu_type: str) -> str:
     """
     tpu_type = infer_tpu_type(tpu_type)
     tpu_version = tpu_type.rsplit("-", 1)[0]  # split from the last occurrence of '-'
+    # Resolve aliases like v5e to v5litepod, since in some cases (e.g. aot compilation) v5e is
+    # expected.
+    tpu_version = _TPU_VERSION_ALIASES.get(tpu_version, tpu_version)
     if tpu_version not in _TPU_VERSIONS:
         raise ValueError(f"Unknown TPU version {tpu_version}. Expected one of {_TPU_VERSIONS}")
     return tpu_version
@@ -238,4 +241,5 @@ def infer_xsc_compiler_options(
     return options


+_TPU_VERSION_ALIASES = {"v5e": "v5litepod"}
 _TPU_VERSIONS = ("v3", "v4", "v5litepod", "v5p", "v6e")
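
A quick sketch of the resulting behavior (the first call mirrors the new unit test below; the second is an extra assumption based on `_TPU_VERSIONS`, not part of this commit's tests):

from axlearn.common import compiler_options

# "v5e" is resolved to its canonical name "v5litepod" before validation.
assert compiler_options.infer_tpu_version("v5e-16") == "v5litepod"
# Canonical names pass through unchanged.
assert compiler_options.infer_tpu_version("v5litepod-16") == "v5litepod"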
7 changes: 7 additions & 0 deletions axlearn/common/compiler_options_test.py
@@ -4,6 +4,7 @@
 import jax
 import jax.numpy as jnp
 import pytest
+from absl.testing import parameterized

 from axlearn.common import compiler_options, test_utils
 from axlearn.common.utils import Tensor
@@ -48,3 +49,9 @@ def test_xsc_compiler_options(self):
         )
         for name, option in options.items():
             self.assertEqual(option, expected_options[name])
+
+    @parameterized.parameters(
+        dict(tpu_type="v5e-16", expected="v5litepod"),
+    )
+    def test_tpu_version_alias(self, tpu_type: str, expected: str):
+        self.assertEqual(expected, compiler_options.infer_tpu_version(tpu_type))
7 changes: 7 additions & 0 deletions axlearn/common/input_base.py
@@ -148,3 +148,10 @@ def constrain_batch_axis(batch):
             global_physical_batch, batch_axis_names=batch_axis_names
         )
         return constrain_batch_axis(global_logical_batch)
+
+    def element_spec(self) -> Nested[jax.ShapeDtypeStruct]:
+        """Returns the per-feed logical batch spec.
+
+        This is used e.g. for AOT compilation and is not strictly required for training.
+        """
+        raise NotImplementedError(type(self))
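
For orientation, a minimal sketch of what a concrete implementation returns; the subclass name, feature name, and shapes here are hypothetical, not from this commit:

import jax
import jax.numpy as jnp

from axlearn.common.input_base import Input
from axlearn.common.utils import Nested


class MyInput(Input):  # Hypothetical subclass for illustration.
    def element_spec(self) -> Nested[jax.ShapeDtypeStruct]:
        # Describe the per-feed (host-local) logical batch without reading data:
        # e.g. 8 examples of 128 int32 token ids per process.
        return {"input_ids": jax.ShapeDtypeStruct(shape=(8, 128), dtype=jnp.int32)}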
20 changes: 20 additions & 0 deletions axlearn/common/input_grain.py
@@ -159,6 +159,7 @@ def maybe_repeat(ds: Dataset):
             ds = ds.repeat()
         return ds

+    # TODO(markblee): Support mixing grain.IterDataset.
     return grain.MapDataset.mix(
         datasets=[maybe_repeat(source) for source in sources],
         weights=weights,
@@ -629,3 +630,22 @@ def dataset(self) -> grain.IterDataset:
                 f"Please make sure to call {shard_dataset.__name__} if using input dispatch."
             )
         return maybe_to_iter_dataset(ds)
+
+    def element_spec(self) -> utils.Nested[jax.ShapeDtypeStruct]:
+        """Infers the element spec.
+
+        Grain requires fetching an example from the dataset to extract the spec. To avoid reading
+        actual data, replace your source dataset with one from `input_fake.fake_grain_source`.
+        """
+        ds = self.dataset()
+        if isinstance(ds, grain.MapDataset):
+            example = ds[0]
+        else:
+            example = next(ds.__iter__())  # pylint: disable=unnecessary-dunder-call
+
+        def shape_dtype(x):
+            if not hasattr(x, "shape") or not hasattr(x, "dtype"):
+                raise ValueError(f"element_spec() requires Tensor-like leaves, got: {x}.")
+            return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
+
+        return jax.tree.map(shape_dtype, example)
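
As the docstring notes, inferring the spec pulls one example from the dataset, so a fake source avoids touching real data. A sketch of that pattern; the `fake_grain_source` call shape and example structure are assumptions based on its use elsewhere in axlearn:

import numpy as np

from axlearn.common.input_fake import fake_grain_source

# One fake example with the same structure, shapes, and dtypes as real data.
fake_ds = fake_grain_source([{"input_ids": np.zeros([128], dtype=np.int32)}])
# Swap fake_ds in for the real source when building the grain Input; element_spec()
# then yields {"input_ids": jax.ShapeDtypeStruct(shape=(128,), dtype=int32)}
# without any real reads.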
21 changes: 21 additions & 0 deletions axlearn/common/input_grain_test.py
@@ -567,3 +567,24 @@ def test_dispatch_tpu(self):
             # Should contain the right ids.
             self.assertEqual([0, 1, 2, 3], replicate_to_local_data(batch)["input_ids"].tolist())
             break
+
+    def test_element_spec(self):
+        ds = range_dataset(start=0, stop=10, seed=123).map(lambda x: {"input_ids": x})
+        grain_input: Input = self._input_config(ds).instantiate(parent=None)
+        # element_spec() requires Tensor-like leaves.
+        with self.assertRaisesRegex(ValueError, "Tensor"):
+            grain_input.element_spec()
+
+        ds = range_dataset(start=0, stop=10, seed=123).map(lambda x: {"input_ids": np.array(x)})
+        cfg = self._input_config(
+            ds.repeat(num_epochs=None),
+            per_process=lambda ds: ds.batch(2),
+            process_count=4,
+            process_index=0,
+        )
+        grain_input: Input = cfg.instantiate(parent=None)
+        self.assertEqual(
+            # Element spec should reflect the per-process shape.
+            {"input_ids": jax.ShapeDtypeStruct(shape=(2,), dtype=np.int64)},
+            grain_input.element_spec(),
+        )
11 changes: 11 additions & 0 deletions axlearn/common/input_tf_data.py
@@ -46,6 +46,7 @@
 from axlearn.common.module import Module
 from axlearn.common.utils import (
     PHYSICAL_TO_LOGICAL_DISPATCH_KEY,
+    Nested,
     Tensor,
     get_data_dir,
     get_recursively,
@@ -1209,6 +1210,16 @@ def processor(self) -> DatasetToDatasetFn:
     def dataset(self) -> tf.data.Dataset:
         return self._batcher(self._processor(self._source()))

+    def element_spec(self) -> Nested[jax.ShapeDtypeStruct]:
+        """Returns the tfds element spec."""
+
+        return jax.tree.map(
+            lambda tf_spec: jax.ShapeDtypeStruct(
+                shape=tf_spec.shape, dtype=tf_spec.dtype.as_numpy_dtype
+            ),
+            self.dataset().element_spec,
+        )
+

 def disable_shuffle_recursively(cfg: Input.Config):
     """Disables all shuffling on the input config.
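
A standalone sketch of the TensorSpec-to-ShapeDtypeStruct conversion this method performs, using plain tf.data rather than an axlearn Input (the dataset contents are illustrative):

import jax
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    {"input_ids": tf.zeros([4, 16], dtype=tf.int64)}
).batch(2, drop_remainder=True)  # drop_remainder keeps the batch dim static.

spec = jax.tree.map(
    lambda tf_spec: jax.ShapeDtypeStruct(
        shape=tf_spec.shape, dtype=tf_spec.dtype.as_numpy_dtype
    ),
    ds.element_spec,
)
# spec == {"input_ids": jax.ShapeDtypeStruct(shape=(2, 16), dtype=int64)}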
23 changes: 23 additions & 0 deletions axlearn/common/input_tf_data_test.py
@@ -1583,5 +1583,28 @@ def test_disable_shuffle_recursively(self):
         self.assertEqual(cfg.source.source.train_shuffle_files, False)


+class ElementSpecTest(parameterized.TestCase):
+    """Tests Input.element_spec()."""
+
+    def test_element_spec(self):
+        cfg = Input.default_config().set(
+            source=config_for_function(with_processor).set(
+                source=config_for_function(fake_text_source),
+                processor=config_for_function(identity),
+            ),
+            processor=config_for_function(identity),
+            batcher=config_for_function(batch).set(
+                global_batch_size=2,
+                pad_example_fn=default_pad_example_fn,
+            ),
+            is_training=True,
+            name="test",
+        )
+        self.assertEqual(
+            {"text": jax.ShapeDtypeStruct(shape=(2,), dtype=object)},
+            cfg.instantiate(parent=None).element_spec(),
+        )
+
+
 if __name__ == "__main__":
     absltest.main()
8 changes: 1 addition & 7 deletions axlearn/common/trainer.py
@@ -1114,13 +1114,7 @@ def compile_train_step(
         if input_batch is None:
             # Infer input batch shapes from input element spec.
             # N.B. in a multi-process setting these will be host-local (per process).
-            # TODO(markblee): This path currently assumes input_tf_data; fix for generic inputs.
-            input_batch = jax.tree.map(
-                lambda tf_spec: jax.ShapeDtypeStruct(
-                    shape=tf_spec.shape, dtype=tf_spec.dtype.as_numpy_dtype
-                ),
-                self.input.dataset().element_spec,  # pytype: disable=attribute-error
-            )
+            input_batch = self.input.element_spec()
         # Rely on the instance handle to ensure that we hit the compilation cache if possible.
         jit_train_step = self._jit_train_step or self._pjit_train_step()
         # Note(Jan 2022):
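
With this change, `compile_train_step` no longer reaches into `self.input.dataset().element_spec`, which is tf.data-specific; any `Input` implementing `element_spec()` works, including grain inputs. A usage sketch (trainer construction is elided; `trainer` is assumed to be an instantiated `SpmdTrainer`):

# trainer: SpmdTrainer = trainer_cfg.instantiate(parent=None)  # elided
# With no explicit input_batch, shapes/dtypes are inferred via
# trainer.input.element_spec(), host-local (per process) in multi-process settings.
compiled_train_step = trainer.compile_train_step()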
14 changes: 13 additions & 1 deletion axlearn/common/trainer_test.py
@@ -2,10 +2,12 @@

 """Tests SpmdTrainer."""

-# pylint: disable=no-self-use
 import copy
 import dataclasses
 import math
+
+# pylint: disable=no-self-use
+import os
 import os.path
 import shutil
 import tempfile
@@ -70,6 +72,8 @@

 NUM_CLASSES = 16

+os.environ["TPU_SKIP_MDS_QUERY"] = "1"
+

 class DummyInput(Module):
     """A dummy input."""
Expand Down Expand Up @@ -172,6 +176,14 @@ def __iter__(self):
# guaranteed to be savable).
yield from self.dataset()

def element_spec(self):
return jax.tree.map(
lambda tf_spec: jax.ShapeDtypeStruct(
shape=tf_spec.shape, dtype=tf_spec.dtype.as_numpy_dtype
),
self.dataset().element_spec,
)


class DummyModel(BaseModel):
"""A dummy model."""
Expand Down
