Learner: add new unittests using Model. #900

Merged 1 commit on Dec 19, 2024
9 changes: 3 additions & 6 deletions axlearn/common/learner.py
@@ -526,15 +526,12 @@ def should_apply(tree: Nested[Any]) -> Nested[bool]:
            sub_learner_updates = sub_learner_updates.mask(
                # pylint: disable-next=cell-var-from-loop
                lambda _: should_apply(updates.opt_params),
-               fields=(
-                   "opt_params",
-                   "delta_updates",
-               ),
+               fields=("opt_params", "delta_updates"),
            )
            sub_learner_updated_model_params = getattr(self, name).update(sub_learner_updates)
            updated_model_params = jax.tree.map(
                lambda apply, new_v, old_v: new_v if apply else old_v,
-               should_apply(updates.param_values()),
+               should_apply(updated_model_params),
                sub_learner_updated_model_params,
                updated_model_params,
            )
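Note on the hunk above: jax.tree.map zips all three trees leaf by leaf, so the boolean tree produced by should_apply must have the same structure as the parameter trees being merged. A minimal, self-contained sketch of this masking pattern, using made-up parameter names rather than axlearn code:

import jax
import jax.numpy as jnp

old_params = {"encoder": jnp.zeros(3), "decoder": jnp.zeros(3)}
new_params = {"encoder": jnp.full(3, 2.0), "decoder": jnp.full(3, 5.0)}
# Boolean tree with the same structure as the parameter trees: True where the
# current sub-learner owns the leaf and its update should be applied.
mask = {"encoder": True, "decoder": False}

merged = jax.tree.map(
    lambda apply, new_v, old_v: new_v if apply else old_v,
    mask,
    new_params,
    old_params,
)
# merged["encoder"] takes the new values; merged["decoder"] keeps the old ones.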
@@ -712,7 +709,7 @@ def _value_and_grad(

        split_params = split_params_fn(opt_params)
        model_params_grad, model_params_nograd = jax.tree.map(lambda p: p.value, split_params)
-       (_, forward_pass), grads = jax.value_and_grad(loss_fun, has_aux=True)(
+       (unused_loss, forward_pass), grads = jax.value_and_grad(loss_fun, has_aux=True)(
            model_params_grad, inputs=(model_params_nograd, inputs)
        )
        return Updates(
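The rename to unused_loss above reflects the calling convention of jax.value_and_grad with has_aux=True, which returns ((value, aux), grads); the primal loss is available but not needed at this point. A minimal sketch of that convention with a toy loss function (assumed for illustration, not axlearn code):

import jax
import jax.numpy as jnp

def loss_fn(w, x):
    y = x @ w
    loss = jnp.mean(y**2)
    aux = {"mean_abs": jnp.mean(jnp.abs(y))}
    return loss, aux

w = jnp.ones((4, 1))
x = jnp.ones((2, 4))
# has_aux=True makes value_and_grad return ((loss, aux), grads);
# the loss itself is intentionally left unused here.
(unused_loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(w, x)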
210 changes: 209 additions & 1 deletion axlearn/common/learner_test.py
@@ -15,8 +15,10 @@
import axlearn.common.update_transformation_test
from axlearn.common import schedule
from axlearn.common.base_layer import FactorizationSpec, ParameterSpec
+ from axlearn.common.base_model import BaseModel
from axlearn.common.config import REQUIRED, Required, config_class, config_for_function
from axlearn.common.gradient_accumulation import with_minibatch_steps
+ from axlearn.common.layers import Linear
from axlearn.common.learner import (
    CompositeLearner,
    Learner,
@@ -28,7 +30,7 @@
    should_update_with_optimizers,
)
from axlearn.common.metrics import MetricAccumulator, WeightedScalar
- from axlearn.common.module import OutputCollection
+ from axlearn.common.module import OutputCollection, child_context
from axlearn.common.module import functional as F
from axlearn.common.module import new_output_collection
from axlearn.common.optimizer_base import OptParam, OptStateSpec
@@ -50,6 +52,7 @@
)
from axlearn.common.utils import (
    Nested,
+   NestedTensor,
    PartitionSpec,
    Tensor,
    VDict,
@@ -59,7 +62,113 @@
)


class TestModel(BaseModel):
"""A simple model for test."""

@config_class
class Config(BaseModel.Config):
dim: int = 4

def __init__(self, cfg, *, parent):
super().__init__(cfg, parent=parent)
enc_cfg = Linear.default_config().set(
input_dim=cfg.dim,
output_dim=cfg.dim,
)
self._add_child("encoder", enc_cfg)

dec_cfg = Linear.default_config().set(
input_dim=cfg.dim,
output_dim=1,
)
self._add_child("decoder", dec_cfg)

def forward(self, input_batch: NestedTensor) -> tuple[Tensor, NestedTensor]:
x = self.encoder(input_batch["x"])
y = self.decoder(x)
loss = jnp.mean(y**2)
aux = dict(discriminator_loss=jnp.mean(jnp.abs(y)))
return loss, aux


class LearnerTest(TestCase):
@parameterized.parameters(None, 0.999)
def test_forward_and_backward(self, ema_decay):
"""Demonstrates how API users should use the API while ensuring that it works correctly."""
# Init a learner.
learning_rate = config_for_function(schedule.stepwise).set(
sub=[0.1, 0.01, 0.001],
start_step=[100, 200],
)
optimizer_cfg = config_for_function(adam_optimizer).set(
learning_rate=learning_rate, b1=0.9, b2=0.99, eps=1e-5, l2_regularizer_weight=1.0
)
cfg = Learner.default_config().set(name="test", optimizer=optimizer_cfg)
cfg.ema.decay = ema_decay
learner: Learner = cfg.instantiate(parent=None)

# Init a model.
input_dim = 4
model_cfg = TestModel.default_config().set(name="test", dim=input_dim)
model = model_cfg.instantiate(parent=None)
prng_key = jax.random.PRNGKey(123)
init_key, data_key, fwd_key, learner_key, prng_key = jax.random.split(prng_key, num=5)
params = model.initialize_parameters_recursively(init_key)

# Create model and learner states.
model_param_specs = model.create_parameter_specs_recursively()
opt_params = jax.tree.map(
lambda param, spec: OptParam(
value=param,
factorization_spec=spec.factorization if spec else None,
weight_decay_scale=spec.weight_decay_scale if spec else 1.0,
),
params,
model_param_specs,
)
learner_state = learner.init(model_params=opt_params)

# Forward and backward.
def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutputs:
model_output_collection = new_output_collection()
with child_context(
"model",
module=model,
state=model_params,
prng_key=inputs["forward_key"],
output_collection=model_output_collection,
):
loss, aux = model(input_batch=inputs["input_batch"])
return ForwardOutputs(loss=loss, aux=aux, output_collection=model_output_collection)

batch = 2
input_batch = dict(x=jax.random.uniform(data_key, (batch, input_dim)))
fwd_bwd_outputs, learner_output_collection = F(
learner,
method="forward_and_backward",
state=learner_state,
is_training=True,
prng_key=learner_key,
inputs=dict(
fn=_forward,
opt_params=opt_params,
inputs=dict(
input_batch=input_batch,
forward_key=fwd_key,
),
),
)
forward_outputs: ForwardOutputs = fwd_bwd_outputs.forward_outputs
updated_model_params = fwd_bwd_outputs.backward_outputs.updated_params
learner_state = learner_output_collection.state_updates
self.assertGreater(forward_outputs.loss, 0.0)
self.assertGreater(forward_outputs.aux["discriminator_loss"], 0.0)
        # The updated params and the Adam mu state have the same tree structure.
self.assertNestedEqual(
jax.tree_util.tree_structure(updated_model_params),
jax.tree_util.tree_structure(learner_state["optimizer"][1].mu),
)
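The final assertion above relies on Adam's first-moment state mirroring the parameter tree. A minimal sketch of that property with plain optax; axlearn's adam_optimizer is assumed to chain transforms in a similar way, and the index into the chained state is implementation-specific:

import jax
import jax.numpy as jnp
import optax

params = {"encoder": {"weight": jnp.ones((4, 4))}, "decoder": {"weight": jnp.ones((4, 1))}}
opt = optax.adam(1e-3)
opt_state = opt.init(params)
# scale_by_adam's state is the first element of the chained state here; its
# first moment `mu` has the same tree structure as the parameters.
mu = opt_state[0].mu
assert jax.tree_util.tree_structure(mu) == jax.tree_util.tree_structure(params)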

def test_prune_empty_state(self):
state = {
"state": {
@@ -816,6 +925,105 @@ def test__value_and_grad(self):


class CompositeLearnerTest(TestCase):
@parameterized.parameters(None, 0.999)
def test_forward_and_backward(self, ema_decay):
"""Demonstrates how API users should use the API while ensuring that it works correctly."""
# Init a learner.
encoder_lr = 0.1
opt1_cfg = config_for_function(sgd_optimizer).set(
learning_rate=encoder_lr, decouple_weight_decay=True, weight_decay=1.0
)
opt2_cfg = config_for_function(adam_optimizer).set(
learning_rate=0.0, b1=0.9, b2=0.99, eps=1e-5, l2_regularizer_weight=1.0
)
learner_rules = [(".*encoder.*", "encoder"), (".*decoder.*", "decoder")]

cfg = CompositeLearner.default_config().set(
name="test",
rules=learner_rules,
learners={
"encoder": Learner.default_config().set(
optimizer=opt1_cfg, enable_per_variable_summaries=True
),
"decoder": Learner.default_config().set(
optimizer=opt2_cfg, enable_per_variable_summaries=False
),
},
)
cfg.ema.decay = ema_decay
learner: CompositeLearner = cfg.instantiate(parent=None)

# Init a model.
input_dim = 4
model_cfg = TestModel.default_config().set(name="test", dim=input_dim)
model = model_cfg.instantiate(parent=None)
prng_key = jax.random.PRNGKey(123)
init_key, data_key, fwd_key, learner_key, prng_key = jax.random.split(prng_key, num=5)
params = model.initialize_parameters_recursively(init_key)

# Create model and learner states.
model_param_specs = model.create_parameter_specs_recursively()
opt_params = jax.tree.map(
lambda param, spec: OptParam(
value=param,
factorization_spec=spec.factorization if spec else None,
weight_decay_scale=spec.weight_decay_scale if spec else 1.0,
),
params,
model_param_specs,
)
learner_state = learner.init(model_params=opt_params)

# Forward and backward.
def _forward(*, model_params: NestedTensor, inputs: NestedTensor) -> ForwardOutputs:
model_output_collection = new_output_collection()
with child_context(
"model",
module=model,
state=model_params,
prng_key=inputs["forward_key"],
output_collection=model_output_collection,
):
loss, aux = model(input_batch=inputs["input_batch"])
return ForwardOutputs(loss=loss, aux=aux, output_collection=model_output_collection)

batch = 2
input_batch = dict(x=jax.random.uniform(data_key, (batch, input_dim)))
fwd_bwd_outputs, learner_output_collection = F(
learner,
method="forward_and_backward",
state=learner_state,
is_training=True,
prng_key=learner_key,
inputs=dict(
fn=_forward,
opt_params=opt_params,
inputs=dict(
input_batch=input_batch,
forward_key=fwd_key,
),
),
)
forward_outputs: ForwardOutputs = fwd_bwd_outputs.forward_outputs
updated_model_params = fwd_bwd_outputs.backward_outputs.updated_params
learner_state = learner_output_collection.state_updates
self.assertGreater(forward_outputs.loss, 0.0)
self.assertGreater(forward_outputs.aux["discriminator_loss"], 0.0)
        # The updated params and the per-learner optimizer states have the same tree structure.
opt_state_leaf_fn = lambda x: isinstance(x, (Tensor, optax.MaskedNode))
self.assertNestedEqual(
jax.tree_util.tree_structure(updated_model_params),
jax.tree_util.tree_structure(
learner_state["encoder"]["optimizer"][0].trace, is_leaf=opt_state_leaf_fn
),
)
self.assertNestedEqual(
jax.tree_util.tree_structure(updated_model_params),
jax.tree_util.tree_structure(
learner_state["decoder"]["optimizer"][1].mu, is_leaf=opt_state_leaf_fn
),
)
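The is_leaf argument above matters because optax's masking wrapper stores optax.MaskedNode() for parameters a sub-learner skips, and MaskedNode is an empty NamedTuple that would otherwise flatten into an empty subtree and change the tree structure. A minimal sketch of the effect with toy trees (assumed for illustration):

import jax
import optax

params = {"w": 1.0, "b": 2.0}
state = {"w": 1.0, "b": optax.MaskedNode()}

# Without is_leaf, MaskedNode flattens to an empty subtree, so the structures differ.
assert jax.tree_util.tree_structure(params) != jax.tree_util.tree_structure(state)

# Treating MaskedNode as a leaf keeps masked entries aligned one-to-one with params.
is_leaf = lambda x: isinstance(x, optax.MaskedNode)
assert jax.tree_util.tree_structure(params) == jax.tree_util.tree_structure(
    state, is_leaf=is_leaf
)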

@parameterized.product(ema_decay=(None, 0.9), method=("update", "forward_and_backward"))
# pylint: disable-next=too-many-statements
def test_learner(self, ema_decay: Optional[float], method: str):