Fix error in AXLearn tests, and remove unnecessary ones #1443


Open · wants to merge 48 commits into main
Conversation

@Steboss Steboss commented May 9, 2025

This PR does the following:

  • remove array_serialization_test.py, which causes the tests to hang with the following error:
[pod/axlearn-14925960362-xb57t/axlearn]   File "/opt/jax/jax/experimental/array_serialization/serialization.py", line 193, in __del__
[pod/axlearn-14925960362-xb57t/axlearn]     logger.warning('Please add `.wait_until_finished()` in the main thread '
[pod/axlearn-14925960362-xb57t/axlearn] Message: 'Please add `.wait_until_finished()` in the main thread before your program finishes because there is a possibility of losing errors raised if the this class is deleted before writing is completed.'
[pod/axlearn-14925960362-xb57t/axlearn] Arguments: ()
[pod/axlearn-14925960362-xb57t/axlearn] sssssssssssssss.ssssssssssss.ssssssssssssssssssssssFsFs.s...sssssssFs.Fs [ 97%]

This causes the EKS job to run out of time, so we can't collect the test results.

  • remove tests that are redundant (similar tests already run elsewhere), such as:
"/opt/axlearn/axlearn/common/deberta_test.py"
"/opt/axlearn/axlearn/common/distilbert_test.py"
"/opt/axlearn/axlearn/common/trainer_test.py"
"/opt/axlearn/axlearn/common/decoder_test.py"
"/opt/axlearn/axlearn/common/adapter_torch_test.py"
"/opt/axlearn/axlearn/common/attention_test.py"
"/opt/axlearn/axlearn/common/convolution_test.py"
  • remove tests for models that we're not currently using:
"/opt/axlearn/axlearn/common/mixture_of_experts_test.py"
"/opt/axlearn/axlearn/common/t5_test.py"
"/opt/axlearn/axlearn/common/vision_transformer_test.py"
"/opt/axlearn/axlearn/common/input_reading_comprehension_test.py"
"/opt/axlearn/axlearn/common/input_t5_test.py"
  • remove tests like summary_writer_test.py that mostly exercise Python libraries we're not using here (e.g. wandb)
  • install pytest-xdist and pytest-reportlog to avoid the following error (a sketch of the fix follows this list):
ERROR: usage: pytest [options] [file_or_dir] [file_or_dir] [...]
pytest: error: unrecognized arguments: --report-log=/tmp/tmp.iL7rVQtXBq --dist=load --tx --tx 
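
Both sets of unrecognized flags come from pytest plugins: --dist/--tx from pytest-xdist, --report-log from pytest-reportlog. A minimal sketch of the fix; the invocation below is illustrative, not the exact CI command:

# Install the plugins that provide --dist/--tx and --report-log.
pip install pytest-xdist pytest-reportlog
# Illustrative invocation; real paths and worker counts come from the CI scripts.
pytest --report-log=/tmp/report.jsonl --dist=load -n 6 /opt/axlearn/axlearn/common/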

Overall, this should reduce the testing time from 50 minutes to 30 minutes while still covering the most important tests, i.e. those exercising the general AXLearn infrastructure.

@olupton olupton left a comment


I see an error that seems like a test/infra bug, rather than a failing test: https://github.com/NVIDIA/JAX-Toolbox/actions/runs/14931992154/job/41952767038#step:6:41440

Some of the tests look like they are failing due to missing input data.

Also, the CI job is marked as successful despite the tests failing.

Steboss commented May 13, 2025

In this PR I added a workflow dispatch, so we can test individual parts of the CI.
In particular, this could be a useful pattern for the CI in general: set MODE=SOMETHING to trigger only specific parts of the workflow, rather than running everything. A sketch of the idea follows.
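
A minimal sketch of that dispatch in the CI driver script; the MODE values and helper functions are hypothetical, not the actual workflow inputs:

# Hypothetical dispatch on MODE; the cases and helpers are placeholders.
case "${MODE:-all}" in
  tests) run_unit_tests ;;
  eks)   run_eks_job ;;
  all)   run_unit_tests && run_eks_job ;;
  *)     echo "Unknown MODE: ${MODE}" >&2; exit 1 ;;
esac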

@Steboss Steboss requested a review from olupton May 13, 2025 16:41
Steboss commented May 14, 2025

@olupton
The AXLearn tests are now running fine, though a few are still failing; I can have a look at those.
I'm also working on how we monitor and return the k8s job status, since the AXLearn EKS job still reports green even when tests fail.

Steboss commented May 14, 2025

@olupton
It looks like we need XLA_FLAGS="--xla_force_host_platform_device_count=8" for the for_8_devices tests; otherwise they fail with:

[pod/axlearn-15024599931-2vfmc/axlearn] ___________________ HostArrayTest.test_fixed_process_shape67 ___________________
[pod/axlearn-15024599931-2vfmc/axlearn] [gw93] linux -- Python 3.12.3 /usr/bin/python3
[pod/axlearn-15024599931-2vfmc/axlearn] 
[pod/axlearn-15024599931-2vfmc/axlearn] self = <axlearn.common.host_array_test.HostArrayTest testMethod=test_fixed_process_shape67>
[pod/axlearn-15024599931-2vfmc/axlearn] platform = 'cpu', mesh_shape = (-1, 2), process_shape = [1]
[pod/axlearn-15024599931-2vfmc/axlearn] partition = PartitionSpec('data', 'model')
[pod/axlearn-15024599931-2vfmc/axlearn] 
[pod/axlearn-15024599931-2vfmc/axlearn]     @parameterized.product(
[pod/axlearn-15024599931-2vfmc/axlearn]         platform=("cpu", "tpu"),
[pod/axlearn-15024599931-2vfmc/axlearn]         mesh_shape=[
[pod/axlearn-15024599931-2vfmc/axlearn]             (-1, 1),  # Fully partitioned along one dim.
[pod/axlearn-15024599931-2vfmc/axlearn]             (2, -1),  # Partitioned along multiple dims.
[pod/axlearn-15024599931-2vfmc/axlearn]             (-1, 2),  # Test the other way.
[pod/axlearn-15024599931-2vfmc/axlearn]             (1, -1),
[pod/axlearn-15024599931-2vfmc/axlearn]         ],
[pod/axlearn-15024599931-2vfmc/axlearn]         process_shape=[
[pod/axlearn-15024599931-2vfmc/axlearn]             # Each process produces single dim.
[pod/axlearn-15024599931-2vfmc/axlearn]             [1],  # Not divisible by number of devices (replicated).
[pod/axlearn-15024599931-2vfmc/axlearn]             [8],  # Divisible by number of devices.
[pod/axlearn-15024599931-2vfmc/axlearn]             [16],  # Multiple elements per device.
[pod/axlearn-15024599931-2vfmc/axlearn]             # Each process produces multiple dims.
[pod/axlearn-15024599931-2vfmc/axlearn]             [1, 1],  # Not divisible by number of devices (replicated).
[pod/axlearn-15024599931-2vfmc/axlearn]             [2, 1],  # Can be partitioned over dim=0, replicated on dim=1.
[pod/axlearn-15024599931-2vfmc/axlearn]             [16, 1],  # Multiple elements per device.
[pod/axlearn-15024599931-2vfmc/axlearn]             [2, 4],  # Can be fully partitioned.
[pod/axlearn-15024599931-2vfmc/axlearn]             [8, 8],  # Can be fully partitioned.
[pod/axlearn-15024599931-2vfmc/axlearn]         ],
[pod/axlearn-15024599931-2vfmc/axlearn]         partition=(
[pod/axlearn-15024599931-2vfmc/axlearn]             DataPartitionType.FULL,
[pod/axlearn-15024599931-2vfmc/axlearn]             DataPartitionType.REPLICATED,
[pod/axlearn-15024599931-2vfmc/axlearn]             PartitionSpec("data"),
[pod/axlearn-15024599931-2vfmc/axlearn]             PartitionSpec("data", "model"),
[pod/axlearn-15024599931-2vfmc/axlearn]         ),
[pod/axlearn-15024599931-2vfmc/axlearn]     )
[pod/axlearn-15024599931-2vfmc/axlearn]     # NOTE: while annotated with `for_8_devices`, this runs on other configurations.
[pod/axlearn-15024599931-2vfmc/axlearn]     @pytest.mark.for_8_devices
[pod/axlearn-15024599931-2vfmc/axlearn]     def test_fixed_process_shape(
[pod/axlearn-15024599931-2vfmc/axlearn]         self,
[pod/axlearn-15024599931-2vfmc/axlearn]         platform: str,
[pod/axlearn-15024599931-2vfmc/axlearn]         mesh_shape: tuple[int, int],
[pod/axlearn-15024599931-2vfmc/axlearn]         process_shape: Sequence[int],
[pod/axlearn-15024599931-2vfmc/axlearn]         partition: Union[DataPartitionType, PartitionSpec],
[pod/axlearn-15024599931-2vfmc/axlearn]     ):
[pod/axlearn-15024599931-2vfmc/axlearn]         """Tests roundtrip host-to-global and global-to-host with fixed process shape."""
[pod/axlearn-15024599931-2vfmc/axlearn]     
[pod/axlearn-15024599931-2vfmc/axlearn] >       mesh_shape = infer_mesh_shape(mesh_shape)
[pod/axlearn-15024599931-2vfmc/axlearn] 
[pod/axlearn-15024599931-2vfmc/axlearn] axlearn/common/host_array_test.py:124: 
[pod/axlearn-15024599931-2vfmc/axlearn] _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
[pod/axlearn-15024599931-2vfmc/axlearn] 
[pod/axlearn-15024599931-2vfmc/axlearn] mesh_shape = (-1, 2)
[pod/axlearn-15024599931-2vfmc/axlearn] 
[pod/axlearn-15024599931-2vfmc/axlearn]     def infer_mesh_shape(mesh_shape: MeshShape, *, num_devices: Optional[int] = None) -> MeshShape:
[pod/axlearn-15024599931-2vfmc/axlearn]         """Infer the value for -1 from len(jax.devices()) and other dims if there is -1 in mesh shape.
[pod/axlearn-15024599931-2vfmc/axlearn]     
[pod/axlearn-15024599931-2vfmc/axlearn]         Args:
[pod/axlearn-15024599931-2vfmc/axlearn]             mesh_shape: The original MeshShape, which might have -1 in one axis.
[pod/axlearn-15024599931-2vfmc/axlearn]             num_devices: The devices that will be used to construct the mesh.
[pod/axlearn-15024599931-2vfmc/axlearn]                 If None, defaults to len(jax.devices()).
[pod/axlearn-15024599931-2vfmc/axlearn]     
[pod/axlearn-15024599931-2vfmc/axlearn]         Returns
[pod/axlearn-15024599931-2vfmc/axlearn]             A new MeshShape with inferred value for -1.
[pod/axlearn-15024599931-2vfmc/axlearn]         """
[pod/axlearn-15024599931-2vfmc/axlearn]         if -1 not in mesh_shape:
[pod/axlearn-15024599931-2vfmc/axlearn]             return mesh_shape
[pod/axlearn-15024599931-2vfmc/axlearn]     
[pod/axlearn-15024599931-2vfmc/axlearn]         if mesh_shape.count(-1) > 1:
[pod/axlearn-15024599931-2vfmc/axlearn]             raise ValueError(f"Only one axis can be -1 in {mesh_shape=}.")
[pod/axlearn-15024599931-2vfmc/axlearn]     
[pod/axlearn-15024599931-2vfmc/axlearn]         # Handle the case with one -1.
[pod/axlearn-15024599931-2vfmc/axlearn]         prod = math.prod(mesh_shape, start=-1)
[pod/axlearn-15024599931-2vfmc/axlearn]         if num_devices is None:
[pod/axlearn-15024599931-2vfmc/axlearn]             num_devices = len(jax.devices())
[pod/axlearn-15024599931-2vfmc/axlearn]         if num_devices % prod != 0:
[pod/axlearn-15024599931-2vfmc/axlearn] >           raise ValueError(
[pod/axlearn-15024599931-2vfmc/axlearn]                 f"Unable to infer -1 in mesh shape {mesh_shape} as num_devices {num_devices} "
[pod/axlearn-15024599931-2vfmc/axlearn]                 f"is not a multiple of the product {prod} of mesh axes."
[pod/axlearn-15024599931-2vfmc/axlearn]             )
[pod/axlearn-15024599931-2vfmc/axlearn] E           ValueError: Unable to infer -1 in mesh shape (-1, 2) as num_devices 1 is not a multiple of the product 2 of mesh axes.
[pod/axlearn-15024599931-2vfmc/axlearn] 
[pod/axlearn-15024599931-2vfmc/axlearn] axlearn/common/utils.py:1834: ValueError

I ran the tests both ways:

  • with the flag: 1 failed, 161 passed
  • without the flag: 203 failed, 100 passed
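
For reference, a sketch of setting the flag before the run; only the XLA_FLAGS value is taken from above, the pytest invocation itself is illustrative:

# Expose 8 virtual CPU devices so meshes like (-1, 2) can be inferred.
export XLA_FLAGS="--xla_force_host_platform_device_count=8"
pytest -m for_8_devices axlearn/common/host_array_test.py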

Steboss commented May 14, 2025

This may be an error introduced by a newer JAX version:

AttributeError: module 'jax.experimental.array_serialization.serialization' has no attribute '_spec_has_metadata'

@Steboss Steboss requested a review from olupton May 16, 2025 16:29
@Steboss Steboss requested a review from olupton May 19, 2025 16:39
olupton previously approved these changes May 20, 2025
@olupton olupton left a comment


I think we can merge this to speed up the pipeline, but I left some more comments on error handling and robustness.

# Run tests
pytest-xdist.sh 1 6 ${LOG_DIR}/axlearn-unittests.jsonl test-axlearn.sh --directory "." --output ${LOG_DIR} --test-files "/opt/axlearn/axlearn/common/*_test.py" | tee -a ${LOG_DIR}/pytest_stdout.log

# test on JAX, make sure 8 devices are visible
@olupton (Collaborator) commented:
Presumably this means we are leaving parallelism on the table by launching 1-GPU tests with 8 visible.
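
One possible way to reclaim that parallelism, sketched under the assumption that execnet's popen//env: spec syntax is available (untested here): pin each pytest-xdist worker to its own GPU.

# Build one --tx spec per GPU so each worker sees a single device.
tx_args=()
for gpu in $(seq 0 7); do
  tx_args+=(--tx "popen//env:CUDA_VISIBLE_DEVICES=${gpu}")
done
pytest --dist=load "${tx_args[@]}" /opt/axlearn/axlearn/common/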

@Steboss (Contributor Author) replied:
Fixed :)

echo "Total number of failed tests ${failed}"
echo "Total number of skipped tests ${skipped}"
# add these to summary.txt; we use it to extract the counts later
echo "PASSED: ${passed} FAILED: ${failed} SKIPPED: ${skipped}" >> ${LOG_DIRECTORY}/summary.txt
@Steboss Steboss requested a review from olupton May 21, 2025 12:32