Skip to content

Commit

Permalink
wrap up simple diffusion example
Browse files Browse the repository at this point in the history
  • Loading branch information
Sylvain Le Groux committed Feb 14, 2025
1 parent 64dc812 commit 88b57cf
Show file tree
Hide file tree
Showing 5 changed files with 509 additions and 411 deletions.
8 changes: 4 additions & 4 deletions config/data/image/mnist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ batch_size: 64
num_workers: 0

pin_memory: True
persistent_workers: False
persistent_workers: True
transforms:
_target_: torchvision.transforms.Compose
_target_: torchvision.transforms.v2.Compose
transforms:
- _target_: torchvision.transforms.ToTensor
- _target_: torchvision.transforms.Normalize
- _target_: torchvision.transforms.v2.ToImage
- _target_: torchvision.transforms.v2.Normalize
mean: [0.1307,]
std: [0.3081,]

397 changes: 132 additions & 265 deletions nbs/models.diffusion.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion nimrod/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,9 @@
'nimrod.models.diffusion.DiffusorX._step': ( 'models.diffusion.html#diffusorx._step',
'nimrod/models/diffusion.py'),
'nimrod.models.diffusion.DiffusorX.forward': ( 'models.diffusion.html#diffusorx.forward',
'nimrod/models/diffusion.py')},
'nimrod/models/diffusion.py'),
'nimrod.models.diffusion.DiffusorX.generate_images': ( 'models.diffusion.html#diffusorx.generate_images',
'nimrod/models/diffusion.py')},
'nimrod.models.lm': { 'nimrod.models.lm.NNBigram': ('models.lm.html#nnbigram', 'nimrod/models/lm.py'),
'nimrod.models.lm.NNBigram.__init__': ('models.lm.html#nnbigram.__init__', 'nimrod/models/lm.py'),
'nimrod.models.lm.NNBigram.forward': ('models.lm.html#nnbigram.forward', 'nimrod/models/lm.py'),
Expand Down
16 changes: 12 additions & 4 deletions nimrod/models/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision.transforms import transforms

import torchvision.transforms.v2 as transforms
from lightning import Trainer
import os
import logging
import warnings
Expand All @@ -34,12 +34,11 @@
from diffusers.utils import make_image_grid
from diffusers.optimization import get_cosine_schedule_with_warmup

# %% ../../nbs/models.diffusion.ipynb 4
set_seed(42)
logger = logging.getLogger(__name__)
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib.image")

# %% ../../nbs/models.diffusion.ipynb 32
# %% ../../nbs/models.diffusion.ipynb 30
class DiffusorX(Regressor):
def __init__(
self,
Expand Down Expand Up @@ -79,3 +78,12 @@ def _step(self, batch, batch_idx):
loss = self.criterion(noise_pred, noise)
return loss, noise_pred, noise # loss, y_hat, y

def generate_images(self, img_shape):
logger.info("diffuse a batch")
B, C, H, W = img_shape
sample = torch.randn(img_shape).to(self.device)
for i, t in enumerate(self.noise_scheduler.timesteps):
with torch.no_grad():
residual = self.forward(sample, t)
sample = self.noise_scheduler.step(residual, t, sample).prev_sample
return sample
495 changes: 358 additions & 137 deletions tutorials/diffusion.ipynb

Large diffs are not rendered by default.

0 comments on commit 88b57cf

Please sign in to comment.