diff --git a/axlearn/common/compiler_options.py b/axlearn/common/compiler_options.py
index d4555b11..c1c056cd 100644
--- a/axlearn/common/compiler_options.py
+++ b/axlearn/common/compiler_options.py
@@ -174,6 +174,9 @@ def infer_tpu_version(tpu_type: str) -> str:
     """
     tpu_type = infer_tpu_type(tpu_type)
     tpu_version = tpu_type.rsplit("-", 1)[0]  # split from the last occurrence of '-'
+    # Resolve aliases like v5e to v5litepod, since in some cases (e.g. AOT compilation)
+    # callers pass the alias rather than the canonical version name.
+    tpu_version = _TPU_VERSION_ALIASES.get(tpu_version, tpu_version)
     if tpu_version not in _TPU_VERSIONS:
         raise ValueError(f"Unknown TPU version {tpu_version}. Expected one of {_TPU_VERSIONS}")
     return tpu_version
@@ -238,4 +241,5 @@ def infer_xsc_compiler_options(
     return options
 
 
+_TPU_VERSION_ALIASES = {"v5e": "v5litepod"}
 _TPU_VERSIONS = ("v3", "v4", "v5litepod", "v5p", "v6e")
diff --git a/axlearn/common/compiler_options_test.py b/axlearn/common/compiler_options_test.py
index 30d1b2d3..ff146c2e 100644
--- a/axlearn/common/compiler_options_test.py
+++ b/axlearn/common/compiler_options_test.py
@@ -4,6 +4,7 @@
 import jax
 import jax.numpy as jnp
 import pytest
+from absl.testing import parameterized
 
 from axlearn.common import compiler_options, test_utils
 from axlearn.common.utils import Tensor
@@ -48,3 +49,9 @@ def test_xsc_compiler_options(self):
         )
         for name, option in options.items():
             self.assertEqual(option, expected_options[name])
+
+    @parameterized.parameters(
+        dict(tpu_type="v5e-16", expected="v5litepod"),
+    )
+    def test_tpu_version_alias(self, tpu_type: str, expected: str):
+        self.assertEqual(expected, compiler_options.infer_tpu_version(tpu_type))
diff --git a/axlearn/common/input_base.py b/axlearn/common/input_base.py
index 31b7d3e9..efec40ea 100644
--- a/axlearn/common/input_base.py
+++ b/axlearn/common/input_base.py
@@ -148,3 +148,10 @@ def constrain_batch_axis(batch):
                 global_physical_batch, batch_axis_names=batch_axis_names
             )
         return constrain_batch_axis(global_logical_batch)
+
+    def element_spec(self) -> Nested[jax.ShapeDtypeStruct]:
+        """Returns the per-feed logical batch spec.
+
+        This is used e.g. for AOT compilation and is not strictly required for training.
+        """
+        raise NotImplementedError(type(self))
diff --git a/axlearn/common/input_grain.py b/axlearn/common/input_grain.py
index e06d8ae3..10c78456 100644
--- a/axlearn/common/input_grain.py
+++ b/axlearn/common/input_grain.py
@@ -159,6 +159,7 @@ def maybe_repeat(ds: Dataset):
             ds = ds.repeat()
         return ds
 
+    # TODO(markblee): Support mixing grain.IterDataset.
     return grain.MapDataset.mix(
         datasets=[maybe_repeat(source) for source in sources],
         weights=weights,
@@ -629,3 +630,22 @@ def dataset(self) -> grain.IterDataset:
                 f"Please make sure to call {shard_dataset.__name__} if using input dispatch."
             )
         return maybe_to_iter_dataset(ds)
+
+    def element_spec(self) -> utils.Nested[jax.ShapeDtypeStruct]:
+        """Infers the element spec.
+
+        Grain requires fetching an example from the dataset to extract the spec. To avoid reading
+        actual data, replace your source dataset with one from `input_fake.fake_grain_source`.
+ """ + ds = self.dataset() + if isinstance(ds, grain.MapDataset): + example = ds[0] + else: + example = next(ds.__iter__()) # pylint: disable=unnecessary-dunder-call + + def shape_dtype(x): + if not hasattr(x, "shape") or not hasattr(x, "dtype"): + raise ValueError(f"element_spec() requires Tensor-like leaves, got: {x}.") + return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype) + + return jax.tree.map(shape_dtype, example) diff --git a/axlearn/common/input_grain_test.py b/axlearn/common/input_grain_test.py index 4cfed0e7..b052d2ca 100644 --- a/axlearn/common/input_grain_test.py +++ b/axlearn/common/input_grain_test.py @@ -567,3 +567,24 @@ def test_dispatch_tpu(self): # Should contain the right ids. self.assertEqual([0, 1, 2, 3], replicate_to_local_data(batch)["input_ids"].tolist()) break + + def test_element_spec(self): + ds = range_dataset(start=0, stop=10, seed=123).map(lambda x: {"input_ids": x}) + grain_input: Input = self._input_config(ds).instantiate(parent=None) + # element_spec() requires Tensor-like leaves. + with self.assertRaisesRegex(ValueError, "Tensor"): + grain_input.element_spec() + + ds = range_dataset(start=0, stop=10, seed=123).map(lambda x: {"input_ids": np.array(x)}) + cfg = self._input_config( + ds.repeat(num_epochs=None), + per_process=lambda ds: ds.batch(2), + process_count=4, + process_index=0, + ) + grain_input: Input = cfg.instantiate(parent=None) + self.assertEqual( + # Element spec should reflect the per-process shape. + {"input_ids": jax.ShapeDtypeStruct(shape=(2,), dtype=np.int64)}, + grain_input.element_spec(), + ) diff --git a/axlearn/common/input_tf_data.py b/axlearn/common/input_tf_data.py index 1d44ad1b..6cb1e752 100644 --- a/axlearn/common/input_tf_data.py +++ b/axlearn/common/input_tf_data.py @@ -46,6 +46,7 @@ from axlearn.common.module import Module from axlearn.common.utils import ( PHYSICAL_TO_LOGICAL_DISPATCH_KEY, + Nested, Tensor, get_data_dir, get_recursively, @@ -1209,6 +1210,16 @@ def processor(self) -> DatasetToDatasetFn: def dataset(self) -> tf.data.Dataset: return self._batcher(self._processor(self._source())) + def element_spec(self) -> Nested[jax.ShapeDtypeStruct]: + """Returns the tfds element spec.""" + + return jax.tree.map( + lambda tf_spec: jax.ShapeDtypeStruct( + shape=tf_spec.shape, dtype=tf_spec.dtype.as_numpy_dtype + ), + self.dataset().element_spec, + ) + def disable_shuffle_recursively(cfg: Input.Config): """Disables all shuffling on the input config. 
diff --git a/axlearn/common/input_tf_data_test.py b/axlearn/common/input_tf_data_test.py
index 63b859b4..f43c5d5e 100644
--- a/axlearn/common/input_tf_data_test.py
+++ b/axlearn/common/input_tf_data_test.py
@@ -1583,5 +1583,28 @@ def test_disable_shuffle_recursively(self):
         self.assertEqual(cfg.source.source.train_shuffle_files, False)
 
 
+class ElementSpecTest(parameterized.TestCase):
+    """Tests Input.element_spec()."""
+
+    def test_element_spec(self):
+        cfg = Input.default_config().set(
+            source=config_for_function(with_processor).set(
+                source=config_for_function(fake_text_source),
+                processor=config_for_function(identity),
+            ),
+            processor=config_for_function(identity),
+            batcher=config_for_function(batch).set(
+                global_batch_size=2,
+                pad_example_fn=default_pad_example_fn,
+            ),
+            is_training=True,
+            name="test",
+        )
+        self.assertEqual(
+            {"text": jax.ShapeDtypeStruct(shape=(2,), dtype=object)},
+            cfg.instantiate(parent=None).element_spec(),
+        )
+
+
 if __name__ == "__main__":
     absltest.main()
diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py
index a6056076..1bc9b601 100644
--- a/axlearn/common/trainer.py
+++ b/axlearn/common/trainer.py
@@ -1114,13 +1114,7 @@ def compile_train_step(
         if input_batch is None:
             # Infer input batch shapes from input element spec.
             # N.B. in a multi-process setting these will be host-local (per process).
-            # TODO(markblee): This path currently assumes input_tf_data; fix for generic inputs.
-            input_batch = jax.tree.map(
-                lambda tf_spec: jax.ShapeDtypeStruct(
-                    shape=tf_spec.shape, dtype=tf_spec.dtype.as_numpy_dtype
-                ),
-                self.input.dataset().element_spec,  # pytype: disable=attribute-error
-            )
+            input_batch = self.input.element_spec()
         # Rely on the instance handle to ensure that we hit the compilation cache if possible.
         jit_train_step = self._jit_train_step or self._pjit_train_step()
         # Note(Jan 2022):
diff --git a/axlearn/common/trainer_test.py b/axlearn/common/trainer_test.py
index 39177429..3851c097 100644
--- a/axlearn/common/trainer_test.py
+++ b/axlearn/common/trainer_test.py
@@ -2,10 +2,12 @@
 
 """Tests SpmdTrainer."""
 
-# pylint: disable=no-self-use
 import copy
 import dataclasses
 import math
+
+# pylint: disable=no-self-use
+import os
 import os.path
 import shutil
 import tempfile
@@ -70,6 +72,8 @@
 
 NUM_CLASSES = 16
 
+os.environ["TPU_SKIP_MDS_QUERY"] = "1"
+
 
 class DummyInput(Module):
     """A dummy input."""
@@ -172,6 +176,14 @@ def __iter__(self):
             # guaranteed to be savable).
             yield from self.dataset()
 
+    def element_spec(self):
+        return jax.tree.map(
+            lambda tf_spec: jax.ShapeDtypeStruct(
+                shape=tf_spec.shape, dtype=tf_spec.dtype.as_numpy_dtype
+            ),
+            self.dataset().element_spec,
+        )
+
 
 class DummyModel(BaseModel):
     """A dummy model."""
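
For context, a minimal runnable sketch of the spec-extraction logic added in input_grain.py, assuming a toy example dict in place of a real grain dataset (the example data and variable names below are hypothetical, not from the patch):

    import jax
    import numpy as np

    # A toy example standing in for one element fetched from a grain dataset.
    example = {"input_ids": np.zeros((2,), dtype=np.int64)}

    def shape_dtype(x):
        # As in the patch: every leaf must be Tensor-like, i.e. carry shape/dtype.
        if not hasattr(x, "shape") or not hasattr(x, "dtype"):
            raise ValueError(f"element_spec() requires Tensor-like leaves, got: {x}.")
        return jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)

    # Prints {'input_ids': ShapeDtypeStruct(shape=(2,), dtype=int64)}, the same
    # per-process spec asserted in input_grain_test.py's test_element_spec.
    print(jax.tree.map(shape_dtype, example))

On the trainer side, compile_train_step can hand this pytree of ShapeDtypeStructs straight to the jitted train step for AOT compilation, without materializing a real batch.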