Skip to content

Commit

Permalink
Cleaned repo
Browse files Browse the repository at this point in the history
  • Loading branch information
lpaillet-laas committed Mar 14, 2024
1 parent 35fc173 commit 7631803
Show file tree
Hide file tree
Showing 25 changed files with 786 additions and 769 deletions.
15 changes: 12 additions & 3 deletions MST/simulation/train_code/architecture/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,19 @@ def model_generator(method, pretrained_model_path=None):
else:
print(f'Method {method} is not defined !!!!')
if pretrained_model_path is not None:
print(f'load model from {pretrained_model_path}')
# print(f'load model from {pretrained_model_path}')
# checkpoint = torch.load(pretrained_model_path)
# model.load_state_dict({k.replace('reconstruction_model.', ''): v for k, v in checkpoint.items()},
# strict=False)

checkpoint = torch.load(pretrained_model_path)
model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint.items()},
strict=False)

adjusted_state_dict = {key.replace('reconstruction_module.reconstruction_model.', '').replace('reconstruction_model.', ''): value
for key, value in checkpoint['state_dict'].items()}
# Filter out unexpected keys
model_keys = set(model.state_dict().keys())
filtered_state_dict = {k: v for k, v in adjusted_state_dict.items() if k in model_keys}
model.load_state_dict(filtered_state_dict)
if method == 'hdnet':
return model,fdl_loss
return model
Binary file not shown.
Binary file not shown.
Binary file not shown.
12 changes: 1 addition & 11 deletions Resnet_only.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
import pytorch_lightning as pl
import torch
import torch.nn as nn
from simca.CassiSystem_lightning import CassiSystemOptim
from MST.simulation.train_code.architecture import *
from simca import load_yaml_config
import matplotlib.pyplot as plt
import torchvision
import numpy as np
from simca.functions_acquisition import *
from piqa import SSIM
from torch.utils.tensorboard import SummaryWriter
import io
import torchvision.transforms as transforms
from PIL import Image
from optimization_modules_with_resnet import UnetModel
from optimization_modules_full import UnetModel

class ResnetOnly(pl.LightningModule):

Expand Down Expand Up @@ -48,16 +45,9 @@ def forward(self, x, pattern=None):
self.acquisition = self.acquisition.flip(2)
self.acquisition = self.acquisition.unsqueeze(1).float()

# print("acquisition shape: ", self.acquisition.shape)
# plt.imshow(self.acquisition[0,0,:,:].cpu().numpy())
# plt.show()

self.pattern = self.mask_generation(self.acquisition)
self.pattern = BinarizeFunction.apply(self.pattern)

# print("pattern shape: ", self.pattern.shape)
# plt.imshow(self.pattern[0,0,:,:].detach().cpu().numpy())
# plt.show()

return self.pattern

Expand Down
113 changes: 9 additions & 104 deletions data_handler.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import os
from pytorch_lightning.utilities.types import EVAL_DATALOADERS
import torch
import scipy.io as sio
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np
import random
from pytorch_lightning import LightningDataModule
Expand Down Expand Up @@ -46,38 +42,31 @@ def load_hyperspectral_data(self, idx):

def augment(self, img, crop_size = 128):
h, w, _ = img.shape

# Randomly crop
x_index = np.random.randint(0, h - crop_size)
y_index = np.random.randint(0, w - crop_size)
processed_data = np.zeros((crop_size, crop_size, 28), dtype=np.float32)
processed_data = img[x_index:x_index + crop_size, y_index:y_index + crop_size, :]
processed_data = torch.from_numpy(np.transpose(processed_data, (2, 0, 1))).float()

# Randomly flip and rotate
processed_data = arguement_1(processed_data)

""" # The other half data use splicing.
processed_data = np.zeros((4, crop_size//2, crop_size//2, 28), dtype=np.float32)
for i in range(batch_size - batch_size // 2):
sample_list = np.random.randint(0, len(train_data), 4)
for j in range(4):
x_index = np.random.randint(0, h-crop_size//2)
y_index = np.random.randint(0, w-crop_size//2)
processed_data[j] = train_data[sample_list[j]][x_index:x_index+crop_size//2,y_index:y_index+crop_size//2,:]
gt_batch_2 = torch.from_numpy(np.transpose(processed_data, (0, 3, 1, 2))).cuda() # [4,28,128,128]
gt_batch.append(arguement_2(gt_batch_2))
gt_batch = torch.stack(gt_batch, dim=0) """
return processed_data


class CubesDataModule(LightningDataModule):
def __init__(self, data_dir, batch_size, num_workers=1):
def __init__(self, data_dir, batch_size, num_workers=1, augment=True):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
self.num_workers = num_workers
self.dataset = CubesDataset(self.data_dir,augment=True)
self.dataset = CubesDataset(self.data_dir,augment=augment)

def setup(self, stage=None):
dataset_size = len(self.dataset)
train_size = int(0.7 * dataset_size)
train_size = int(0.79 * dataset_size)
val_size = int(0.2 * dataset_size)
test_size = dataset_size - train_size - val_size

Expand All @@ -102,12 +91,11 @@ def test_dataloader(self):
shuffle=False)

def predict_dataloader(self):
return DataLoader(self.train_ds,
return DataLoader(self.dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False)


def arguement_1(x):
"""
:param x: c,h,w
Expand All @@ -125,88 +113,5 @@ def arguement_1(x):
# Random horizontal Flip
for j in range(hFlip):
x = torch.flip(x, dims=(1,))
return x


def shuffle_crop(train_data, batch_size, crop_size=256, argument=True):
if argument:
gt_batch = []
# The first half data use the original data.
index = np.random.choice(range(len(train_data)), batch_size//2)
processed_data = np.zeros((batch_size//2, crop_size, crop_size, 28), dtype=np.float32)
for i in range(batch_size//2):
img = train_data[index[i]]
h, w, _ = img.shape
x_index = np.random.randint(0, h - crop_size)
y_index = np.random.randint(0, w - crop_size)
processed_data[i, :, :, :] = img[x_index:x_index + crop_size, y_index:y_index + crop_size, :]
processed_data = torch.from_numpy(np.transpose(processed_data, (0, 3, 1, 2))).cuda().float()
for i in range(processed_data.shape[0]):
gt_batch.append(arguement_1(processed_data[i]))

# The other half data use splicing.
processed_data = np.zeros((4, crop_size//2, crop_size//2, 28), dtype=np.float32)
for i in range(batch_size - batch_size // 2):
sample_list = np.random.randint(0, len(train_data), 4)
for j in range(4):
x_index = np.random.randint(0, h-crop_size//2)
y_index = np.random.randint(0, w-crop_size//2)
processed_data[j] = train_data[sample_list[j]][x_index:x_index+crop_size//2,y_index:y_index+crop_size//2,:]
gt_batch_2 = torch.from_numpy(np.transpose(processed_data, (0, 3, 1, 2))).cuda() # [4,28,128,128]
gt_batch.append(arguement_2(gt_batch_2))
gt_batch = torch.stack(gt_batch, dim=0)
return gt_batch
else:
index = np.random.choice(range(len(train_data)), batch_size)
processed_data = np.zeros((batch_size, crop_size, crop_size, 28), dtype=np.float32)
for i in range(batch_size):
h, w, _ = train_data[index[i]].shape
x_index = np.random.randint(0, h - crop_size)
y_index = np.random.randint(0, w - crop_size)
processed_data[i, :, :, :] = train_data[index[i]][x_index:x_index + crop_size, y_index:y_index + crop_size, :]
gt_batch = torch.from_numpy(np.transpose(processed_data, (0, 3, 1, 2)))
return gt_batch

def arguement_2(generate_gt):
c, h, w = generate_gt.shape[1],generate_gt.shape[2],generate_gt.shape[3]
divid_point_h = h//2
divid_point_w = w//2
output_img = torch.zeros(c,h,w).cuda()
output_img[:, :divid_point_h, :divid_point_w] = generate_gt[0]
output_img[:, :divid_point_h, divid_point_w:] = generate_gt[1]
output_img[:, divid_point_h:, :divid_point_w] = generate_gt[2]
output_img[:, divid_point_h:, divid_point_w:] = generate_gt[3]
return output_img

# class AcquisitionDataset(Dataset):
# def __init__(self, input, hs_cubes, transform=None, target_transform=None):
# """_summary_

# Args:
# input (_type_): List of size 2 with each element being a list:
# - First list: List of n torch.tensor acquisitions (2D)
# - Second list: List of n int labels
# hs_cubes (_type_): List of size m, hs_cubes[m] being the m-th hs cube
# transform (_type_, optional): _description_. Defaults to None.
# target_transform (_type_, optional): _description_. Defaults to None.
# """
# self.data = input # list of size 2, first elem is a list of n torch.tensor acquisitions (input), second elem is a list of size n with the index of corresponding hs cubes (output)
# self.labels = self.data[1]

# self.cubes = hs_cubes # list of cubes, number of cubes must be >= max(self.labels)

# self.transform = transform
# self.target_transform = target_transform

# def __len__(self):
# return len(self.data[1])

# def __getitem__(self, index):
# acq = self.data[0][index] # torch tensor of size x*y
# cube = self.cubes[self.labels[index]] # torch tensor of size x*y*w

# return acq, cube

if __name__ == "__main__":
data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28/"
datamodule = CubesDataModule(data_dir, batch_size=5, num_workers=2)
return x
34 changes: 34 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: simca
channels:
- defaults
- nvidia
- pytorch
- pyg
- conda-forge
dependencies:
- einops=0.7.0
- fvcore=0.1.5
- h5py=3.10.0
- imageio=2.34.0
- lightning=2.2.1
- cudatoolkit=11.8.0
- opencv=4.9.0
- opticalglass=1.0.7
- pyqtgraph=0.13.3
- seaborn=0.13.2
- segmentation-models-pytorch=0.3.3
- snoop=0.4.3
- spectral=0.23.1
- tensorboard=2.16.2
- pytorch-cuda=12.1
- pytorch=2.1.2=*cuda*
- torchvision=0.16.2
- pip=23.3.1=py39h06a4308_0
- python=3.9.18=h955ad1f_0
- pytorch-cluster=1.6.3
- pyg=2.5.0
- pytorch-scatter=2.1.2
- pytorch-sparse=0.6.18
- pytorch-spline-conv=1.2.2
- pip:
- piqa==1.3.2
Loading

0 comments on commit 7631803

Please sign in to comment.