Commit 74a27c8

address #149

lucidrains committed Jul 26, 2024
1 parent 4c514db
Showing 3 changed files with 54 additions and 5 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.15.3"
version = "1.15.4"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -25,7 +25,7 @@ classifiers=[
dependencies = [
"torch>=2.0",
"einops>=0.8.0",
"einx>=0.2.2",
"einx>=0.3.0",
]

[project.urls]
25 changes: 25 additions & 0 deletions tests/test_readme.py
@@ -37,6 +37,31 @@ def test_vq_eval():
quantized, indices, commit_loss = vq(x)
assert torch.allclose(quantized, vq.get_output_from_indices(indices))

def test_vq_mask():
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(
dim = 256,
codebook_size = 512, # codebook size
decay = 1., # the exponential moving average decay, lower means the dictionary will change faster
commitment_weight = 1. # the weight on the commitment loss
)

x = torch.randn(1, 1024, 256)
lens = torch.full((1,), 512)

vq.train()

quantized, indices, commit_loss = vq(x[:, :512])
mask_quantized, mask_indices, mask_commit_loss = vq(x, lens = lens)

assert torch.allclose(commit_loss, mask_commit_loss)
assert torch.allclose(quantized, mask_quantized[:, :512])
assert torch.allclose(indices, mask_indices[:, :512])

assert torch.allclose(mask_quantized[:, 512:], x[:, 512:])
assert (mask_indices[:, 512:] == -1).all()

def test_residual_vq():
from vector_quantize_pytorch import ResidualVQ

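As an aside on the new `test_vq_mask` above (not part of the commit): `decay = 1.` freezes the EMA codebook update, so both forward passes see an identical codebook and the masked and unmasked results can be compared directly. The same padding behaviour can also be driven by an explicit boolean `mask` instead of `lens`; a minimal sketch, assuming otherwise default `VectorQuantize` settings:

import torch
from vector_quantize_pytorch import VectorQuantize

vq = VectorQuantize(dim = 256, codebook_size = 512)
vq.train()

x = torch.randn(1, 1024, 256)

mask = torch.zeros(1, 1024, dtype = torch.bool)
mask[:, :512] = True  # only the first 512 tokens are real, the rest is padding

quantized, indices, commit_loss = vq(x, mask = mask)

# padded positions are passed through unquantized and receive index -1
assert torch.allclose(quantized[:, 512:], x[:, 512:])
assert (indices[:, 512:] == -1).all()
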
30 changes: 27 additions & 3 deletions vector_quantize_pytorch/vector_quantize_pytorch.py
Expand Up @@ -3,12 +3,13 @@

import torch
from torch.nn import Module
from torch import nn, einsum
from torch import nn, einsum, Tensor
import torch.nn.functional as F
import torch.distributed as distributed
from torch.optim import Optimizer
from torch.cuda.amp import autocast

import einx
from einops import rearrange, repeat, reduce, pack, unpack

from typing import Callable
@@ -63,6 +64,10 @@ def pack_one(t, pattern):
def unpack_one(t, ps, pattern):
return unpack(t, ps, pattern)[0]

def lens_to_mask(lens, max_length):
seq = torch.arange(max_length, device = lens.device)
return seq < lens[:, None]

def uniform_init(*shape):
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
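
As a quick illustration of the new `lens_to_mask` helper in the hunk above (example values are arbitrary, not from the commit), it turns per-sample lengths into a boolean padding mask:

import torch

def lens_to_mask(lens, max_length):
    # True for positions before each sample's length, False for padding
    seq = torch.arange(max_length, device = lens.device)
    return seq < lens[:, None]

lens = torch.tensor([2, 4])
print(lens_to_mask(lens, max_length = 5))
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True,  True, False]])
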
@@ -897,12 +902,22 @@ def forward(
x,
indices = None,
mask = None,
lens = None,
sample_codebook_temp = None,
freeze_codebook = False,
return_loss_breakdown = False,
):
orig_input = x

# handle masking, either passed in as `mask` or `lens`

assert not (exists(mask) and exists(lens))

if exists(lens):
mask = lens_to_mask(lens, x.shape[1])

# handle one token given

only_one = x.ndim == 2

if only_one:
@@ -917,6 +932,7 @@ def forward(
# rearrange inputs

if self.accept_image_fmap:
assert not exists(mask)
height, width = x.shape[-2:]
x = rearrange(x, 'b c h w -> b (h w) c')

@@ -1117,12 +1133,20 @@ def calculate_ce_loss(codes):
# if masking, only return quantized for where mask has True

if exists(mask):
quantize = torch.where(
rearrange(mask, '... -> ... 1'),
quantize = einx.where(
'b n, b n d, b n d -> b n d',
mask,
quantize,
orig_input
)

embed_ind = einx.where(
'b n, b n ..., -> b n ...',
mask,
embed_ind,
-1
)

if not return_loss_breakdown:
return quantize, embed_ind, loss

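On the masking change in the last hunk: `einx.where` lets the `(b, n)` mask broadcast over the feature dimension `d` through the pattern string, where the previous code needed an explicit `rearrange` to add the trailing axis. A small equivalence sketch, illustrative only and not part of the commit:

import torch
import einx
from einops import rearrange

mask      = torch.tensor([[True, True, False]])  # (b, n)
quantized = torch.randn(1, 3, 4)                 # (b, n, d)
original  = torch.randn(1, 3, 4)

out_einx  = einx.where('b n, b n d, b n d -> b n d', mask, quantized, original)
out_torch = torch.where(rearrange(mask, '... -> ... 1'), quantized, original)

assert torch.allclose(out_einx, out_torch)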
