Skip to content

Commit

Permalink
Revert updates to fuji model config
Browse files Browse the repository at this point in the history
  • Loading branch information
jesus-orozco committed Jan 14, 2025
1 parent c7bc5df commit a0bf9df
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions axlearn/experiments/text/gpt/fuji.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,13 +434,13 @@ def get_trainer_kwargs(
ChainConfigModifier.default_config().set(
config_modifiers=[
MeshShapeModifier.default_config().set(
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=128)
mesh_shape=mesh_shape_from_axes(data=-1, fsdp=256)
),
RematSpecModifier.default_config().set(
remat_policies={
"model.decoder.transformer.layer": RematSpec(
prevent_cse=True,
policy=jax_remat_policies.dots_saveable,
policy=offload_dots_saveable_policy,
),
}
),
Expand Down

0 comments on commit a0bf9df

Please sign in to comment.