Skip to content

Commit

Permalink
Add VAE class
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Dec 8, 2024
1 parent 16f6b29 commit 8aad5d2
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions hypercoast/chla.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,61 @@
from rasterio.transform import from_origin
from rasterio.warp import reproject, Resampling
from scipy.interpolate import griddata


class VAE(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()

# encoder
self.encoder_layer = nn.Sequential(
nn.Linear(input_dim, 64),
nn.BatchNorm1d(64),
nn.LeakyReLU(0.2),
nn.Linear(64, 64),
nn.BatchNorm1d(64),
nn.LeakyReLU(0.2),
)

self.fc1 = nn.Linear(64, 32)
self.fc2 = nn.Linear(64, 32)

# decoder
self.decoder = nn.Sequential(
nn.Linear(32, 64),
nn.BatchNorm1d(64),
nn.LeakyReLU(0.2),
nn.Linear(64, 64),
nn.BatchNorm1d(64),
nn.LeakyReLU(0.2),
nn.Linear(64, output_dim),
nn.Softplus(),
)

def encode(self, x):
x = self.encoder_layer(x)
mu = self.fc1(x)
log_var = self.fc2(x)
return mu, log_var

def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
z = mu + eps * std
return z

def decode(self, z):
return self.decoder(z)

def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
x_reconstructed = self.decode(z)
return x_reconstructed, mu, log_var


def loss_function(recon_x, x, mu, log_var):
L1 = F.l1_loss(recon_x, x, reduction="mean")
BCE = F.mse_loss(recon_x, x, reduction="mean")
KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
return L1

0 comments on commit 8aad5d2

Please sign in to comment.