
Commit 9c79f0d

Updates to embeddings/llm
1 parent 0773055 commit 9c79f0d

File tree

8 files changed: +1519 −164 lines changed


docs/lectures/lectures/deep_learning.html: +755 −2 (large diff not rendered by default)

docs/lectures/lectures/embeddings_nlp_llm.html: +641 −162 (large diff not rendered by default)
lectures/examples/mlp_regression_jax_nnx_logging.py: +106 (new file)
@@ -0,0 +1,106 @@
# Takes the baseline version and uses vmap; adds W&B logging and a jsonargparse CLI
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
import optax
import jax_dataloader as jdl
from jax_dataloader.loaders import DataLoaderJAX
from flax import nnx
import wandb
import jsonargparse


# My MLP: three linear layers with ReLU activations
class MyMLP(nnx.Module):
    def __init__(self, din: int, dout: int, width: int, *, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(din, width, rngs=rngs)
        self.linear2 = nnx.Linear(width, width, rngs=rngs)
        self.linear3 = nnx.Linear(width, dout, rngs=rngs)

    def __call__(self, x: jax.Array):
        x = self.linear1(x)
        x = nnx.relu(x)
        x = self.linear2(x)
        x = nnx.relu(x)
        x = self.linear3(x)
        return x


def fit_model(
    N: int = 500,
    M: int = 2,
    sigma: float = 0.0001,
    width: int = 128,
    lr: float = 0.001,
    num_epochs: int = 2000,
    batch_size: int = 512,
    seed: int = 42,
    wandb_project: str = "econ622_examples",
    wandb_mode: str = "offline",  # or "online", "disabled"
):
    if wandb_mode != "disabled":
        wandb.init(project=wandb_project, mode=wandb_mode)
    rngs = nnx.Rngs(seed)

    # Simulate a linear data-generating process
    theta = random.normal(rngs(), (M,))
    X = random.normal(rngs(), (N, M))
    Y = X @ theta + sigma * random.normal(rngs(), (N,))  # Adding noise

    # Squared error for a single observation; vmap maps it over the batch
    def residual(model, x, y):
        y_hat = model(x)
        return (y_hat - y) ** 2

    def residuals_loss(model, X, Y):
        return jnp.mean(jax.vmap(residual, in_axes=(None, 0, 0))(model, X, Y))

    model = MyMLP(M, 1, width, rngs=rngs)

    n_params = sum(
        np.prod(x.shape) for x in jax.tree.leaves(nnx.state(model, nnx.Param))
    )
    print(f"Number of parameters: {n_params}")

    optimizer = nnx.Optimizer(model, optax.sgd(lr))

    @nnx.jit
    def train_step(model, optimizer, X, Y):
        def loss_fn(model):
            return residuals_loss(model, X, Y)

        loss, grads = nnx.value_and_grad(loss_fn)(model)
        optimizer.update(grads)
        return loss

    dataset = jdl.ArrayDataset(X, Y)
    train_loader = DataLoaderJAX(dataset, batch_size=batch_size, shuffle=True)
    for epoch in range(num_epochs):
        for X_batch, Y_batch in train_loader:
            loss = train_step(model, optimizer, X_batch, Y_batch)

        # loss here is from the last minibatch of the epoch
        if wandb_mode != "disabled":
            wandb.log({"epoch": epoch, "train_loss": loss, "lr": lr})
        if epoch % 100 == 0:
            print(f"Epoch {epoch}, loss {loss}")

    # Evaluate on fresh draws from the same data-generating process
    N_test = 200
    X_test = random.normal(rngs(), (N_test, M))
    Y_test = X_test @ theta + sigma * random.normal(rngs(), (N_test,))  # Adding noise

    loss_data = residuals_loss(model, X, Y)
    loss_test = residuals_loss(model, X_test, Y_test)
    print(f"loss(model, X, Y) = {loss_data}, loss(model, X_test, Y_test) = {loss_test}")
    if wandb_mode != "disabled":
        wandb.log(
            {"train_loss": loss_data, "test_loss": loss_test, "num_params": n_params}
        )

    if wandb_mode != "disabled":
        wandb.finish()


if __name__ == "__main__":
    jsonargparse.CLI(fit_model)
    # Swap with this line to run the debugger with different parameters
    # jsonargparse.CLI(fit_model, args=["--num_epochs", "200", "--wandb_mode", "online"])
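The script trains with a fixed learning rate and logs that constant each epoch. If a schedule were wanted instead, optax optimizers accept a schedule anywhere a float learning rate is allowed; a minimal sketch of a drop-in replacement for the `optimizer = ...` line in fit_model, with illustrative decay values (not from the commit):

# Sketch: swap the fixed rate for an exponential-decay schedule.
# `model` is the MyMLP instance constructed earlier in fit_model.
import optax
from flax import nnx

schedule = optax.exponential_decay(
    init_value=0.001,      # starting learning rate
    transition_steps=100,  # optimizer steps between decays
    decay_rate=0.9,        # multiplicative decay factor
)
optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=schedule))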
@@ -0,0 +1,17 @@
program: lectures/examples/mlp_regression_jax_nnx_logging.py
name: Sweep Example
description: Example Sweep
method: bayes
metric:
  name: test_loss
  goal: minimize
parameters:
  wandb_mode:
    value: online # otherwise won't log
  num_epochs:
    value: 300
  lr:
    min: 0.0001
    max: 0.01
  width:
    values: [64, 128, 256]
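To run this sweep, the config is registered with W&B and then executed by one or more agents. A minimal sketch using the Python API rather than the `wandb sweep` CLI (the config dict mirrors the YAML above; the project name copies the script's default and is an assumption):

# Sketch: register the sweep and print its id; agents then execute the
# `program` above via `wandb agent <entity>/<project>/<sweep_id>`.
import wandb

sweep_config = {
    "program": "lectures/examples/mlp_regression_jax_nnx_logging.py",
    "name": "Sweep Example",
    "method": "bayes",
    "metric": {"name": "test_loss", "goal": "minimize"},
    "parameters": {
        "wandb_mode": {"value": "online"},
        "num_epochs": {"value": 300},
        "lr": {"min": 0.0001, "max": 0.01},
        "width": {"values": [64, 128, 256]},
    },
}

sweep_id = wandb.sweep(sweep=sweep_config, project="econ622_examples")
print(sweep_id)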
Three more files (234 KB, 40.4 KB, 226 KB): diffs not rendered.

0 commit comments
