Fix error in AXLearn tests, and remove unnecessary ones #1443
base: main
Conversation
I see an error that seems like a test/infra bug, rather than a failing test: https://github.com/NVIDIA/JAX-Toolbox/actions/runs/14931992154/job/41952767038#step:6:41440
Some of the tests look like they are failing due to missing input data.
Also, the CI job is marked as successful despite the tests failing.
…olbox into sbosisio/fix_axlearn_tests
In this PR I added a workflow dispatch so we can test individual parts of the CI.
@olupton I ran a test that results in:

```
___________________ HostArrayTest.test_fixed_process_shape67 ___________________
[gw93] linux -- Python 3.12.3 /usr/bin/python3

self = <axlearn.common.host_array_test.HostArrayTest testMethod=test_fixed_process_shape67>
platform = 'cpu', mesh_shape = (-1, 2), process_shape = [1]
partition = PartitionSpec('data', 'model')

    @parameterized.product(
        platform=("cpu", "tpu"),
        mesh_shape=[
            (-1, 1),  # Fully partitioned along one dim.
            (2, -1),  # Partitioned along multiple dims.
            (-1, 2),  # Test the other way.
            (1, -1),
        ],
        process_shape=[
            # Each process produces single dim.
            [1],   # Not divisible by number of devices (replicated).
            [8],   # Divisible by number of devices.
            [16],  # Multiple elements per device.
            # Each process produces multiple dims.
            [1, 1],   # Not divisible by number of devices (replicated).
            [2, 1],   # Can be partitioned over dim=0, replicated on dim=1.
            [16, 1],  # Multiple elements per device.
            [2, 4],   # Can be fully partitioned.
            [8, 8],   # Can be fully partitioned.
        ],
        partition=(
            DataPartitionType.FULL,
            DataPartitionType.REPLICATED,
            PartitionSpec("data"),
            PartitionSpec("data", "model"),
        ),
    )
    # NOTE: while annotated with `for_8_devices`, this runs on other configurations.
    @pytest.mark.for_8_devices
    def test_fixed_process_shape(
        self,
        platform: str,
        mesh_shape: tuple[int, int],
        process_shape: Sequence[int],
        partition: Union[DataPartitionType, PartitionSpec],
    ):
        """Tests roundtrip host-to-global and global-to-host with fixed process shape."""

>       mesh_shape = infer_mesh_shape(mesh_shape)

axlearn/common/host_array_test.py:124:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

mesh_shape = (-1, 2)

    def infer_mesh_shape(mesh_shape: MeshShape, *, num_devices: Optional[int] = None) -> MeshShape:
        """Infer the value for -1 from len(jax.devices()) and other dims if there is -1 in mesh shape.

        Args:
            mesh_shape: The original MeshShape, which might have -1 in one axis.
            num_devices: The devices that will be used to construct the mesh.
                If None, defaults to len(jax.devices()).

        Returns:
            A new MeshShape with inferred value for -1.
        """
        if -1 not in mesh_shape:
            return mesh_shape

        if mesh_shape.count(-1) > 1:
            raise ValueError(f"Only one axis can be -1 in {mesh_shape=}.")

        # Handle the case with one -1.
        prod = math.prod(mesh_shape, start=-1)
        if num_devices is None:
            num_devices = len(jax.devices())
        if num_devices % prod != 0:
>           raise ValueError(
                f"Unable to infer -1 in mesh shape {mesh_shape} as num_devices {num_devices} "
                f"is not a multiple of the product {prod} of mesh axes."
            )
E           ValueError: Unable to infer -1 in mesh shape (-1, 2) as num_devices 1 is not a multiple of the product 2 of mesh axes.

axlearn/common/utils.py:1834: ValueError
```

This may be a new JAX version error.
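The root cause is visible in the traceback: with only one visible device, `infer_mesh_shape((-1, 2))` cannot divide the device count by the product of the fixed axes. A minimal standalone sketch of that logic, taking `num_devices` explicitly instead of calling `jax.devices()` (the final substitution line is an assumption, since the log cuts off at the `raise`):

```python
import math


def infer_mesh_shape(mesh_shape, num_devices):
    """Replace a single -1 axis with the size inferred from num_devices.

    Mirrors the logic shown in the traceback from axlearn/common/utils.py.
    """
    if -1 not in mesh_shape:
        return mesh_shape
    if mesh_shape.count(-1) > 1:
        raise ValueError(f"Only one axis can be -1 in {mesh_shape=}.")
    # start=-1 cancels the -1 inside mesh_shape, leaving the product of the
    # fixed axes: (-1) * (-1) * 2 == 2 for mesh_shape (-1, 2).
    prod = math.prod(mesh_shape, start=-1)
    if num_devices % prod != 0:
        raise ValueError(
            f"Unable to infer -1 in mesh shape {mesh_shape} as num_devices "
            f"{num_devices} is not a multiple of the product {prod} of mesh axes."
        )
    # Assumed completion: substitute the inferred size for the -1 axis.
    return tuple(num_devices // prod if d == -1 else d for d in mesh_shape)


print(infer_mesh_shape((-1, 2), num_devices=8))  # (4, 2)
# With a single visible device, this reproduces the CI failure:
try:
    infer_mesh_shape((-1, 2), num_devices=1)
except ValueError as e:
    print(e)
```

This suggests the test only makes sense when the worker actually sees enough devices for the fixed mesh axes, which is what the `for_8_devices` marker is meant to guarantee.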
Co-authored-by: Olli Lupton <olupton@nvidia.com>
I think we can merge this to speed up the pipeline, but I left some more comments on error handling and robustness.
```sh
# Run tests
pytest-xdist.sh 1 6 ${LOG_DIR}/axlearn-unittests.jsonl test-axlearn.sh --directory "." --output ${LOG_DIR} --test-files "/opt/axlearn/axlearn/common/*_test.py" | tee -a ${LOG_DIR}/pytest_stdout.log

# test on JAX, make sure 8 devices are visible
```
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Presumably this means we are leaving parallelism on the table by launching 1-GPU tests with 8 visible.
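One way to reclaim that parallelism (a hypothetical sketch, not what this PR implements) is to pin each pytest-xdist worker to its own GPU via `CUDA_VISIBLE_DEVICES`, using the `PYTEST_XDIST_WORKER` id that xdist exports (e.g. `gw93` in the log above):

```python
import os


def gpu_for_worker(worker_id: str, num_gpus: int = 8) -> str:
    """Map a pytest-xdist worker id like 'gw93' to a GPU index (round-robin)."""
    if not worker_id.startswith("gw"):
        return "0"  # master process / non-xdist run
    return str(int(worker_id[2:]) % num_gpus)


# For example, in a conftest.py, before jax is imported:
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_for_worker(
    os.environ.get("PYTEST_XDIST_WORKER", "")
)
```

Tests marked `for_8_devices` would still need to bypass this mapping, since they require all eight GPUs visible in one process.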
Fixed :)
```sh
echo "Total number of failed tests ${failed}"
echo "Total number of skipped tests ${skipped}"
# add those to summary.txt and we're using it for extracting values
echo "PASSED: ${passed} FAILED: ${failed} SKIPPED: ${skipped}" >> ${LOG_DIRECTORY}/summary.txt
```
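Since the job is currently marked green even when tests fail, the summary line written above could be parsed back and turned into a nonzero exit status. A sketch under the assumption that `summary.txt` keeps the exact `PASSED: x FAILED: y SKIPPED: z` format from the `echo` (the gate itself is hypothetical):

```python
import re
import sys

SUMMARY_RE = re.compile(r"PASSED: (\d+) FAILED: (\d+) SKIPPED: (\d+)")


def parse_summary(text: str) -> dict:
    """Extract counts from the 'PASSED: x FAILED: y SKIPPED: z' summary line."""
    m = SUMMARY_RE.search(text)
    if m is None:
        raise ValueError(f"Unrecognized summary line: {text!r}")
    passed, failed, skipped = map(int, m.groups())
    return {"passed": passed, "failed": failed, "skipped": skipped}


def gate(text: str) -> int:
    """Return a process exit code: nonzero when any test failed."""
    return 1 if parse_summary(text)["failed"] > 0 else 0


if __name__ == "__main__":
    sys.exit(gate(sys.stdin.read()))
```

Exiting with the gate's code from the CI step would make the job red whenever `FAILED` is nonzero, instead of silently succeeding.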
This is not capturing this error in the run: https://github.com/NVIDIA/JAX-Toolbox/actions/runs/15144748689/job/42579078908?pr=1443#step:6:680
This PR does the following:

- Removes `array_serialization_test.py`, which was causing the tests to hang and the EKS job to run out of time, so we couldn't collect the test results.
- Removes `summary_writer_test.py`, which mostly exercises Python libraries we're not employing here (e.g. `wandb`).
- Adds `pytest-xdist` and `pytest-reportlog` to avoid the following error:

Overall, this should allow us to reduce the testing time from 50 minutes to 30 minutes, while still covering the most important tests, i.e. those dealing with general AXLearn infrastructure.
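With `pytest-reportlog` in place, the `axlearn-unittests.jsonl` file passed to `pytest-xdist.sh` contains one JSON object per line, so outcome counts can be tallied without scraping stdout. A sketch, assuming standard report-log entries (`"$report_type": "TestReport"`, with skips recorded in the `setup` phase):

```python
import json
from collections import Counter


def count_outcomes(jsonl_text: str) -> Counter:
    """Tally test outcomes from a pytest-reportlog .jsonl file.

    Counts the 'call' phase of each TestReport; skipped tests never reach
    'call', so their 'setup'-phase reports are counted as well.
    """
    counts = Counter()
    for line in jsonl_text.splitlines():
        if not line.strip():
            continue
        rec = json.loads(line)
        if rec.get("$report_type") != "TestReport":
            continue
        when, outcome = rec.get("when"), rec.get("outcome")
        if when == "call" or (when == "setup" and outcome == "skipped"):
            counts[outcome] += 1
    return counts


sample = "\n".join([
    json.dumps({"$report_type": "TestReport", "when": "call", "outcome": "passed"}),
    json.dumps({"$report_type": "TestReport", "when": "call", "outcome": "failed"}),
    json.dumps({"$report_type": "TestReport", "when": "setup", "outcome": "skipped"}),
    json.dumps({"$report_type": "CollectReport", "outcome": "passed"}),
])
print(count_outcomes(sample))  # passed=1, failed=1, skipped=1
```

This would also catch failures that the grep-based summary misses, such as the uncaptured error linked above.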