# Takes the baseline version and uses vmap, adds in a learning rate scheduler
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
import optax
import jax_dataloader as jdl
from jax_dataloader.loaders import DataLoaderJAX
from flax import nnx
from typing import List, Optional, Callable
import wandb
import jsonargparse


# My MLP
class MyMLP(nnx.Module):
    def __init__(self, din: int, dout: int, width: int, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(din, width, rngs=rngs)
        self.linear2 = nnx.Linear(width, width, rngs=rngs)
        self.linear3 = nnx.Linear(width, dout, rngs=rngs)

    def __call__(self, x: jax.Array):
        x = self.linear1(x)
        x = nnx.relu(x)
        x = self.linear2(x)
        x = nnx.relu(x)
        x = self.linear3(x)
        return x


def fit_model(
    N: int = 500,
    M: int = 2,
    sigma: float = 0.0001,
    width: int = 128,
    lr: float = 0.001,
    num_epochs: int = 2000,
    batch_size: int = 512,
    seed: int = 42,
    wandb_project: str = "econ622_examples",
    wandb_mode: str = "offline",  # or "online", "disabled"
):
    if wandb_mode != "disabled":
        wandb.init(project=wandb_project, mode=wandb_mode)
    rngs = nnx.Rngs(seed)

    theta = random.normal(rngs(), (M,))
    X = random.normal(rngs(), (N, M))
    Y = X @ theta + sigma * random.normal(rngs(), (N,))  # Adding noise

    def residual(model, x, y):
        y_hat = model(x)
        return (y_hat - y) ** 2

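    # Note: residuals_loss below vmaps residual over rows of X and entries of Y;
    # in_axes=(None, 0, 0) keeps the model fixed while batching over the data.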
    def residuals_loss(model, X, Y):
        return jnp.mean(jax.vmap(residual, in_axes=(None, 0, 0))(model, X, Y))

    model = MyMLP(M, 1, width, rngs=rngs)

    n_params = sum(
        np.prod(x.shape) for x in jax.tree.leaves(nnx.state(model, nnx.Param))
    )
    print(f"Number of parameters: {n_params}")

    optimizer = nnx.Optimizer(model, optax.sgd(lr))

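    # The header mentions a learning rate scheduler, but the optimizer above uses a
    # constant rate. A minimal sketch, assuming exponential decay is the intended
    # schedule (the decay hyperparameters here are illustrative, not from the source):
    #     schedule = optax.exponential_decay(
    #         init_value=lr, transition_steps=1_000, decay_rate=0.9
    #     )
    #     optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=schedule))
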
    @nnx.jit
    def train_step(model, optimizer, X, Y):
        def loss_fn(model):
            return residuals_loss(model, X, Y)

        loss, grads = nnx.value_and_grad(loss_fn)(model)
        optimizer.update(grads)
        return loss

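    # Note: nnx.jit propagates Module and Optimizer state, so optimizer.update(grads)
    # mutates the model parameters in place even though train_step only returns the loss.
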
    dataset = jdl.ArrayDataset(X, Y)
    train_loader = DataLoaderJAX(dataset, batch_size=batch_size, shuffle=True)
    for epoch in range(num_epochs):
        for X_batch, Y_batch in train_loader:
            loss = train_step(model, optimizer, X_batch, Y_batch)

        if wandb_mode != "disabled":
            # Logs the loss from the last batch of the epoch
            wandb.log({"epoch": epoch, "train_loss": loss, "lr": lr})
        if epoch % 100 == 0:
            print(f"Epoch {epoch}, loss {loss}")

    N_test = 200
    X_test = random.normal(rngs(), (N_test, M))
    Y_test = X_test @ theta + sigma * random.normal(rngs(), (N_test,))  # Adding noise

    loss_data = residuals_loss(model, X, Y)
    loss_test = residuals_loss(model, X_test, Y_test)
    print(f"loss(model, X, Y) = {loss_data}, loss(model, X_test, Y_test) = {loss_test}")
    if wandb_mode != "disabled":
        wandb.log(
            {"train_loss": loss_data, "test_loss": loss_test, "num_params": n_params}
        )

    if wandb_mode != "disabled":
        wandb.finish()


if __name__ == "__main__":
    jsonargparse.CLI(fit_model)
    # Swap in the line below to run the debugger with different parameters:
    # jsonargparse.CLI(fit_model, args=["--num_epochs", "200", "--wandb_mode", "online"])
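# Usage sketch (the script filename is hypothetical; jsonargparse exposes each
# keyword argument of fit_model as a command-line flag):
#     python fit_mlp_vmap.py --num_epochs 200 --width 256 --wandb_mode disabled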