diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py index 8fc53878..6cd49814 100644 --- a/axlearn/experiments/text/gpt/fuji.py +++ b/axlearn/experiments/text/gpt/fuji.py @@ -434,13 +434,13 @@ def get_trainer_kwargs( ChainConfigModifier.default_config().set( config_modifiers=[ MeshShapeModifier.default_config().set( - mesh_shape=mesh_shape_from_axes(data=-1, fsdp=128) + mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256) ), RematSpecModifier.default_config().set( remat_policies={ "model.decoder.transformer.layer": RematSpec( prevent_cse=True, - policy=jax_remat_policies.dots_saveable, + policy=offload_dots_saveable_policy, ), } ),