This repository provides a TensorFlow implementation of FSQ based on Finite Scalar Quantization: VQ-VAE Made Simple (Mentzer et al., 2023, https://arxiv.org/abs/2309.15505).
```sh
$ git clone https://github.com/Nikolai10/FSQ
```
```python
import sys
sys.path.append('/content/FSQ') # adjust path to your needs

from finite_scalar_quantization import FSQ
import numpy as np

fsq = FSQ(levels=[3, 5, 4])

z = np.asarray([0.25, 0.6, -7])
zhat = fsq(z) # == fsq.quantize(z)
print(f"Quantized {z} -> {zhat}") # Quantized [ 0.25  0.6  -7.  ] -> [ 0.   0.5 -1. ]

# We can map to an index in the codebook.
idx = fsq.codes_to_indices(zhat)
print(f"Code {zhat} is the {idx}-th index.") # Code [ 0.   0.5 -1. ] is the 10-th index.

# Back to code
code_out = fsq.indices_to_codes(idx)
print(f"Index {idx} mapped back to {zhat}.") # Index 10 mapped back to [ 0.   0.5 -1. ].
```
See the Colab with the FSQ code for more details.
A notebook on how to train an FSQ-VAE is additionally provided here: . This notebook largely follows the Keras tutorial Vector-Quantized Variational Autoencoders; the main change is that we replace the VectorQuantizer with our FSQ class (a rough sketch of this swap is shown below).
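As a loose illustration of that swap, here is a minimal sketch of an FSQ-VAE, assuming the FSQ class quantizes the last axis of a feature map as in the quick-start above; the architecture, the level choice [8, 5, 5, 5], and the FSQVAE name are illustrative, not the notebook's exact code:

```python
import tensorflow as tf
from finite_scalar_quantization import FSQ

class FSQVAE(tf.keras.Model):
    """Sketch of a VQ-VAE-style autoencoder with FSQ as the bottleneck."""

    def __init__(self, levels=(8, 5, 5, 5)):
        super().__init__()
        latent_dim = len(levels)  # one encoder channel per FSQ dimension
        self.encoder = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, strides=2, padding="same", activation="relu"),
            tf.keras.layers.Conv2D(64, 3, strides=2, padding="same", activation="relu"),
            tf.keras.layers.Conv2D(latent_dim, 1, padding="same"),
        ])
        # Drop-in replacement for VectorQuantizer: no codebook parameters,
        # no commitment/codebook losses to track.
        self.fsq = FSQ(levels=list(levels))
        self.decoder = tf.keras.Sequential([
            tf.keras.layers.Conv2DTranspose(64, 3, strides=2, padding="same", activation="relu"),
            tf.keras.layers.Conv2DTranspose(32, 3, strides=2, padding="same", activation="relu"),
            tf.keras.layers.Conv2DTranspose(1, 3, padding="same"),
        ])

    def call(self, x):
        zhat = self.fsq(self.encoder(x))  # straight-through quantization
        return self.decoder(zhat)
```

Training then reduces to a plain reconstruction loss, since FSQ contributes no auxiliary terms.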
|  | VQ | FSQ |
|---|---|---|
| Quantization | argmin_c ‖z - c‖ | round(f(z)) |
| Gradients | Straight Through Estimation (STE) | STE |
| Auxiliary Losses | Commitment, codebook, entropy loss, ... | N/A |
| Tricks | EMA on codebook, codebook splitting, projections, ... | N/A |
| Parameters | Codebook | N/A |
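To make the round(f(z)) entry concrete, here is a short TensorFlow sketch of the straight-through rounding at the heart of FSQ, restricted to odd numbers of levels (the even-level offset used in the paper is omitted for brevity; `fsq_quantize` and `round_ste` are illustrative names, not this repository's API):

```python
import numpy as np
import tensorflow as tf

def round_ste(z):
    # Forward pass: round to the nearest integer. Backward pass: identity,
    # because tf.stop_gradient hides the non-differentiable rounding.
    return z + tf.stop_gradient(tf.round(z) - z)

def fsq_quantize(z, levels=(3, 5, 5)):
    # f(z) bounds each channel with tanh so that rounding yields exactly
    # levels[i] integers; dividing by half_width renormalizes to [-1, 1].
    half_width = tf.constant(np.asarray(levels) // 2, dtype=z.dtype)  # [1, 2, 2]
    bounded = tf.tanh(z) * half_width  # odd-level case of f(z)
    return round_ste(bounded) / half_width

z = tf.constant([[0.25, 0.6, -7.0]])
print(fsq_quantize(z).numpy())  # [[ 0.   0.5 -1. ]]
```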
Related resources:
- official JAX implementation: https://github.com/google-research/google-research/tree/master/fsq
- external PyTorch port: https://github.com/lucidrains/vector-quantize-pytorch