flesh out the maskbit trainer

lucidrains committed Nov 12, 2024
1 parent 74ec85e commit eaca2ee
Showing 3 changed files with 210 additions and 6 deletions.
3 changes: 3 additions & 0 deletions maskbit_pytorch/maskbit.py
@@ -546,6 +546,9 @@ def __init__(

self._c = vae.channels

def parameters(self):
return self.demasking_transformer.parameters()

@property
def device(self):
return next(self.parameters()).device
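
The parameters() override added above is what lets the trainer further down optimize only the demasking transformer. A minimal sketch of the effect, assuming maskbit is an already-constructed MaskBit instance and that the wrapped BQVAE is meant to stay frozen during this stage:

from torch.optim import Adam

# maskbit.parameters() now yields only the demasking transformer's weights,
# so this optimizer never touches the underlying BQVAE
optim = Adam(maskbit.parameters(), lr = 3e-4)
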
211 changes: 206 additions & 5 deletions maskbit_pytorch/trainer.py
Expand Up @@ -15,7 +15,7 @@
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

from maskbit_pytorch.maskbit import BQVAE, MaskBit

from einops import rearrange

@@ -357,9 +357,9 @@ def train_step(self):

self.discr_optim.step()

# log

self.print(f"{steps}: vae loss: {logs['loss']:.3f} - discr loss: {logs['discr_loss']:.3f}")

# update exponential moving averaged generator

@@ -424,6 +424,207 @@ def forward(self):
# maskbit trainer

class MaskBitTrainer(Module):
def __init__(
self,
maskbit: MaskBit,
folder,
num_train_steps,
batch_size,
image_size,
lr = 3e-4,
grad_accum_every = 1,
max_grad_norm = None,
save_results_every = 100,
save_model_every = 1000,
results_folder = './results',
valid_frac = 0.05,
random_split_seed = 42,
accelerate_kwargs: dict = dict()
):
super().__init__()

# instantiate accelerator

kwargs_handlers = accelerate_kwargs.get('kwargs_handlers', [])

ddp_kwargs = find_and_pop(
kwargs_handlers,
lambda x: isinstance(x, DistributedDataParallelKwargs),
partial(DistributedDataParallelKwargs, find_unused_parameters = True)
)
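# force find_unused_parameters for DDP, presumably because the BQVAE wrapped inside
# MaskBit never receives gradients (only the demasking transformer does)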

ddp_kwargs.find_unused_parameters = True
kwargs_handlers.append(ddp_kwargs)
accelerate_kwargs.update(kwargs_handlers = kwargs_handlers)

self.accelerator = Accelerator(**accelerate_kwargs)

# training params

self.register_buffer('steps', tensor(0))

self.num_train_steps = num_train_steps
self.batch_size = batch_size
self.grad_accum_every = grad_accum_every

# model

self.maskbit = maskbit

# optimizers

self.optim = Adam(maskbit.parameters(), lr = lr)

self.max_grad_norm = max_grad_norm

# create dataset

self.ds = ImageDataset(folder, image_size)

# split for validation

if valid_frac > 0:
train_size = int((1 - valid_frac) * len(self.ds))
valid_size = len(self.ds) - train_size
self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
else:
self.valid_ds = self.ds
self.print(f'training with shared training and valid dataset of {len(self.ds)} samples')

# dataloader

self.dl = DataLoader(
self.ds,
batch_size = batch_size,
shuffle = True
)

self.valid_dl = DataLoader(
self.valid_ds,
batch_size = batch_size,
shuffle = True
)

# prepare with accelerator

(
self.maskbit,
self.optim,
self.dl,
self.valid_dl
) = self.accelerator.prepare(
self.maskbit,
self.optim,
self.dl,
self.valid_dl
)

self.dl_iter = cycle(self.dl)
self.valid_dl_iter = cycle(self.valid_dl)

self.save_model_every = save_model_every
self.save_results_every = save_results_every

self.results_folder = Path(results_folder)

if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
rmtree(str(self.results_folder))

self.results_folder.mkdir(parents = True, exist_ok = True)

def save(self, path):
if not self.accelerator.is_local_main_process:
return

pkg = dict(
model = self.accelerator.get_state_dict(self.maskbit),
optim = self.optim.state_dict(),
)

torch.save(pkg, path)

def load(self, path):
path = Path(path)
assert path.exists()
pkg = torch.load(path)

maskbit = self.accelerator.unwrap_model(self.maskbit)
maskbit.load_state_dict(pkg['model'])

self.optim.load_state_dict(pkg['optim'])

def print(self, msg):
self.accelerator.print(msg)

@property
def device(self):
return self.accelerator.device

@property
def is_distributed(self):
return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)

@property
def is_main(self):
return self.accelerator.is_main_process

@property
def is_local_main(self):
return self.accelerator.is_local_main_process

def train_step(self):
acc = self.accelerator
device = self.device

steps = int(self.steps.item())

self.maskbit.train()

# logs

logs = dict()

# update maskbit (demasking transformer)

for _ in range(self.grad_accum_every):
img = next(self.dl_iter)
img = img.to(device)

with acc.autocast():
loss = self.maskbit(img)

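# divide by grad_accum_every so the accumulated gradients average over the micro-batches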
acc.backward(loss / self.grad_accum_every)

accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

if exists(self.max_grad_norm):
acc.clip_grad_norm_(self.maskbit.parameters(), self.max_grad_norm)

self.optim.step()
self.optim.zero_grad()

# log

self.print(f"{steps}: maskbit loss: {logs['loss']:.3f}")

# save model every so often
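# (all processes sync below; only the main process unwraps the model and writes the checkpoint)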

acc.wait_for_everyone()

if self.is_main and not (steps % self.save_model_every):
state_dict = acc.unwrap_model(self.maskbit).state_dict()
model_path = str(self.results_folder / f'maskbit.{steps}.pt')
acc.save(state_dict, model_path)

self.print(f'{steps}: saving model to {str(self.results_folder)}')

self.steps += 1
return logs

def forward(self):

while self.steps < self.num_train_steps:
logs = self.train_step()

self.print('training complete')
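
For orientation, a hedged usage sketch of the new trainer follows. The MaskBitTrainer arguments are the ones introduced in this commit; the BQVAE and MaskBit constructor arguments live elsewhere in the package and are elided rather than guessed, and passing the vae as the first positional argument of MaskBit is an assumption here.

from maskbit_pytorch.maskbit import BQVAE, MaskBit
from maskbit_pytorch.trainer import MaskBitTrainer

vae = BQVAE(...)              # binary-quantized VAE, trained beforehand (args elided)
maskbit = MaskBit(vae, ...)   # demasking transformer wrapping the vae (args elided)

trainer = MaskBitTrainer(
    maskbit,
    folder = '/path/to/images',   # hypothetical dataset path
    num_train_steps = 100_000,
    batch_size = 4,
    image_size = 256,
    grad_accum_every = 4
)

trainer()   # MaskBitTrainer is a Module, so calling it runs forward(), i.e. the loop above
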
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "maskbit-pytorch"
version = "0.0.1"
version = "0.0.2"
description = "MaskBit"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
