From fd32aacafc01fa423d2325ca74ea7d2c03c16410 Mon Sep 17 00:00:00 2001 From: Bhavya Kohli <81435750+BhavyaKohli@users.noreply.github.com> Date: Fri, 19 Apr 2024 23:30:40 +0530 Subject: [PATCH 1/5] changed default device to cpu, removed .to(device) calls in wavemix/__init__.py --- wavemix/__init__.py | 29 +++++++++++++++++------------ wavemix/classification.py | 20 ++++++++++---------- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/wavemix/__init__.py b/wavemix/__init__.py index 414262b..92d1fe4 100644 --- a/wavemix/__init__.py +++ b/wavemix/__init__.py @@ -8,8 +8,6 @@ from einops.layers.torch import Rearrange -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - def sfb1d(lo, hi, g0, g1, mode='zero', dim=-1): """ 1D synthesis filter bank of an image tensor """ @@ -57,6 +55,7 @@ def sfb1d(lo, hi, g0, g1, mode='zero', dim=-1): return y + def reflect(x, minx, maxx): """Reflect the values in matrix *x* about the scalar values *minx* and *maxx*. Hence a vector *x* containing a long linearly increasing series is @@ -74,6 +73,7 @@ def reflect(x, minx, maxx): out = np.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx return np.array(out, dtype=x.dtype) + def mode_to_int(mode): if mode == 'zero': return 0 @@ -92,6 +92,7 @@ def mode_to_int(mode): else: raise ValueError("Unkown pad type: {}".format(mode)) + def int_to_mode(mode): if mode == 0: return 'zero' @@ -110,6 +111,7 @@ def int_to_mode(mode): else: raise ValueError("Unkown pad type: {}".format(mode)) + def afb1d(x, h0, h1, mode='zero', dim=-1): """ 1D analysis filter bank (along one dimension only) of an image Inputs: @@ -192,7 +194,6 @@ def afb1d(x, h0, h1, mode='zero', dim=-1): return lohi - class AFB2D(Function): """ Does a single level 2d wavelet decomposition of an input. Does separate row and column filtering by two calls to @@ -245,7 +246,7 @@ def backward(ctx, low, highs): return dx, None, None, None, None, None -def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device=device): +def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device='cpu'): """ Prepares the filters to be of the right form for the afb2d function. In particular, makes the tensors the right shape. It takes mirror images of @@ -274,7 +275,7 @@ def prep_filt_afb2d(h0_col, h1_col, h0_row=None, h1_row=None, device=device): return h0_col, h1_col, h0_row, h1_row -def prep_filt_afb1d(h0, h1, device=device): +def prep_filt_afb1d(h0, h1, device='cpu'): """ Prepares the filters to be of the right form for the afb2d function. In particular, makes the tensors the right shape. It takes mirror images of @@ -289,10 +290,11 @@ def prep_filt_afb1d(h0, h1, device=device): h0 = np.array(h0[::-1]).ravel() h1 = np.array(h1[::-1]).ravel() t = torch.get_default_dtype() - h0 = torch.tensor(h0, device=device, dtype=t).reshape((1, 1, -1)) - h1 = torch.tensor(h1, device=device, dtype=t).reshape((1, 1, -1)) + h0 = torch.tensor(h0, device='cpu', dtype=t).reshape((1, 1, -1)) + h1 = torch.tensor(h1, device='cpu', dtype=t).reshape((1, 1, -1)) return h0, h1 + class DWTForward(nn.Module): """ Performs a 2d DWT Forward decomposition of an image Args: @@ -358,12 +360,14 @@ def forward(self, x): return ll, yh + from numpy.lib.function_base import hamming -xf1 = DWTForward(J=1, mode='zero', wave='db1').to(device) -xf2 = DWTForward(J=2, mode='zero', wave='db1').to(device) -xf3 = DWTForward(J=3, mode='zero', wave='db1').to(device) -xf4 = DWTForward(J=4, mode='zero', wave='db1').to(device) +xf1 = DWTForward(J=1, mode='zero', wave='db1') +xf2 = DWTForward(J=2, mode='zero', wave='db1') +xf3 = DWTForward(J=3, mode='zero', wave='db1') +xf4 = DWTForward(J=4, mode='zero', wave='db1') + class Level1Waveblock(nn.Module): def __init__( @@ -405,6 +409,7 @@ def forward(self, x): return x + class Level2Waveblock(nn.Module): def __init__( self, @@ -618,4 +623,4 @@ def forward(self, x): x1 = torch.cat((x1,x2), dim = 1) x = self.feedforward1(x1) - return x + return x \ No newline at end of file diff --git a/wavemix/classification.py b/wavemix/classification.py index 4353f0e..cb864e7 100644 --- a/wavemix/classification.py +++ b/wavemix/classification.py @@ -2,6 +2,7 @@ import torch.nn as nn from einops.layers.torch import Rearrange + class WaveMix(nn.Module): def __init__( self, @@ -22,14 +23,14 @@ def __init__( self.layers = nn.ModuleList([]) for _ in range(depth): - if level == 4: - self.layers.append(Level4Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout)) - elif level == 3: - self.layers.append(Level3Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout)) - elif level == 2: - self.layers.append(Level2Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout)) - else: - self.layers.append(Level1Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout)) + if level == 4: + self.layers.append(Level4Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout)) + elif level == 3: + self.layers.append(Level3Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout)) + elif level == 2: + self.layers.append(Level2Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout)) + else: + self.layers.append(Level1Waveblock(mult = mult, ff_channel = ff_channel, final_dim = final_dim, dropout = dropout)) self.pool = nn.Sequential( nn.AdaptiveAvgPool2d(1), @@ -49,8 +50,7 @@ def __init__( nn.Conv2d(int(final_dim/2), final_dim, patch_size, patch_size), nn.GELU(), nn.BatchNorm2d(final_dim) - ) - + ) def forward(self, img): x = self.conv(img) From f6a3c616e67dc186c2e6a8f3c783d3d88248be89 Mon Sep 17 00:00:00 2001 From: Bhavya Kohli <81435750+BhavyaKohli@users.noreply.github.com> Date: Fri, 19 Apr 2024 23:51:21 +0530 Subject: [PATCH 2/5] device call fix --- wavemix/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wavemix/__init__.py b/wavemix/__init__.py index 92d1fe4..8e2a202 100644 --- a/wavemix/__init__.py +++ b/wavemix/__init__.py @@ -290,8 +290,8 @@ def prep_filt_afb1d(h0, h1, device='cpu'): h0 = np.array(h0[::-1]).ravel() h1 = np.array(h1[::-1]).ravel() t = torch.get_default_dtype() - h0 = torch.tensor(h0, device='cpu', dtype=t).reshape((1, 1, -1)) - h1 = torch.tensor(h1, device='cpu', dtype=t).reshape((1, 1, -1)) + h0 = torch.tensor(h0, device=device, dtype=t).reshape((1, 1, -1)) + h1 = torch.tensor(h1, device=device, dtype=t).reshape((1, 1, -1)) return h0, h1 From f902ea5da2b3aa4b7b0286dc105aa1ae2015f5c0 Mon Sep 17 00:00:00 2001 From: Bhavya Kohli <81435750+BhavyaKohli@users.noreply.github.com> Date: Sat, 20 Apr 2024 00:19:43 +0530 Subject: [PATCH 3/5] added DWTForward filters inside their respective level waveblocks, to allow pytorch to put them on a specified device when calling LevelXWaveblock.to(device) --- wavemix/__init__.py | 51 ++++++++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/wavemix/__init__.py b/wavemix/__init__.py index 8e2a202..58c460c 100644 --- a/wavemix/__init__.py +++ b/wavemix/__init__.py @@ -162,7 +162,7 @@ def afb1d(x, h0, h1, mode='zero', dim=-1): N += 1 x = roll(x, -L2, dim=d) pad = (L-1, 0) if d == 2 else (0, L-1) - lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C) + lohi = F.conv2d(x, h.to(x.device), padding=pad, stride=s, groups=C) N2 = N//2 if d == 2: lohi[:,:,:L2] = lohi[:,:,:L2] + lohi[:,:,N2:N2+L2] @@ -183,11 +183,11 @@ def afb1d(x, h0, h1, mode='zero', dim=-1): x = F.pad(x, pad) pad = (p//2, 0) if d == 2 else (0, p//2) # Calculate the high and lowpass - lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C) + lohi = F.conv2d(x, h.to(x.device), padding=pad, stride=s, groups=C) elif mode == 'symmetric' or mode == 'reflect' or mode == 'periodic': pad = (0, 0, p//2, (p+1)//2) if d == 2 else (p//2, (p+1)//2, 0, 0) x = mypad(x, pad=pad, mode=mode) - lohi = F.conv2d(x, h, stride=s, groups=C) + lohi = F.conv2d(x, h.to(x.device), stride=s, groups=C) else: raise ValueError("Unkown pad type: {}".format(mode)) @@ -362,11 +362,14 @@ def forward(self, x): from numpy.lib.function_base import hamming - -xf1 = DWTForward(J=1, mode='zero', wave='db1') -xf2 = DWTForward(J=2, mode='zero', wave='db1') -xf3 = DWTForward(J=3, mode='zero', wave='db1') -xf4 = DWTForward(J=4, mode='zero', wave='db1') + + +def get_dwt_filters(level, mode='zero', wave='db1'): + xf1 = DWTForward(J=1, mode=mode, wave=wave) + xf2 = DWTForward(J=2, mode=mode, wave=wave) + xf3 = DWTForward(J=3, mode=mode, wave=wave) + xf4 = DWTForward(J=4, mode=mode, wave=wave) + return [xf1, xf2, xf3, xf4][:level] class Level1Waveblock(nn.Module): @@ -392,14 +395,15 @@ def __init__( ) self.reduction = nn.Conv2d(final_dim, int(final_dim/4), 1) - + self.xf1 = get_dwt_filters(level=1) + def forward(self, x): b, c, h, w = x.shape x = self.reduction(x) - Y1, Yh = xf1(x) + Y1, Yh = self.xf1(x) x = torch.reshape(Yh[0], (b, int(c*3/4), int(h/2), int(w/2))) @@ -440,15 +444,16 @@ def __init__( ) self.reduction = nn.Conv2d(final_dim, int(final_dim/4), 1) - + self.xf1, self.xf2 = get_dwt_filters(level=2) + def forward(self, x): b, c, h, w = x.shape x = self.reduction(x) - Y1, Yh = xf1(x) - Y2, Yh = xf2(x) + Y1, Yh = self.xf1(x) + Y2, Yh = self.xf2(x) x1 = torch.reshape(Yh[0], (b, int(c*3/4), int(h/2), int(w/2))) @@ -504,16 +509,17 @@ def __init__( ) self.reduction = nn.Conv2d(final_dim, int(final_dim/4), 1) - + self.xf1, self.xf2, self.xf3 = get_dwt_filters(level=3) + def forward(self, x): b, c, h, w = x.shape x = self.reduction(x) - Y1, Yh = xf1(x) - Y2, Yh = xf2(x) - Y3, Yh = xf3(x) + Y1, Yh = self.xf1(x) + Y2, Yh = self.xf2(x) + Y3, Yh = self.xf3(x) x1 = torch.reshape(Yh[0], (b, int(c*3/4), int(h/2), int(w/2))) @@ -587,17 +593,18 @@ def __init__( ) self.reduction = nn.Conv2d(final_dim, int(final_dim/4), 1) - + self.xf1, self.xf2, self.xf3, self.xf4 = get_dwt_filters(level=4) + def forward(self, x): b, c, h, w = x.shape x = self.reduction(x) - Y1, Yh = xf1(x) - Y2, Yh = xf2(x) - Y3, Yh = xf3(x) - Y4, Yh = xf4(x) + Y1, Yh = self.xf1(x) + Y2, Yh = self.xf2(x) + Y3, Yh = self.xf3(x) + Y4, Yh = self.xf4(x) x1 = torch.reshape(Yh[0], (b, int(c*3/4), int(h/2), int(w/2))) x2 = torch.reshape(Yh[1], (b, int(c*3/4), int(h/4), int(w/4))) From f0aae5daec7f4163197422da47499c51d1253e98 Mon Sep 17 00:00:00 2001 From: Bhavya Kohli <81435750+BhavyaKohli@users.noreply.github.com> Date: Sat, 20 Apr 2024 05:10:02 +0530 Subject: [PATCH 4/5] fixed dwt_filters length issue for level 1 --- wavemix/__init__.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/wavemix/__init__.py b/wavemix/__init__.py index 58c460c..110b7f8 100644 --- a/wavemix/__init__.py +++ b/wavemix/__init__.py @@ -365,11 +365,14 @@ def forward(self, x): def get_dwt_filters(level, mode='zero', wave='db1'): - xf1 = DWTForward(J=1, mode=mode, wave=wave) - xf2 = DWTForward(J=2, mode=mode, wave=wave) - xf3 = DWTForward(J=3, mode=mode, wave=wave) - xf4 = DWTForward(J=4, mode=mode, wave=wave) - return [xf1, xf2, xf3, xf4][:level] + xf = [] + for j in range(1,level+1,1): + xf.append(DWTForward(J=j, mode=mode, wave=wave) + + if level == 1: + xf = xf[0] + + return xf class Level1Waveblock(nn.Module): @@ -630,4 +633,4 @@ def forward(self, x): x1 = torch.cat((x1,x2), dim = 1) x = self.feedforward1(x1) - return x \ No newline at end of file + return x From 14dd70872bd46eef113902908312dc01a59e2db2 Mon Sep 17 00:00:00 2001 From: Bhavya Kohli <81435750+BhavyaKohli@users.noreply.github.com> Date: Sat, 20 Apr 2024 05:14:58 +0530 Subject: [PATCH 5/5] fixed dwt_filters length issue for level 1 --- wavemix/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wavemix/__init__.py b/wavemix/__init__.py index 110b7f8..ce86895 100644 --- a/wavemix/__init__.py +++ b/wavemix/__init__.py @@ -367,7 +367,7 @@ def forward(self, x): def get_dwt_filters(level, mode='zero', wave='db1'): xf = [] for j in range(1,level+1,1): - xf.append(DWTForward(J=j, mode=mode, wave=wave) + xf.append(DWTForward(J=j, mode=mode, wave=wave)) if level == 1: xf = xf[0]