Skip to content

Commit

Permalink
try to fix sync seed again
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 7, 2024
1 parent ea13758 commit 35a8a41
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.18.7"
version = "1.18.8"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
Expand Down
6 changes: 3 additions & 3 deletions vector_quantize_pytorch/residual_fsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def forward(
# check if seed is manually passed in

if not exists(rand_quantize_dropout_fixed_seed):
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)

rand = random.Random(rand_quantize_dropout_fixed_seed)

Expand Down Expand Up @@ -296,7 +296,7 @@ def forward(
x,
return_all_codes = False
):
shape, split_dim = x.shape, self.split_dim
shape, split_dim, device = x.shape, self.split_dim, x.device
assert shape[split_dim] == self.dim

# split the feature dimension into groups
Expand All @@ -305,7 +305,7 @@ def forward(

forward_kwargs = dict(
return_all_codes = return_all_codes,
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
)

# invoke residual vq on each group
Expand Down
10 changes: 5 additions & 5 deletions vector_quantize_pytorch/residual_lfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def round_up_multiple(num, mult):
def is_distributed():
    """Return True only when torch.distributed is running with more than one process."""
    # Guard clause: without an initialized process group there is nothing to sync.
    if not dist.is_initialized():
        return False
    return dist.get_world_size() > 1

def get_maybe_sync_seed(max_size = 10_000):
rand_int = torch.randint(0, max_size, ())
def get_maybe_sync_seed(device, max_size = 10_000):
rand_int = torch.randint(0, max_size, (), device = device)

if is_distributed():
dist.all_reduce(rand_int)
Expand Down Expand Up @@ -162,7 +162,7 @@ def forward(
# check if seed is manually passed in

if not exists(rand_quantize_dropout_fixed_seed):
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)

rand = random.Random(rand_quantize_dropout_fixed_seed)

Expand Down Expand Up @@ -262,7 +262,7 @@ def forward(
mask = None,
return_all_codes = False
):
shape, split_dim = x.shape, self.split_dim
shape, split_dim, device = x.shape, self.split_dim, x.device
assert shape[split_dim] == self.dim

# split the feature dimension into groups
Expand All @@ -272,7 +272,7 @@ def forward(
forward_kwargs = dict(
mask = mask,
return_all_codes = return_all_codes,
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
)

# invoke residual vq on each group
Expand Down
10 changes: 5 additions & 5 deletions vector_quantize_pytorch/residual_vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def round_up_multiple(num, mult):
def is_distributed():
    """True iff a torch.distributed process group is initialized with world size > 1."""
    # Short-circuit when no process group exists; get_world_size would be invalid then.
    initialized = dist.is_initialized()
    return initialized and dist.get_world_size() > 1

def get_maybe_sync_seed(max_size = 10_000):
rand_int = torch.randint(0, max_size, ())
def get_maybe_sync_seed(device, max_size = 10_000):
rand_int = torch.randint(0, max_size, (), device = device)

if is_distributed():
dist.all_reduce(rand_int)
Expand Down Expand Up @@ -296,7 +296,7 @@ def forward(
# check if seed is manually passed in

if not exists(rand_quantize_dropout_fixed_seed):
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)

rand = random.Random(rand_quantize_dropout_fixed_seed)

Expand Down Expand Up @@ -452,7 +452,7 @@ def forward(
freeze_codebook = False,
mask = None,
):
shape, split_dim = x.shape, self.split_dim
shape, split_dim, device = x.shape, self.split_dim, x.device
assert shape[split_dim] == self.dim

# split the feature dimension into groups
Expand All @@ -468,7 +468,7 @@ def forward(
sample_codebook_temp = sample_codebook_temp,
mask = mask,
freeze_codebook = freeze_codebook,
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed()
rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device)
)

# invoke residual vq on each group
Expand Down

0 comments on commit 35a8a41

Please sign in to comment.