diff --git a/MST/simulation/train_code/architecture/__init__.py b/MST/simulation/train_code/architecture/__init__.py index 3fa9eb2..b55be2a 100755 --- a/MST/simulation/train_code/architecture/__init__.py +++ b/MST/simulation/train_code/architecture/__init__.py @@ -58,10 +58,19 @@ def model_generator(method, pretrained_model_path=None): else: print(f'Method {method} is not defined !!!!') if pretrained_model_path is not None: - print(f'load model from {pretrained_model_path}') + # print(f'load model from {pretrained_model_path}') + # checkpoint = torch.load(pretrained_model_path) + # model.load_state_dict({k.replace('reconstruction_model.', ''): v for k, v in checkpoint.items()}, + # strict=False) + checkpoint = torch.load(pretrained_model_path) - model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint.items()}, - strict=False) + + adjusted_state_dict = {key.replace('reconstruction_module.reconstruction_model.', '').replace('reconstruction_model.', ''): value + for key, value in checkpoint['state_dict'].items()} + # Filter out unexpected keys + model_keys = set(model.state_dict().keys()) + filtered_state_dict = {k: v for k, v in adjusted_state_dict.items() if k in model_keys} + model.load_state_dict(filtered_state_dict) if method == 'hdnet': return model,fdl_loss return model \ No newline at end of file diff --git a/MST/simulation/train_code/architecture/__pycache__/BIRNAT.cpython-39.pyc b/MST/simulation/train_code/architecture/__pycache__/BIRNAT.cpython-39.pyc index 39ac929..a61ae14 100644 Binary files a/MST/simulation/train_code/architecture/__pycache__/BIRNAT.cpython-39.pyc and b/MST/simulation/train_code/architecture/__pycache__/BIRNAT.cpython-39.pyc differ diff --git a/MST/simulation/train_code/architecture/__pycache__/DAUHST.cpython-39.pyc b/MST/simulation/train_code/architecture/__pycache__/DAUHST.cpython-39.pyc index ecee790..e1637b7 100644 Binary files a/MST/simulation/train_code/architecture/__pycache__/DAUHST.cpython-39.pyc and b/MST/simulation/train_code/architecture/__pycache__/DAUHST.cpython-39.pyc differ diff --git a/MST/simulation/train_code/architecture/__pycache__/__init__.cpython-39.pyc b/MST/simulation/train_code/architecture/__pycache__/__init__.cpython-39.pyc index e821b7c..7e9be86 100644 Binary files a/MST/simulation/train_code/architecture/__pycache__/__init__.cpython-39.pyc and b/MST/simulation/train_code/architecture/__pycache__/__init__.cpython-39.pyc differ diff --git a/Resnet_only.py b/Resnet_only.py index 1365597..413d368 100644 --- a/Resnet_only.py +++ b/Resnet_only.py @@ -1,19 +1,16 @@ import pytorch_lightning as pl import torch import torch.nn as nn -from simca.CassiSystem_lightning import CassiSystemOptim from MST.simulation.train_code.architecture import * -from simca import load_yaml_config import matplotlib.pyplot as plt import torchvision import numpy as np from simca.functions_acquisition import * -from piqa import SSIM from torch.utils.tensorboard import SummaryWriter import io import torchvision.transforms as transforms from PIL import Image -from optimization_modules_with_resnet import UnetModel +from optimization_modules_full import UnetModel class ResnetOnly(pl.LightningModule): @@ -48,16 +45,9 @@ def forward(self, x, pattern=None): self.acquisition = self.acquisition.flip(2) self.acquisition = self.acquisition.unsqueeze(1).float() - # print("acquisition shape: ", self.acquisition.shape) - # plt.imshow(self.acquisition[0,0,:,:].cpu().numpy()) - # plt.show() - self.pattern = self.mask_generation(self.acquisition) self.pattern = 
BinarizeFunction.apply(self.pattern) - # print("pattern shape: ", self.pattern.shape) - # plt.imshow(self.pattern[0,0,:,:].detach().cpu().numpy()) - # plt.show() return self.pattern diff --git a/data_handler.py b/data_handler.py index e33b896..660f9d4 100755 --- a/data_handler.py +++ b/data_handler.py @@ -1,12 +1,8 @@ import os -from pytorch_lightning.utilities.types import EVAL_DATALOADERS import torch import scipy.io as sio from torch.utils.data import Dataset, DataLoader from torch.utils.data import random_split -from torchvision import datasets -from torchvision.transforms import ToTensor -import matplotlib.pyplot as plt import numpy as np import random from pytorch_lightning import LightningDataModule @@ -46,38 +42,31 @@ def load_hyperspectral_data(self, idx): def augment(self, img, crop_size = 128): h, w, _ = img.shape + + # Randomly crop x_index = np.random.randint(0, h - crop_size) y_index = np.random.randint(0, w - crop_size) processed_data = np.zeros((crop_size, crop_size, 28), dtype=np.float32) processed_data = img[x_index:x_index + crop_size, y_index:y_index + crop_size, :] processed_data = torch.from_numpy(np.transpose(processed_data, (2, 0, 1))).float() + + # Randomly flip and rotate processed_data = arguement_1(processed_data) - """ # The other half data use splicing. - processed_data = np.zeros((4, crop_size//2, crop_size//2, 28), dtype=np.float32) - for i in range(batch_size - batch_size // 2): - sample_list = np.random.randint(0, len(train_data), 4) - for j in range(4): - x_index = np.random.randint(0, h-crop_size//2) - y_index = np.random.randint(0, w-crop_size//2) - processed_data[j] = train_data[sample_list[j]][x_index:x_index+crop_size//2,y_index:y_index+crop_size//2,:] - gt_batch_2 = torch.from_numpy(np.transpose(processed_data, (0, 3, 1, 2))).cuda() # [4,28,128,128] - gt_batch.append(arguement_2(gt_batch_2)) - gt_batch = torch.stack(gt_batch, dim=0) """ return processed_data class CubesDataModule(LightningDataModule): - def __init__(self, data_dir, batch_size, num_workers=1): + def __init__(self, data_dir, batch_size, num_workers=1, augment=True): super().__init__() self.data_dir = data_dir self.batch_size = batch_size self.num_workers = num_workers - self.dataset = CubesDataset(self.data_dir,augment=True) + self.dataset = CubesDataset(self.data_dir,augment=augment) def setup(self, stage=None): dataset_size = len(self.dataset) - train_size = int(0.7 * dataset_size) + train_size = int(0.79 * dataset_size) val_size = int(0.2 * dataset_size) test_size = dataset_size - train_size - val_size @@ -102,12 +91,11 @@ def test_dataloader(self): shuffle=False) def predict_dataloader(self): - return DataLoader(self.train_ds, + return DataLoader(self.dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) - def arguement_1(x): """ :param x: c,h,w @@ -125,88 +113,5 @@ def arguement_1(x): # Random horizontal Flip for j in range(hFlip): x = torch.flip(x, dims=(1,)) - return x - - -def shuffle_crop(train_data, batch_size, crop_size=256, argument=True): - if argument: - gt_batch = [] - # The first half data use the original data. 
- index = np.random.choice(range(len(train_data)), batch_size//2) - processed_data = np.zeros((batch_size//2, crop_size, crop_size, 28), dtype=np.float32) - for i in range(batch_size//2): - img = train_data[index[i]] - h, w, _ = img.shape - x_index = np.random.randint(0, h - crop_size) - y_index = np.random.randint(0, w - crop_size) - processed_data[i, :, :, :] = img[x_index:x_index + crop_size, y_index:y_index + crop_size, :] - processed_data = torch.from_numpy(np.transpose(processed_data, (0, 3, 1, 2))).cuda().float() - for i in range(processed_data.shape[0]): - gt_batch.append(arguement_1(processed_data[i])) - - # The other half data use splicing. - processed_data = np.zeros((4, crop_size//2, crop_size//2, 28), dtype=np.float32) - for i in range(batch_size - batch_size // 2): - sample_list = np.random.randint(0, len(train_data), 4) - for j in range(4): - x_index = np.random.randint(0, h-crop_size//2) - y_index = np.random.randint(0, w-crop_size//2) - processed_data[j] = train_data[sample_list[j]][x_index:x_index+crop_size//2,y_index:y_index+crop_size//2,:] - gt_batch_2 = torch.from_numpy(np.transpose(processed_data, (0, 3, 1, 2))).cuda() # [4,28,128,128] - gt_batch.append(arguement_2(gt_batch_2)) - gt_batch = torch.stack(gt_batch, dim=0) - return gt_batch - else: - index = np.random.choice(range(len(train_data)), batch_size) - processed_data = np.zeros((batch_size, crop_size, crop_size, 28), dtype=np.float32) - for i in range(batch_size): - h, w, _ = train_data[index[i]].shape - x_index = np.random.randint(0, h - crop_size) - y_index = np.random.randint(0, w - crop_size) - processed_data[i, :, :, :] = train_data[index[i]][x_index:x_index + crop_size, y_index:y_index + crop_size, :] - gt_batch = torch.from_numpy(np.transpose(processed_data, (0, 3, 1, 2))) - return gt_batch - -def arguement_2(generate_gt): - c, h, w = generate_gt.shape[1],generate_gt.shape[2],generate_gt.shape[3] - divid_point_h = h//2 - divid_point_w = w//2 - output_img = torch.zeros(c,h,w).cuda() - output_img[:, :divid_point_h, :divid_point_w] = generate_gt[0] - output_img[:, :divid_point_h, divid_point_w:] = generate_gt[1] - output_img[:, divid_point_h:, :divid_point_w] = generate_gt[2] - output_img[:, divid_point_h:, divid_point_w:] = generate_gt[3] - return output_img - -# class AcquisitionDataset(Dataset): -# def __init__(self, input, hs_cubes, transform=None, target_transform=None): -# """_summary_ - -# Args: -# input (_type_): List of size 2 with each element being a list: -# - First list: List of n torch.tensor acquisitions (2D) -# - Second list: List of n int labels -# hs_cubes (_type_): List of size m, hs_cubes[m] being the m-th hs cube -# transform (_type_, optional): _description_. Defaults to None. -# target_transform (_type_, optional): _description_. Defaults to None. 
-# """ -# self.data = input # list of size 2, first elem is a list of n torch.tensor acquisitions (input), second elem is a list of size n with the index of corresponding hs cubes (output) -# self.labels = self.data[1] - -# self.cubes = hs_cubes # list of cubes, number of cubes must be >= max(self.labels) - -# self.transform = transform -# self.target_transform = target_transform - -# def __len__(self): -# return len(self.data[1]) - -# def __getitem__(self, index): -# acq = self.data[0][index] # torch tensor of size x*y -# cube = self.cubes[self.labels[index]] # torch tensor of size x*y*w - -# return acq, cube -if __name__ == "__main__": - data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28/" - datamodule = CubesDataModule(data_dir, batch_size=5, num_workers=2) \ No newline at end of file + return x \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100755 index 0000000..4f42ad4 --- /dev/null +++ b/environment.yml @@ -0,0 +1,34 @@ +name: simca +channels: + - defaults + - nvidia + - pytorch + - pyg + - conda-forge +dependencies: + - einops=0.7.0 + - fvcore=0.1.5 + - h5py=3.10.0 + - imageio=2.34.0 + - lightning=2.2.1 + - cudatoolkit=11.8.0 + - opencv=4.9.0 + - opticalglass=1.0.7 + - pyqtgraph=0.13.3 + - seaborn=0.13.2 + - segmentation-models-pytorch=0.3.3 + - snoop=0.4.3 + - spectral=0.23.1 + - tensorboard=2.16.2 + - pytorch-cuda=12.1 + - pytorch=2.1.2=*cuda* + - torchvision=0.16.2 + - pip=23.3.1=py39h06a4308_0 + - python=3.9.18=h955ad1f_0 + - pytorch-cluster=1.6.3 + - pyg=2.5.0 + - pytorch-scatter=2.1.2 + - pytorch-sparse=0.6.18 + - pytorch-spline-conv=1.2.2 + - pip: + - piqa==1.3.2 \ No newline at end of file diff --git a/optimization_modules.py b/optimization_modules.py index 225d3cc..a2a59d1 100755 --- a/optimization_modules.py +++ b/optimization_modules.py @@ -3,7 +3,7 @@ import torch.nn as nn from simca.CassiSystem_lightning import CassiSystemOptim from MST.simulation.train_code.architecture import * -from simca import load_yaml_config +from simca import load_yaml_config, save_config_file import matplotlib.pyplot as plt import torchvision import numpy as np @@ -12,19 +12,24 @@ from torch.utils.tensorboard import SummaryWriter import io import torchvision.transforms as transforms +from torchmetrics.image import PeakSignalNoiseRatio +import os from PIL import Image class JointReconstructionModule_V1(pl.LightningModule): - def __init__(self, model_name,log_dir="tb_logs", reconstruction_checkpoint=None): + def __init__(self, model_name,log_dir="tb_logs", reconstruction_checkpoint=None, fix_random_pattern=False): super().__init__() self.model_name = model_name self.reconstruction_model = model_generator(self.model_name, reconstruction_checkpoint) + self.fix_random_pattern = fix_random_pattern self.loss_fn = nn.MSELoss() - #self.ssim_loss = SSIM(window_size=11, n_channels=28) - self.ssim_loss = SSIM(window_size=11, n_channels=3) + self.ssim_loss = SSIM(window_size=11, n_channels=28) + + # for param in self.reconstruction_model.parameters(): + # param.requires_grad = False self.writer = SummaryWriter(log_dir) @@ -39,6 +44,15 @@ def on_validation_start(self,stage=None): def on_predict_start(self,stage=None): print("---PREDICT START---") + + if not os.path.exists('./results'): + os.makedirs('./results') + + if self.ssim_loss.kernel.shape[0]!=28: + self.ssim_loss = SSIM(window_size=11, n_channels=28).to(self.device) + + self.psnr = PeakSignalNoiseRatio().to(self.device) + self.config = 
"simca/configs/cassi_system_optim_optics_full_triplet_dd_cassi.yml" config_system = load_yaml_config(self.config) @@ -46,6 +60,22 @@ def on_predict_start(self,stage=None): self.cassi_system = CassiSystemOptim(system_config=config_system) self.cassi_system.propagate_coded_aperture_grid() + try: + if self.fix_random_pattern: + file_name = "predict_results_recons_fixed_random.yml" + else: + file_name = "predict_results_recons.yml" + self.predict_results = load_yaml_config(file_name) + except: + self.predict_results = {} + + def on_predict_end(self): + if self.fix_random_pattern: + file_name = "predict_results_recons_fixed_random" + else: + file_name = "predict_results_recons" + save_config_file('./results/' + file_name,self.predict_results,".") + def _normalize_data_by_itself(self, data): # Calculate the mean and std for each batch individually # Keep dimensions for broadcasting @@ -64,38 +94,22 @@ def forward(self, x, pattern=None): hyperspectral_cube = hyperspectral_cube.permute(0, 2, 3, 1).to(self.device) batch_size, H, W, C = hyperspectral_cube.shape - # fig, ax = plt.subplots(1, 1) - # plt.title(f"entry cube") - # ax.imshow(hyperspectral_cube[0, :, :, 0].cpu().detach().numpy()) - # plt.show() - # print(f"batch size:{batch_size}") - # generate pattern - + # Generate pattern if pattern is None: - self.pattern = self.cassi_system.generate_2D_pattern(self.config_patterns,nb_of_patterns=batch_size) + self.pattern = self.cassi_system.generate_2D_pattern(self.config_patterns,nb_of_patterns=batch_size, fix_random_pattern=self.fix_random_pattern) self.pattern = self.pattern.to(self.device) else: self.pattern = pattern.to(self.device) self.cassi_system.pattern = pattern.to(self.device) - # plt.imshow(pattern[0, :, :].cpu().detach().numpy()) - # plt.show() - - # print(f"pattern_size: {pattern.shape}") - - # generate first acquisition with simca - + # Generate first acquisition with simca filtering_cube = self.cassi_system.generate_filtering_cube().to(self.device) self.acquired_image1 = self.cassi_system.image_acquisition(hyperspectral_cube, self.pattern, wavelengths).to(self.device) - # self.acquired_image1 = self._normalize_data_by_itself(self.acquired_image1) - # acquired_cubes = self.acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) # b x W x R x C - filtering_cubes = subsample(filtering_cube, torch.linspace(450, 650, filtering_cube.shape[-1]), torch.linspace(450, 650, 28)).permute((0, 3, 1, 2)).float().to(self.device) if self.model_name == "birnat": - # acquisition = self.acquired_image1.unsqueeze(1) acquisition = self.acquired_image1.float() filtering_cubes = filtering_cubes.float() elif "dauhst" in self.model_name: @@ -107,8 +121,6 @@ def forward(self, x, pattern=None): elif self.model_name == "mst_plus_plus": acquisition = self.acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) - #print(f"acquisition shape: {acquisition.shape}") - #print(f"filtering_cubes shape: {filtering_cubes.shape}") reconstructed_cube = self.reconstruction_model(acquisition, filtering_cubes) @@ -175,6 +187,14 @@ def validation_step(self, batch, batch_idx): prog_bar=True, ) + self.log_dict( + { "val_ssim_loss": ssim_loss, + }, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + return {"loss": loss} def test_step(self, batch, batch_idx): @@ -192,14 +212,30 @@ def test_step(self, batch, batch_idx): def predict_step(self, batch, batch_idx): print("Predict step") loss, ssim_loss, reconstructed_cube, ref_cube = self._common_step(batch, batch_idx) - print("Predict loss: ", 
loss.item()) - print("Predict ssim loss: ", ssim_loss) - #self.log('predict_step', loss,on_step=True, on_epoch=True, prog_bar=True, logger=True) + psnr_loss = self.psnr(reconstructed_cube.permute(0, 3, 1, 2), ref_cube.permute(0, 3, 1, 2)) + + print("Predict PSNR loss: ", psnr_loss.item()) + print("Predict RMSE loss: ", loss.item()) + print("Predict SSIM loss: ", ssim_loss.item()) + + ignored_ids = [2, 3, 9, 10, 11, 12, 13, 17] # Those contain a low amount of signal and are thus less interesting + if (batch_idx+1) not in ignored_ids: + try: + self.predict_results[f"RMSE_scene{batch_idx+1}"].append(loss.item()) + self.predict_results[f"SSIM_scene{batch_idx+1}"].append(ssim_loss.item()) + self.predict_results[f"PSNR_scene{batch_idx+1}"].append(psnr_loss.item()) + except: + self.predict_results[f"RMSE_scene{batch_idx+1}"] = [loss.item()] + self.predict_results[f"SSIM_scene{batch_idx+1}"] = [ssim_loss.item()] + self.predict_results[f"PSNR_scene{batch_idx+1}"] = [psnr_loss.item()] + + if batch_idx == 19-1: + torch.save(reconstructed_cube, f'./results/recons_cube_random.pt') + return loss def _common_step(self, batch, batch_idx): - reconstructed_cube = self.forward(batch) hyperspectral_cube, wavelengths = batch @@ -207,16 +243,8 @@ def _common_step(self, batch, batch_idx): reconstructed_cube = reconstructed_cube.permute(0, 2, 3, 1).to(self.device) ref_cube = match_dataset_to_instrument(hyperspectral_cube, reconstructed_cube[0, :, :,0]) - # fig, ax = plt.subplots(1, 2) - # plt.title(f"true cube vs reconstructed cube") - # ax[0].imshow(hyperspectral_cube[0, :, :, 0].cpu().detach().numpy()) - # ax[1].imshow(reconstructed_cube[0, :, :, 0].cpu().detach().numpy()) - # plt.show() - - loss = torch.sqrt(self.loss_fn(reconstructed_cube, ref_cube)) ssim_loss = self.ssim_loss(torch.clamp(reconstructed_cube.permute(0, 3, 1, 2), 0, 1), ref_cube.permute(0, 3, 1, 2)) - #ssim_loss = 0 return loss, ssim_loss, reconstructed_cube, ref_cube @@ -245,7 +273,7 @@ def _convert_output_to_images(self, acquired_images): def plot_spectral_filter(self,ref_hyperspectral_cube,recontructed_hyperspectral_cube): - batch_size, y,x, lmabda_ = ref_hyperspectral_cube.shape + batch_size, y,x, lambda_ = ref_hyperspectral_cube.shape # Create a figure with subplots arranged horizontally fig, axs = plt.subplots(1, batch_size, figsize=(batch_size * 5, 4)) # Adjust figure size as needed @@ -291,25 +319,10 @@ def plot_spectral_filter(self,ref_hyperspectral_cube,recontructed_hyperspectral_ def subsample(input, origin_sampling, target_sampling): - [bs, row, col, nC] = input.shape + # Subsample input from origin_sampling to target_sampling indices = torch.zeros(len(target_sampling), dtype=torch.int) for i in range(len(target_sampling)): sample = target_sampling[i] idx = torch.abs(origin_sampling-sample).argmin() indices[i] = idx return input[:,:,:,indices] - -def expand_mask_3d(mask_batch): - if len(mask_batch.shape)==3: - mask3d = mask_batch.unsqueeze(-1).repeat((1, 1, 1, 28)) - else: - mask3d = mask_batch.repeat((1, 1, 1, 28)) - mask3d = torch.permute(mask3d, (0, 3, 1, 2)) - return mask3d - -class EmptyModule(nn.Module): - def __init__(self): - super().__init__() - self.useless_linear = nn.Linear(1, 1) - def forward(self, x): - return x diff --git a/optimization_modules_with_resnet.py b/optimization_modules_full.py similarity index 63% rename from optimization_modules_with_resnet.py rename to optimization_modules_full.py index 510a7f8..fffd36a 100755 --- a/optimization_modules_with_resnet.py +++ b/optimization_modules_full.py @@ -3,7 +3,7 @@ 
import torch.nn as nn from simca.CassiSystem_lightning import CassiSystemOptim from MST.simulation.train_code.architecture import * -from simca import load_yaml_config +from simca import load_yaml_config, save_config_file import matplotlib.pyplot as plt import torchvision import numpy as np @@ -11,8 +11,10 @@ from piqa import SSIM from torch.utils.tensorboard import SummaryWriter import io +import os import torchvision.transforms as transforms from PIL import Image +from torchmetrics.image import PeakSignalNoiseRatio import segmentation_models_pytorch as smp import torch.nn.functional as F @@ -28,23 +30,58 @@ def forward(self,x): class JointReconstructionModule_V2(pl.LightningModule): - def __init__(self, model_name,log_dir="tb_logs",reconstruction_checkpoint=None): + def __init__(self, model_name, log_dir="tb_logs", mask_model="learned_mask", + reconstruction_checkpoint=None, resnet_checkpoint = None, + full_checkpoint=None, + train_reconstruction=False): super().__init__() self.model_name = model_name - self.reconstruction_model = model_generator(self.model_name, pretrained_model_path=None) - if reconstruction_checkpoint is not None: - #self.reconstruction_model = model_generator(self.model_name, pretrained_model_path=reconstruction_checkpoint) - self.reconstruction_model.load_state_dict(torch.load(reconstruction_checkpoint), strict=False) - self.mask_generation = UnetModel(classes=1,encoder_weights=None,in_channels=1) + self.mask_model = mask_model + if full_checkpoint is None: + self.reconstruction_model = model_generator(self.model_name, pretrained_model_path=reconstruction_checkpoint) + else: # In that case, full_checkpoint also contains the weights of the model + self.reconstruction_model = model_generator(self.model_name, pretrained_model_path=full_checkpoint) + + if self.mask_model=="resnet": + self.mask_generation = UnetModel(classes=1,encoder_weights=None,in_channels=1) + elif (self.mask_model=="learned_mask") or (self.mask_model=="learned_mask_float"): + self.learned_pattern = nn.Parameter(torch.rand(131, 131, requires_grad=True)) + + if (self.mask_model=="resnet") and (resnet_checkpoint is not None) and (full_checkpoint is None): + # Load only a pretrained resnet + checkpoint = torch.load(resnet_checkpoint, map_location=self.device) + if (self.mask_model=="resnet"): + # Adjust the keys + adjusted_state_dict = {key.replace('mask_generation.', ''): value + for key, value in checkpoint['state_dict'].items()} + # Filter out unexpected keys + model_keys = set(self.mask_generation.state_dict().keys()) + filtered_state_dict = {k: v for k, v in adjusted_state_dict.items() if k in model_keys} + self.mask_generation.load_state_dict(filtered_state_dict) + + if full_checkpoint is not None: + # Load the weights from the checkpoint into self.model + checkpoint = torch.load(full_checkpoint, map_location=self.device) + if (self.mask_model=="resnet"): + # Adjust the keys + adjusted_state_dict = {key.replace('mask_generation.', ''): value + for key, value in checkpoint['state_dict'].items()} + # Filter out unexpected keys + model_keys = set(self.mask_generation.state_dict().keys()) + filtered_state_dict = {k: v for k, v in adjusted_state_dict.items() if k in model_keys} + self.mask_generation.load_state_dict(filtered_state_dict) + elif "learned_mask" in self.mask_model: + self.learned_pattern = nn.Parameter(torch.tensor(checkpoint['state_dict']['learned_pattern'], device=self.device, requires_grad=True)) self.loss_fn = nn.MSELoss() self.ssim_loss = SSIM(window_size=11, n_channels=28) 
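+        # Note on the checkpoint handling above: Lightning checkpoints store weights under
+        # 'state_dict' with keys prefixed by the owning attribute (e.g. 'mask_generation.'),
+        # so the prefix is stripped and any keys the target sub-module does not declare are
+        # dropped before calling load_state_dict().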
self.writer = SummaryWriter(log_dir) - # for param in self.reconstruction_model.parameters(): - # param.requires_grad = False + if not train_reconstruction: + for param in self.reconstruction_model.parameters(): + param.requires_grad = False def on_validation_start(self,stage=None): print("---VALIDATION START---") @@ -53,6 +90,28 @@ def on_validation_start(self,stage=None): self.config_patterns = load_yaml_config("simca/configs/pattern.yml") self.cassi_system = CassiSystemOptim(system_config=config_system) self.cassi_system.propagate_coded_aperture_grid() + + def on_predict_start(self,stage=None): + print("---PREDICT START---") + + if not os.path.exists('./results'): + os.makedirs('./results') + + self.config = "simca/configs/cassi_system_optim_optics_full_triplet_dd_cassi.yml" + config_system = load_yaml_config(self.config) + self.config_patterns = load_yaml_config("simca/configs/pattern.yml") + self.cassi_system = CassiSystemOptim(system_config=config_system) + self.cassi_system.propagate_coded_aperture_grid() + + self.psnr = PeakSignalNoiseRatio().to(self.device) + + try: + self.predict_results = load_yaml_config(f"./results/predict_results_full_{self.mask_model}.yml") + except: + self.predict_results = {} + + def on_predict_end(self): + save_config_file(f"./results/predict_results_full_{self.mask_model}",self.predict_results,".") def _normalize_data_by_itself(self, data): # Calculate the mean and std for each batch individually @@ -72,32 +131,34 @@ def forward(self, x): hyperspectral_cube = hyperspectral_cube.permute(0, 2, 3, 1).to(self.device) batch_size, H, W, C = hyperspectral_cube.shape + if self.mask_model=="resnet": + self.pattern = self.cassi_system.generate_2D_pattern(self.config_patterns,nb_of_patterns=batch_size) + self.pattern = self.pattern.to(self.device) + filtering_cube = self.cassi_system.generate_filtering_cube().to(self.device) + self.acquired_image1 = self.cassi_system.image_acquisition(hyperspectral_cube, self.pattern, wavelengths).to(self.device) - self.pattern = self.cassi_system.generate_2D_pattern(self.config_patterns,nb_of_patterns=batch_size) - self.pattern = self.pattern.to(self.device) - filtering_cube = self.cassi_system.generate_filtering_cube().to(self.device) - self.acquired_image1 = self.cassi_system.image_acquisition(hyperspectral_cube, self.pattern, wavelengths).to(self.device) + self.acquired_image1 = pad_tensor(self.acquired_image1, (128, 128)) - self.acquired_image1 = pad_tensor(self.acquired_image1, (128, 128)) + # flip along second and third axis + self.acquired_image1 = self.acquired_image1.flip(1) + self.acquired_image1 = self.acquired_image1.flip(2) + self.acquired_image1 = self.acquired_image1.unsqueeze(1).float() + #self.acquired_image1 = self._normalize_data_by_itself(self.acquired_image1) - # flip along second and thrid axis - self.acquired_image1 = self.acquired_image1.flip(1) - self.acquired_image1 = self.acquired_image1.flip(2) - self.acquired_image1 = self.acquired_image1.unsqueeze(1).float() - #self.acquired_image1 = self._normalize_data_by_itself(self.acquired_image1) - - - self.pattern = self.mask_generation(self.acquired_image1).squeeze(1) - self.pattern = BinarizeFunction.apply(self.pattern) - self.pattern = pad_tensor(self.pattern, (131, 131)) - self.cassi_system.pattern = self.pattern.to(self.device) + self.pattern = self.mask_generation(self.acquired_image1).squeeze(1) + self.pattern = BinarizeFunction.apply(self.pattern) + self.pattern = pad_tensor(self.pattern, (131, 131)) + elif (self.mask_model=="learned_mask") or 
(self.mask_model=="learned_mask_float"): + self.pattern = torch.clamp(self.learned_pattern.unsqueeze(0).repeat(batch_size, 1,1), 0, 1).float().to(self.device) + if self.mask_model=="learned_mask": + self.pattern = BinarizeFunction.apply(self.pattern) + + self.cassi_system.pattern = self.pattern.to(self.device) filtering_cube = self.cassi_system.generate_filtering_cube().to(self.device) self.acquired_image2 = self.cassi_system.image_acquisition(hyperspectral_cube, self.pattern, wavelengths).to(self.device) - # self.acquired_image1 = self._normalize_data_by_itself(self.acquired_image1) - # acquired_cubes = self.acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) # b x W x R x C filtering_cubes = subsample(filtering_cube, torch.linspace(450, 650, filtering_cube.shape[-1]), torch.linspace(450, 650, 28)).permute((0, 3, 1, 2)).float().to(self.device) @@ -112,12 +173,11 @@ def forward(self, x): filtering_cubes = (filtering_cubes.float(), filtering_cubes_s.float()) elif self.model_name == "mst_plus_plus": - acquisition = self.acquired_image2.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) + acquisition = self.acquired_image2.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) # b x C x H x W reconstructed_cube = self.reconstruction_model(acquisition, filtering_cubes) - return reconstructed_cube @@ -139,12 +199,9 @@ def training_step(self, batch, batch_idx): self._log_images('train/reconstructed', reconstructed_image, self.global_step) self._log_images('train/patterns', patterns, self.global_step) - plt.imshow(self.pattern[0,:,:].cpu().detach().numpy()) - plt.colorbar() - plt.savefig('./pattern.png') - spectral_filter_plot = self.plot_spectral_filter(ref_cube,reconstructed_cube) - self.log_gradients(self.global_step) + if self.mask_model=="resnet": + self.log_gradients(self.global_step) self.writer.add_image('Spectral Filter', spectral_filter_plot, self.global_step) self.log_dict( @@ -219,14 +276,31 @@ def test_step(self, batch, batch_idx): def predict_step(self, batch, batch_idx): print("Predict step") loss, ssim_loss, reconstructed_cube, ref_cube= self._common_step(batch, batch_idx) - print("Predict loss: ", loss.item()) - print("Predict ssim loss: ", ssim_loss) - #self.log('predict_step', loss,on_step=True, on_epoch=True, prog_bar=True, logger=True) + psnr_loss = self.psnr(reconstructed_cube.permute(0, 3, 1, 2), ref_cube.permute(0, 3, 1, 2)) + + print("Predict PSNR loss: ", psnr_loss.item()) + print("Predict RMSE loss: ", loss.item()) + print("Predict SSIM loss: ", ssim_loss.item()) + + ignored_ids = [2, 3, 9, 10, 11, 12, 13, 17] # Those contain a low amount of signal and are thus less interesting + if (batch_idx+1) not in ignored_ids: + try: + self.predict_results[f"RMSE_scene{batch_idx+1}"].append(loss.item()) + self.predict_results[f"SSIM_scene{batch_idx+1}"].append(ssim_loss.item()) + self.predict_results[f"PSNR_scene{batch_idx+1}"].append(psnr_loss.item()) + except: + self.predict_results[f"RMSE_scene{batch_idx+1}"] = [loss.item()] + self.predict_results[f"SSIM_scene{batch_idx+1}"] = [ssim_loss.item()] + self.predict_results[f"PSNR_scene{batch_idx+1}"] = [psnr_loss.item()] + + # extract reconstructed cube and ref cube + if batch_idx == 19-1: + torch.save(reconstructed_cube, f'./results/recons_cube_{self.mask_model}.pt') + torch.save(ref_cube, f'./results/gt_cube.pt') + return loss def _common_step(self, batch, batch_idx): - - reconstructed_cube = self.forward(batch) hyperspectral_cube, wavelengths = batch @@ -234,30 +308,20 @@ def _common_step(self, 
batch, batch_idx): reconstructed_cube = reconstructed_cube.permute(0, 2, 3, 1).to(self.device) ref_cube = match_dataset_to_instrument(hyperspectral_cube, reconstructed_cube[0, :, :,0]) - # fig, ax = plt.subplots(1, 2) - # plt.title(f"true cube vs reconstructed cube") - # ax[0].imshow(hyperspectral_cube[0, :, :, 0].cpu().detach().numpy()) - # ax[1].imshow(reconstructed_cube[0, :, :, 0].cpu().detach().numpy()) - # plt.show() total_sum_pattern = torch.sum(self.pattern, dim=(1, 2)) total_half_pattern_equal_1 = torch.sum(torch.ones_like(self.pattern), dim=(1, 2)) / 2 - print(f"total_sum_pattern {total_sum_pattern}") - print(f"total_half_pattern_equal_1 {total_half_pattern_equal_1}") - - loss1 = torch.sqrt(self.loss_fn(reconstructed_cube, ref_cube)) - loss2 = torch.sum(torch.abs((total_sum_pattern - total_half_pattern_equal_1)/(self.pattern.shape[1]*self.pattern.shape[2]))**2) - loss = loss1 + loss2 + loss2 = torch.sum(torch.abs((total_sum_pattern - total_half_pattern_equal_1)/(self.pattern.shape[1]*self.pattern.shape[2]))**2) # Force it to have about 50% opened mirrors + loss = loss1 ssim_loss = self.ssim_loss(torch.clamp(reconstructed_cube.permute(0, 3, 1, 2), 0, 1), ref_cube.permute(0, 3, 1, 2)) - print(f"loss1 {loss1}") - print(f"loss2 {loss2}") + return loss, ssim_loss, reconstructed_cube, ref_cube def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=4e-4) + optimizer = torch.optim.Adam(self.parameters(), lr=1.5e-4) return { "optimizer":optimizer, "lr_scheduler":{ "scheduler":torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500, eta_min=1e-6), @@ -333,7 +397,6 @@ def plot_spectral_filter(self,ref_hyperspectral_cube,recontructed_hyperspectral_ def subsample(input, origin_sampling, target_sampling): - [bs, row, col, nC] = input.shape indices = torch.zeros(len(target_sampling), dtype=torch.int) for i in range(len(target_sampling)): sample = target_sampling[i] @@ -341,21 +404,6 @@ def subsample(input, origin_sampling, target_sampling): indices[i] = idx return input[:,:,:,indices] -def expand_mask_3d(mask_batch): - if len(mask_batch.shape)==3: - mask3d = mask_batch.unsqueeze(-1).repeat((1, 1, 1, 28)) - else: - mask3d = mask_batch.repeat((1, 1, 1, 28)) - mask3d = torch.permute(mask3d, (0, 3, 1, 2)) - return mask3d - -class EmptyModule(nn.Module): - def __init__(self): - super().__init__() - self.useless_linear = nn.Linear(1, 1) - def forward(self, x): - return x - def pad_tensor(input_tensor, target_shape): [bs, row, col] = input_tensor.shape @@ -382,18 +430,6 @@ def crop_tensor(input_tensor, target_shape): crop_col = (col-target_col)//2 return input_tensor[:,crop_row:crop_row+target_row,crop_col:crop_col+target_col] -# img = torch.randn(3, 112, 112) # Example tensor -# target_height = 128 -# target_width = 128 -# -# plt.imshow(img[0,...].cpu().detach().numpy()) -# plt.show() -# -# zda = pad_tensor(img, (target_height, target_width)) -# -# plt.imshow(zda[0,...].cpu().detach().numpy()) -# plt.show() - class BinarizeFunction(torch.autograd.Function): @staticmethod diff --git a/show_spectrum.py b/show_spectrum.py new file mode 100755 index 0000000..d57fad6 --- /dev/null +++ b/show_spectrum.py @@ -0,0 +1,61 @@ +import torch +import matplotlib.pyplot as plt +import matplotlib.patches as patches +import numpy as np + +file_binary = "./metrics-results/recons_cube_learned_mask.pt" +file_float = "./metrics-results/recons_cube_learned_mask_float.pt" +file_random = "./metrics-results/recons_cube_random.pt" +gt = "./metrics-results/gt_cube.pt" + +binary_cube = 
torch.load(file_binary, map_location=torch.device('cpu')).cpu()[0,...] +float_cube = torch.load(file_float, map_location=torch.device('cpu')).cpu()[0,...] +random_cube = torch.load(file_random, map_location=torch.device('cpu')).cpu()[0,...] +gt_cube = torch.load(gt, map_location=torch.device('cpu')).cpu()[0,...] + +panchro = torch.sum(gt_cube, dim=2) + +pixel_i_1 = 104 +pixel_j_1 = 105 + +pixel_i_2 = 43 +pixel_j_2 = 80 + +pixel_i_3 = 45 +pixel_j_3 = 2 + +# Slice at 554nm +slice = binary_cube[:,:,14] + +# Cube plot +fig, ax = plt.subplots() +ax.imshow(slice.numpy(), cmap='gray') +rect1 = patches.Rectangle((pixel_j_1, pixel_i_1), 2, 2, linewidth=1, edgecolor='r', facecolor='none') +rect2 = patches.Rectangle((pixel_j_2, pixel_i_2), 2, 2, linewidth=1, edgecolor='r', facecolor='none') +rect3 = patches.Rectangle((pixel_j_3, pixel_i_3), 2, 2, linewidth=1, edgecolor='r', facecolor='none') +ax.add_patch(rect1) +ax.add_patch(rect2) +ax.add_patch(rect3) +plt.show() +fig.savefig('./metrics-results/reconstructed_cube.svg', format='svg') + +# Spectra plot +fig, ax = plt.subplots() +ax.plot(np.linspace(450,650,28), gt_cube[pixel_i_1, pixel_j_1].cpu().numpy(), 'k--', label='truth') +ax.plot(np.linspace(450,650,28), binary_cube[pixel_i_1, pixel_j_1].cpu().numpy(), 'r', label='binary') +ax.plot(np.linspace(450,650,28), float_cube[pixel_i_1, pixel_j_1].cpu().numpy(), 'b', label = 'float') +ax.plot(np.linspace(450,650,28), random_cube[pixel_i_1, pixel_j_1].cpu().numpy(), 'g', label = 'random') + +ax.plot(np.linspace(450,650,28), gt_cube[pixel_i_2, pixel_j_2].cpu().numpy(), 'k--') +ax.plot(np.linspace(450,650,28), binary_cube[pixel_i_2, pixel_j_2].cpu().numpy(), 'r') +ax.plot(np.linspace(450,650,28), float_cube[pixel_i_2, pixel_j_2].cpu().numpy(), 'b') +ax.plot(np.linspace(450,650,28), random_cube[pixel_i_2, pixel_j_2].cpu().numpy(), 'g') + +ax.plot(np.linspace(450,650,28), gt_cube[pixel_i_3, pixel_j_3].cpu().numpy(), 'k--') +ax.plot(np.linspace(450,650,28), binary_cube[pixel_i_3, pixel_j_3].cpu().numpy(), 'r') +ax.plot(np.linspace(450,650,28), float_cube[pixel_i_3, pixel_j_3].cpu().numpy(), 'b') +ax.plot(np.linspace(450,650,28), random_cube[pixel_i_3, pixel_j_3].cpu().numpy(), 'g') +ax.legend() +plt.show() + +fig.savefig('./metrics-results/spectra.svg', format='svg') \ No newline at end of file diff --git a/simca/CassiSystem.py b/simca/CassiSystem.py index f88b4ce..541d35e 100644 --- a/simca/CassiSystem.py +++ b/simca/CassiSystem.py @@ -144,7 +144,7 @@ def interpolate_dataset_along_wavelengths(self, new_wavelengths_sampling, chunk_ else: raise ValueError("The new wavelengths sampling must be inside the dataset wavelengths range") - def generate_2D_pattern(self, config_pattern, nb_of_patterns=1): + def generate_2D_pattern(self, config_pattern, nb_of_patterns=1, fix_random_pattern=False): """ Generate multiple coded aperture 2D patterns based on the "pattern" configuration file and stack them to match the desired number of patterns. 
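The new fix_random_pattern flag is forwarded to generate_random_pattern in the hunk below. That helper's implementation is not part of this diff (it presumably lives in simca/functions_patterns_generation.py), but a minimal sketch of the intended behaviour (a reproducible random coded aperture when the flag is set, assuming ROM is the fraction of open aperture elements) could look like this:

import numpy as np

def generate_random_pattern(shape, ROM, fix_random_pattern=False):
    # A fixed seed makes every "fixed random" predict run see the same coded aperture;
    # otherwise a fresh random pattern is drawn on each call.
    rng = np.random.default_rng(seed=0 if fix_random_pattern else None)
    return (rng.random(shape) < ROM).astype(np.float32)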
@@ -165,7 +165,7 @@ def generate_2D_pattern(self, config_pattern, nb_of_patterns=1): if pattern_type == "random": pattern = generate_random_pattern((self.system_config["coded aperture"]["number of pixels along Y"], self.system_config["coded aperture"]["number of pixels along X"]), - config_pattern['pattern']['ROM']) + config_pattern['pattern']['ROM'], fix_random_pattern) elif pattern_type == "slit": pattern = generate_slit_pattern((self.system_config["coded aperture"]["number of pixels along Y"], self.system_config["coded aperture"]["number of pixels along X"]), diff --git a/simca/CassiSystemTorch.py b/simca/CassiSystemTorch.py deleted file mode 100644 index b85330e..0000000 --- a/simca/CassiSystemTorch.py +++ /dev/null @@ -1,288 +0,0 @@ -from simca.OpticalModel import OpticalModelTorch -from simca.functions_acquisition import * -from simca.functions_patterns_generation import * -from simca.functions_scenes import * -from simca.functions_general_purpose import * -from CassiSystem import CassiSystem -from functions_acquisition_torch import * - -class CassiSystemTorch(CassiSystem): - """Class that contains the cassi system main attributes and methods""" - - def __init__(self, system_config=None, system_config_path=None): - - """ - - Args: - system_config_path (str): path to the configs file - system_config (dict): system configuration - - """ - super().__init__(system_config=system_config, system_config_path=system_config_path) - self.set_up_system(system_config_path=system_config_path, system_config=system_config) - - def set_up_system(self, system_config_path=None, system_config=None): - """ - Loading system config & initializing the grids coordinates for the coded aperture and the detector - - Args: - system_config_path (str): path to the configs file - system_config (dict): system configuration - - """ - - if system_config_path is not None: - self.system_config = load_yaml_config(system_config_path) - elif system_config is not None: - self.system_config = system_config - - self.optical_model = OpticalModelTorch(self.system_config) - - self.X_coded_aper_coordinates, self.Y_coded_aper_coordinates = self.create_coordinates_grid( - self.system_config["coded aperture"]["number of pixels along X"], - self.system_config["coded aperture"]["number of pixels along Y"], - self.system_config["coded aperture"]["pixel size along X"], - self.system_config["coded aperture"]["pixel size along Y"]) - - self.X_detector_coordinates_grid, self.Y_detector_coordinates_grid = self.create_coordinates_grid( - self.system_config["detector"]["number of pixels along X"], - self.system_config["detector"]["number of pixels along Y"], - self.system_config["detector"]["pixel size along X"], - self.system_config["detector"]["pixel size along Y"]) - - def update_config(self, system_config_path=None, system_config=None): - - """ - Update the system configuration file and re-initialize the grids for the coded aperture and the detector - - Args: - system_config_path (str): path to the configs file - system_config (dict): system configuration - Returns: - dict: updated system configuration - - """ - - self.set_up_system(system_config_path=system_config_path, system_config=system_config) - - return self.system_config - - def generate_filtering_cube(self): - """ - Generate filtering cube : each slice of the cube is a propagated pattern interpolated on the detector grid - - Returns: - numpy.ndarray: filtering cube generated according to the optical system & the pattern configuration (R x C x W) - - """ - - self.filtering_cube = 
interpolate_data_on_grid_positions_torch(data=self.pattern, - X_init=self.X_coordinates_propagated_coded_aperture, - Y_init=self.Y_coordinates_propagated_coded_aperture, - X_target=self.X_detector_coordinates_grid, - Y_target=self.Y_detector_coordinates_grid) - - - return self.filtering_cube - - def generate_multiple_filtering_cubes(self, number_of_patterns): - """ - Generate multiple filtering cubes, each cube corresponds to a pattern, and for each pattern, each slice is a propagated coded apertureinterpolated on the detector grid - - Args: - number_of_patterns (int): number of patterns to generate - Returns: - list: filtering cubes generated according to the current optical system and the pattern configuration - - """ - self.list_of_filtering_cubes = [] - - for idx in range(number_of_patterns): - - self.filtering_cube = interpolate_data_on_grid_positions_torch(data=self.list_of_patterns[idx], - X_init=self.X_coordinates_propagated_coded_aperture, - Y_init=self.Y_coordinates_propagated_coded_aperture, - X_target=self.X_detector_coordinates_grid, - Y_target=self.Y_detector_coordinates_grid) - - self.list_of_filtering_cubes.append(self.filtering_cube) - - return self.list_of_filtering_cubes - - def image_acquisition(self, use_psf=False, chunck_size=50): - """ - Run the acquisition/measurement process depending on the cassi system type - - Args: - chunck_size (int): default block size for the interpolation - - Returns: - numpy.ndarray: compressed measurement (R x C) - """ - - dataset = self.interpolate_dataset_along_wavelengths_torch(self.optical_model.system_wavelengths, chunck_size) - - if dataset is None: - return None - dataset_labels = self.dataset_labels - - if self.system_config["system architecture"]["system type"] == "DD-CASSI": - - try: - self.filtering_cube - except: - return print("Please generate filtering cube first") - - scene = torch.from_numpy(match_dataset_to_instrument(dataset, self.filtering_cube)) - - measurement_in_3D = generate_dd_measurement_torch(scene, self.filtering_cube, chunck_size) - - self.last_filtered_interpolated_scene = measurement_in_3D - self.interpolated_scene = scene - - if dataset_labels is not None: - scene_labels = torch.from_numpy(match_dataset_labels_to_instrument(dataset_labels, self.filtering_cube)) - self.scene_labels = scene_labels - - - elif self.system_config["system architecture"]["system type"] == "SD-CASSI": - - X_coded_aper_coordinates_crop = crop_center(self.X_coded_aper_coordinates,dataset.shape[1], dataset.shape[0]) - Y_coded_aper_coordinates_crop = crop_center(self.Y_coded_aper_coordinates,dataset.shape[1], dataset.shape[0]) - - scene = torch.from_numpy(match_dataset_to_instrument(dataset, X_coded_aper_coordinates_crop)) - - pattern_crop = crop_center(self.pattern, scene.shape[1], scene.shape[0]) - - filtered_scene = scene * pattern_crop[..., None].repeat((1, 1, scene.shape[2])) - - self.propagate_coded_aperture_grid(X_input_grid=X_coded_aper_coordinates_crop, Y_input_grid=Y_coded_aper_coordinates_crop, use_torch = True) - - sd_measurement = interpolate_data_on_grid_positions_torch(filtered_scene, - self.X_coordinates_propagated_coded_aperture, - self.Y_coordinates_propagated_coded_aperture, - self.X_detector_coordinates_grid, - self.Y_detector_coordinates_grid) - - self.last_filtered_interpolated_scene = sd_measurement - self.interpolated_scene = scene - - if dataset_labels is not None: - scene_labels = torch.from_numpy(match_dataset_labels_to_instrument(dataset_labels, self.last_filtered_interpolated_scene)) - self.scene_labels = 
scene_labels - - self.panchro = torch.sum(self.interpolated_scene, dim=2) - - if use_psf: - self.apply_psf_torch() - else: - print("No PSF was applied") - - # Calculate the other two arrays - self.measurement = torch.sum(self.last_filtered_interpolated_scene, dim=2) - - return self.measurement - - def multiple_image_acquisitions(self, use_psf=False, nb_of_filtering_cubes=1,chunck_size=50): - """ - Run the acquisition process depending on the cassi system type - - Args: - chunck_size (int): default block size for the dataset - - Returns: - list: list of compressed measurements (list of numpy.ndarray of size R x C) - """ - - dataset = self.interpolate_dataset_along_wavelengths_torch(self.optical_model.system_wavelengths, chunck_size) - if dataset is None: - return None - dataset_labels = self.dataset_labels - - self.list_of_filtered_scenes = [] - - if self.system_config["system architecture"]["system type"] == "DD-CASSI": - try: - self.list_of_filtering_cubes - except: - return print("Please generate list of filtering cubes first") - - scene = torch.from_numpy(match_dataset_to_instrument(dataset, self.list_of_filtering_cubes[0])) - - if dataset_labels is not None: - scene_labels = torch.from_numpy(match_dataset_labels_to_instrument(dataset_labels, self.filtering_cube)) - self.scene_labels = scene_labels - - self.interpolated_scene = scene - - for i in range(nb_of_filtering_cubes): - - filtered_scene = generate_dd_measurement_torch(scene, self.list_of_filtering_cubes[i], chunck_size) - self.list_of_filtered_scenes.append(filtered_scene) - - elif self.system_config["system architecture"]["system type"] == "SD-CASSI": - - X_coded_aper_coordinates_crop = crop_center(self.X_coded_aper_coordinates,dataset.shape[1], dataset.shape[0]) - Y_coded_aper_coordinates_crop = crop_center(self.Y_coded_aper_coordinates,dataset.shape[1], dataset.shape[0]) - - scene = torch.from_numpy(match_dataset_to_instrument(dataset, X_coded_aper_coordinates_crop)) - - if dataset_labels is not None: - scene_labels = torch.from_numpy(match_dataset_labels_to_instrument(dataset_labels, self.filtering_cube)) - self.scene_labels = scene_labels - - self.interpolated_scene = scene - - for i in range(nb_of_filtering_cubes): - - mask_crop = crop_center(self.list_of_patterns[i], scene.shape[1], scene.shape[0]) - - filtered_scene = scene * mask_crop[..., None].repeat((1, 1, scene.shape[2])) - - self.propagate_coded_aperture_grid(X_input_grid=X_coded_aper_coordinates_crop, Y_input_grid=Y_coded_aper_coordinates_crop, use_torch = True) - - sd_measurement_cube = interpolate_data_on_grid_positions_torch(filtered_scene, - self.X_coordinates_propagated_coded_aperture, - self.Y_coordinates_propagated_coded_aperture, - self.X_detector_coordinates_grid, - self.Y_detector_coordinates_grid) - self.list_of_filtered_scenes.append(sd_measurement_cube) - - self.panchro = torch.sum(self.interpolated_scene, dim=2) - - if use_psf: - self.apply_psf_torch() - else: - print("No PSF was applied") - - # Calculate the other two arrays - self.list_of_measurements = [] - for i in range(nb_of_filtering_cubes): - self.list_of_measurements.append(torch.sum(self.list_of_filtered_scenes[i], dim=2)) - - return self.list_of_measurements - - def apply_psf(self): - """ - Apply the PSF to the last measurement - - Returns: - numpy.ndarray: last measurement cube convolved with by PSF (shape= R x C x W). 
Each slice of the 3D filtered scene is convolved with the PSF - """ - if (self.optical_model.psf is not None) and (self.last_filtered_interpolated_scene is not None): - # Expand the dimensions of the 2D matrix to match the 3D matrix - psf_3D = self.optical_model.psf[..., None] - - # Perform the convolution using convolve - result = torch.nn.functional.conv3d(self.last_filtered_interpolated_scene[None, None, ...], torch.flip(psf_3D, (0,1,2))[None, None, ...], padding = tuple((np.array(psf_3D.shape)-1)//2)).squeeze(0,1) - result_panchro = torch.nn.functional.conv2d(self.panchro[None, None, ...], torch.flip(self.optical_model.psf, (0,1))[None, None, ...], padding = tuple((np.array(self.optical_model.psf.shape)-1)//2)).squeeze(0,1) - - else: - print("No PSF or last measurement to apply PSF") - result = self.last_filtered_interpolated_scene - result_panchro = self.panchro - - self.last_filtered_interpolated_scene = result - self.panchro = result_panchro - - return self.last_filtered_interpolated_scene \ No newline at end of file diff --git a/simca/CassiSystem_lightning.py b/simca/CassiSystem_lightning.py index ff1e1ce..7cf6b2e 100755 --- a/simca/CassiSystem_lightning.py +++ b/simca/CassiSystem_lightning.py @@ -234,9 +234,6 @@ def image_acquisition(self, hyperspectral_cube, pattern,wavelengths,use_psf=Fals self.X_coded_aper_coordinates = X_coded_aper_coordinates_crop self.Y_coded_aper_coordinates = Y_coded_aper_coordinates_crop - # print("dataset shape: ", dataset.shape) - # print("X coded shape: ", X_coded_aper_coordinates_crop.shape) - scene = match_dataset_to_instrument(dataset, X_coded_aper_coordinates_crop) @@ -246,32 +243,9 @@ def image_acquisition(self, hyperspectral_cube, pattern,wavelengths,use_psf=Fals pattern_crop = pattern_crop.unsqueeze(-1).repeat(1, 1, 1, scene.size(-1)) - #print(scene.get_device()) - #print(pattern_crop.get_device()) - - plt.imshow(scene[0,:,:,0].cpu().numpy()) - plt.title("scene") - plt.show() - # filtered_scene = scene * pattern_crop[..., None].repeat((1, 1, scene.shape[2])) - # print(f"scene: {scene.shape}") - # print(f"pattern_crop: {pattern_crop.shape}") filtered_scene = scene * pattern_crop - - # print(f"filtered_scene: {filtered_scene.shape}") - - - plt.imshow(pattern_crop[0,:,:,0].cpu().numpy()) - plt.title("Pattern crop1") - plt.show() - - - plt.imshow(filtered_scene[0,:,:,0].cpu().numpy()) - plt.title("Filtered scene") - plt.show() - - #print("filtered_scene",filtered_scene.shape) self.propagate_coded_aperture_grid() @@ -372,25 +346,8 @@ def generate_custom_slit_pattern_width(self, start_pattern = "line", start_posit # Set the position of the slit (j,i) bottom_pad = (nb_rows - i-1)*height_slits + self.system_config["coded aperture"]["number of pixels along Y"] % nb_rows # Padding necessary below slit (j,i) top_pad = i*height_slits - """ array_x_pos = torch.tensor(start_position[i]) - - # Create a grid to represent positions - grid_positions = torch.arange(self.empty_grid.shape[1], dtype=torch.float32) - # Expand dimensions for broadcasting - expanded_x_positions = (array_x_pos.unsqueeze(-1)) * (self.empty_grid.shape[1]-1) - expanded_grid_positions = grid_positions.unsqueeze(0) - - # Apply Gaussian-like function - sigma = (self.array_x_positions[j,i]+1)/2 - gaussian = torch.exp(-(((expanded_grid_positions - expanded_x_positions)) ** 2) / (2 * sigma ** 2)) - - padded = torch.nn.functional.pad(gaussian, (0,0,top_pad,bottom_pad)) # padding: left - right - top - bottom - - # Normalize to make sure the maximum value is 1 - self.pattern = self.pattern + 
padded/padded.max() """ c = start_position[i].clone().detach() # center of the slit - #d = ((torch.tanh(1.1*self.array_x_positions[j,i])+1)/2)/2 # width of the slit at pos d = self.array_x_positions[j,i]/2 # width of the slit at pos m = (c-d)*(self.system_config["coded aperture"]["number of pixels along X"]-1) # left bound M = (c+d)*(self.system_config["coded aperture"]["number of pixels along X"]-1) # right bound @@ -408,7 +365,6 @@ def generate_custom_slit_pattern_width(self, start_pattern = "line", start_posit rect = clamp_M - clamp_m +1 rect = torch.where(rect!=2, rect, 0) rect = torch.where(rect <= 1, rect, rect-1) - #rect = torch.clamp(-(rect-m)*(rect-M)+1,0,1).to(self.device) gaussian_range = torch.arange(self.system_config["coded aperture"]["number of pixels along X"], dtype=torch.float32) center_pos = 0.5*(len(gaussian_range)-1) @@ -443,26 +399,6 @@ def generate_custom_slit_pattern(self): return self.pattern - - # def generate_custom_slit_pattern(self): - # """ - # Generate a custom slit pattern - - # Args: - # array_x_positions (numpy.ndarray): array of the x positions of the slits between -1 and 1 - - # Returns: - # numpy.ndarray: generated slit pattern - # """ - # pattern = torch.clone(self.empty_grid) - # self.array_x_positions += 1 - # self.array_x_positions *= self.empty_grid.shape[1] // 2 - # self.array_x_positions = self.array_x_positions.type(torch.int32) - # for i in range(self.array_x_positions.shape[0]): - # pattern[0, self.array_x_positions[i]] = 1 - - # return self.pattern - def interpolate_dataset_along_wavelengths_torch(self, hyperspectral_cube, wavelengths, new_wavelengths_sampling, chunk_size): """ @@ -481,8 +417,6 @@ def interpolate_dataset_along_wavelengths_torch(self, hyperspectral_cube, wavele except: self.dataset = hyperspectral_cube self.dataset_wavelengths = wavelengths - #print(self.dataset.shape) - #print(self.dataset_wavelengths.shape) self.dataset = hyperspectral_cube self.dataset_wavelengths = wavelengths diff --git a/simca/configs/cassi_system_optim_optics_full_triplet_prev.yml b/simca/configs/cassi_system_optim_optics_full_triplet_prev.yml new file mode 100755 index 0000000..d51f4ab --- /dev/null +++ b/simca/configs/cassi_system_optim_optics_full_triplet_prev.yml @@ -0,0 +1,48 @@ +##### Configuration file for the chosen optical system + +infos: + system name: HYACAMEO + +system architecture: + system type: DD-CASSI + propagation type: simca + focal lens: 50000 + dispersive element: + # dispersive element caracteristics + type: tripleprism # name of the dispersive element + glass1: P-SK60 # glass type of the dispersive element (only used if type == 'prism') + glass2: SF4 # glass type of the dispersive element (only used if type == 'prism') + glass3: P-SK60 # glass type of the dispersive element (only used if type == 'prism') + A1: 21.5 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees + A2: 43.0 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees + A3: 21.5 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees + nd1: 1.6074 + nd2: 1.7552 + nd3: 1.6074 + vd1: 56.65 + vd2: 27.58 + vd3: 56.65 + continuous glass materials 1: False + continuous glass materials 2: False + continuous glass materials 3: False + m: 1 # grating order to consider (only used if type == 'grating') -- no units + G: 30 # grating density (only used if type == 'grating') -- in lines/mm + alpha_c: 2.6 + delta alpha c: 0 + delta beta c: 0 + wavelength center: 600 # central wavelength -- in nm 
+detector: + number of pixels along X: 145 # number of pixels along X axis -- no units + number of pixels along Y: 301 # number of pixels along Y axis -- no units + pixel size along X: 15 # pixel size along X -- in um + pixel size along Y: 50 # pixel size along Y -- in um +coded aperture: + number of pixels along X: 145 # 151 # number of pixels along X axis -- no units + number of pixels along Y: 301 # 151 # number of pixels along Y axis -- no units + pixel size along X: 10 # pixel size along X -- in um + pixel size along Y: 50 # pixel size along Y -- in um + +spectral range: + wavelength min: 410 # minimum wavelength -- in nm + wavelength max: 1050 # maximum wavelength -- in nm + number of spectral samples: 251 diff --git a/simca/configs/cassi_system_satur_test.yml b/simca/configs/cassi_system_satur_test.yml new file mode 100755 index 0000000..e6b6a06 --- /dev/null +++ b/simca/configs/cassi_system_satur_test.yml @@ -0,0 +1,48 @@ +##### Configuration file for the chosen optical system + +infos: + system name: HYACAMEO + +system architecture: + system type: DD-CASSI + propagation type: simca + focal lens: 50000 + dispersive element: + # dispersive element caracteristics + type: tripleprism # name of the dispersive element + glass1: P-SK60 # glass type of the dispersive element (only used if type == 'prism') + glass2: SF4 # glass type of the dispersive element (only used if type == 'prism') + glass3: P-SK60 # glass type of the dispersive element (only used if type == 'prism') + A1: 21.5 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees + A2: 43.0 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees + A3: 21.5 # apex angle of the prism in degrees (only used if type == 'prism') -- in degrees + nd1: 1.6074 + nd2: 1.7552 + nd3: 1.6074 + vd1: 56.65 + vd2: 27.58 + vd3: 56.65 + continuous glass materials 1: False + continuous glass materials 2: False + continuous glass materials 3: False + m: 1 # grating order to consider (only used if type == 'grating') -- no units + G: 30 # grating density (only used if type == 'grating') -- in lines/mm + alpha_c: 2.6 + delta alpha c: 0 + delta beta c: 0 + wavelength center: 600 # central wavelength -- in nm +detector: + number of pixels along X: 51 # number of pixels along X axis -- no units + number of pixels along Y: 51 # number of pixels along Y axis -- no units + pixel size along X: 15 # pixel size along X -- in um + pixel size along Y: 15 # pixel size along Y -- in um +coded aperture: + number of pixels along X: 51 # 151 # number of pixels along X axis -- no units + number of pixels along Y: 51 # 151 # number of pixels along Y axis -- no units + pixel size along X: 15 # pixel size along X -- in um + pixel size along Y: 15 # pixel size along Y -- in um + +spectral range: + wavelength min: 410 # minimum wavelength -- in nm + wavelength max: 1050 # maximum wavelength -- in nm + number of spectral samples: 251 diff --git a/simca/cost_functions.py b/simca/cost_functions.py index 07ac8e6..2b3e245 100644 --- a/simca/cost_functions.py +++ b/simca/cost_functions.py @@ -15,21 +15,11 @@ def evaluate_slit_scanning_straightness(filtering_cube, device, sigma = 0.75, po gaussian = torch.exp(-torch.square(gaussian)/(2*sigma**2)).unsqueeze(0).to(device) gaussian = gaussian w = filtering_cube.shape[2]//2 - """for i in range(filtering_cube.shape[2]): - vertical_binning = torch.sum(filtering_cube[:, :, i], axis=0) - #max_value = torch.max(vertical_binning) - std_deviation_vertical = torch.std(vertical_binning) - 
#std_deviation_horizontal = torch.std(torch.sum(filtering_cube[:,:,i], axis=1)) - # Reward the max value and penalize based on the standard deviation - # Calculate the differences between consecutive rows (vectorized) - row_diffs = filtering_cube[1:, :, i] - filtering_cube[:-1, :, i] - #cost_value = cost_value + max_value / std_deviation - torch.sum(torch.sum(torch.abs(row_diffs))) - cost_value = cost_value + std_deviation_vertical - 0.2*torch.sum(torch.sum(torch.square(row_diffs))) #- 0.2*torch.sum(torch.sum(torch.abs(row_diffs)))""" + # Minimize the smile distorsion at the central wavelength row_diffs = filtering_cube[1:, :, w] - filtering_cube[:-1, :, w] cost_value = cost_value - torch.sum((torch.abs(filtering_cube[:, :, w] - gaussian)+1e-8)**0.4) + 0.6*torch.sum(filtering_cube[:, pos_cube,w]) - 0.8*torch.sum(torch.sum(torch.abs(row_diffs))) #- 2*torch.sum(torch.var(filtering_cube[:, :, 0], dim=0)) - #delta = 2 - #cost_value = cost_value - (delta**2)*torch.sum((torch.sqrt(1+((filtering_cube[:, :, 0] - gaussian)/delta)**2)-1)) # pseudo-huber loss + # Minimizing the negative of cost_value to maximize the original objective return -cost_value @@ -45,7 +35,6 @@ def evaluate_mean_lighting(acquisition): """ Evaluate the mean and std values of the acquisition, maximizing mean and minimizing std """ - #cost_value = torch.mean(acquisition)/(torch.std(acquisition)+1e-9) cost_value = 2*torch.mean(acquisition) - 8*torch.std(acquisition) return -cost_value @@ -53,47 +42,9 @@ def evaluate_mean_lighting(acquisition): def evaluate_max_lighting(widths, acquisition, target): cost_value = 0 - #col = acquisition[:, pos_cube] - #col = acquisition[:, pos_cube-2:pos_cube+2] - col = acquisition[acquisition>100].flatten() - #col = acquisition[93:208, 30].unsqueeze(1) - """for i in range(1, 5): - col = torch.cat((col, acquisition[93:208, 30+i*15].unsqueeze(1)), 1) """ - - #cost_value = 2*torch.mean(col)**2 - 25*torch.var(col) - #cost_value = - torch.var(col) - """ cost_value = 15*torch.mean(col)**2 - 25*torch.var(col) - cost_value = 8000*torch.mean(col)**2 - torch.sum((col-10000)**2) - cost_value = 0.75*torch.mean(col) - torch.mean(torch.abs(col-6000)) - cost_value = torch.mean(col) - 2*torch.std(col) - - lines = torch.mean(acquisition, axis=1) - - #cost_value = - torch.var(torch.log(col)) - - cost_value = - torch.var((torch.log(col)- torch.log(torch.tensor([40000])))**2) - cost_value = - torch.var(torch.log(col)) - torch.log(torch.var(col))# - torch.mean((torch.log(col)- torch.log(torch.tensor([14000])))**2) - #cost_value = - torch.var(torch.log(col)**2) - torch.var((torch.log(col)- torch.log(torch.tensor([20000])))**2) - cost_value = - torch.var(torch.log(col)**2) - 2*torch.var((torch.log(col)- torch.log(torch.tensor([6000])))**2) - #cost_value = - torch.sum((2000*10000*((col-2000)+(col-10000)) - (10000-2000))/2) - #cost_value = -torch.var(torch.log(col)) """ - #cost_value = -torch.var(torch.exp(col/11000)) - row_diffs = torch.abs(widths[0,1:] - widths[0,:-1]) - #cost_value = - torch.var(torch.exp(col/18100)) - torch.sum(-torch.log(1+row_diffs)) - #print(torch.var(torch.exp(col/20000))) - #print(torch.mean(torch.log(col))) - #print(torch.sum(-torch.log(1+row_diffs))) - - - def saturation(scene, target_, margin=0.05): - cost = 0 - for elem in scene: - if elem <= target_*(1+margin): - cost += (target_-elem) - else: - cost += (elem-target_)**3 - return cost + acq = acquisition[acquisition>100].flatten() + # Squared loss on the left, log-barrier on the right def bowl(scene, target_, saturation=None): if saturation 
is None: saturation = target_*1.2 @@ -108,12 +59,11 @@ def bowl(scene, target_, saturation=None): B = -1/(s-t) + t/((s-t)**2) - 2/t C = - 1/2*A*t - B*t cost += - math.log((s-x)/(s-t)) + 1/2*A*x**2 + B*x + C - #cost += - math.exp(80*(1-(saturation-elem)/(saturation-target_))) - #cost += (elem-target_)**4 else: cost += 1e18*x**2 return cost + # Squared loss on the left, inverse on the right def bowl_inverse(scene, target_, saturation=None): if saturation is None: saturation = target_*1.2 @@ -128,64 +78,15 @@ def bowl_inverse(scene, target_, saturation=None): B = -1/((s-t)**2) - A*t C = -1/(s-t) - 1/2*A*t - B*t cost += - 1/(s-x) + 1/2*A*x**2 + B*x + C - #cost += - math.exp(80*(1-(saturation-elem)/(saturation-target_))) - #cost += (elem-target_)**4 else: cost += 1e18*x**2 return cost + print("Var: ", torch.var(acq)) + print("Min: ", torch.min(acq)) + print("Mean: ", torch.mean(acq)) + print("Max: ", torch.max(acq)) - #cost_value = - torch.var(col) - torch.sum(-torch.log(1+row_diffs)) - #cost_value = - torch.var(torch.exp(col/18000)) - 10000*torch.sum(torch.log(1/(1+row_diffs))) - #cost_value = - torch.var(torch.exp(col/30000))# - 2*torch.count_nonzero(row_diffs) - #cost_value = - torch.var(torch.exp(col/20000)) - 40*torch.count_nonzero(row_diffs) - #cost_value = torch.mean(col) - #cost_value = - 2*torch.var((torch.exp(col/20000)- torch.exp(torch.tensor([9000])/20000))**2) - #cost_value = - saturation(col, 45000, margin=0.1) - - #print("Jumps: ", torch.count_nonzero(row_diffs)) - print("Var: ", torch.var(col)) - #print("Saturation: ", - saturation(col.flatten(), 120000, margin=0.1)) - print("Min: ", torch.min(col)) - print("Mean: ", torch.mean(col)) - print("Max: ", torch.max(col)) - - #cost_value = - torch.var(torch.exp(col/100000)) #- 1e-5*saturation(col.flatten(), 200000, margin=0.1) - #cost_value = - torch.var((torch.log(col.squeeze())- torch.log(torch.tensor([200000]).to('cuda')))**2) - #cost_value = - torch.var(col) #- saturation(col.flatten(), 120000, margin=0.1) - cost_value = - bowl(col, target, saturation=2.2e6) + cost_value = - bowl(acq, target, saturation=2.2e6) print("Cost: ", - cost_value) return - cost_value - -# def evalute_slit_scanning_straightness(filtering_cube,threshold): -# """ -# Evaluate the straightness of the slit scanning. -# working cost function for up to focal >100000 -# """ -# cost_value = torch.tensor(0.0, requires_grad=True) - -# for i in range(filtering_cube.shape[2]): -# vertical_binning = torch.sum(filtering_cube[:, :, i], axis=0) -# max_value = torch.max(vertical_binning) -# values_above_threshold = vertical_binning > threshold -# number_of_values_above_threshold = torch.sum(values_above_threshold) -# cost_value = cost_value + max_value / number_of_values_above_threshold -# return -cost_value - - -# def evalute_slit_scanning_straightness(filtering_cube,threshold): -# """ -# Evaluate the straightness of the slit scanning. 
-# """ -# cost_value =0 - -# for i in range(filtering_cube.shape[2]): -# vertical_binning = torch.sum(filtering_cube[:,:,i],axis=0) -# max_value = torch.max(vertical_binning) -# # Differentiable way to count values above threshold -# values_above_threshold = vertical_binning > threshold -# number_of_values_above_threshold = torch.sum(values_above_threshold) -# cost_value += 1 / (number_of_values_above_threshold ** 2) - -# cost_value = -1*(cost_value) -# return cost_value \ No newline at end of file diff --git a/simca/functions_patterns_generation.py b/simca/functions_patterns_generation.py index bfe9468..246eadc 100644 --- a/simca/functions_patterns_generation.py +++ b/simca/functions_patterns_generation.py @@ -110,7 +110,7 @@ def generate_ln_orthogonal_pattern(size, W, N): return list_of_patterns -def generate_random_pattern(shape, ROM): +def generate_random_pattern(shape, ROM, fix_random_pattern=False): """ Generate a random pattern with a given rate of open/close mirrors @@ -121,9 +121,14 @@ def generate_random_pattern(shape, ROM): Returns: numpy.ndarray: random pattern """ - + if fix_random_pattern: + np.random.seed(0) + pattern = np.random.choice([0, 1], size=shape, p=[1 - ROM, ROM]) + if fix_random_pattern: + np.random.seed() + return pattern def generate_slit_pattern(shape, slit_position,slit_width): diff --git a/summarize_results.py b/summarize_results.py new file mode 100755 index 0000000..c8b2f1b --- /dev/null +++ b/summarize_results.py @@ -0,0 +1,63 @@ +from simca import load_yaml_config, save_config_file +import numpy as np +import os + +file_list = ["predict_results_recons.yml", "predict_results_full_learned_mask.yml", "predict_results_full_learned_mask_float.yml"] + +for file in file_list: + if not os.path.exists(f'./results/{file}'): + continue + dict = load_yaml_config(f'./results/{file}') + + ssim_array = [0 for i in range(20)] + rmse_array = [0 for i in range(20)] + psnr_array = [0 for i in range(20)] + + list_rmse = [] + list_ssim = [] + list_psnr = [] + + for key in dict.keys(): + list_ = dict[key] + val = np.array(list_) + if all(char.isdigit() for char in key[-2:]): + if 'SSIM' in key: + ssim_array[int(key[-2:])-1] = float(np.mean(val)) + list_ssim = list_ssim + list_ + elif 'RMSE' in key: + rmse_array[int(key[-2:])-1] = float(np.mean(val)) + list_rmse = list_rmse + list_ + elif 'PSNR' in key: + psnr_array[int(key[-2:])-1] = float(np.mean(val)) + list_psnr = list_psnr + list_ + else: + if 'SSIM' in key: + ssim_array[int(key[-1])-1] = float(np.mean(val)) + list_ssim = list_ssim + list_ + elif 'RMSE' in key: + rmse_array[int(key[-1])-1] = float(np.mean(val)) + list_rmse = list_rmse + list_ + elif 'PSNR' in key: + psnr_array[int(key[-1])-1] = float(np.mean(val)) + list_psnr = list_psnr + list_ + + array_list_ssim = np.array(list_ssim) + array_list_rmse = np.array(list_rmse) + array_list_psnr = np.array(list_psnr) + + array_list_ssim = array_list_ssim[np.nonzero(array_list_ssim)] + array_list_rmse = array_list_rmse[np.nonzero(array_list_rmse)] + array_list_psnr = array_list_psnr[np.nonzero(array_list_psnr)] + + res_dict = {'SSIM': ssim_array, + 'SSIM overall': float(np.mean(array_list_ssim)), + 'RMSE': rmse_array, + 'RMSE overall': float(np.mean(array_list_rmse)), + 'PSNR': psnr_array, + 'PSNR overall': float(np.mean(array_list_psnr))} + + print("SSIM std: ", np.std(np.array(list_ssim))) + print("RMSE std: ", np.std(np.array(list_rmse))) + print("PSNR std: ", np.std(np.array(list_psnr))) + + save_config_file("mean_"+file[:-4], res_dict, './results') \ No newline at end of 
file diff --git a/test_simca_reconstruction.py b/test_simca_reconstruction.py new file mode 100755 index 0000000..105f9c4 --- /dev/null +++ b/test_simca_reconstruction.py @@ -0,0 +1,51 @@ +import pytorch_lightning as pl +from data_handler import CubesDataModule +from optimization_modules import JointReconstructionModule_V1 +from pytorch_lightning.loggers import TensorBoardLogger +import torch +import datetime + + +predict_data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28_test" # Folder where the test dataset is + +predict_datamodule = CubesDataModule(predict_data_dir, batch_size=1, num_workers=5, augment=False) + +datetime_ = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M') + +name = "test_simca_reconstruction" +model_name = "dauhst_9" + +reconstruction_checkpoint = "./saved_checkpoints/best-checkpoint-recons-only.ckpt" + +log_dir = 'tb_logs' + +train = False +fix_random_pattern = False # Set to True to fix the random pattern to only learn reconstruction for a single fixed pattern +run_on_cpu = False # Set to True if you prefer to run it on cpu + +if not train: + name += '_predict' + +logger = TensorBoardLogger(log_dir, name=name) + + +reconstruction_module = JointReconstructionModule_V1(model_name,log_dir=log_dir+'/'+ name, + reconstruction_checkpoint=reconstruction_checkpoint, + fix_random_pattern=fix_random_pattern) + +max_epoch = 330 + +if (not run_on_cpu) and (torch.cuda.is_available()): + trainer = pl.Trainer( logger=logger, + accelerator="gpu", + max_epochs=max_epoch, + log_every_n_steps=1) +else: + trainer = pl.Trainer( logger=logger, + accelerator="cpu", + max_epochs=max_epoch, + log_every_n_steps=1) + + +reconstruction_module.eval() +trainer.predict(reconstruction_module, predict_datamodule) diff --git a/test_simca_reconstruction_full_binary.py b/test_simca_reconstruction_full_binary.py new file mode 100755 index 0000000..756328a --- /dev/null +++ b/test_simca_reconstruction_full_binary.py @@ -0,0 +1,55 @@ +import pytorch_lightning as pl +from data_handler import CubesDataModule +from optimization_modules_full import JointReconstructionModule_V2 +from pytorch_lightning.loggers import TensorBoardLogger +import torch +import datetime + + +predict_data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28_test" # Folder where the test dataset is + +predict_datamodule = CubesDataModule(predict_data_dir, batch_size=1, num_workers=5, augment=False) + +datetime_ = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M') + +name = "test_simca_reconstruction_full_binary" +model_name = "dauhst_9" + +reconstruction_checkpoint = "./saved_checkpoints/best-checkpoint-recons-only.ckpt" + +full_model_checkpoint = "./saved_checkpoints/best-checkpoint-full-binary.ckpt" + +mask_model = "learned_mask" + +log_dir = 'tb_logs' + +train = False +retrain_recons = False +run_on_cpu = False # Set to True if you prefer to run it on cpu + +logger = TensorBoardLogger(log_dir, name=name) + + +reconstruction_module = JointReconstructionModule_V2(model_name, + log_dir=log_dir+'/'+ name, + mask_model=mask_model, + reconstruction_checkpoint=reconstruction_checkpoint, + full_checkpoint=full_model_checkpoint, + train_reconstruction=retrain_recons) + + +max_epoch = 150 + +if (not run_on_cpu) and (torch.cuda.is_available()): + trainer = pl.Trainer( logger=logger, + accelerator="gpu", + max_epochs=max_epoch, + log_every_n_steps=1) +else: + trainer = pl.Trainer( logger=logger, + accelerator="cpu", + max_epochs=max_epoch, + log_every_n_steps=1) + +reconstruction_module.eval() 
+trainer.predict(reconstruction_module, predict_datamodule) diff --git a/test_simca_reconstruction_full_float.py b/test_simca_reconstruction_full_float.py new file mode 100755 index 0000000..f767025 --- /dev/null +++ b/test_simca_reconstruction_full_float.py @@ -0,0 +1,55 @@ +import pytorch_lightning as pl +from data_handler import CubesDataModule +from optimization_modules_full import JointReconstructionModule_V2 +from pytorch_lightning.loggers import TensorBoardLogger +import torch +import datetime + + +predict_data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28_test" # Folder where the test dataset is + +predict_datamodule = CubesDataModule(predict_data_dir, batch_size=1, num_workers=5, augment=False) + +datetime_ = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M') + +name = "test_simca_reconstruction_full_float" +model_name = "dauhst_9" + +reconstruction_checkpoint = "./saved_checkpoints/best-checkpoint-recons-only.ckpt" + +full_model_checkpoint = "./saved_checkpoints/best-checkpoint-full-float.ckpt" + +mask_model = "learned_mask_float" + +log_dir = 'tb_logs' + +train = False +retrain_recons = False +run_on_cpu = False # Set to True if you prefer to run it on cpu + +logger = TensorBoardLogger(log_dir, name=name) + + +reconstruction_module = JointReconstructionModule_V2(model_name, + log_dir=log_dir+'/'+ name, + mask_model=mask_model, + reconstruction_checkpoint=reconstruction_checkpoint, + full_checkpoint=full_model_checkpoint, + train_reconstruction=retrain_recons) + + +max_epoch = 150 + +if (not run_on_cpu) and (torch.cuda.is_available()): + trainer = pl.Trainer( logger=logger, + accelerator="gpu", + max_epochs=max_epoch, + log_every_n_steps=1) +else: + trainer = pl.Trainer( logger=logger, + accelerator="cpu", + max_epochs=max_epoch, + log_every_n_steps=1) + +reconstruction_module.eval() +trainer.predict(reconstruction_module, predict_datamodule) diff --git a/training_simca_reconstruction.py b/training_simca_reconstruction.py index 500ca70..2d0f5ee 100755 --- a/training_simca_reconstruction.py +++ b/training_simca_reconstruction.py @@ -7,25 +7,27 @@ import datetime -data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28" +data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28_train" # Folder where the train dataset is + datamodule = CubesDataModule(data_dir, batch_size=4, num_workers=11) datetime_ = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M') -name = "testing_simca_reconstruction" +name = "training_simca_reconstruction" model_name = "dauhst_9" -reconstruction_checkpoint = "/home/lpaillet/Documents/simca/tb_logs/testing_simca_reconstruction/version_24/checkpoints/epoch=499-step=18000.ckpt" -reconstruction_checkpoint = None log_dir = 'tb_logs' train = True +fix_random_pattern = False # Set to True to fix the random pattern to only learn reconstruction for a single fixed pattern +run_on_cpu = False # Set to True if you prefer to run it on cpu + logger = TensorBoardLogger(log_dir, name=name) early_stop_callback = EarlyStopping( monitor='val_loss', # Metric to monitor - patience=40, # Number of epochs to wait for improvement + patience=500, # Number of epochs to wait for improvement verbose=True, mode='min' # 'min' for metrics where lower is better, 'max' for vice versa ) @@ -40,24 +42,23 @@ ) reconstruction_module = JointReconstructionModule_V1(model_name,log_dir=log_dir+'/'+ name, - reconstruction_checkpoint=reconstruction_checkpoint) + reconstruction_checkpoint=None, + fix_random_pattern=fix_random_pattern) +max_epoch = 330 -if 
torch.cuda.is_available(): +if (not run_on_cpu) and (torch.cuda.is_available()): trainer = pl.Trainer( logger=logger, accelerator="gpu", - max_epochs=500, + max_epochs=max_epoch, log_every_n_steps=1, callbacks=[early_stop_callback, checkpoint_callback]) else: trainer = pl.Trainer( logger=logger, accelerator="cpu", - max_epochs=500, + max_epochs=max_epoch, log_every_n_steps=1, callbacks=[early_stop_callback, checkpoint_callback]) -if train: - trainer.fit(reconstruction_module, datamodule) -else: - #trainer.predict(reconstruction_module, datamodule) - trainer.predict(reconstruction_module, datamodule, ckpt_path=reconstruction_checkpoint) + +trainer.fit(reconstruction_module, datamodule) diff --git a/training_simca_reconstruction_with_resnet.py b/training_simca_reconstruction_full_binary.py similarity index 55% rename from training_simca_reconstruction_with_resnet.py rename to training_simca_reconstruction_full_binary.py index cf1fda3..faff09b 100755 --- a/training_simca_reconstruction_with_resnet.py +++ b/training_simca_reconstruction_full_binary.py @@ -1,66 +1,72 @@ import pytorch_lightning as pl from data_handler import CubesDataModule -from optimization_modules_with_resnet import JointReconstructionModule_V2 +from optimization_modules_full import JointReconstructionModule_V2 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger import torch import datetime -# data_dir = "./datasets_reconstruction/" -data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28" -data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28" -datamodule = CubesDataModule(data_dir, batch_size=4, num_workers=5) +data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28_train" # Folder where the train dataset is + +datamodule = CubesDataModule(data_dir, batch_size=4, num_workers=5, augment=True) datetime_ = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M') -name = "testing_simca_reconstruction_full" +name = "training_simca_reconstruction_full_binary" model_name = "dauhst_9" -reconstruction_checkpoint = "/home/lpaillet/Documents/simca/tb_logs/testing_simca_reconstruction/version_24/checkpoints/epoch=499-step=18000.ckpt" + +reconstruction_checkpoint = "./saved_checkpoints/best-checkpoint-recons-only.ckpt" + +mask_model = "learned_mask" log_dir = 'tb_logs' train = True +retrain_recons = True # Set to False if you don't want to fine-tune the reconstruction network +run_on_cpu = False # Set to True if you prefer to run it on cpu logger = TensorBoardLogger(log_dir, name=name) + early_stop_callback = EarlyStopping( monitor='val_loss', # Metric to monitor - patience=40, # Number of epochs to wait for improvement + patience=500, # Number of epochs to wait for improvement verbose=True, mode='min' # 'min' for metrics where lower is better, 'max' for vice versa ) checkpoint_callback = ModelCheckpoint( - monitor='val_loss', # Metric to monitor - dirpath='checkpoints_with_resnet/', # Directory path for saving checkpoints + monitor='val_ssim_loss', # Metric to monitor + dirpath='checkpoints_full_binary/', # Directory path for saving checkpoints filename=f'best-checkpoint_{model_name}_{datetime_}', # Checkpoint file name save_top_k=1, # Save the top k models - mode='min', # 'min' for metrics where lower is better, 'max' for vice versa + mode='max', # 'min' for metrics where lower is better, 'max' for vice versa save_last=True # Additionally, save the last checkpoint to a file named 'last.ckpt' ) + reconstruction_module = 
JointReconstructionModule_V2(model_name, log_dir=log_dir+'/'+ name, - reconstruction_checkpoint = reconstruction_checkpoint) + mask_model=mask_model, + reconstruction_checkpoint=reconstruction_checkpoint, + full_checkpoint=None, + train_reconstruction=retrain_recons) + +max_epoch = 150 -if torch.cuda.is_available(): +if (not run_on_cpu) and (torch.cuda.is_available()): trainer = pl.Trainer( logger=logger, accelerator="gpu", - max_epochs=500, + max_epochs=max_epoch, log_every_n_steps=30, callbacks=[early_stop_callback, checkpoint_callback]) else: trainer = pl.Trainer( logger=logger, - accelerator="gpu", - max_epochs=500, + accelerator="cpu", + max_epochs=max_epoch, log_every_n_steps=30, callbacks=[early_stop_callback, checkpoint_callback]) -if train: - #reconstruction_checkpoint = None - #trainer.fit(reconstruction_module, datamodule, ckpt_path = reconstruction_checkpoint) - trainer.fit(reconstruction_module, datamodule) -else: - trainer.predict(reconstruction_module, datamodule, ckpt_path=reconstruction_checkpoint) \ No newline at end of file +trainer.fit(reconstruction_module, datamodule) diff --git a/training_simca_reconstruction_full_float.py b/training_simca_reconstruction_full_float.py new file mode 100755 index 0000000..03e3403 --- /dev/null +++ b/training_simca_reconstruction_full_float.py @@ -0,0 +1,72 @@ +import pytorch_lightning as pl +from data_handler import CubesDataModule +from optimization_modules_full import JointReconstructionModule_V2 +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint +from pytorch_lightning.loggers import TensorBoardLogger +import torch +import datetime + + +data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28_train" # Folder where the train dataset is + +datamodule = CubesDataModule(data_dir, batch_size=4, num_workers=5, augment=True) + +datetime_ = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M') + +name = "training_simca_reconstruction_full_float" +model_name = "dauhst_9" + +reconstruction_checkpoint = "./saved_checkpoints/best-checkpoint-recons-only.ckpt" + +mask_model = "learned_mask_float" + +log_dir = 'tb_logs' + +train = True +retrain_recons = True # Set to False if you don't want to fine-tune the reconstruction network +run_on_cpu = False # Set to True if you prefer to run it on cpu + +logger = TensorBoardLogger(log_dir, name=name) + + +early_stop_callback = EarlyStopping( + monitor='val_loss', # Metric to monitor + patience=500, # Number of epochs to wait for improvement + verbose=True, + mode='min' # 'min' for metrics where lower is better, 'max' for vice versa + ) + +checkpoint_callback = ModelCheckpoint( + monitor='val_ssim_loss', # Metric to monitor + dirpath='checkpoints_full_float/', # Directory path for saving checkpoints + filename=f'best-checkpoint_{model_name}_{datetime_}', # Checkpoint file name + save_top_k=1, # Save the top k models + mode='max', # 'min' for metrics where lower is better, 'max' for vice versa + save_last=True # Additionally, save the last checkpoint to a file named 'last.ckpt' +) + + +reconstruction_module = JointReconstructionModule_V2(model_name, + log_dir=log_dir+'/'+ name, + mask_model=mask_model, + reconstruction_checkpoint=reconstruction_checkpoint, + full_checkpoint=None, + train_reconstruction=retrain_recons) + + +max_epoch = 150 + +if (not run_on_cpu) and (torch.cuda.is_available()): + trainer = pl.Trainer( logger=logger, + accelerator="gpu", + max_epochs=max_epoch, + log_every_n_steps=30, + callbacks=[early_stop_callback, checkpoint_callback]) +else: + trainer = 
pl.Trainer( logger=logger, + accelerator="cpu", + max_epochs=max_epoch, + log_every_n_steps=30, + callbacks=[early_stop_callback, checkpoint_callback]) + +trainer.fit(reconstruction_module, datamodule) diff --git a/training_simca_reconstruction_with_resnet_v2.py b/training_simca_reconstruction_with_resnet_v2.py index 22148eb..ab72015 100755 --- a/training_simca_reconstruction_with_resnet_v2.py +++ b/training_simca_reconstruction_with_resnet_v2.py @@ -1,7 +1,7 @@ import pytorch_lightning as pl from data_handler import CubesDataModule from optimization_modules import JointReconstructionModule_V1 -from optimization_modules_with_resnet import JointReconstructionModule_V2 +from optimization_modules_full import JointReconstructionModule_V2 from optimization_modules_with_resnet_v2 import JointReconstructionModule_V3 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger @@ -10,37 +10,57 @@ # data_dir = "./datasets_reconstruction/" -data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28" -# data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28" +#data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28" +data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28_train" +predict_data_dir = "./datasets_reconstruction/mst_datasets/TSA_simu_data/Truth" + datamodule = CubesDataModule(data_dir, batch_size=4, num_workers=5) +predict_datamodule = CubesDataModule(predict_data_dir, batch_size=1, num_workers=5, augment=False) datetime_ = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M') name = "testing_simca_reconstruction_full" model_name = "dauhst_9" reconstruction_checkpoint = "./checkpoints/epoch=499-step=18000.ckpt" -resnet_checkpoint = None +reconstruction_checkpoint = "./checkpoints/best-checkpoint_dauhst_9_24-03-10_19h55.ckpt" + +resnet_checkpoint = "./checkpoints_with_resnet/best-checkpoint_resnet_only_24-03-09_18h05.ckpt" +full_model_checkpoint = "./checkpoints_with_resnet/best-checkpoint_dauhst_9_24-03-09_19h56.ckpt" +full_model_checkpoint = "./checkpoints_with_resnet/best-checkpoint_dauhst_9_24-03-10_12h26.ckpt" # learned_mask +#full_model_checkpoint = "./checkpoints_with_resnet/best-checkpoint_dauhst_9_24-03-10_13h25.ckpt" # learned_mask_float +#full_model_checkpoint = "./checkpoints_with_resnet/best-checkpoint_dauhst_9_24-03-10_15h40.ckpt" # learned_mask_float + +mask_model = "learned_mask" log_dir = 'tb_logs' train = True retrain_recons = True +if mask_model == "learned_mask_float": + name += '_float' +elif mask_model == 'learned_mask': + name += '_binary' + +if not train: + name += '_predict' + logger = TensorBoardLogger(log_dir, name=name) + early_stop_callback = EarlyStopping( monitor='val_loss', # Metric to monitor - patience=40, # Number of epochs to wait for improvement + patience=4000000, # Number of epochs to wait for improvement verbose=True, mode='min' # 'min' for metrics where lower is better, 'max' for vice versa ) checkpoint_callback = ModelCheckpoint( - monitor='val_loss', # Metric to monitor + monitor='val_ssim_loss', # Metric to monitor dirpath='checkpoints_with_resnet/', # Directory path for saving checkpoints filename=f'best-checkpoint_{model_name}_{datetime_}', # Checkpoint file name save_top_k=1, # Save the top k models - mode='min', # 'min' for metrics where lower is better, 'max' for vice versa + mode='max', # 'min' for metrics where lower is better, 'max' for vice versa save_last=True # Additionally, save the last checkpoint to a file named 'last.ckpt' ) @@ 
-49,15 +69,13 @@ sub_module = JointReconstructionModule_V1(model_name, log_dir) sub_module.load_state_dict(checkpoint["state_dict"]) - -resnet_checkpoint = "./checkpoints/best-checkpoint_resnet_only_24-03-09_18h05.ckpt" - if not retrain_recons or not train: sub_module.eval() reconstruction_module = JointReconstructionModule_V3(sub_module, log_dir=log_dir+'/'+ name, - resnet_checkpoint=resnet_checkpoint) + resnet_checkpoint=resnet_checkpoint, + mask_model=mask_model) if torch.cuda.is_available(): @@ -68,7 +86,7 @@ callbacks=[early_stop_callback, checkpoint_callback]) else: trainer = pl.Trainer( logger=logger, - accelerator="gpu", + accelerator="cpu", max_epochs=500, log_every_n_steps=30, callbacks=[early_stop_callback, checkpoint_callback]) @@ -76,4 +94,4 @@ if train: trainer.fit(reconstruction_module, datamodule) else: - trainer.predict(reconstruction_module, datamodule, ckpt_path=resnet_checkpoint) + trainer.predict(reconstruction_module, predict_datamodule, ckpt_path=full_model_checkpoint)
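
Note on the reworked evaluate_max_lighting cost in simca/cost_functions.py: acquisition values above the 100-count floor are pushed toward a target exposure with a bowl-shaped penalty, described in the patch's own comments as a squared loss on the left and a log-barrier on the right (bowl_inverse swaps the barrier for an inverse term), and called with saturation=2.2e6. The sketch below only reproduces that shape; the smoothing constants A, B, C used by bowl() are simplified away, and the function and variable names here are illustrative, not the repository's.

import torch

def bowl_penalty_sketch(values, target, saturation=None):
    # Simplified stand-in for bowl(): quadratic for under-exposed pixels,
    # log-barrier for pixels climbing toward the saturation level.
    if saturation is None:
        saturation = 1.2 * target
    shortfall = torch.clamp(target - values, min=0.0)
    cost_low = (shortfall / target) ** 2
    # margin shrinks to 0 as a value approaches saturation; the clamp avoids log(0)
    margin = torch.clamp((saturation - values) / (saturation - target), min=1e-6)
    cost_high = torch.where(values > target, -torch.log(margin), torch.zeros_like(values))
    return (cost_low + cost_high).sum()

# Usage mirroring evaluate_max_lighting: keep only informative pixels,
# then minimise the penalty (the patch negates it twice on the way out).
# acq = acquisition[acquisition > 100].flatten()
# loss = bowl_penalty_sketch(acq, target, saturation=2.2e6)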
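
generate_random_pattern's new fix_random_pattern flag pins the coded aperture by seeding NumPy's global RNG with 0 and then reseeding it from system entropy after drawing the pattern. If the goal is only a reproducible pattern, a per-call Generator gives the same result without touching global state; the variant below is an illustrative alternative under that assumption, not the code in the patch.

import numpy as np

def generate_random_pattern_local_rng(shape, ROM, fix_random_pattern=False):
    # Reproducible when fix_random_pattern is True, random otherwise,
    # without mutating np.random's global state.
    rng = np.random.default_rng(0) if fix_random_pattern else np.random.default_rng()
    return rng.choice([0, 1], size=shape, p=[1 - ROM, ROM])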
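
The new training_* and test_* entry points all share the same accelerator selection: GPU when available unless run_on_cpu is set, CPU otherwise, followed by trainer.fit() in the training scripts or reconstruction_module.eval() plus trainer.predict() in the test scripts. A condensed sketch of that idiom, with placeholder module and datamodule names standing in for the scripts' objects:

import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import TensorBoardLogger

run_on_cpu = False      # set True to force CPU, as in the new scripts
max_epoch = 150         # 330 in the reconstruction-only scripts
logger = TensorBoardLogger('tb_logs', name='example_run')

accelerator = "gpu" if (not run_on_cpu) and torch.cuda.is_available() else "cpu"
trainer = pl.Trainer(logger=logger,
                     accelerator=accelerator,
                     max_epochs=max_epoch,
                     log_every_n_steps=1)

# trainer.fit(reconstruction_module, datamodule)              # training scripts
# reconstruction_module.eval()
# trainer.predict(reconstruction_module, predict_datamodule)  # test scripts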