diff --git a/.idea/.gitignore b/.idea/.gitignore deleted file mode 100644 index 26d3352..0000000 --- a/.idea/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -# Default ignored files -/shelf/ -/workspace.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml deleted file mode 100644 index 7d747dc..0000000 --- a/.idea/inspectionProfiles/Project_Default.xml +++ /dev/null @@ -1,34 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml deleted file mode 100644 index 105ce2d..0000000 --- a/.idea/inspectionProfiles/profiles_settings.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml deleted file mode 100644 index c288038..0000000 --- a/.idea/misc.xml +++ /dev/null @@ -1,7 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml deleted file mode 100644 index eeea93f..0000000 --- a/.idea/modules.xml +++ /dev/null @@ -1,8 +0,0 @@ - - - - - - - - \ No newline at end of file diff --git a/.idea/pyHFO.iml b/.idea/pyHFO.iml deleted file mode 100644 index b157454..0000000 --- a/.idea/pyHFO.iml +++ /dev/null @@ -1,12 +0,0 @@ - - - - - - - - - - \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml deleted file mode 100644 index 35eb1dd..0000000 --- a/.idea/vcs.xml +++ /dev/null @@ -1,6 +0,0 @@ - - - - - - \ No newline at end of file diff --git a/ckpt/model_toy/config.json b/ckpt/model_toy/config.json new file mode 100644 index 0000000..c0a4b11 --- /dev/null +++ b/ckpt/model_toy/config.json @@ -0,0 +1,16 @@ +{ + "architectures": [ + "NeuralCNNForImageClassification" + ], + "channel_selection": true, + "freeze": false, + "hidden_size": 64, + "input_channels": 1, + "kernel_size": 7, + "model_type": "resnet", + "num_classes": 1, + "padding": 3, + "stride": 2, + "torch_dtype": "float32", + "transformers_version": "4.41.2" +} diff --git a/ckpt/model_toy/model.safetensors b/ckpt/model_toy/model.safetensors new file mode 100644 index 0000000..2c6259d Binary files /dev/null and b/ckpt/model_toy/model.safetensors differ diff --git a/ckpt/model_toy/optimizer.pt b/ckpt/model_toy/optimizer.pt new file mode 100644 index 0000000..f6e3f21 Binary files /dev/null and b/ckpt/model_toy/optimizer.pt differ diff --git a/ckpt/model_toy/rng_state.pth b/ckpt/model_toy/rng_state.pth new file mode 100644 index 0000000..7248786 Binary files /dev/null and b/ckpt/model_toy/rng_state.pth differ diff --git a/ckpt/model_toy/scheduler.pt b/ckpt/model_toy/scheduler.pt new file mode 100644 index 0000000..26345ed Binary files /dev/null and b/ckpt/model_toy/scheduler.pt differ diff --git a/ckpt/model_toy/trainer_state.json b/ckpt/model_toy/trainer_state.json new file mode 100644 index 0000000..7ce37f4 --- /dev/null +++ b/ckpt/model_toy/trainer_state.json @@ -0,0 +1,333 @@ +{ + "best_metric": 0.20507247745990753, + "best_model_checkpoint": "./result/ucla_10min_win285_freq10_300_shift50/ckpt/fold_0/checkpoint-160", + "epoch": 30.0, + "eval_steps": 500, + "global_step": 240, + "is_hyper_param_search": false, + "is_local_process_zero": true, + "is_world_process_zero": true, + "log_history": [ + { + "epoch": 1.0, + "eval_accuracy": 0.850609756097561, + "eval_f1": 0.8993839835728953, + "eval_loss": 0.5683605670928955, + "eval_runtime": 1.1816, + "eval_samples_per_second": 277.58, + "eval_steps_per_second": 2.539, + "step": 8 + }, + { + "epoch": 2.0, + "eval_accuracy": 0.9176829268292683, + "eval_f1": 0.9409190371991247, + "eval_loss": 0.41057729721069336, + "eval_runtime": 1.0987, + "eval_samples_per_second": 298.529, + "eval_steps_per_second": 2.73, + "step": 16 + }, + { + "epoch": 3.0, + "eval_accuracy": 0.9603658536585366, + "eval_f1": 0.9706546275395034, + "eval_loss": 0.3411404490470886, + "eval_runtime": 1.1038, + "eval_samples_per_second": 297.161, + "eval_steps_per_second": 2.718, + "step": 24 + }, + { + "epoch": 4.0, + "eval_accuracy": 0.9603658536585366, + "eval_f1": 0.9699769053117783, + "eval_loss": 0.29127684235572815, + "eval_runtime": 1.2175, + "eval_samples_per_second": 269.41, + "eval_steps_per_second": 2.464, + "step": 32 + }, + { + "epoch": 5.0, + "eval_accuracy": 0.9573170731707317, + "eval_f1": 0.9674418604651163, + "eval_loss": 0.2733766436576843, + "eval_runtime": 1.1854, + "eval_samples_per_second": 276.7, + "eval_steps_per_second": 2.531, + "step": 40 + }, + { + "epoch": 6.0, + "eval_accuracy": 0.9390243902439024, + "eval_f1": 0.9521531100478469, + "eval_loss": 0.28573527932167053, + "eval_runtime": 1.1529, + "eval_samples_per_second": 284.499, + "eval_steps_per_second": 2.602, + "step": 48 + }, + { + "epoch": 7.0, + "eval_accuracy": 0.948170731707317, + "eval_f1": 0.96, + "eval_loss": 0.26964008808135986, + "eval_runtime": 1.0965, + "eval_samples_per_second": 299.144, + "eval_steps_per_second": 2.736, + "step": 56 + }, + { + "epoch": 8.0, + "eval_accuracy": 0.9664634146341463, + "eval_f1": 0.974477958236659, + "eval_loss": 0.2196166068315506, + "eval_runtime": 1.0768, + "eval_samples_per_second": 304.611, + "eval_steps_per_second": 2.786, + "step": 64 + }, + { + "epoch": 9.0, + "eval_accuracy": 0.9725609756097561, + "eval_f1": 0.9793103448275862, + "eval_loss": 0.21076323091983795, + "eval_runtime": 1.0639, + "eval_samples_per_second": 308.285, + "eval_steps_per_second": 2.82, + "step": 72 + }, + { + "epoch": 10.0, + "eval_accuracy": 0.9573170731707317, + "eval_f1": 0.9681818181818181, + "eval_loss": 0.21153110265731812, + "eval_runtime": 1.0957, + "eval_samples_per_second": 299.354, + "eval_steps_per_second": 2.738, + "step": 80 + }, + { + "epoch": 11.0, + "eval_accuracy": 0.9512195121951219, + "eval_f1": 0.9629629629629629, + "eval_loss": 0.2417690008878708, + "eval_runtime": 1.0774, + "eval_samples_per_second": 304.425, + "eval_steps_per_second": 2.784, + "step": 88 + }, + { + "epoch": 12.0, + "eval_accuracy": 0.9542682926829268, + "eval_f1": 0.9651972157772621, + "eval_loss": 0.22638092935085297, + "eval_runtime": 1.1599, + "eval_samples_per_second": 282.786, + "eval_steps_per_second": 2.586, + "step": 96 + }, + { + "epoch": 13.0, + "eval_accuracy": 0.9664634146341463, + "eval_f1": 0.9747126436781609, + "eval_loss": 0.21275581419467926, + "eval_runtime": 1.1752, + "eval_samples_per_second": 279.101, + "eval_steps_per_second": 2.553, + "step": 104 + }, + { + "epoch": 14.0, + "eval_accuracy": 0.9634146341463414, + "eval_f1": 0.9724770642201835, + "eval_loss": 0.2107168734073639, + "eval_runtime": 1.1628, + "eval_samples_per_second": 282.081, + "eval_steps_per_second": 2.58, + "step": 112 + }, + { + "epoch": 15.0, + "eval_accuracy": 0.9573170731707317, + "eval_f1": 0.967741935483871, + "eval_loss": 0.21056009829044342, + "eval_runtime": 1.0834, + "eval_samples_per_second": 302.739, + "eval_steps_per_second": 2.769, + "step": 120 + }, + { + "epoch": 16.0, + "eval_accuracy": 0.9603658536585366, + "eval_f1": 0.9702517162471396, + "eval_loss": 0.20577265322208405, + "eval_runtime": 1.1045, + "eval_samples_per_second": 296.956, + "eval_steps_per_second": 2.716, + "step": 128 + }, + { + "epoch": 17.0, + "eval_accuracy": 0.9573170731707317, + "eval_f1": 0.967741935483871, + "eval_loss": 0.2126343846321106, + "eval_runtime": 1.1305, + "eval_samples_per_second": 290.138, + "eval_steps_per_second": 2.654, + "step": 136 + }, + { + "epoch": 18.0, + "eval_accuracy": 0.9512195121951219, + "eval_f1": 0.9627906976744186, + "eval_loss": 0.21749918162822723, + "eval_runtime": 1.0912, + "eval_samples_per_second": 300.591, + "eval_steps_per_second": 2.749, + "step": 144 + }, + { + "epoch": 19.0, + "eval_accuracy": 0.948170731707317, + "eval_f1": 0.9605568445475638, + "eval_loss": 0.20949558913707733, + "eval_runtime": 1.0982, + "eval_samples_per_second": 298.667, + "eval_steps_per_second": 2.732, + "step": 152 + }, + { + "epoch": 20.0, + "eval_accuracy": 0.9573170731707317, + "eval_f1": 0.9675925925925926, + "eval_loss": 0.20507247745990753, + "eval_runtime": 1.1981, + "eval_samples_per_second": 273.778, + "eval_steps_per_second": 2.504, + "step": 160 + }, + { + "epoch": 21.0, + "eval_accuracy": 0.9542682926829268, + "eval_f1": 0.965034965034965, + "eval_loss": 0.20550090074539185, + "eval_runtime": 1.0729, + "eval_samples_per_second": 305.706, + "eval_steps_per_second": 2.796, + "step": 168 + }, + { + "epoch": 22.0, + "eval_accuracy": 0.9573170731707317, + "eval_f1": 0.9672897196261683, + "eval_loss": 0.20781563222408295, + "eval_runtime": 1.1205, + "eval_samples_per_second": 292.726, + "eval_steps_per_second": 2.677, + "step": 176 + }, + { + "epoch": 23.0, + "eval_accuracy": 0.9542682926829268, + "eval_f1": 0.965034965034965, + "eval_loss": 0.20751667022705078, + "eval_runtime": 1.0708, + "eval_samples_per_second": 306.319, + "eval_steps_per_second": 2.802, + "step": 184 + }, + { + "epoch": 24.0, + "eval_accuracy": 0.9542682926829268, + "eval_f1": 0.965034965034965, + "eval_loss": 0.20922790467739105, + "eval_runtime": 1.0761, + "eval_samples_per_second": 304.802, + "eval_steps_per_second": 2.788, + "step": 192 + }, + { + "epoch": 25.0, + "eval_accuracy": 0.9573170731707317, + "eval_f1": 0.9672897196261683, + "eval_loss": 0.21191848814487457, + "eval_runtime": 1.0929, + "eval_samples_per_second": 300.107, + "eval_steps_per_second": 2.745, + "step": 200 + }, + { + "epoch": 26.0, + "eval_accuracy": 0.9573170731707317, + "eval_f1": 0.9672897196261683, + "eval_loss": 0.2145329713821411, + "eval_runtime": 1.0973, + "eval_samples_per_second": 298.909, + "eval_steps_per_second": 2.734, + "step": 208 + }, + { + "epoch": 27.0, + "eval_accuracy": 0.9542682926829268, + "eval_f1": 0.965034965034965, + "eval_loss": 0.21384301781654358, + "eval_runtime": 1.0774, + "eval_samples_per_second": 304.433, + "eval_steps_per_second": 2.784, + "step": 216 + }, + { + "epoch": 28.0, + "eval_accuracy": 0.9542682926829268, + "eval_f1": 0.965034965034965, + "eval_loss": 0.21024774014949799, + "eval_runtime": 1.1044, + "eval_samples_per_second": 296.997, + "eval_steps_per_second": 2.716, + "step": 224 + }, + { + "epoch": 29.0, + "eval_accuracy": 0.9512195121951219, + "eval_f1": 0.9627906976744186, + "eval_loss": 0.2104579359292984, + "eval_runtime": 1.2639, + "eval_samples_per_second": 259.513, + "eval_steps_per_second": 2.374, + "step": 232 + }, + { + "epoch": 30.0, + "eval_accuracy": 0.9512195121951219, + "eval_f1": 0.9627906976744186, + "eval_loss": 0.2128143310546875, + "eval_runtime": 1.2639, + "eval_samples_per_second": 259.518, + "eval_steps_per_second": 2.374, + "step": 240 + } + ], + "logging_steps": 500, + "max_steps": 240, + "num_input_tokens_seen": 0, + "num_train_epochs": 30, + "save_steps": 500, + "stateful_callbacks": { + "TrainerControl": { + "args": { + "should_epoch_stop": false, + "should_evaluate": false, + "should_log": false, + "should_save": true, + "should_training_stop": true + }, + "attributes": {} + } + }, + "total_flos": 0.0, + "train_batch_size": 128, + "trial_name": null, + "trial_params": null +} diff --git a/ckpt/model_toy/training_args.bin b/ckpt/model_toy/training_args.bin new file mode 100644 index 0000000..3fbdf7e Binary files /dev/null and b/ckpt/model_toy/training_args.bin differ diff --git a/main.py b/main.py index 4909549..6d40842 100644 --- a/main.py +++ b/main.py @@ -1,844 +1,17 @@ import os import re import sys -import traceback -from pathlib import Path -from queue import Queue - -from PyQt5 import uic -from PyQt5.QtGui import * -from PyQt5.QtCore import * +from src.ui.main_window import MainWindow +from PyQt5.QtCore import Qt from PyQt5.QtWidgets import * -from PyQt5.QtWidgets import QMessageBox - - -from src.hfo_app import HFO_App -from src.param.param_classifier import ParamClassifier -from src.param.param_detector import ParamDetector, ParamSTE, ParamMNI -from src.param.param_filter import ParamFilter -from src.ui.quick_detection import HFOQuickDetector -from src.ui.channels_selection import ChannelSelectionWindow -from src.ui.bipolar_channel_selection import BipolarChannelSelectionWindow -from src.ui.annotation import HFOAnnotation -from src.utils.utils_gui import * -from src.ui.plot_waveform import * -from PyQt5.QtCore import pyqtSignal -# import tkinter as tk -# from tkinter import * -# from tkinter import messagebox -import threading -import time - import multiprocessing as mp import torch - import warnings warnings.filterwarnings("ignore") -ROOT_DIR = Path(__file__).parent - - -class HFOMainWindow(QMainWindow): - close_signal = pyqtSignal() - def __init__(self): - super(HFOMainWindow, self).__init__() - self.ui = uic.loadUi(os.path.join(ROOT_DIR, 'src/ui/main_window.ui'), self) - self.setWindowIcon(QtGui.QIcon(os.path.join(ROOT_DIR, 'src/ui/images/icon1.png'))) - self.setWindowTitle("pyHFO") - self.hfo_app = HFO_App() - self.threadpool = QThreadPool() - self.replace_last_line = False - self.stdout = Queue() - self.stderr = Queue() - sys.stdout = WriteStream(self.stdout) - sys.stderr = WriteStream(self.stderr) - - self.thread_stdout = STDOutReceiver(self.stdout) - self.thread_stdout.std_received_signal.connect(self.message_handler) - self.thread_stdout.start() - - self.thread_stderr = STDErrReceiver(self.stderr) - self.thread_stderr.std_received_signal.connect(self.message_handler) - self.thread_stderr.start() - - self.action_Open_EDF.triggered.connect(self.open_file) - self.actionQuick_Detection.triggered.connect(self.open_quick_detection) - self.action_Load_Detection.triggered.connect(self.load_from_npz) - self.overview_filter_button.clicked.connect(self.filter_data) - # set filter button to be disabled by default - self.overview_filter_button.setEnabled(False) - # self.show_original_button.clicked.connect(self.toggle_filtered) - - self.is_data_filtered = False - - self.waveform_plot_widget = pg.PlotWidget() - self.waveform_mini_widget = pg.PlotWidget() - self.widget.layout().addWidget(self.waveform_plot_widget, 0, 1) - self.widget.layout().addWidget(self.waveform_mini_widget, 1, 1) - self.widget.layout().setRowStretch(0, 9) - self.widget.layout().setRowStretch(1, 1) - self.waveform_plot = PlotWaveform(self.waveform_plot_widget, self.waveform_mini_widget, self.hfo_app) - - ## top toolbar buttoms - self.actionOpen_EDF_toolbar.triggered.connect(self.open_file) - self.actionQuick_Detection_toolbar.triggered.connect(self.open_quick_detection) - self.actionLoad_Detection_toolbar.triggered.connect(self.load_from_npz) - - self.mni_detect_button.clicked.connect(self.detect_HFOs) - self.mni_detect_button.setEnabled(False) - self.ste_detect_button.clicked.connect(self.detect_HFOs) - self.ste_detect_button.setEnabled(False) - - #classifier tab buttons - self.classifier_param = ParamClassifier() - #self.classifier_save_button.clicked.connect(self.hfo_app.set_classifier()) - - #init inputs - self.init_default_filter_input_params() - self.init_default_ste_input_params() - self.init_default_mni_input_params() - - #classifier default buttons - self.default_cpu_button.clicked.connect(self.set_classifier_param_cpu_default) - self.default_gpu_button.clicked.connect(self.set_classifier_param_gpu_default) - - #choose model files connection - self.choose_artifact_model_button.clicked.connect(lambda : self.choose_model_file("artifact")) - self.choose_spike_model_button.clicked.connect(lambda : self.choose_model_file("spike")) - - #custom model param connection - self.classifier_save_button.clicked.connect(self.set_custom_classifier_param) - - #detect_all_button - self.detect_all_button.clicked.connect(lambda: self.classify(True)) - self.detect_all_button.setEnabled(False) - # self.detect_artifacts_button.clicked.connect(lambda : self.classify(False)) - - self.save_csv_button.clicked.connect(self.save_to_excel) - self.save_csv_button.setEnabled(False) - - #set n_jobs min and max - self.n_jobs_spinbox.setMinimum(1) - self.n_jobs_spinbox.setMaximum(mp.cpu_count()) - - #set default n_jobs - self.n_jobs_spinbox.setValue(self.hfo_app.n_jobs) - self.n_jobs_ok_button.clicked.connect(self.set_n_jobs) - - self.STE_save_button.clicked.connect(self.save_ste_params) - self.MNI_save_button.clicked.connect(self.save_mni_params) - self.STE_save_button.setEnabled(False) - self.MNI_save_button.setEnabled(False) - - self.save_npz_button.clicked.connect(self.save_to_npz) - self.save_npz_button.setEnabled(False) - - self.Filter60Button.toggled.connect(self.switch_60) - self.Filter60Button.setEnabled(False) - - self.bipolar_button.clicked.connect(self.open_bipolar_channel_selection) - self.bipolar_button.setEnabled(False) - - #annotation button - self.annotation_button.clicked.connect(self.open_annotation) - self.annotation_button.setEnabled(False) - - self.Choose_Channels_Button.setEnabled(False) - self.waveform_plot_button.setEnabled(False) - - self.channels_to_plot = [] - - #check if gpu is available - self.gpu = torch.cuda.is_available() - # print(f"GPU available: {self.gpu}") - if not self.gpu: - #disable gpu buttons - self.default_gpu_button.setEnabled(False) - - self.quick_detect_open = False - self.set_mni_input_len(8) - self.set_ste_input_len(8) - - #close window signal - - def reinitialize_buttons(self): - self.mni_detect_button.setEnabled(False) - self.ste_detect_button.setEnabled(False) - self.detect_all_button.setEnabled(False) - self.save_csv_button.setEnabled(False) - self.save_npz_button.setEnabled(False) - self.STE_save_button.setEnabled(False) - self.MNI_save_button.setEnabled(False) - self.Filter60Button.setEnabled(False) - - def set_mni_input_len(self,max_len = 5): - self.mni_epoch_time_input.setMaxLength(max_len) - self.mni_epoch_chf_input.setMaxLength(max_len) - self.mni_chf_percentage_input.setMaxLength(max_len) - self.mni_min_window_input.setMaxLength(max_len) - self.mni_min_gap_time_input.setMaxLength(max_len) - self.mni_threshold_percentage_input.setMaxLength(max_len) - self.mni_baseline_window_input.setMaxLength(max_len) - self.mni_baseline_shift_input.setMaxLength(max_len) - self.mni_baseline_threshold_input.setMaxLength(max_len) - self.mni_baseline_min_time_input.setMaxLength(max_len) - - def set_ste_input_len(self,max_len = 5): - self.ste_rms_window_input.setMaxLength(max_len) - self.ste_min_window_input.setMaxLength(max_len) - self.ste_min_gap_input.setMaxLength(max_len) - self.ste_epoch_length_input.setMaxLength(max_len) - self.ste_min_oscillation_input.setMaxLength(max_len) - self.ste_rms_threshold_input.setMaxLength(max_len) - self.ste_peak_threshold_input.setMaxLength(max_len) - - - def close_other_window(self): - self.close_signal.emit() - - def set_n_jobs(self): - self.hfo_app.n_jobs = int(self.n_jobs_spinbox.value()) - # print(f"n_jobs set to {self.hfo_app.n_jobs}") - - def set_channels_to_plot(self, channels_to_plot, display_all = True): - self.waveform_plot.set_channels_to_plot(channels_to_plot) - # print(f"Channels to plot: {self.channels_to_plot}") - self.n_channel_input.setMaximum(len(channels_to_plot)) - if display_all: - self.n_channel_input.setValue(len(channels_to_plot)) - self.waveform_plot_button_clicked() - - def open_channel_selection(self): - self.channel_selection_window = ChannelSelectionWindow(self.hfo_app, self, self.close_signal) - self.channel_selection_window.show() - - def channel_selection_update(self): - self.channel_scroll_bar.setValue(0) - self.waveform_time_scroll_bar.setValue(0) - is_empty = self.n_channel_input.maximum() == 0 - self.waveform_plot.plot(0,0,empty=is_empty,update_hfo=True) - - def switch_60(self): - #get the value of the Filter60Button radio button - filter_60 = self.Filter60Button.isChecked() - # print("filtering:", filter_60) - #if yes - if filter_60: - self.hfo_app.set_filter_60() - #if not - else: - self.hfo_app.set_unfiltered_60() - - #replot - self.waveform_plot.plot() - #add a warning to the text about the HFO info saying that it is outdated now - - @pyqtSlot(str) - def message_handler(self, s): - s = s.replace("\n", "") - horScrollBar = self.STDTextEdit.horizontalScrollBar() - verScrollBar = self.STDTextEdit.verticalScrollBar() - scrollIsAtEnd = verScrollBar.maximum() - verScrollBar.value() <= 10 - - contain_percentage = re.findall(r'%', s) - contain_one_hundred_percentage = re.findall(r'100%', s) - if contain_one_hundred_percentage: - cursor = self.STDTextEdit.textCursor() - cursor.movePosition(QTextCursor.End - 1) - cursor.select(QTextCursor.BlockUnderCursor) - cursor.removeSelectedText() - self.STDTextEdit.setTextCursor(cursor) - self.STDTextEdit.insertPlainText(s) - elif contain_percentage: - cursor = self.STDTextEdit.textCursor() - cursor.movePosition(QTextCursor.End) - cursor.select(QTextCursor.BlockUnderCursor) - cursor.removeSelectedText() - self.STDTextEdit.setTextCursor(cursor) - self.STDTextEdit.insertPlainText(s) - else: - self.STDTextEdit.append(s) - - if scrollIsAtEnd: - verScrollBar.setValue(verScrollBar.maximum()) # Scrolls to the bottom - horScrollBar.setValue(0) # scroll to the left - - def reinitialize(self): - #kill all threads in self.threadpool - self.close_other_window() - self.hfo_app = HFO_App() - self.waveform_plot.update_backend(self.hfo_app, False) - self.main_filename.setText("") - self.main_sampfreq.setText("") - self.main_numchannels.setText("") - self.main_length.setText("") - self.statistics_label.setText("") - - - @pyqtSlot(list) - def update_edf_info(self, results): - self.main_filename.setText(results[0]) - self.main_sampfreq.setText(results[1]) - self.sample_freq = float(results[1]) - self.main_numchannels.setText(results[2]) - # print("updated") - self.main_length.setText(str(round(float(results[3])/(60*float(results[1])),3))+" min") - self.waveform_plot.plot(0, update_hfo=True) - # print("plotted") - #connect buttons - self.waveform_time_scroll_bar.valueChanged.connect(self.scroll_time_waveform_plot) - self.channel_scroll_bar.valueChanged.connect(self.scroll_channel_waveform_plot) - self.waveform_plot_button.clicked.connect(self.waveform_plot_button_clicked) - self.waveform_plot_button.setEnabled(True) - self.Choose_Channels_Button.clicked.connect(self.open_channel_selection) - self.Choose_Channels_Button.setEnabled(True) - #set the display time window spin box - self.display_time_window_input.setValue(self.waveform_plot.get_time_window()) - self.display_time_window_input.setMaximum(self.waveform_plot.get_total_time()) - self.display_time_window_input.setMinimum(0.1) - #set the n channel spin box - self.n_channel_input.setValue(self.waveform_plot.get_n_channels_to_plot()) - self.n_channel_input.setMaximum(self.waveform_plot.get_n_channels()) - self.n_channel_input.setMinimum(1) - #set the time scroll bar range - self.waveform_time_scroll_bar.setMaximum(int(self.waveform_plot.get_total_time()/(self.waveform_plot.get_time_window()*self.waveform_plot.get_time_increment()/100))) - self.waveform_time_scroll_bar.setValue(0) - #set the channel scroll bar range - self.channel_scroll_bar.setMaximum(self.waveform_plot.get_n_channels()-self.waveform_plot.get_n_channels_to_plot()) - #enable the filter button - self.overview_filter_button.setEnabled(True) - self.toggle_filtered_checkbox.stateChanged.connect(self.toggle_filtered) - self.normalize_vertical_input.stateChanged.connect(self.waveform_plot_button_clicked) - #enable the plot out the 60Hz bandstopped signal - self.Filter60Button.setEnabled(True) - self.bipolar_button.setEnabled(True) - #print("EDF file loaded") - - - def init_default_filter_input_params(self): - default_params=ParamFilter() - self.fp_input.setText(str(default_params.fp)) - self.fs_input.setText(str(default_params.fs)) - self.rp_input.setText(str(default_params.rp)) - self.rs_input.setText(str(default_params.rs)) - - def init_default_ste_input_params(self): - default_params=ParamSTE(2000) - self.ste_rms_window_input.setText(str(default_params.rms_window)) - self.ste_rms_threshold_input.setText(str(default_params.rms_thres)) - self.ste_min_window_input.setText(str(default_params.min_window)) - self.ste_epoch_length_input.setText(str(default_params.epoch_len)) - self.ste_min_gap_input.setText(str(default_params.min_gap)) - self.ste_min_oscillation_input.setText(str(default_params.min_osc)) - self.ste_peak_threshold_input.setText(str(default_params.peak_thres)) - - def init_default_mni_input_params(self): - """this is how I got the params, I reversed it here - - epoch_time = self.mni_epoch_time_input.text() - epo_CHF = self.mni_epoch_CHF_input.text() - per_CHF = self.mni_chf_percentage_input.text() - min_win = self.mni_min_window_input.text() - min_gap = self.mni_min_gap_time_input.text() - thrd_perc = self.mni_threshold_percentage_input.text() - base_seg = self.mni_baseline_window_input.text() - base_shift = self.mni_baseline_shift_input.text() - base_thrd = self.mni_baseline_threshold_input.text() - base_min = self.mni_baseline_min_time_input.text() - """ - default_params=ParamMNI(200) - self.mni_epoch_time_input.setText(str(default_params.epoch_time)) - self.mni_epoch_chf_input.setText(str(default_params.epo_CHF)) - self.mni_chf_percentage_input.setText(str(default_params.per_CHF)) - self.mni_min_window_input.setText(str(default_params.min_win)) - self.mni_min_gap_time_input.setText(str(default_params.min_gap)) - self.mni_threshold_percentage_input.setText(str(default_params.thrd_perc*100)) - self.mni_baseline_window_input.setText(str(default_params.base_seg)) - self.mni_baseline_shift_input.setText(str(default_params.base_shift)) - self.mni_baseline_threshold_input.setText(str(default_params.base_thrd)) - self.mni_baseline_min_time_input.setText(str(default_params.base_min)) - - def scroll_time_waveform_plot(self, event): - t_start=self.waveform_time_scroll_bar.value()*self.waveform_plot.get_time_window()*self.waveform_plot.get_time_increment()/100 - self.waveform_plot.plot(t_start) - - def scroll_channel_waveform_plot(self, event): - channel_start=self.channel_scroll_bar.value() - self.waveform_plot.plot(first_channel_to_plot=channel_start, update_hfo=True) - - def get_channels_to_plot(self): - return self.waveform_plot.get_channels_to_plot() - - def get_channel_indices_to_plot(self): - return self.waveform_plot.get_channel_indices_to_plot() - - def waveform_plot_button_clicked(self): - time_window=self.display_time_window_input.value() - self.waveform_plot.set_time_window(time_window) - n_channels_to_plot=self.n_channel_input.value() - self.waveform_plot.set_n_channels_to_plot(n_channels_to_plot) - time_increment = self.Time_Increment_Input.value() - self.waveform_plot.set_time_increment(time_increment) - normalize_vertical = self.normalize_vertical_input.isChecked() - self.waveform_plot.set_normalize_vertical(normalize_vertical) - is_empty = self.n_channel_input.maximum() == 0 - start = self.waveform_plot.t_start - first_channel_to_plot = self.waveform_plot.first_channel_to_plot - - t_value = int(start//(self.waveform_plot.get_time_window()*self.waveform_plot.get_time_increment()/100)) - self.waveform_time_scroll_bar.setMaximum(int(self.waveform_plot.get_total_time()/(self.waveform_plot.get_time_window()*self.waveform_plot.get_time_increment()/100))) - self.waveform_time_scroll_bar.setValue(t_value) - c_value = self.channel_scroll_bar.value() - self.channel_scroll_bar.setMaximum(len(self.waveform_plot.get_channels_to_plot())-n_channels_to_plot) - self.channel_scroll_bar.setValue(c_value) - self.waveform_plot.plot(start,first_channel_to_plot,empty=is_empty,update_hfo=True) - - def open_file(self): - #reinitialize the app - self.hfo_app = HFO_App() - fname, _ = QFileDialog.getOpenFileName(self, "Open File", "", "Recordings Files (*.edf *.eeg *.vhdr *.vmrk)") - if fname: - worker = Worker(self.read_edf, fname) - worker.signals.result.connect(self.update_edf_info) - self.threadpool.start(worker) - - def filtering_complete(self): - self.message_handler('Filtering COMPLETE!') - filter_60 = self.Filter60Button.isChecked() - # print("filtering:", filter_60) - #if yes - if filter_60: - self.hfo_app.set_filter_60() - #if not - else: - self.hfo_app.set_unfiltered_60() - - self.STE_save_button.setEnabled(True) - self.ste_detect_button.setEnabled(True) - self.MNI_save_button.setEnabled(True) - self.mni_detect_button.setEnabled(True) - self.is_data_filtered = True - self.show_filtered = True - self.waveform_plot.set_filtered(True) - self.save_npz_button.setEnabled(True) - - def filter_data(self): - self.message_handler("Filtering data...") - try: - #get filter parameters - fp_raw = self.fp_input.text() - fs_raw = self.fs_input.text() - rp_raw = self.rp_input.text() - rs_raw = self.rs_input.text() - #self.pop_window() - param_dict={"fp":float(fp_raw), "fs":float(fs_raw), "rp":float(rp_raw), "rs":float(rs_raw)} - filter_param = ParamFilter.from_dict(param_dict) - self.hfo_app.set_filter_parameter(filter_param) - except: - # there is error of the filter machine - # therefore pop up window to show that filter failed - msg = QMessageBox() - msg.setIcon(QMessageBox.Critical) - msg.setText("Error") - msg.setInformativeText('Filter could not be constructed with the given parameters') - msg.setWindowTitle("Filter Construction Error") - msg.exec_() - return - worker=Worker(self._filter) - worker.signals.finished.connect(self.filtering_complete) - self.threadpool.start(worker) - - def toggle_filtered(self): - # self.message_handler('Showing original data...') - if self.is_data_filtered: - self.show_filtered = not self.show_filtered - self.waveform_plot.set_filtered(self.show_filtered) - self.waveform_plot_button_clicked() - - def read_edf(self, fname, progress_callback): - self.reinitialize() - self.hfo_app.load_edf(fname) - eeg_data,channel_names=self.hfo_app.get_eeg_data() - edf_info=self.hfo_app.get_edf_info() - self.waveform_plot.init_eeg_data() - filename = os.path.basename(fname) - sample_freq = str(self.hfo_app.sample_freq) - num_channels = str(len(self.hfo_app.channel_names)) - length = str(self.hfo_app.eeg_data.shape[1]) - return [filename, sample_freq, num_channels, length] - - - def _filter(self, progress_callback): - self.hfo_app.filter_eeg_data() - return [] - - - def open_detector(self): - # Pass the function to execute, function, args, kwargs - worker = Worker(self.quick_detect) - self.threadpool.start(worker) - - def round_dict(self, d:dict, n:int): - for key in d.keys(): - if type(d[key]) == float: - d[key] = round(d[key], n) - return d - - def save_ste_params(self): - #get filter parameters - rms_window_raw = self.ste_rms_window_input.text() - min_window_raw = self.ste_min_window_input.text() - min_gap_raw = self.ste_min_gap_input.text() - epoch_len_raw = self.ste_epoch_length_input.text() - min_osc_raw = self.ste_min_oscillation_input.text() - rms_thres_raw = self.ste_rms_threshold_input.text() - peak_thres_raw = self.ste_peak_threshold_input.text() - try: - param_dict = {"sample_freq":2000,"pass_band":1, "stop_band":80, #these are placeholder params, will be updated later - "rms_window":float(rms_window_raw), "min_window":float(min_window_raw), "min_gap":float(min_gap_raw), - "epoch_len":float(epoch_len_raw), "min_osc":float(min_osc_raw), "rms_thres":float(rms_thres_raw), - "peak_thres":float(peak_thres_raw),"n_jobs":self.hfo_app.n_jobs} - detector_params = {"detector_type":"STE", "detector_param":param_dict} - self.hfo_app.set_detector(ParamDetector.from_dict(detector_params)) - - #set display parameters - self.ste_epoch_display.setText(epoch_len_raw) - self.ste_min_window_display.setText(min_window_raw) - self.ste_rms_window_display.setText(rms_window_raw) - self.ste_min_gap_time_display.setText(min_gap_raw) - self.ste_min_oscillations_display.setText(min_osc_raw) - self.ste_peak_threshold_display.setText(peak_thres_raw) - self.ste_rms_threshold_display.setText(rms_thres_raw) - self.update_detector_tab("STE") - except: - msg = QMessageBox() - msg.setIcon(QMessageBox.Critical) - msg.setText("Error!") - msg.setInformativeText('Detector could not be constructed given the parameters') - msg.setWindowTitle("Detector Construction Failed") - msg.exec_() - - - def save_mni_params(self): - try: - epoch_time = self.mni_epoch_time_input.text() - epo_CHF = self.mni_epoch_chf_input.text() - per_CHF = self.mni_chf_percentage_input.text() - min_win = self.mni_min_window_input.text() - min_gap = self.mni_min_gap_time_input.text() - thrd_perc = self.mni_threshold_percentage_input.text() - base_seg = self.mni_baseline_window_input.text() - base_shift = self.mni_baseline_shift_input.text() - base_thrd = self.mni_baseline_threshold_input.text() - base_min = self.mni_baseline_min_time_input.text() - - param_dict = {"sample_freq":2000,"pass_band":1, "stop_band":80, #these are placeholder params, will be updated later - "epoch_time":float(epoch_time), "epo_CHF":float(epo_CHF), "per_CHF":float(per_CHF), - "min_win":float(min_win), "min_gap":float(min_gap), "base_seg":float(base_seg), - "thrd_perc":float(thrd_perc)/100, - "base_shift":float(base_shift), "base_thrd":float(base_thrd), "base_min":float(base_min), - "n_jobs":self.hfo_app.n_jobs} - # param_dict = self.round_dict(param_dict, 3) - detector_params = {"detector_type":"MNI", "detector_param":param_dict} - self.hfo_app.set_detector(ParamDetector.from_dict(detector_params)) - - #set display parameters - self.mni_epoch_display.setText(epoch_time) - self.mni_epoch_chf_display.setText(epo_CHF) - self.mni_chf_percentage_display.setText(per_CHF) - self.mni_min_window_display.setText(min_win) - self.mni_min_gap_time_display.setText(min_gap) - self.mni_threshold_percentile_display.setText(thrd_perc) - self.mni_baseline_window_display.setText(base_seg) - self.mni_baseline_shift_display.setText(base_shift) - self.mni_baseline_threshold_display.setText(base_thrd) - self.mni_baseline_min_time_display.setText(base_min) - - self.update_detector_tab("MNI") - except: - msg = QMessageBox() - msg.setIcon(QMessageBox.Critical) - msg.setText("Error!") - msg.setInformativeText('Detector could not be constructed given the parameters') - msg.setWindowTitle("Detector Construction Failed") - msg.exec_() - - def detect_HFOs(self): - print("Detecting HFOs...") - worker=Worker(self._detect) - worker.signals.result.connect(self._detect_finished) - self.threadpool.start(worker) - - def _detect_finished(self): - #right now do nothing beyond message handler saying that - # it has detected HFOs - self.message_handler("HFOs detected") - self.update_statistics_label() - self.waveform_plot.set_plot_HFOs(True) - self.detect_all_button.setEnabled(True) - self.annotation_button.setEnabled(True) - - def _detect(self, progress_callback): - #call detect HFO function on backend - self.hfo_app.detect_HFO() - return [] - - def open_quick_detection(self): - # if we want to open multiple qd dialog - if not self.quick_detect_open: - qd = HFOQuickDetector(HFO_App(), self, self.close_signal) - # print("created new quick detector") - qd.show() - self.quick_detect_open = True - - def set_quick_detect_open(self, open): - self.quick_detect_open = open - - def update_detector_tab(self, index): - if index == "MNI": - self.stackedWidget.setCurrentIndex(0) - elif index == "STE": - self.stackedWidget.setCurrentIndex(1) - - def set_classifier_param_display(self): - classifier_param = self.hfo_app.get_classifier_param() - - self.overview_artifact_path_display.setText(classifier_param.artifact_path) - self.overview_spike_path_display.setText(classifier_param.spike_path) - self.overview_use_spike_checkbox.setChecked(classifier_param.use_spike) - self.overview_device_display.setText(str(classifier_param.device)) - self.overview_batch_size_display.setText(str(classifier_param.batch_size)) - - #set also the input fields - self.classifier_artifact_filename.setText(classifier_param.artifact_path) - self.classifier_spike_filename.setText(classifier_param.spike_path) - self.use_spike_checkbox.setChecked(classifier_param.use_spike) - self.classifier_device_input.setText(str(classifier_param.device)) - self.classifier_batch_size_input.setText(str(classifier_param.batch_size)) - - def set_classifier_param_gpu_default(self): - self.hfo_app.set_default_gpu_classifier() - self.set_classifier_param_display() - - def set_classifier_param_cpu_default(self): - self.hfo_app.set_default_cpu_classifier() - self.set_classifier_param_display() - - def set_custom_classifier_param(self): - artifact_path = self.classifier_artifact_filename.text() - spike_path = self.classifier_spike_filename.text() - use_spike = self.use_spike_checkbox.isChecked() - device = self.classifier_device_input.text() - if device=="cpu": - model_type = "default_cpu" - elif device=="cuda:0" and self.gpu: - model_type = "default_gpu" - else: - # print("device not recognized, please set to cpu for cpu or cuda:0 for gpu") - msg = QMessageBox() - msg.setIcon(QMessageBox.Critical) - msg.setText("Error!") - msg.setInformativeText('Device not recognized, please set to CPU for CPU or cuda:0 for GPU') - msg.setWindowTitle("Device not recognized") - msg.exec_() - return - batch_size = self.classifier_batch_size_input.text() - - classifier_param = ParamClassifier(artifact_path=artifact_path, spike_path=spike_path, use_spike=use_spike, - device=device, batch_size=int(batch_size), model_type=model_type) - self.hfo_app.set_classifier(classifier_param) - self.set_classifier_param_display() - - def choose_model_file(self, model_type): - fname,_ = QFileDialog.getOpenFileName(self, 'Open file', "", ".tar files (*.tar)") - if model_type == "artifact": - self.classifier_artifact_filename.setText(fname) - elif model_type == "spike": - self.classifier_spike_filename.setText(fname) - - def _classify(self,artifact_only=False): - threshold = 0.5 - seconds_to_ignore_before=float(self.overview_ignore_before_input.text()) - seconds_to_ignore_after=float(self.overview_ignore_after_input.text()) - self.hfo_app.classify_artifacts([seconds_to_ignore_before,seconds_to_ignore_after], threshold) - if not artifact_only: - self.hfo_app.classify_spikes() - return [] - - def _classify_finished(self): - self.message_handler("Classification finished!..") - self.update_statistics_label() - self.waveform_plot.set_plot_HFOs(True) - self.save_csv_button.setEnabled(True) - - def classify(self,check_spike=True): - self.message_handler("Classifying HFOs...") - if check_spike: - use_spike=self.overview_use_spike_checkbox.isChecked() - else: - use_spike=False - worker=Worker(lambda progress_callback: self._classify((not use_spike))) - worker.signals.result.connect(self._classify_finished) - self.threadpool.start(worker) - - def update_statistics_label(self): - num_HFO = self.hfo_app.hfo_features.get_num_HFO() - num_artifact = self.hfo_app.hfo_features.get_num_artifact() - num_spike = self.hfo_app.hfo_features.get_num_spike() - num_real = self.hfo_app.hfo_features.get_num_real() - - self.statistics_label.setText(" Number of HFOs: " + str(num_HFO) +\ - "\n Number of artifacts: " + str(num_artifact) +\ - "\n Number of spikes: " + str(num_spike) +\ - "\n Number of real HFOs: " + str(num_real)) - - def save_to_excel(self): - #open file dialog - fname,_ = QFileDialog.getSaveFileName(self, 'Save file', "", ".xlsx files (*.xlsx)") - if fname: - self.hfo_app.export_excel(fname) - - def _save_to_npz(self,fname,progress_callback): - self.hfo_app.export_app(fname) - return [] - - def save_to_npz(self): - #open file dialog - # print("saving to npz...",end="") - fname,_ = QFileDialog.getSaveFileName(self, 'Save file', "", ".npz files (*.npz)") - if fname: - # print("saving to {fname}...",end="") - worker = Worker(self._save_to_npz, fname) - worker.signals.result.connect(lambda: 0) - self.threadpool.start(worker) - - def _load_from_npz(self,fname,progress_callback): - self.hfo_app = self.hfo_app.import_app(fname) - return [] - - def load_from_npz(self): - #open file dialog - fname,_ = QFileDialog.getOpenFileName(self, 'Open file', "", ".npz files (*.npz)") - self.message_handler("Loading from npz...") - if fname: - self.reinitialize() - worker = Worker(self._load_from_npz, fname) - worker.signals.result.connect(self.load_from_npz_finished) - self.threadpool.start(worker) - # print(self.hfo_app.get_edf_info()) - - def load_from_npz_finished(self): - edf_info = self.hfo_app.get_edf_info() - self.waveform_plot.update_backend(self.hfo_app) - self.waveform_plot.init_eeg_data() - edf_name=str(edf_info["edf_fn"]) - edf_name=edf_name[edf_name.rfind("/")+1:] - self.update_edf_info([edf_name, str(edf_info["sfreq"]), - str(edf_info["nchan"]), str(self.hfo_app.eeg_data.shape[1])]) - #update number of jobs - self.n_jobs_spinbox.setValue(self.hfo_app.n_jobs) - if self.hfo_app.filtered: - self.filtering_complete() - filter_param = self.hfo_app.param_filter - #update filter params - self.fp_input.setText(str(filter_param.fp)) - self.fs_input.setText(str(filter_param.fs)) - self.rp_input.setText(str(filter_param.rp)) - self.rs_input.setText(str(filter_param.rs)) - #update the detector parameters: - if self.hfo_app.detected: - self.set_detector_param_display() - self._detect_finished() - self.update_statistics_label() - #update classifier param - if self.hfo_app.classified: - self.set_classifier_param_display() - self._classify_finished() - self.update_statistics_label() - - def update_ste_params(self,ste_params): - rms_window = str(ste_params["rms_window"]) - min_window = str(ste_params["min_window"]) - min_gap = str(ste_params["min_gap"]) - epoch_len = str(ste_params["epoch_len"]) - min_osc = str(ste_params["min_osc"]) - rms_thres = str(ste_params["rms_thres"]) - peak_thres = str(ste_params["peak_thres"]) - - self.ste_rms_window_input.setText(rms_window) - self.ste_min_window_input.setText(min_window) - self.ste_min_gap_input.setText(min_gap) - self.ste_epoch_length_input.setText(epoch_len) - self.ste_min_oscillation_input.setText(min_osc) - self.ste_rms_threshold_input.setText(rms_thres) - self.ste_peak_threshold_input.setText(peak_thres) - - #set display parameters - self.ste_epoch_display.setText(epoch_len) - self.ste_min_window_display.setText(min_window) - self.ste_rms_window_display.setText(rms_window) - self.ste_min_gap_time_display.setText(min_gap) - self.ste_min_oscillations_display.setText(min_osc) - self.ste_peak_threshold_display.setText(peak_thres) - self.ste_rms_threshold_display.setText(rms_thres) - - self.update_detector_tab("STE") - self.detector_subtabs.setCurrentIndex(0) - - def update_mni_params(self,mni_params): - epoch_time = str(mni_params["epoch_time"]) - epo_CHF = str(mni_params["epo_CHF"]) - per_CHF = str(mni_params["per_CHF"]) - min_win = str(mni_params["min_win"]) - min_gap = str(mni_params["min_gap"]) - thrd_perc = str(mni_params["thrd_perc"]) - base_seg = str(mni_params["base_seg"]) - base_shift = str(mni_params["base_shift"]) - base_thrd = str(mni_params["base_thrd"]) - base_min = str(mni_params["base_min"]) - - self.mni_epoch_time_input.setText(epoch_time) - self.mni_epoch_chf_input.setText(epo_CHF) - self.mni_chf_percentage_input.setText(per_CHF) - self.mni_min_window_input.setText(min_win) - self.mni_min_gap_time_input.setText(min_gap) - self.mni_threshold_percentage_input.setText(thrd_perc) - self.mni_baseline_window_input.setText(base_seg) - self.mni_baseline_shift_input.setText(base_shift) - self.mni_baseline_threshold_input.setText(base_thrd) - self.mni_baseline_min_time_input.setText(base_min) - - #set display parameters - self.mni_epoch_display.setText(epoch_time) - self.mni_epoch_chf_display.setText(epo_CHF) - self.mni_chf_percentage_display.setText(per_CHF) - self.mni_min_window_display.setText(min_win) - self.mni_min_gap_time_display.setText(min_gap) - self.mni_threshold_percentile_display.setText(thrd_perc) - self.mni_baseline_window_display.setText(base_seg) - self.mni_baseline_shift_display.setText(base_shift) - self.mni_baseline_threshold_display.setText(base_thrd) - self.mni_baseline_min_time_display.setText(base_min) - - self.update_detector_tab("MNI") - self.detector_subtabs.setCurrentIndex(1) - - def set_detector_param_display(self): - detector_params = self.hfo_app.param_detector - detector_type = detector_params.detector_type.lower() - if detector_type == "ste": - self.update_ste_params(detector_params.detector_param.to_dict()) - elif detector_type == "mni": - self.update_mni_params(detector_params.detector_param.to_dict()) - - def open_bipolar_channel_selection(self): - self.bipolar_channel_selection_window = BipolarChannelSelectionWindow(self.hfo_app, self, self.close_signal,self.waveform_plot) - self.bipolar_channel_selection_window.show() - - def open_annotation(self): - self.save_csv_button.setEnabled(True) - annotation = HFOAnnotation(self.hfo_app, self, self.close_signal) - annotation.show() +# Enable DPI scaling +QApplication.setAttribute(Qt.AA_EnableHighDpiScaling, True) +QApplication.setAttribute(Qt.AA_UseHighDpiPixmaps, True) def closeAllWindows(): @@ -848,7 +21,7 @@ def closeAllWindows(): if __name__ == '__main__': mp.freeze_support() app = QApplication(sys.argv) - mainWindow = HFOMainWindow() + mainWindow = MainWindow() mainWindow.show() app.aboutToQuit.connect(closeAllWindows) sys.exit(app.exec_()) diff --git a/main.spec b/main.spec deleted file mode 100644 index 6312311..0000000 --- a/main.spec +++ /dev/null @@ -1,48 +0,0 @@ -# -*- mode: python ; coding: utf-8 -*- -import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5) - -block_cipher = None - - -a = Analysis( - ['main.py'], - pathex=[], - binaries=[], - datas=[("C:\\Users\\Lawrence\\anaconda3\\envs\\pyHFO\\Lib\\site-packages\\mne", "mne"), - ('C:\\Users\\Lawrence\\anaconda3\\envs\\pyHFO\\Lib\\site-packages\\torch', "torch"), - ('C:\\Users\\Lawrence\\anaconda3\\envs\\pyHFO\\Lib\\site-packages\\torchvision', "torchvision"), - ('src','src'), - ('ckpt','ckpt')], - hiddenimports=[], - hookspath=[], - hooksconfig={}, - runtime_hooks=[], - excludes=[], - win_no_prefer_redirects=False, - win_private_assemblies=False, - cipher=block_cipher, - noarchive=False, -) -pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) - -exe = EXE( - pyz, - a.scripts, - a.binaries, - a.zipfiles, - a.datas, - [], - name='pyHFO', - debug=False, - bootloader_ignore_signals=False, - strip=False, - upx=True, - upx_exclude=[], - runtime_tmpdir=None, - console=False, - disable_windowed_traceback=False, - argv_emulation=False, - target_arch=None, - codesign_identity=None, - entitlements_file=None, -) diff --git a/requirements.txt b/requirements.txt index 46bafdb..22c155c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,5 @@ scikit-image==0.21.0 torch==2.0.1 torchvision==0.15.2 tqdm==4.65.0 +transformers[torch] +yasa \ No newline at end of file diff --git a/spike_test.py b/spike_test.py new file mode 100644 index 0000000..4b7872a --- /dev/null +++ b/spike_test.py @@ -0,0 +1,89 @@ +import mne +import numpy as np +from scipy.signal import hilbert, find_peaks +from scipy.signal import detrend + +# Load the EDF file using MNE +file_path = 'MV_2.edf' # Replace with your actual file path +raw_data = mne.io.read_raw_edf(file_path, preload=True) + +# Get basic information about the file +info = raw_data.info +signal_labels = raw_data.ch_names +sampling_rate = raw_data.info['sfreq'] + +# Apply notch filter to remove 60 Hz noise +raw_notched = raw_data.copy().notch_filter(freqs=60, method='fir', fir_design='firwin') + +# Bandpass filter the data between 25 and 80 Hz using MNE's built-in FIR filter +low_cutoff = 25 # Lower bound of the bandpass filter in Hz +high_cutoff = 80 # Upper bound of the bandpass filter in Hz +raw_filtered = raw_notched.copy().filter(l_freq=low_cutoff, h_freq=high_cutoff, method='fir', fir_design='firwin', phase='zero-double') + + +# Extract the filtered signal and unfiltered signal for the first channel +filtered_signal_data, filtered_times = raw_filtered[0, :] +unfiltered_signal_data, _ = raw_notched[0, :] + +# Apply Hilbert transform to the filtered signal +analytic_signal = hilbert(filtered_signal_data[0]) + +# Calculate the amplitude envelope (magnitude of the analytic signal) +amplitude_envelope = np.abs(analytic_signal) + +# Calculate the mean amplitude of the envelope +mean_amplitude = np.mean(amplitude_envelope) + +# Identify candidate spikes where the amplitude envelope exceeds 3 times the mean amplitude +threshold = 3 * mean_amplitude +candidate_spikes = np.where(amplitude_envelope > threshold)[0] + +# Time window around each candidate spike (±0.25 s) +window_size = int(0.25 * sampling_rate) + +# List to store the valid spikes +valid_spikes = [] + +for spike_idx in candidate_spikes: + # Get the time window around the spike for unfiltered data + start_idx = max(0, spike_idx - window_size) + end_idx = min(len(unfiltered_signal_data[0]), spike_idx + window_size) + + signal_window = unfiltered_signal_data[0][start_idx:end_idx] + + # Detrend the window + detrended_signal = detrend(signal_window) + + # Identify peaks and troughs + peaks, _ = find_peaks(detrended_signal) + troughs, _ = find_peaks(-detrended_signal) + + # Calculate the Fano factor + if len(peaks) > 1 and len(troughs) > 1: + # Calculate inter-peak and inter-trough intervals + inter_peak_intervals = np.diff(peaks) / sampling_rate # Convert to seconds + inter_trough_intervals = np.diff(troughs) / sampling_rate + + peak_fano_factor = np.var(inter_peak_intervals) / np.mean(inter_peak_intervals) + trough_fano_factor = np.var(inter_trough_intervals) / np.mean(inter_trough_intervals) + + fano_factor = (peak_fano_factor + trough_fano_factor) / 2 # Average + + # Calculate the maximum amplitude in the window (rectified) + max_amplitude = np.max(np.abs(signal_window)) + + # Check the conditions for valid spikes + if max_amplitude > 3 * mean_amplitude and fano_factor >= 2.5: + valid_spikes.append(spike_idx) + +# Merge spikes detected within 20 ms of each other +merged_spikes = [] +previous_spike = -np.inf +for spike in valid_spikes: + if spike - previous_spike > 0.02 * sampling_rate: + merged_spikes.append(spike) + previous_spike = spike + +# Display final spike indices and times +spike_times = filtered_times[merged_spikes] +print(f"Detected spikes at times (s): {spike_times}") diff --git a/src/classifer.py b/src/classifer.py index ba50dae..842e3d0 100644 --- a/src/classifer.py +++ b/src/classifer.py @@ -3,6 +3,12 @@ from src.param.param_classifier import ParamClassifier import torch from src.model import PreProcessing +from transformers import TrainingArguments, ViTForImageClassification +from transformers import Trainer +from src.dl_models import * +import os + + class Classifier(): def __init__(self, param:ParamClassifier): self.device = param.device @@ -12,6 +18,7 @@ def __init__(self, param:ParamClassifier): self.load_func = torch.load if "default" in self.model_type else torch.load #torch.hub.load_state_dict_from_url if param.artifact_path: self.update_model_a(param) + self.update_model_toy(param) if param.spike_path: self.update_model_s(param) @@ -47,41 +54,61 @@ def update_model_a(self, param:ParamClassifier): self.model_a = model.to(self.device) self.preprocessing_artifact = PreProcessing.from_param(self.param_artifact_preprocessing) - def artifact_detection(self, HFO_features, ignore_region, threshold=0.5): - if not self.model_a: + def update_model_toy(self, param:ParamClassifier): + self.model_type = param.model_type + self.artifact_path = param.artifact_path + res_dir = os.path.dirname(param.artifact_path) + model = NeuralCNNForImageClassification.from_pretrained(os.path.join(res_dir, 'model_toy')) + + self.param_artifact_preprocessing, _ = load_ckpt(self.load_func, param.artifact_path) + if "default" in self.model_type: + model.channel_selection = True + model.input_channels = 1 + if self.model_type == "default_cpu": + param.device = "cpu" + elif self.model_type == "default_gpu": + param.device = "cuda:0" + else: + raise ValueError("Model type not supported!") + self.device = param.device if torch.cuda.is_available() else "cpu" + self.model_toy = model.to(self.device) + self.preprocessing_artifact = PreProcessing.from_param(self.param_artifact_preprocessing) + + def artifact_detection(self, biomarker_features, ignore_region, threshold=0.5): + if not self.model_toy: raise ValueError("Please load artifact model first!") - return self._classify_artifacts(self.model_a, HFO_features, ignore_region, threshold=threshold) + return self._classify_artifacts(self.model_toy, biomarker_features, ignore_region, threshold=threshold) - def spike_detection(self, HFO_features): + def spike_detection(self, biomarker_features): if not self.model_s: raise ValueError("Please load spike model first!") - return self._classify_spikes(self.model_s, HFO_features) + return self._classify_spikes(self.model_s, biomarker_features) - def _classify_artifacts(self, model, HFO_feature, ignore_region, threshold=0.5): + def _classify_artifacts(self, model, biomarker_feature, ignore_region, threshold=0.5): model = model.to(self.device) - features = self.preprocessing_artifact.process_hfo_feature(HFO_feature) + features = self.preprocessing_artifact.process_biomarker_feature(biomarker_feature) artifact_predictions = np.zeros(features.shape[0]) -1 - starts = HFO_feature.starts - ends = HFO_feature.ends + starts = biomarker_feature.starts + ends = biomarker_feature.ends keep_index = np.where(np.logical_and(starts > ignore_region[0], ends < ignore_region[1]) == True)[0] features = features[keep_index] if len(features) != 0: - predictions = inference(model, features, self.device ,self.batch_size, threshold=threshold) + predictions = inference(model, features, self.device, self.batch_size, threshold=threshold) artifact_predictions[keep_index] = predictions - HFO_feature.update_artifact_pred(artifact_predictions) - return HFO_feature + biomarker_feature.update_artifact_pred(artifact_predictions) + return biomarker_feature - def _classify_spikes(self, model, HFO_feature): - if len(HFO_feature.artifact_predictions) == 0: + def _classify_spikes(self, model, biomarker_feature): + if len(biomarker_feature.artifact_predictions) == 0: raise ValueError("Please run artifact classifier first!") model = model.to(self.device) - features = self.preprocessing_spike.process_hfo_feature(HFO_feature) + features = self.preprocessing_spike.process_biomarker_feature(biomarker_feature) spike_predictions = np.zeros(features.shape[0]) -1 - keep_index = np.where(HFO_feature.artifact_predictions > 0)[0] + keep_index = np.where(biomarker_feature.artifact_predictions > 0)[0] features = features[keep_index] if len(features) != 0: predictions = inference(model, features, self.device, self.batch_size) spike_predictions[keep_index] = predictions - HFO_feature.update_spike_pred(spike_predictions) - return HFO_feature + biomarker_feature.update_spike_pred(spike_predictions) + return biomarker_feature diff --git a/src/controllers/__init__.py b/src/controllers/__init__.py new file mode 100644 index 0000000..d2f89ef --- /dev/null +++ b/src/controllers/__init__.py @@ -0,0 +1,4 @@ +from .mini_plot_controller import MiniPlotController +from .main_waveform_plot_controller import MainWaveformPlotController +from .annotation_controller import AnnotationController +# from .main_window_controller import MainWindowController diff --git a/src/controllers/annotation_controller.py b/src/controllers/annotation_controller.py new file mode 100644 index 0000000..bbb5211 --- /dev/null +++ b/src/controllers/annotation_controller.py @@ -0,0 +1,51 @@ +import numpy as np +from src.models.annotation_model import AnnotationModel +from src.views.annotation_view import AnnotationView + + +class AnnotationController: + def __init__(self, annotation_widget, backend=None): + self.model = AnnotationModel(backend) + self.view = AnnotationView(annotation_widget) + + # define window length + if self.model.backend.biomarker_type == "HFO": + self.interval = 1.0 + elif self.model.backend.biomarker_type == "Spindle": + self.interval = 4.0 + + def create_waveform_plot(self): + self.model.create_waveform_plot() + self.view.add_widget('VisulaizationVerticalLayout', self.model.waveform_plot) + channel, start, end = self.get_current_event() + self.model.waveform_plot.plot(start, end, channel, interval=self.interval) # Default interval + + def create_fft_plot(self): + self.model.create_fft_plot() + self.view.add_widget('FFT_layout', self.model.fft_plot) + channel, start, end = self.get_current_event() + self.model.fft_plot.plot(start, end, channel, interval=self.interval) # Default interval + + def update_plots(self, start, end, channel, interval): + self.model.waveform_plot.plot(start, end, channel, interval=interval) + self.model.fft_plot.plot(start, end, channel, interval=interval) + + def get_current_event(self): + channel, start, end = self.model.get_current_event() + return channel, start, end + + def get_previous_event(self): + channel, start, end = self.model.get_previous_event() + return channel, start, end + + def get_next_event(self): + channel, start, end = self.model.get_next_event() + return channel, start, end + + def get_jumped_event(self, index): + channel, start, end = self.model.get_jumped_event(index) + return channel, start, end + + def set_doctor_annotation(self, ann): + selected_index, item_text = self.model.set_doctor_annotation(ann) + return selected_index, item_text diff --git a/src/controllers/main_waveform_plot_controller.py b/src/controllers/main_waveform_plot_controller.py new file mode 100644 index 0000000..a311c28 --- /dev/null +++ b/src/controllers/main_waveform_plot_controller.py @@ -0,0 +1,116 @@ +from src.models.main_waveform_plot_model import MainWaveformPlotModel +from src.views.main_waveform_plot_view import MainWaveformPlotView +import pyqtgraph as pg +import numpy as np + + +class MainWaveformPlotController: + def __init__(self, main_waveform_plot_widget, backend): + self.model = MainWaveformPlotModel(backend) + self.view = MainWaveformPlotView(main_waveform_plot_widget) + + def init_eeg_data(self): + self.model.init_eeg_data() + + def clear(self): + self.view.clear() + + def init_waveform_display(self): + self.view.enable_axis_information() + + def plot_one_channel(self, x, y, color, width=5): + self.view.plot_waveform(x, y, color, width) + + def update_backend(self, new_backend): + self.model.update_backend(new_backend) + + def set_time_window(self, time_window:int): + self.model.set_time_window(time_window) + + def set_n_channels_to_plot(self, n_channels_to_plot:int): + self.model.set_n_channels_to_plot(n_channels_to_plot) + + def set_channel_indices_to_plot(self, channel_indices_to_plot:list): + self.model.set_channel_indices_to_plot(channel_indices_to_plot) + + def set_channels_to_plot(self, channels_to_plot:list): + self.model.set_channels_to_plot(channels_to_plot) + + def set_normalize_vertical(self, normalize_vertical:bool): + self.model.set_normalize_vertical(normalize_vertical) + + def update_channel_names(self, new_channel_names): + self.model.update_channel_names(new_channel_names) + + def set_first_channel_to_plot(self, first_channel_to_plot:int): + self.model.set_first_channel_to_plot(first_channel_to_plot) + + def set_waveform_filter(self, filtered:bool): + self.model.set_waveform_filter(filtered) + + def get_first_channel_to_plot(self): + return self.model.first_channel_to_plot + + def get_current_eeg_data_to_display(self): + return self.model.get_all_current_eeg_data_to_display() + + def set_current_time_window(self, start_in_time): + self.model.set_current_time_window(start_in_time) + + def get_current_start_end(self): + return self.model.get_current_start_end() + + def get_current_time_window(self): + return self.model.get_current_time_window() + + def set_plot_biomarkers(self, plot_biomarkers:bool): + self.model.set_plot_biomarkers(plot_biomarkers) + + def plot_all_current_channels_for_window(self): + eeg_data_to_display, y_100_length, y_scale_length, offset_value = self.get_current_eeg_data_to_display() + time_to_display = self.get_current_time_window() + first_channel_to_plot = self.get_first_channel_to_plot() + n_channels_to_plot = self.model.n_channels_to_plot + waveform_color = self.model.get_waveform_color() + + for disp_i, ch_i in enumerate(range(first_channel_to_plot, first_channel_to_plot + n_channels_to_plot)): + self.view.plot_waveform(time_to_display, eeg_data_to_display[ch_i] - disp_i*offset_value, waveform_color, 0.5) + + return eeg_data_to_display, y_100_length, y_scale_length, offset_value + + def plot_all_current_biomarkers_for_window(self, eeg_data_to_display, offset_value, top_value): + first_channel_to_plot = self.get_first_channel_to_plot() + n_channels_to_plot = self.model.n_channels_to_plot + channels_to_plot = self.model.channels_to_plot + start_in_time, end_in_time = self.get_current_start_end() + + for disp_i, ch_i in enumerate(range(first_channel_to_plot, first_channel_to_plot+n_channels_to_plot)): + channel = channels_to_plot[ch_i] + (biomarker_starts, biomarker_ends, + biomarker_starts_in_time, biomarker_ends_in_time, + windows_in_time, colors) = self.model.get_all_biomarkers_for_all_current_channels_and_color(channel) + + if self.model.plot_biomarkers: + for i in range(len(biomarker_starts)): + event_start = int(biomarker_starts[i]-start_in_time*self.model.sample_freq) + event_end = int(biomarker_ends[i]-start_in_time*self.model.sample_freq) + self.view.plot_waveform(windows_in_time[i], eeg_data_to_display[ch_i, event_start:event_end]-disp_i*offset_value, colors[i], 2) + self.view.plot_waveform([biomarker_starts_in_time[i], biomarker_ends_in_time[i]], [top_value+0.2,top_value+0.2], colors[i], 10) + + + def draw_scale_bar(self, eeg_data_to_display, offset_value, y_100_length, y_scale_length): + start_in_time, end_in_time = self.get_current_start_end() + n_channels_to_plot = self.model.n_channels_to_plot + # Determine the position for the scale indicator (bottom right corner of the plot) + x_pos = end_in_time #+ 0.15 + y_pos = np.min(eeg_data_to_display[-1]) - n_channels_to_plot * offset_value + 0.8 * offset_value + + # Use a dashed line for the scale + self.view.draw_scale_bar(x_pos, y_pos, y_100_length, y_scale_length) + + def draw_channel_names(self, offset_value): + n_channels_to_plot = self.model.n_channels_to_plot + channels_to_plot = self.model.channels_to_plot + first_channel_to_plot = self.get_first_channel_to_plot() + start_in_time, end_in_time = self.get_current_start_end() + self.view.draw_channel_names(offset_value, n_channels_to_plot, channels_to_plot, first_channel_to_plot, start_in_time, end_in_time) \ No newline at end of file diff --git a/src/controllers/main_window_controller.py b/src/controllers/main_window_controller.py new file mode 100644 index 0000000..bc89cd6 --- /dev/null +++ b/src/controllers/main_window_controller.py @@ -0,0 +1,91 @@ +from src.utils.utils_gui import * + + +class MainWindowController: + def __init__(self, view, model): + self.model = model + self.view = view + + self.supported_biomarker = { + 'HFO': self.create_hfo_window, + 'Spindle': self.create_spindle_window, + 'Spike': self.create_spike_window, + } + + def init_biomarker_window(self, biomarker_type): + # To dynamically create frame for different biomarkers, need first init (optimize later) + self.view.window.frame_biomarker_layout = QHBoxLayout(self.view.window.frame_biomarker_type) + self.supported_biomarker[biomarker_type]() + + def init_general_window(self): + self.view.init_general_window() + + self.model.init_error_terminal_display() + self.model.init_menu_bar() + self.model.init_waveform_display() + + def get_biomarker_type(self): + return self.view.get_biomarker_type() + + def set_biomarker_type(self, bio_type): + self.model.set_biomarker_type_and_init_backend(bio_type) + + def init_biomarker_type(self): + default_biomarker = self.get_biomarker_type() + self.set_biomarker_type(default_biomarker) + + safe_connect_signal_slot(self.view.window.combo_box_biomarker.currentIndexChanged, self.switch_biomarker) + # self.view.window.combo_box_biomarker.currentIndexChanged.connect(self.switch_biomarker) + + def switch_biomarker(self): + selected_biomarker = self.get_biomarker_type() + self.supported_biomarker[selected_biomarker]() + + def create_hfo_window(self): + # set biomarker type + self.set_biomarker_type('HFO') + + # create detection parameters stacked widget + self.view.create_stacked_widget_detection_param('HFO') + + # create biomarker typ frame widget + self.view.create_frame_biomarker('HFO') + + # manage flag + self.view.window.is_data_filtered = False + self.view.window.quick_detect_open = False + + # create center waveform and mini plot + self.model.create_center_waveform_and_mini_plot() + + # connect signal & slot + self.model.connect_signal_and_slot('HFO') + + # init params + self.model.init_param('HFO') + + def create_spindle_window(self): + # set biomarker type + self.set_biomarker_type('Spindle') + + # create detection parameters stacked widget + self.view.create_stacked_widget_detection_param('Spindle') + + # create biomarker typ frame widget + self.view.create_frame_biomarker('Spindle') + + # manage flag + self.view.window.is_data_filtered = False + self.view.window.quick_detect_open = False + + # create center waveform and mini plot + self.model.create_center_waveform_and_mini_plot() + + # connect signal & slot + self.model.connect_signal_and_slot('Spindle') + + # init params + self.model.init_param('Spindle') + + def create_spike_window(self): + print('not implemented yet') diff --git a/src/controllers/mini_plot_controller.py b/src/controllers/mini_plot_controller.py new file mode 100644 index 0000000..246ee95 --- /dev/null +++ b/src/controllers/mini_plot_controller.py @@ -0,0 +1,69 @@ +import sys +from src.models.mini_plot_model import MiniPlotModel +from src.views.mini_plot_view import MiniPlotView + + +class MiniPlotController: + def __init__(self, mini_plot_widget, backend): + self.model = MiniPlotModel(backend) + self.view = MiniPlotView(mini_plot_widget) + + def clear(self): + self.view.clear() + + def init_biomarker_display(self): + self.view.enable_axis_information() + self.view.add_linear_region() + + def init_eeg_data(self): + self.model.init_eeg_data() + + def get_first_channel_to_plot(self): + return self.model.first_channel_to_plot + + def set_channel_indices_to_plot(self, channel_indices_to_plot): + self.model.set_channel_indices_to_plot(channel_indices_to_plot) + + def set_channels_to_plot(self, channels_to_plot): + self.model.set_channels_to_plot(channels_to_plot) + + def set_n_channels_to_plot(self, n_channels_to_plot): + self.model.set_n_channels_to_plot(n_channels_to_plot) + + def set_first_channel_to_plot(self, first_channel_to_plot): + self.model.set_first_channel_to_plot(first_channel_to_plot) + + def update_channel_names(self, new_channel_names): + self.model.update_channel_names(new_channel_names) + + def plot_one_biomarker(self, start_time, end_time, height, color, width=5): + self.view.plot_biomarker(start_time, end_time, height, color, width) + + def plot_all_current_biomarkers_for_one_channel(self, channel, plot_height): + starts_in_time, ends_in_time, colors = self.model.get_all_biomarkers_for_channel_and_color(channel) + + for i in range(len(starts_in_time)): + self.plot_one_biomarker(starts_in_time[i], ends_in_time[i], plot_height, colors[i], 5) + + def plot_all_current_biomarkers_for_all_channels(self, plot_height): + first_channel_to_plot = self.get_first_channel_to_plot() + n_channels_to_plot = self.model.n_channels_to_plot + channels_to_plot = self.model.channels_to_plot + for disp_i, ch_i in enumerate(range(first_channel_to_plot, first_channel_to_plot+n_channels_to_plot)): + channel = channels_to_plot[ch_i] + self.plot_all_current_biomarkers_for_one_channel(channel, plot_height) + + def set_miniplot_title(self, title, height): + self.view.set_miniplot_title(title, height) + + def set_total_x_y_range(self, top_value): + time_max = int(self.model.time.shape[0] / self.model.sample_freq) + self.view.set_x_y_range([0, time_max], [top_value-0.25, top_value+0.25]) + + def update_highlight_window(self, start, end, height): + self.view.update_highlight_window(start, end, height) + + def update_backend(self, new_backend): + self.model.update_backend(new_backend) + + \ No newline at end of file diff --git a/src/dl_models/__init__.py b/src/dl_models/__init__.py new file mode 100644 index 0000000..24c938b --- /dev/null +++ b/src/dl_models/__init__.py @@ -0,0 +1,2 @@ +from .configuration_neuralcnn import ResnetConfig +from .modeling_neuralcnn import NeuralCNNModel, NeuralCNNForImageClassification \ No newline at end of file diff --git a/src/dl_models/configuration_neuralcnn.py b/src/dl_models/configuration_neuralcnn.py new file mode 100644 index 0000000..fe110f7 --- /dev/null +++ b/src/dl_models/configuration_neuralcnn.py @@ -0,0 +1,27 @@ +from transformers import PretrainedConfig +from typing import List + +class ResnetConfig(PretrainedConfig): + model_type = "resnet" + def __init__( + self, + layers: List[int] = [3, 4, 6, 3], + num_classes: int = 1000, + input_channels: int = 1, + hidden_size: int = 64, + kernel_size: int = 7, + stride: int = 2, + padding: int = 3, + freeze: bool = False, + channel_selection: bool = True, + **kwargs, + ): + self.num_classes = num_classes + self.input_channels = input_channels + self.hidden_size = hidden_size + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.freeze = freeze + self.channel_selection = channel_selection + super().__init__(**kwargs) \ No newline at end of file diff --git a/src/dl_models/modeling_neuralcnn.py b/src/dl_models/modeling_neuralcnn.py new file mode 100644 index 0000000..d09cd5e --- /dev/null +++ b/src/dl_models/modeling_neuralcnn.py @@ -0,0 +1,120 @@ +import collections.abc +import math +from transformers import PretrainedConfig +from transformers import PreTrainedModel +from .configuration_neuralcnn import ResnetConfig +import torch +import torch.nn as nn +import torchvision.models as models +from torchvision.models import ResNet18_Weights +from typing import Dict, List, Optional, Set, Tuple, Union +from torch.nn import BCELoss, BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + + +class NeuralCNNModel(PreTrainedModel): + config_class = ResnetConfig + + # def __init__(self, in_channels, outputs, freeze = False, channel_selection = True): + def __init__(self, config): + super().__init__(config) + """ + In the constructor we instantiate two nn.Linear modules and assign them as + member variables. + """ + self.input_channels = config.input_channels + self.outputs = config.num_classes + self.channel_selection = config.channel_selection + self.hidden_size = config.hidden_size + self.kernel_size = config.kernel_size + self.stride = config.stride + self.padding = config.padding + self.cnn = models.resnet18(weights=ResNet18_Weights.DEFAULT) + self.cnn.conv1 = nn.Conv2d(self.input_channels, self.hidden_size, self.kernel_size, self.stride, self.padding, + bias=False) + self.cnn.fc = nn.Sequential(nn.Linear(512, self.hidden_size // 2)) + for param in self.cnn.fc.parameters(): + param.requires_grad = not config.freeze + self.bn0 = nn.BatchNorm1d(self.hidden_size // 2) + self.relu0 = nn.LeakyReLU() + self.fc = nn.Linear(self.hidden_size // 2, self.hidden_size // 2) + self.bn = nn.BatchNorm1d(self.hidden_size // 2) + self.relu = nn.LeakyReLU() + self.fc1 = nn.Linear(self.hidden_size // 2, self.hidden_size // 4) + self.bn1 = nn.BatchNorm1d(self.hidden_size // 4) + self.relu1 = nn.LeakyReLU() + + self.fc_out = nn.Linear(self.hidden_size // 4, self.outputs) + self.final_ac = nn.Sigmoid() + self.criterion = nn.BCELoss() + + def forward(self, x): + """ + In the forward function we accept a Tensor of input data and we must return + a Tensor of output data. We can use Modules defined in the constructor as + well as arbitrary operators on Tensors. + """ + batch = self.cnn(x) + batch = self.bn(self.relu(self.fc(batch))) + batch = self.bn1(self.relu1(self.fc1(batch))) + logits = self.final_ac(self.fc_out(batch)) + + return logits + + +class NeuralCNNForImageClassification(PreTrainedModel): + config_class = ResnetConfig + + # def __init__(self, in_channels, outputs, freeze = False, channel_selection = True): + def __init__(self, config): + super().__init__(config) + """ + In the constructor we instantiate two nn.Linear modules and assign them as + member variables. + """ + self.input_channels = config.input_channels + self.outputs = config.num_classes + self.channel_selection = config.channel_selection + self.hidden_size = config.hidden_size + self.kernel_size = config.kernel_size + self.stride = config.stride + self.padding = config.padding + self.cnn = models.resnet18(weights=ResNet18_Weights.DEFAULT) + self.cnn.conv1 = nn.Conv2d(self.input_channels, self.hidden_size, self.kernel_size, self.stride, self.padding, + bias=False) + self.cnn.fc = nn.Linear(512, self.hidden_size // 2) + for param in self.cnn.fc.parameters(): + param.requires_grad = not config.freeze + self.bn0 = nn.BatchNorm1d(self.hidden_size // 2) + self.relu0 = nn.LeakyReLU() + self.fc = nn.Linear(self.hidden_size // 2, self.hidden_size // 2) + self.bn = nn.BatchNorm1d(self.hidden_size // 2) + self.relu = nn.LeakyReLU() + self.fc1 = nn.Linear(self.hidden_size // 2, self.hidden_size // 4) + self.bn1 = nn.BatchNorm1d(self.hidden_size // 4) + self.relu1 = nn.LeakyReLU() + + self.fc_out = nn.Linear(self.hidden_size // 4, self.outputs) + self.final_ac = nn.Sigmoid() + + def forward(self, + input_features: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + ): + """ + In the forward function we accept a Tensor of input data and we must return + a Tensor of output data. We can use Modules defined in the constructor as + well as arbitrary operators on Tensors. + """ + if self.input_channels == 1: + input_features = input_features[:, 0:1, :, :] + batch = self.cnn(input_features) + batch = self.bn(self.relu(self.fc(batch))) + batch = self.bn1(self.relu1(self.fc1(batch))) + logits = self.fc_out(batch) + out = self.final_ac(logits) + # if labels is not None: + # loss_fct = BCEWithLogitsLoss() + # loss = loss_fct(logits, labels) + # return {"loss": loss, "logits": logits} + # return {"logits": logits} + return out diff --git a/src/hfo_app.py b/src/hfo_app.py index 2f7ce11..df0d484 100644 --- a/src/hfo_app.py +++ b/src/hfo_app.py @@ -7,7 +7,7 @@ from src.classifer import Classifier from src.utils.utils_feature import * from src.utils.utils_filter import construct_filter, filter_data -from src.utils.utils_detector import set_STE_detector, set_MNI_detector +from src.utils.utils_detector import set_STE_detector, set_MNI_detector, set_HIL_detector from src.utils.utils_io import get_edf_info, read_eeg_data, dump_to_npz from src.utils.utils_plotting import plot_feature @@ -22,11 +22,13 @@ class HFO_App(object): def __init__(self): self.version = "1.0.0" + self.biomarker_type = 'HFO' self.n_jobs = 4 ## eeg related self.eeg_data = None self.raw = None self.channel_names = None + self.event_channel_names = None self.sample_freq = 0 # Hz self.edf_param = None @@ -49,7 +51,7 @@ def __init__(self): ## feature related self.feature_param = None - self.hfo_features = None + self.event_features = None ## classifier related self.param_classifier = None @@ -216,7 +218,7 @@ def set_detector(self, param:ParamDetector): param should be a type of ParamDetector param/param_detector.py it should contain the following fields: - detector_type: str, "STE" or "MNI" + detector_type: str, "STE", "MNI" or "HIL" detector_param: param_detector/ParamSTE or param_detector/ParamMNI ''' @@ -230,9 +232,9 @@ def set_detector(self, param:ParamDetector): self.detector = set_STE_detector(param.detector_param) elif param.detector_type.lower() == "mni": self.detector = set_MNI_detector(param.detector_param) - # print(param.detector_param.to_dict()) - - def detect_HFO(self, param_filter:ParamFilter =None, param_detector:ParamDetector=None): + elif param.detector_type.lower() == "hil": + self.detector = set_HIL_detector(param.detector_param) + def detect_biomarker(self, param_filter:ParamFilter =None, param_detector:ParamDetector=None): ''' This the function should be linked to the detect button in the overview window, it can also be called with a param to set the detector, the detector will be reseted if the param is not None @@ -244,8 +246,8 @@ def detect_HFO(self, param_filter:ParamFilter =None, param_detector:ParamDetecto self.set_detector(param_detector) if self.filter_data is None or len(self.filter_data) == 0: self.filter_eeg_data() - self.channel_names, self.HFOs = self.detector.detect_multi_channels(self.filter_data, self.channel_names, filtered=True) - self.hfo_features = HFO_Feature.construct(self.channel_names, self.HFOs, self.param_detector.detector_type, self.sample_freq) + self.event_channel_names, self.HFOs = self.detector.detect_multi_channels(self.filter_data, self.channel_names, filtered=True) + self.event_features = HFO_Feature.construct(self.event_channel_names, self.HFOs, self.param_detector.detector_type, self.sample_freq) self.detected = True ''' @@ -261,24 +263,24 @@ def generate_HFO_features(self): win_size = 224 time_range = [0, 1000] # 0~1000ms - starts = self.hfo_features.starts - ends = self.hfo_features.ends - channel_names = self.hfo_features.channel_names - hfo_waveforms = extract_waveforms(self.eeg_data, starts, ends, channel_names, self.channel_names, self.sample_freq, time_range) - param_list = [{"start":starts[i], "end":ends[i], "data":hfo_waveforms[i], "channel_name":channel_names[i], + starts = self.event_features.starts + ends = self.event_features.ends + channel_names = self.event_features.channel_names + biomarker_waveforms = extract_waveforms(self.eeg_data, starts, ends, channel_names, self.channel_names, self.sample_freq, time_range) + param_list = [{"start":starts[i], "end":ends[i], "data":biomarker_waveforms[i], "channel_name":channel_names[i], "sample_rate": self.sample_freq, "win_size": win_size, "ps_MinFreqHz": freq_range[0], "ps_MaxFreqHz": freq_range[1], "time_window_ms" : (time_range[1] - time_range[0])//2, } for i in range(len(starts))] - ret = parallel_process(param_list, compute_hfo_feature, n_jobs=self.n_jobs, use_kwargs=True, front_num=2) + ret = parallel_process(param_list, compute_biomarker_feature, n_jobs=self.n_jobs, use_kwargs=True, front_num=2) starts, ends, channel_names, time_frequncy_img, amplitude_coding_plot = np.zeros(len(ret)), np.zeros(len(ret)), np.empty(len(ret), dtype= object), np.zeros((len(ret), win_size,win_size)), np.zeros((len(ret), win_size, win_size)) for i in range(len(ret)): channel_names[i], starts[i], ends[i], time_frequncy_img[i], amplitude_coding_plot[i] = ret[i] interval = np.concatenate([starts[:, None], ends[:, None]], axis=1) feature = np.concatenate([time_frequncy_img[:, None, :, :], amplitude_coding_plot[:, None, :, :]], axis=1) - self.hfo_features = HFO_Feature(channel_names, interval, feature, sample_freq = self.sample_freq, HFO_type=self.param_detector.detector_type, feature_size=win_size, freq_range=freq_range, time_range=time_range) + self.event_features = HFO_Feature(channel_names, interval, feature, sample_freq = self.sample_freq, HFO_type=self.param_detector.detector_type, feature_size=win_size, freq_range=freq_range, time_range=time_range) ''' @@ -352,17 +354,17 @@ def set_default_gpu_classifier(self): def classify_artifacts(self, ignore_region = [1, 1], threshold=0.5): - if not self.hfo_features.has_feature(): + if not self.event_features.has_feature(): self.generate_HFO_features() ignore_region = np.array(ignore_region) * self.sample_freq ignore_region = np.array([ignore_region[0], len(self.eeg_data[0]) - ignore_region[1]]) - self.classifier.artifact_detection(self.hfo_features, ignore_region, threshold=threshold) + self.classifier.artifact_detection(self.event_features, ignore_region, threshold=threshold) self.classified = True def classify_spikes(self): - if not self.hfo_features.has_feature(): + if not self.event_features.has_feature(): self.generate_HFO_features() - self.classifier.spike_detection(self.hfo_features) + self.classifier.spike_detection(self.event_features) ''' results APIs @@ -372,24 +374,24 @@ def get_res_overview(self): ''' return the overview of the results ''' - if not self.hfo_features.has_feature(): + if not self.event_features.has_feature(): self.generate_HFO_features() return { - "n_HFO": self.hfo_features.num_HFO, - "n_artifact": self.hfo_features.num_artifact, - "n_real": self.hfo_features.num_real, - "n_spike": self.hfo_features.num_spike + "n_HFO": self.event_features.num_HFO, + "n_artifact": self.event_features.num_artifact, + "n_real": self.event_features.num_real, + "n_spike": self.event_features.num_spike } def export_report(self, path): - if not self.hfo_features: + if not self.event_features: return None - self.hfo_features.export_csv(path) + self.event_features.export_csv(path) def export_excel(self, path): - if not self.hfo_features: + if not self.event_features: return None - self.hfo_features.export_excel(path) + self.event_features.export_excel(path) def export_app(self, path): ''' @@ -404,16 +406,16 @@ def export_app(self, path): "param_filter": self.param_filter.to_dict() if self.param_filter else None, "HFOs": self.HFOs, "param_detector": self.param_detector.to_dict() if self.param_detector else None, - "HFO_features": self.hfo_features.to_dict() if self.hfo_features else None, + "HFO_features": self.event_features.to_dict() if self.event_features else None, "param_classifier": self.param_classifier.to_dict() if self.param_classifier else None, "classified": self.classified, "filtered": self.filtered, "detected": self.detected, - "artifact_predictions": np.array(self.hfo_features.artifact_predictions), - "spike_predictions": np.array(self.hfo_features.spike_predictions), - "artifact_annotations": np.array(self.hfo_features.artifact_annotations), - "spike_annotations": np.array(self.hfo_features.spike_annotations), - "annotated": np.array(self.hfo_features.annotated), + "artifact_predictions": np.array(self.event_features.artifact_predictions), + "spike_predictions": np.array(self.event_features.spike_predictions), + "artifact_annotations": np.array(self.event_features.artifact_annotations), + "spike_annotations": np.array(self.event_features.spike_annotations), + "annotated": np.array(self.event_features.annotated), } dump_to_npz(checkpoint, path) @@ -432,11 +434,11 @@ def import_app(path): app.classified = checkpoint["classified"].item() app.filtered = checkpoint["filtered"].item() app.detected = checkpoint["detected"].item() - app.hfo_features.artifact_predictions = checkpoint["artifact_predictions"].item() - app.hfo_features.spike_predictions = checkpoint["spike_predictions"].item() - app.hfo_features.artifact_annotations = checkpoint["artifact_annotations"].item() - app.hfo_features.spike_annotations = checkpoint["spike_annotations"].item() - app.hfo_features.annotated = checkpoint["annotated"].item() + app.event_features.artifact_predictions = checkpoint["artifact_predictions"].item() + app.event_features.spike_predictions = checkpoint["spike_predictions"].item() + app.event_features.artifact_annotations = checkpoint["artifact_annotations"].item() + app.event_features.spike_annotations = checkpoint["spike_annotations"].item() + app.event_features.annotated = checkpoint["annotated"].item() if app.filtered: app.param_filter = ParamFilter.from_dict(checkpoint["param_filter"].item()) app.filter_eeg_data(app.param_filter) @@ -445,8 +447,8 @@ def import_app(path): app.HFOs = checkpoint["HFOs"] app.param_detector = ParamDetector.from_dict(checkpoint["param_detector"].item()) #print("new HFO features") - app.hfo_features = HFO_Feature.from_dict(checkpoint["HFO_features"].item()) - # print(app.hfo_features) + app.event_features = HFO_Feature.from_dict(checkpoint["HFO_features"].item()) + # print(app.event_features) if app.classified: app.param_classifier = ParamClassifier.from_dict(checkpoint["param_classifier"].item()) return app @@ -463,34 +465,34 @@ def extract_data(data, data_filtered, start, end): data_filtered = np.squeeze(data_filtered) if start < self.sample_freq // 2: plot_start, plot_end = 0, self.sample_freq - hfo_start, hfo_end = start, min(end, self.sample_freq) + biomarker_start, biomarker_end = start, min(end, self.sample_freq) elif end > len(data) - self.sample_freq // 2: plot_start, plot_end = len(data) - self.sample_freq, len(data) - hfo_start, hfo_end = max(plot_start, start) - plot_start, min(plot_end, end) - plot_start + biomarker_start, biomarker_end = max(plot_start, start) - plot_start, min(plot_end, end) - plot_start else: plot_start, plot_end = (start + end)//2-self.sample_freq // 2, (start+end)//2+self.sample_freq // 2 - hfo_start, hfo_end = max(plot_start, start) - plot_start, min(plot_end, end) - plot_start - plot_start, plot_end, hfo_start, hfo_end = int(plot_start), int(plot_end), int(hfo_start), int(hfo_end) + biomarker_start, biomarker_end = max(plot_start, start) - plot_start, min(plot_end, end) - plot_start + plot_start, plot_end, biomarker_start, biomarker_end = int(plot_start), int(plot_end), int(biomarker_start), int(biomarker_end) channel_data = data[plot_start:plot_end] channel_data_f = data_filtered[plot_start:plot_end] #print(hfo_start, hfo_end, start, end, plot_start, plot_end, channel_data.shape, channel_data_f.shape) - return channel_data, channel_data_f, hfo_start, hfo_end + return channel_data, channel_data_f, biomarker_start, biomarker_end def extract_waveform(data, data_filtered, starts, ends, channel_names, unique_channel_names): - hfo_waveform_l, hfo_waveform_f_l, hfo_start_l , hfo_end_l = np.zeros((len(starts), 2000)), np.zeros((len(starts), 2000)), [], [] + biomarker_waveform_l, biomarker_waveform_f_l, biomarker_start_l , biomarker_end_l = np.zeros((len(starts), 2000)), np.zeros((len(starts), 2000)), [], [] for i in tqdm(range(len(starts))): channel_name = channel_names[i] start = starts[i] end = ends[i] channel_index = np.where(unique_channel_names == channel_name)[0] - hfo_waveform, hfo_waveform_f, hfo_start, hfo_end = extract_data(data[channel_index], data_filtered[channel_index], start, end) - hfo_waveform_l[i] = hfo_waveform - hfo_waveform_f_l[i] = hfo_waveform_f - hfo_start_l.append(hfo_start) - hfo_end_l.append(hfo_end) - return hfo_waveform_l, hfo_waveform_f_l, np.array(hfo_start_l), np.array(hfo_end_l) - - if not self.hfo_features: + biomarker_waveform, biomarker_waveform_f, biomarker_start, biomarker_end = extract_data(data[channel_index], data_filtered[channel_index], start, end) + biomarker_waveform_l[i] = biomarker_waveform + biomarker_waveform_f_l[i] = biomarker_waveform_f + biomarker_start_l.append(biomarker_start) + biomarker_end_l.append(biomarker_end) + return biomarker_waveform_l, biomarker_waveform_f_l, np.array(biomarker_start_l), np.array(biomarker_end_l) + + if not self.event_features: return None os.makedirs(folder, exist_ok=True) artifact_folder = os.path.join(folder, "artifact") @@ -499,11 +501,11 @@ def extract_waveform(data, data_filtered, starts, ends, channel_names, unique_ch clean_folder(artifact_folder) clean_folder(spike_folder) clean_folder(non_spike_folder) - starts = self.hfo_features.starts - ends = self.hfo_features.ends - feature = self.hfo_features.features - channel_names = self.hfo_features.channel_names - spike_predictions = self.hfo_features.spike_predictions + starts = self.event_features.starts + ends = self.event_features.ends + feature = self.event_features.features + channel_names = self.event_features.channel_names + spike_predictions = self.event_features.spike_predictions index_s = np.where(spike_predictions == 1)[0] start_s, end_s, feature_s, channel_names_s = starts[index_s], ends[index_s], feature[index_s], channel_names[index_s] index_a = np.where(spike_predictions == -1)[0] @@ -511,12 +513,12 @@ def extract_waveform(data, data_filtered, starts, ends, channel_names, unique_ch index_r = np.where(spike_predictions == 0)[0] start_r, end_r, feature_r, channel_names_r = starts[index_r], ends[index_r], feature[index_r], channel_names[index_r] #print("plotting HFO with spike") - waveform_s, waveform_f_s, hfo_start_s, hfo_end_s = extract_waveform(self.eeg_data, self.filter_data, start_s, end_s, channel_names_s, self.channel_names) - param_list = [{"folder": spike_folder, "start": start_s[i], "end": end_s[i], "feature": feature_s[i], "channel_name": channel_names_s[i], "data":waveform_s[i], "data_filtered":waveform_f_s[i], "hfo_start":hfo_start_s[i], "hfo_end":hfo_end_s[i]} for i in range(len(start_s))] + waveform_s, waveform_f_s, biomarker_start_s, biomarker_end_s = extract_waveform(self.eeg_data, self.filter_data, start_s, end_s, channel_names_s, self.channel_names) + param_list = [{"folder": spike_folder, "start": start_s[i], "end": end_s[i], "feature": feature_s[i], "channel_name": channel_names_s[i], "data":waveform_s[i], "data_filtered":waveform_f_s[i], "hfo_start":biomarker_start_s[i], "hfo_end":biomarker_end_s[i]} for i in range(len(start_s))] ret = parallel_process(param_list, plot_feature, self.n_jobs, use_kwargs=True, front_num=3) - waveform_a, waveform_f_a, hfo_start_a, hfo_end_a = extract_waveform(self.eeg_data, self.filter_data, start_a, end_a, channel_names_a, self.channel_names) - param_list = [{"folder": artifact_folder, "start": start_a[i], "end": end_a[i], "feature": feature_a[i], "channel_name": channel_names_a[i], "data":waveform_a[i], "data_filtered":waveform_f_a[i], "hfo_start":hfo_start_a[i], "hfo_end":hfo_end_a[i]} for i in range(len(start_a))] + waveform_a, waveform_f_a, biomarker_start_a, biomarker_end_a = extract_waveform(self.eeg_data, self.filter_data, start_a, end_a, channel_names_a, self.channel_names) + param_list = [{"folder": artifact_folder, "start": start_a[i], "end": end_a[i], "feature": feature_a[i], "channel_name": channel_names_a[i], "data":waveform_a[i], "data_filtered":waveform_f_a[i], "hfo_start":biomarker_start_a[i], "hfo_end":biomarker_end_a[i]} for i in range(len(start_a))] ret = parallel_process(param_list, plot_feature, self.n_jobs, use_kwargs=True, front_num=3) - waveform_r, waveform_f_r, hfo_start_r, hfo_end_r = extract_waveform(self.eeg_data, self.filter_data, start_r, end_r, channel_names_r, self.channel_names) - param_list = [{"folder": non_spike_folder, "start": start_r[i], "end": end_r[i], "feature": feature_r[i], "channel_name": channel_names_r[i], "data":waveform_r[i], "data_filtered":waveform_f_r[i], "hfo_start":hfo_start_r[i], "hfo_end":hfo_end_r[i]} for i in range(len(start_r))] + waveform_r, waveform_f_r, biomarker_start_r, biomarker_end_r = extract_waveform(self.eeg_data, self.filter_data, start_r, end_r, channel_names_r, self.channel_names) + param_list = [{"folder": non_spike_folder, "start": start_r[i], "end": end_r[i], "feature": feature_r[i], "channel_name": channel_names_r[i], "data":waveform_r[i], "data_filtered":waveform_f_r[i], "hfo_start":biomarker_start_r[i], "hfo_end":biomarker_end_r[i]} for i in range(len(start_r))] ret = parallel_process(param_list, plot_feature, self.n_jobs, use_kwargs=True, front_num=3) diff --git a/src/hfo_feature.py b/src/hfo_feature.py index 7e2ce75..92bc4ea 100644 --- a/src/hfo_feature.py +++ b/src/hfo_feature.py @@ -46,7 +46,7 @@ def construct(channel_names, start_end, HFO_type = "STE", sample_freq = 2000, fr start_end = np.concatenate(start_end) if len(start_end) > 0 else np.array([]) return HFO_Feature(channel_names, start_end, np.array([]), HFO_type, sample_freq, freq_range, time_range, feature_size) - def get_num_HFO(self): + def get_num_biomarker(self): return self.num_HFO def has_prediction(self): @@ -155,10 +155,10 @@ def from_dict(data): feature_size = data["feature_size"] freq_range = data["freq_range"] time_range = data["time_range"] - hfo_feature = HFO_Feature(channel_names, np.array([starts, ends]).T, feature, HFO_type, sample_freq, freq_range, time_range, feature_size) - hfo_feature.update_pred(artifact_predictions, spike_predictions) + biomarker_feature = HFO_Feature(channel_names, np.array([starts, ends]).T, feature, HFO_type, sample_freq, freq_range, time_range, feature_size) + biomarker_feature.update_pred(artifact_predictions, spike_predictions) - return hfo_feature + return biomarker_feature def update_artifact_pred(self, artifact_predictions): self.artifact_predicted = True @@ -195,7 +195,7 @@ def group_by_channel(self): spike_predictions_g.append(spike_predictions[channel_index]) return channel_name_g, interval_g, artifact_predictions_g, spike_predictions_g - def get_HFOs_for_channel(self, channel_name:str, min_start:int=None, max_end:int=None): + def get_biomarkers_for_channel(self, channel_name:str, min_start:int=None, max_end:int=None): channel_names = self.channel_names starts = self.starts ends = self.ends diff --git a/src/model.py b/src/model.py index cb3f63f..345ea8b 100644 --- a/src/model.py +++ b/src/model.py @@ -164,7 +164,7 @@ def __call__(self, data): data = self._cropping(data) return data - def process_hfo_feature(self, feature): + def process_biomarker_feature(self, feature): data = feature.get_features() self.freq_range = feature.freq_range self.event_length = max(feature.time_range) diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..1f8896e --- /dev/null +++ b/src/models/__init__.py @@ -0,0 +1,4 @@ +# from .mini_plot_model import MiniPlotModel +# from .main_waveform_plot_model import MainWaveformPlotModel +# from .annotation_model import AnnotationModel +# from .main_window_model import MainWindowModel diff --git a/src/models/annotation_model.py b/src/models/annotation_model.py new file mode 100644 index 0000000..e06a961 --- /dev/null +++ b/src/models/annotation_model.py @@ -0,0 +1,37 @@ +import numpy as np +from src.hfo_app import HFO_App +from src.ui.annotation_plot import AnnotationPlot, FFTPlot + + +class AnnotationModel: + def __init__(self, backend): + self.backend = backend + + def create_waveform_plot(self): + self.waveform_plot = AnnotationPlot(backend=self.backend) + + def create_fft_plot(self): + self.fft_plot = FFTPlot(backend=self.backend) + + def get_current_event(self): + channel, start, end = self.backend.event_features.get_current() + return channel, start, end + + def get_previous_event(self): + channel, start, end = self.backend.event_features.get_prev() + return channel, start, end + + def get_next_event(self): + channel, start, end = self.backend.event_features.get_next() + return channel, start, end + + def get_jumped_event(self, index): + channel, start, end = self.backend.event_features.get_jump(index) + return channel, start, end + + def set_doctor_annotation(self, ann): + self.backend.event_features.doctor_annotation(ann) + # Update the text of the selected item in the dropdown menu + selected_index = self.backend.event_features.index + item_text = self.backend.event_features.get_annotation_text(selected_index) + return selected_index, item_text diff --git a/src/models/main_waveform_plot_model.py b/src/models/main_waveform_plot_model.py new file mode 100644 index 0000000..cc735d7 --- /dev/null +++ b/src/models/main_waveform_plot_model.py @@ -0,0 +1,143 @@ +import numpy as np +from src.hfo_app import HFO_App +from src.spindle_app import SpindleApp +import sys + + +class MainWaveformPlotModel: + def __init__(self, backend: HFO_App): + self.backend = backend + self.color_dict={"artifact":(245,130,48), #orange + "spike":(240,30,250), #pink + "non_spike":(60,180,75), #green + "HFO":(60,180,75), #green + "waveform": (0,0,255), + } + + self.start_in_time = 0 + self.end_in_time = 20 + self.time_window = 20 #20 second time window + self.first_channel_to_plot = 0 + self.n_channels_to_plot = 10 + self.filtered = False + self.normalize_vertical = False + + def update_backend(self, new_backend): + self.backend = new_backend + + def init_eeg_data(self): + eeg_data, self.channel_names = self.backend.get_eeg_data() + self.edf_info = self.backend.get_edf_info() + self.sample_freq = self.edf_info['sfreq'] + self.time = np.arange(0, eeg_data.shape[1]/self.sample_freq, 1/self.sample_freq) + + self.filtered = False + self.plot_biomarkers = False + self.channel_names = list(self.channel_names) + self.n_channels = len(self.channel_names) + self.n_channels_to_plot = min(self.n_channels,self.n_channels_to_plot) + self.channels_to_plot = self.channel_names.copy() + self.channel_indices_to_plot = np.arange(self.n_channels) + + def set_time_window(self, time_window:int): + self.time_window = time_window + + def set_plot_biomarkers(self, plot_biomarkers:bool): + self.plot_biomarkers = plot_biomarkers + + def set_current_time_window(self, start_in_time): + self.start_in_time = max(start_in_time, 0) + self.end_in_time = min(start_in_time + self.time_window, self.time[-1]) + self.start_in_time, self.end_in_time = self.time[int(start_in_time * self.sample_freq)], self.time[int(self.end_in_time * self.sample_freq)] + + def get_current_start_end(self): + return self.start_in_time, self.end_in_time + + def get_current_time_window(self): + return self.time[int(self.start_in_time * self.sample_freq): int(self.end_in_time * self.sample_freq)] + + def set_first_channel_to_plot(self, first_channel_to_plot): + self.first_channel_to_plot = first_channel_to_plot + + def set_n_channels_to_plot(self, n_channels_to_plot:int): + self.n_channels_to_plot = n_channels_to_plot + + def set_channels_to_plot(self, channels_to_plot:list): + self.channels_to_plot = channels_to_plot + self.channel_indices_to_plot = [self.channel_names.index(channel) for channel in channels_to_plot] + + def set_channel_indices_to_plot(self,channel_indices_to_plot:list): + self.channel_indices_to_plot = channel_indices_to_plot + self.channels_to_plot = [self.channel_names[index] for index in channel_indices_to_plot] + + def update_channel_names(self, new_channel_names): + self.channel_names = list(new_channel_names) + + def set_waveform_filter(self, filtered): + self.filtered = filtered + + def set_normalize_vertical(self, normalize_vertical:bool): + self.normalize_vertical = normalize_vertical + + def get_waveform_color(self): + return self.color_dict["waveform"] + + def get_all_current_eeg_data_to_display(self): + eeg_data_to_display, _ = self.backend.get_eeg_data(int(self.start_in_time*self.sample_freq),int(self.end_in_time*self.sample_freq), self.filtered) + eeg_data_to_display = eeg_data_to_display[self.channel_indices_to_plot,:] + + if self.normalize_vertical: + eeg_data_to_display = (eeg_data_to_display-eeg_data_to_display.min(axis = 1,keepdims = True)) + eeg_data_to_display = eeg_data_to_display/np.max(eeg_data_to_display) + else: + if self.filtered: + means = np.mean(eeg_data_to_display) + self.stds = np.std(eeg_data_to_display) * 2 + eeg_data_to_display = (eeg_data_to_display - means) / self.stds + eeg_data_to_display[np.isnan(eeg_data_to_display)] = 0 + else: + # standardized signal globally + means = np.mean(eeg_data_to_display) + self.stds = np.std(eeg_data_to_display) + eeg_data_to_display = (eeg_data_to_display - means) / self.stds + #replace nans with 0 + eeg_data_to_display[np.isnan(eeg_data_to_display)] = 0 + #shift the ith channel by 1.1*i + # eeg_data_to_display = eeg_data_to_display-1.1*np.arange(eeg_data_to_display.shape[0])[:,None] + if self.filtered: + # Add scale indicators + # Set the length of the scale lines + y_100_length = 50 # 100 microvolts + offset_value = 6 + y_scale_length = y_100_length / self.stds + else: + y_100_length = 100 # 100 microvolts + offset_value = 6 + y_scale_length = y_100_length / self.stds + + return eeg_data_to_display, y_100_length, y_scale_length, offset_value + + def get_all_biomarkers_for_all_current_channels_and_color(self, channel_in_name): + starts, ends, artifacts, spikes = self.backend.event_features.get_biomarkers_for_channel(channel_in_name, int(self.start_in_time*self.sample_freq),int(self.end_in_time*self.sample_freq)) + colors = [] + windows_in_time = [] + + for j in range(len(starts)): + try: + if int(artifacts[j])<1: + color = self.color_dict["artifact"] + elif spikes[j]: + color = self.color_dict["spike"] + else: + color = self.color_dict["non_spike"] + except: + color = self.color_dict["non_spike"] + colors.append(color) + # s_ind, e_ind = np.searchsorted(self.time, starts[j]), np.searchsorted(self.time, ends[j]) + windows_in_time.append(self.time[int(starts[j]):int(ends[j])]) + + starts_in_time = [self.time[int(i)] for i in starts] + ends_in_time = [self.time[min(int(i), len(self.time)-1)] for i in ends] + + return starts, ends, starts_in_time, ends_in_time, windows_in_time, colors + diff --git a/src/models/main_window_model.py b/src/models/main_window_model.py new file mode 100644 index 0000000..cec8b24 --- /dev/null +++ b/src/models/main_window_model.py @@ -0,0 +1,1069 @@ +from pathlib import Path +from queue import Queue + +from PyQt5 import uic +from PyQt5.QtGui import * +from PyQt5.QtCore import * +from PyQt5.QtWidgets import * +from PyQt5.QtWidgets import QMessageBox +from PyQt5.QtCore import pyqtSignal +import ast +import multiprocessing as mp +import torch +from src.hfo_app import HFO_App +from src.spindle_app import SpindleApp +from src.ui.quick_detection import HFOQuickDetector +from src.ui.channels_selection import ChannelSelectionWindow +from src.param.param_classifier import ParamClassifier +from src.param.param_detector import ParamDetector, ParamSTE, ParamMNI, ParamHIL, ParamYASA +from src.param.param_filter import ParamFilter, ParamFilterSpindle +from src.ui.bipolar_channel_selection import BipolarChannelSelectionWindow +from src.ui.annotation import Annotation +from src.utils.utils_gui import * +from src.ui.plot_waveform import * + + +class MainWindowModel(QObject): + def __init__(self, main_window): + super(MainWindowModel, self).__init__() + self.window = main_window + self.backend = None + self.biomarker_type = None + + def set_biomarker_type_and_init_backend(self, bio_type): + self.biomarker_type = bio_type + if bio_type == 'HFO': + self.backend = HFO_App() + elif bio_type == 'Spindle': + self.backend = SpindleApp() + elif bio_type == 'Spike': + self.backend = HFO_App() + + def init_error_terminal_display(self): + self.window.stdout = Queue() + self.window.stderr = Queue() + sys.stdout = WriteStream(self.window.stdout) + sys.stderr = WriteStream(self.window.stderr) + self.window.thread_stdout = STDOutReceiver(self.window.stdout) + self.window.thread_stdout.std_received_signal.connect(self.message_handler) + self.window.thread_stdout.start() + + self.window.thread_stderr = STDErrReceiver(self.window.stderr) + self.window.thread_stderr.std_received_signal.connect(self.message_handler) + self.window.thread_stderr.start() + + def init_menu_bar(self): + self.window.action_Open_EDF.triggered.connect(self.open_file) + self.window.actionQuick_Detection.triggered.connect(self.open_quick_detection) + self.window.action_Load_Detection.triggered.connect(self.load_from_npz) + + ## top toolbar buttoms + self.window.actionOpen_EDF_toolbar.triggered.connect(self.open_file) + self.window.actionQuick_Detection_toolbar.triggered.connect(self.open_quick_detection) + self.window.actionLoad_Detection_toolbar.triggered.connect(self.load_from_npz) + + def init_waveform_display(self): + # waveform display widget + self.window.waveform_plot_widget = pg.PlotWidget() + self.window.waveform_mini_widget = pg.PlotWidget() + self.window.widget.layout().addWidget(self.window.waveform_plot_widget, 0, 1) + self.window.widget.layout().addWidget(self.window.waveform_mini_widget, 1, 1) + self.window.widget.layout().setRowStretch(0, 9) + self.window.widget.layout().setRowStretch(1, 1) + + def set_backend(self, backend): + self.backend = backend + + def filter_data(self): + self.message_handler("Filtering data...") + try: + # get filter parameters + fp_raw = self.window.fp_input.text() + fs_raw = self.window.fs_input.text() + rp_raw = self.window.rp_input.text() + rs_raw = self.window.rs_input.text() + # self.pop_window() + param_dict = {"fp": float(fp_raw), "fs": float(fs_raw), "rp": float(rp_raw), "rs": float(rs_raw)} + filter_param = ParamFilter.from_dict(param_dict) + self.backend.set_filter_parameter(filter_param) + except: + # there is error of the filter machine + # therefore pop up window to show that filter failed + msg = QMessageBox() + msg.setIcon(QMessageBox.Critical) + msg.setText("Error") + msg.setInformativeText('Filter could not be constructed with the given parameters') + msg.setWindowTitle("Filter Construction Error") + msg.exec_() + return + worker = Worker(self._filter) + worker.signals.finished.connect(self.filtering_complete) + self.window.threadpool.start(worker) + + def create_center_waveform_and_mini_plot(self): + self.window.channels_to_plot = [] + self.window.waveform_plot = CenterWaveformAndMiniPlotController(self.window.waveform_plot_widget, + self.window.waveform_mini_widget, + self.backend) + + # part of “clear everything if exit”, optimize in the future + safe_connect_signal_slot(self.window.waveform_time_scroll_bar.valueChanged, self.scroll_time_waveform_plot) + safe_connect_signal_slot(self.window.channel_scroll_bar.valueChanged, self.scroll_channel_waveform_plot) + self.window.waveform_time_scroll_bar.valueChanged.disconnect(self.scroll_time_waveform_plot) + self.window.channel_scroll_bar.valueChanged.disconnect(self.scroll_channel_waveform_plot) + + def init_classifier_param(self): + self.window.classifier_param = ParamClassifier() + # self.classifier_save_button.clicked.connect(self.hfo_app.set_classifier()) + + def init_param(self, biomarker_type='HFO'): + if biomarker_type == 'HFO': + self.init_classifier_param() + self.init_default_filter_input_params() + self.init_default_ste_input_params() + self.init_default_mni_input_params() + self.init_default_hil_input_params() + + self.set_mni_input_len(8) + self.set_ste_input_len(8) + self.set_hil_input_len(8) + elif biomarker_type == 'Spindle': + self.init_classifier_param() + self.init_default_filter_input_params() + self.init_default_yasa_input_params() + + self.set_yasa_input_len(8) + + def init_default_filter_input_params(self): + if self.biomarker_type == 'HFO': + default_params = ParamFilter() + self.window.fp_input.setText(str(default_params.fp)) + self.window.fs_input.setText(str(default_params.fs)) + self.window.rp_input.setText(str(default_params.rp)) + self.window.rs_input.setText(str(default_params.rs)) + elif self.biomarker_type == 'Spindle': + default_params = ParamFilterSpindle() + self.window.fp_input.setText(str(default_params.fp)) + self.window.fs_input.setText(str(default_params.fs)) + self.window.rp_input.setText(str(default_params.rp)) + self.window.rs_input.setText(str(default_params.rs)) + + def init_default_ste_input_params(self): + default_params = ParamSTE(2000) + self.window.ste_rms_window_input.setText(str(default_params.rms_window)) + self.window.ste_rms_threshold_input.setText(str(default_params.rms_thres)) + self.window.ste_min_window_input.setText(str(default_params.min_window)) + self.window.ste_epoch_length_input.setText(str(default_params.epoch_len)) + self.window.ste_min_gap_input.setText(str(default_params.min_gap)) + self.window.ste_min_oscillation_input.setText(str(default_params.min_osc)) + self.window.ste_peak_threshold_input.setText(str(default_params.peak_thres)) + + def init_default_mni_input_params(self): + """this is how I got the params, I reversed it here + + epoch_time = self.mni_epoch_time_input.text() + epo_CHF = self.mni_epoch_CHF_input.text() + per_CHF = self.mni_chf_percentage_input.text() + min_win = self.mni_min_window_input.text() + min_gap = self.mni_min_gap_time_input.text() + thrd_perc = self.mni_threshold_percentage_input.text() + base_seg = self.mni_baseline_window_input.text() + base_shift = self.mni_baseline_shift_input.text() + base_thrd = self.mni_baseline_threshold_input.text() + base_min = self.mni_baseline_min_time_input.text() + """ + default_params = ParamMNI(200) + self.window.mni_epoch_time_input.setText(str(default_params.epoch_time)) + self.window.mni_epoch_chf_input.setText(str(default_params.epo_CHF)) + self.window.mni_chf_percentage_input.setText(str(default_params.per_CHF)) + self.window.mni_min_window_input.setText(str(default_params.min_win)) + self.window.mni_min_gap_time_input.setText(str(default_params.min_gap)) + self.window.mni_threshold_percentage_input.setText(str(default_params.thrd_perc * 100)) + self.window.mni_baseline_window_input.setText(str(default_params.base_seg)) + self.window.mni_baseline_shift_input.setText(str(default_params.base_shift)) + self.window.mni_baseline_threshold_input.setText(str(default_params.base_thrd)) + self.window.mni_baseline_min_time_input.setText(str(default_params.base_min)) + + def init_default_hil_input_params(self): + default_params = ParamHIL(2000) + self.window.hil_sample_freq_input.setText(str(default_params.sample_freq)) + self.window.hil_pass_band_input.setText(str(default_params.pass_band)) + self.window.hil_stop_band_input.setText(str(default_params.stop_band)) + self.window.hil_epoch_time_input.setText(str(default_params.epoch_time)) + self.window.hil_sd_threshold_input.setText(str(default_params.sd_threshold)) + self.window.hil_min_window_input.setText(str(default_params.min_window)) + + def init_default_yasa_input_params(self): + default_params = ParamYASA(2000) + self.window.yasa_freq_sp_input.setText(str(default_params.freq_sp)) + self.window.yasa_freq_broad_input.setText(str(default_params.freq_broad)) + self.window.yasa_duration_input.setText(str(default_params.duration)) + self.window.yasa_min_distance_input.setText(str(default_params.min_distance)) + self.window.yasa_thresh_rel_pow_input.setText(str(default_params.rel_pow)) + self.window.yasa_thresh_corr_input.setText(str(default_params.corr)) + self.window.yasa_thresh_rms_input.setText(str(default_params.rms)) + + def connect_signal_and_slot(self, biomarker_type='HFO'): + # classifier default buttons + safe_connect_signal_slot(self.window.default_cpu_button.clicked, self.set_classifier_param_cpu_default) + safe_connect_signal_slot(self.window.default_gpu_button.clicked, self.set_classifier_param_gpu_default) + + # choose model files connection + safe_connect_signal_slot(self.window.choose_artifact_model_button.clicked, lambda: self.choose_model_file("artifact")) + safe_connect_signal_slot(self.window.choose_spike_model_button.clicked, lambda: self.choose_model_file("spike")) + + # custom model param connection + safe_connect_signal_slot(self.window.classifier_save_button.clicked, self.set_custom_classifier_param) + + # detect_all_button + safe_connect_signal_slot(self.window.detect_all_button.clicked, lambda: self.classify(True)) + self.window.detect_all_button.setEnabled(False) + # # self.detect_artifacts_button.clicked.connect(lambda : self.classify(False)) + + safe_connect_signal_slot(self.window.save_csv_button.clicked, self.save_to_excel) + self.window.save_csv_button.setEnabled(False) + + # set n_jobs min and max + self.window.n_jobs_spinbox.setMinimum(1) + self.window.n_jobs_spinbox.setMaximum(mp.cpu_count()) + + # set default n_jobs + self.window.n_jobs_spinbox.setValue(self.backend.n_jobs) + safe_connect_signal_slot(self.window.n_jobs_ok_button.clicked, self.set_n_jobs) + + safe_connect_signal_slot(self.window.save_npz_button.clicked, self.save_to_npz) + self.window.save_npz_button.setEnabled(False) + + safe_connect_signal_slot(self.window.Filter60Button.toggled, self.switch_60) + self.window.Filter60Button.setEnabled(False) + + safe_connect_signal_slot(self.window.bipolar_button.clicked, self.open_bipolar_channel_selection) + self.window.bipolar_button.setEnabled(False) + + # annotation button + safe_connect_signal_slot(self.window.annotation_button.clicked, self.open_annotation) + self.window.annotation_button.setEnabled(False) + + self.window.Choose_Channels_Button.setEnabled(False) + self.window.waveform_plot_button.setEnabled(False) + + # check if gpu is available + self.gpu = torch.cuda.is_available() + # print(f"GPU available: {self.gpu}") + if not self.gpu: + # disable gpu buttons + self.window.default_gpu_button.setEnabled(False) + + if biomarker_type == 'HFO': + safe_connect_signal_slot(self.window.overview_filter_button.clicked, self.filter_data) + # set filter button to be disabled by default + self.window.overview_filter_button.setEnabled(False) + # # self.show_original_button.clicked.connect(self.toggle_filtered) + + safe_connect_signal_slot(self.window.mni_detect_button.clicked, self.detect_HFOs) + self.window.mni_detect_button.setEnabled(False) + safe_connect_signal_slot(self.window.ste_detect_button.clicked, self.detect_HFOs) + self.window.ste_detect_button.setEnabled(False) + safe_connect_signal_slot(self.window.hil_detect_button.clicked, self.detect_HFOs) + self.window.hil_detect_button.setEnabled(False) + + safe_connect_signal_slot(self.window.STE_save_button.clicked, self.save_ste_params) + safe_connect_signal_slot(self.window.MNI_save_button.clicked, self.save_mni_params) + safe_connect_signal_slot(self.window.HIL_save_button.clicked, self.save_hil_params) + self.window.STE_save_button.setEnabled(False) + self.window.MNI_save_button.setEnabled(False) + self.window.HIL_save_button.setEnabled(False) + elif biomarker_type == 'Spindle': + safe_connect_signal_slot(self.window.overview_filter_button.clicked, self.filter_data) + + # set filter button to be disabled by default + self.window.overview_filter_button.setEnabled(False) + + safe_connect_signal_slot(self.window.yasa_detect_button.clicked, self.detect_Spindles) + self.window.yasa_detect_button.setEnabled(False) + + safe_connect_signal_slot(self.window.YASA_save_button.clicked, self.save_yasa_params) + # self.window.YASA_save_button.setEnabled(False) + + def set_classifier_param_display(self): + classifier_param = self.backend.get_classifier_param() + + self.window.overview_artifact_path_display.setText(classifier_param.artifact_path) + self.window.overview_spike_path_display.setText(classifier_param.spike_path) + self.window.overview_use_spike_checkbox.setChecked(classifier_param.use_spike) + self.window.overview_device_display.setText(str(classifier_param.device)) + self.window.overview_batch_size_display.setText(str(classifier_param.batch_size)) + + # set also the input fields + self.window.classifier_artifact_filename.setText(classifier_param.artifact_path) + self.window.classifier_spike_filename.setText(classifier_param.spike_path) + self.window.use_spike_checkbox.setChecked(classifier_param.use_spike) + self.window.classifier_device_input.setText(str(classifier_param.device)) + self.window.classifier_batch_size_input.setText(str(classifier_param.batch_size)) + + def set_classifier_param_gpu_default(self): + self.backend.set_default_gpu_classifier() + self.set_classifier_param_display() + + def set_classifier_param_cpu_default(self): + self.backend.set_default_cpu_classifier() + self.set_classifier_param_display() + + def set_custom_classifier_param(self): + artifact_path = self.window.classifier_artifact_filename.text() + spike_path = self.window.classifier_spike_filename.text() + use_spike = self.window.use_spike_checkbox.isChecked() + device = self.window.classifier_device_input.text() + if device == "cpu": + model_type = "default_cpu" + elif device == "cuda:0" and self.window.gpu: + model_type = "default_gpu" + else: + # print("device not recognized, please set to cpu for cpu or cuda:0 for gpu") + msg = QMessageBox() + msg.setIcon(QMessageBox.Critical) + msg.setText("Error!") + msg.setInformativeText('Device not recognized, please set to CPU for CPU or cuda:0 for GPU') + msg.setWindowTitle("Device not recognized") + msg.exec_() + return + batch_size = self.window.classifier_batch_size_input.text() + + classifier_param = ParamClassifier(artifact_path=artifact_path, spike_path=spike_path, use_spike=use_spike, + device=device, batch_size=int(batch_size), model_type=model_type) + self.backend.set_classifier(classifier_param) + self.set_classifier_param_display() + + def choose_model_file(self, model_type): + fname, _ = QFileDialog.getOpenFileName(self.window, 'Open file', "", ".tar files (*.tar)") + if model_type == "artifact": + self.window.classifier_artifact_filename.setText(fname) + elif model_type == "spike": + self.window.classifier_spike_filename.setText(fname) + + def _classify(self, artifact_only=False): + threshold = 0.5 + seconds_to_ignore_before = float(self.window.overview_ignore_before_input.text()) + seconds_to_ignore_after = float(self.window.overview_ignore_after_input.text()) + self.backend.classify_artifacts([seconds_to_ignore_before, seconds_to_ignore_after], threshold) + if not artifact_only: + self.backend.classify_spikes() + return [] + + def _classify_finished(self): + self.message_handler("Classification finished!..") + self.update_statistics_label() + self.window.waveform_plot.set_plot_biomarkers(True) + self.window.save_csv_button.setEnabled(True) + + def classify(self, check_spike=True): + self.message_handler("Classifying HFOs...") + if check_spike: + use_spike = self.window.overview_use_spike_checkbox.isChecked() + else: + use_spike = False + worker = Worker(lambda progress_callback: self._classify((not use_spike))) + worker.signals.result.connect(self._classify_finished) + self.window.threadpool.start(worker) + + def update_statistics_label(self): + if self.biomarker_type == 'HFO': + num_HFO = self.backend.event_features.get_num_biomarker() + num_artifact = self.backend.event_features.get_num_artifact() + num_spike = self.backend.event_features.get_num_spike() + num_real = self.backend.event_features.get_num_real() + + self.window.statistics_label.setText(" Number of HFOs: " + str(num_HFO) + \ + "\n Number of artifacts: " + str(num_artifact) + \ + "\n Number of spikes: " + str(num_spike) + \ + "\n Number of real HFOs: " + str(num_real)) + elif self.biomarker_type == 'Spindle': + num_spindle = self.backend.event_features.get_num_biomarker() + num_artifact = self.backend.event_features.get_num_artifact() + num_spike = self.backend.event_features.get_num_spike() + num_real = self.backend.event_features.get_num_real() + + self.window.statistics_label.setText(" Number of Spindles: " + str(num_spindle) + \ + "\n Number of artifacts: " + str(num_artifact) + \ + "\n Number of spikes: " + str(num_spike) + \ + "\n Number of real Spindles: " + str(num_real)) + elif self.biomarker_type == 'Spike': + num_spindle = self.backend.event_features.get_num_biomarker() + num_artifact = self.backend.event_features.get_num_artifact() + num_spike = self.backend.event_features.get_num_spike() + num_real = self.backend.event_features.get_num_real() + + self.window.statistics_label.setText(" Number of Spindles: " + str(num_spindle) + \ + "\n Number of artifacts: " + str(num_artifact) + \ + "\n Number of spikes: " + str(num_spike) + \ + "\n Number of real Spindles: " + str(num_real)) + + def save_to_excel(self): + # open file dialog + fname, _ = QFileDialog.getSaveFileName(self.window, 'Save file', "", ".xlsx files (*.xlsx)") + if fname: + self.backend.export_excel(fname) + + def _save_to_npz(self, fname, progress_callback): + self.backend.export_app(fname) + return [] + + def save_to_npz(self): + # open file dialog + # print("saving to npz...",end="") + fname, _ = QFileDialog.getSaveFileName(self.window, 'Save file', "", ".npz files (*.npz)") + if fname: + # print("saving to {fname}...",end="") + worker = Worker(self._save_to_npz, fname) + worker.signals.result.connect(lambda: 0) + self.window.threadpool.start(worker) + + def _load_from_npz(self, fname, progress_callback): + self.backend = self.backend.import_app(fname) + return [] + + def load_from_npz(self): + # open file dialog + fname, _ = QFileDialog.getOpenFileName(self.window, 'Open file', "", ".npz files (*.npz)") + self.message_handler("Loading from npz...") + if fname: + self.reinitialize() + worker = Worker(self._load_from_npz, fname) + worker.signals.result.connect(self.load_from_npz_finished) + self.window.threadpool.start(worker) + # print(self.hfo_app.get_edf_info()) + + def load_from_npz_finished(self): + edf_info = self.backend.get_edf_info() + self.window.waveform_plot.update_backend(self.backend) + self.window.waveform_plot.init_eeg_data() + edf_name = str(edf_info["edf_fn"]) + edf_name = edf_name[edf_name.rfind("/") + 1:] + self.update_edf_info([edf_name, str(edf_info["sfreq"]), + str(edf_info["nchan"]), str(self.backend.eeg_data.shape[1])]) + # update number of jobs + self.window.n_jobs_spinbox.setValue(self.backend.n_jobs) + if self.backend.filtered: + self.filtering_complete() + filter_param = self.backend.param_filter + # update filter params + self.window.fp_input.setText(str(filter_param.fp)) + self.window.fs_input.setText(str(filter_param.fs)) + self.window.rp_input.setText(str(filter_param.rp)) + self.window.rs_input.setText(str(filter_param.rs)) + # update the detector parameters: + if self.backend.detected: + self.set_detector_param_display() + self._detect_finished() + self.update_statistics_label() + # update classifier param + if self.backend.classified: + self.set_classifier_param_display() + self._classify_finished() + self.update_statistics_label() + + def open_channel_selection(self): + self.window.channel_selection_window = ChannelSelectionWindow(self.backend, self, self.window.close_signal) + self.window.channel_selection_window.show() + + def channel_selection_update(self): + self.window.channel_scroll_bar.setValue(0) + self.window.waveform_time_scroll_bar.setValue(0) + is_empty = self.window.n_channel_input.maximum() == 0 + self.window.waveform_plot.plot(0, 0, empty=is_empty, update_biomarker=True) + + def switch_60(self): + # get the value of the Filter60Button radio button + filter_60 = self.window.Filter60Button.isChecked() + # print("filtering:", filter_60) + # if yes + if filter_60: + self.backend.set_filter_60() + # if not + else: + self.backend.set_unfiltered_60() + + # replot + self.window.waveform_plot.plot() + # add a warning to the text about the HFO info saying that it is outdated now + + @pyqtSlot(str) + def message_handler(self, s): + s = s.replace("\n", "") + horScrollBar = self.window.STDTextEdit.horizontalScrollBar() + verScrollBar = self.window.STDTextEdit.verticalScrollBar() + scrollIsAtEnd = verScrollBar.maximum() - verScrollBar.value() <= 10 + + contain_percentage = re.findall(r'%', s) + contain_one_hundred_percentage = re.findall(r'100%', s) + if contain_one_hundred_percentage: + cursor = self.window.STDTextEdit.textCursor() + cursor.movePosition(QTextCursor.End - 1) + cursor.select(QTextCursor.BlockUnderCursor) + cursor.removeSelectedText() + self.window.STDTextEdit.setTextCursor(cursor) + self.window.STDTextEdit.insertPlainText(s) + elif contain_percentage: + cursor = self.window.STDTextEdit.textCursor() + cursor.movePosition(QTextCursor.End) + cursor.select(QTextCursor.BlockUnderCursor) + cursor.removeSelectedText() + self.window.STDTextEdit.setTextCursor(cursor) + self.window.STDTextEdit.insertPlainText(s) + else: + self.window.STDTextEdit.append(s) + + if scrollIsAtEnd: + verScrollBar.setValue(verScrollBar.maximum()) # Scrolls to the bottom + horScrollBar.setValue(0) # scroll to the left + + @pyqtSlot(list) + def update_edf_info(self, results): + self.window.main_filename.setText(results[0]) + self.window.main_sampfreq.setText(results[1]) + self.window.sample_freq = float(results[1]) + self.window.main_numchannels.setText(results[2]) + # print("updated") + self.window.main_length.setText(str(round(float(results[3]) / (60 * float(results[1])), 3)) + " min") + # self.window.waveform_plot.plot(0, update_biomarker=True) + self.window.waveform_plot.set_plot_biomarkers(False) + + # print("plotted") + # connect buttons + self.window.waveform_time_scroll_bar.valueChanged.connect(self.scroll_time_waveform_plot) + self.window.channel_scroll_bar.valueChanged.connect(self.scroll_channel_waveform_plot) + + self.signal_connected = True + + self.window.waveform_plot_button.clicked.connect(self.waveform_plot_button_clicked) + self.window.waveform_plot_button.setEnabled(True) + self.window.Choose_Channels_Button.clicked.connect(self.open_channel_selection) + self.window.Choose_Channels_Button.setEnabled(True) + # set the display time window spin box + self.window.display_time_window_input.setValue(self.window.waveform_plot.get_time_window()) + self.window.display_time_window_input.setMaximum(self.window.waveform_plot.get_total_time()) + self.window.display_time_window_input.setMinimum(0.1) + # set the n channel spin box + self.window.n_channel_input.setValue(self.window.waveform_plot.get_n_channels_to_plot()) + self.window.n_channel_input.setMaximum(self.window.waveform_plot.get_n_channels()) + self.window.n_channel_input.setMinimum(1) + # set the time scroll bar range + self.window.waveform_time_scroll_bar.setMaximum(int(self.window.waveform_plot.get_total_time() / ( + self.window.waveform_plot.get_time_window() * self.window.waveform_plot.get_time_increment() / 100))) + self.window.waveform_time_scroll_bar.setValue(0) + # set the channel scroll bar range + self.window.channel_scroll_bar.setMaximum( + self.window.waveform_plot.get_n_channels() - self.window.waveform_plot.get_n_channels_to_plot()) + # enable the filter button + self.window.overview_filter_button.setEnabled(True) + self.window.toggle_filtered_checkbox.stateChanged.connect(self.toggle_filtered) + self.window.normalize_vertical_input.stateChanged.connect(self.waveform_plot_button_clicked) + # enable the plot out the 60Hz bandstopped signal + self.window.Filter60Button.setEnabled(True) + self.window.bipolar_button.setEnabled(True) + # print("EDF file loaded") + + def toggle_filtered(self): + # self.message_handler('Showing original data...') + if self.window.is_data_filtered: + self.window.show_filtered = not self.window.show_filtered + self.window.waveform_plot.set_filtered(self.window.show_filtered) + self.waveform_plot_button_clicked() + + def read_edf(self, fname, progress_callback): + self.reinitialize() + self.backend.load_edf(fname) + eeg_data, channel_names = self.backend.get_eeg_data() + edf_info = self.backend.get_edf_info() + self.window.waveform_plot.init_eeg_data() + filename = os.path.basename(fname) + sample_freq = str(self.backend.sample_freq) + num_channels = str(len(self.backend.channel_names)) + length = str(self.backend.eeg_data.shape[1]) + return [filename, sample_freq, num_channels, length] + + def _filter(self, progress_callback): + self.backend.filter_eeg_data() + return [] + + # def open_detector(self): + # # Pass the function to execute, function, args, kwargs + # worker = Worker(self.quick_detect) + # self.window.threadpool.start(worker) + # + # def round_dict(self, d: dict, n: int): + # for key in d.keys(): + # if type(d[key]) == float: + # d[key] = round(d[key], n) + # return d + + def scroll_time_waveform_plot(self, event): + t_start = self.window.waveform_time_scroll_bar.value() * self.window.waveform_plot.get_time_window() * self.window.waveform_plot.get_time_increment() / 100 + self.window.waveform_plot.plot(t_start) + + def scroll_channel_waveform_plot(self, event): + channel_start = self.window.channel_scroll_bar.value() + self.window.waveform_plot.plot(first_channel_to_plot=channel_start, update_biomarker=True) + + def get_channels_to_plot(self): + return self.window.waveform_plot.get_channels_to_plot() + + def get_channel_indices_to_plot(self): + return self.window.waveform_plot.get_channel_indices_to_plot() + + def waveform_plot_button_clicked(self): + time_window = self.window.display_time_window_input.value() + self.window.waveform_plot.set_time_window(time_window) + n_channels_to_plot = self.window.n_channel_input.value() + self.window.waveform_plot.set_n_channels_to_plot(n_channels_to_plot) + time_increment = self.window.Time_Increment_Input.value() + self.window.waveform_plot.set_time_increment(time_increment) + normalize_vertical = self.window.normalize_vertical_input.isChecked() + self.window.waveform_plot.set_normalize_vertical(normalize_vertical) + is_empty = self.window.n_channel_input.maximum() == 0 + start = self.window.waveform_plot.t_start + first_channel_to_plot = self.window.waveform_plot.first_channel_to_plot + + t_value = int(start // (self.window.waveform_plot.get_time_window() * self.window.waveform_plot.get_time_increment() / 100)) + self.window.waveform_time_scroll_bar.setMaximum(int(self.window.waveform_plot.get_total_time() / ( + self.window.waveform_plot.get_time_window() * self.window.waveform_plot.get_time_increment() / 100))) + self.window.waveform_time_scroll_bar.setValue(t_value) + c_value = self.window.channel_scroll_bar.value() + self.window.channel_scroll_bar.setMaximum(len(self.window.waveform_plot.get_channels_to_plot()) - n_channels_to_plot) + self.window.channel_scroll_bar.setValue(c_value) + self.window.waveform_plot.plot(start, first_channel_to_plot, empty=is_empty, update_biomarker=True) + + def open_file(self): + # reinitialize the app + self.set_biomarker_type_and_init_backend(self.biomarker_type) + fname, _ = QFileDialog.getOpenFileName(self.window, "Open File", "", "Recordings Files (*.edf *.eeg *.vhdr *.vmrk)") + if fname: + worker = Worker(self.read_edf, fname) + worker.signals.result.connect(self.update_edf_info) + self.window.threadpool.start(worker) + + def filtering_complete(self): + self.message_handler('Filtering COMPLETE!') + filter_60 = self.window.Filter60Button.isChecked() + # print("filtering:", filter_60) + # if yes + if filter_60: + self.backend.set_filter_60() + # if not + else: + self.backend.set_unfiltered_60() + + if self.biomarker_type == 'HFO': + self.window.STE_save_button.setEnabled(True) + self.window.ste_detect_button.setEnabled(True) + self.window.MNI_save_button.setEnabled(True) + self.window.mni_detect_button.setEnabled(True) + self.window.HIL_save_button.setEnabled(True) + self.window.hil_detect_button.setEnabled(True) + self.window.is_data_filtered = True + self.window.show_filtered = True + self.window.waveform_plot.set_filtered(True) + self.window.save_npz_button.setEnabled(True) + elif self.biomarker_type == 'Spindle': + self.window.is_data_filtered = True + self.window.show_filtered = True + self.window.waveform_plot.set_filtered(True) + self.window.save_npz_button.setEnabled(True) + + def detect_HFOs(self): + print("Detecting HFOs...") + worker = Worker(self._detect) + worker.signals.result.connect(self._detect_finished) + self.window.threadpool.start(worker) + + def detect_Spindles(self): + print("Detecting Spindles...") + worker = Worker(self._detect) + worker.signals.result.connect(self._detect_finished) + self.window.threadpool.start(worker) + + def _detect_finished(self): + # right now do nothing beyond message handler saying that + # it has detected HFOs + self.message_handler("Biomarker detected") + self.update_statistics_label() + self.window.waveform_plot.set_plot_biomarkers(True) + self.window.detect_all_button.setEnabled(True) + self.window.annotation_button.setEnabled(True) + + def _detect(self, progress_callback): + # call detect HFO function on backend + self.backend.detect_biomarker() + return [] + + def open_quick_detection(self): + # if we want to open multiple qd dialog + if not self.window.quick_detect_open: + qd = HFOQuickDetector(HFO_App(), self, self.window.close_signal) + # print("created new quick detector") + qd.show() + self.window.quick_detect_open = True + + def set_quick_detect_open(self, open): + self.window.quick_detect_open = open + + def reinitialize_buttons(self): + self.window.mni_detect_button.setEnabled(False) + self.window.ste_detect_button.setEnabled(False) + self.window.hil_detect_button.setEnabled(False) + self.window.detect_all_button.setEnabled(False) + self.window.save_csv_button.setEnabled(False) + self.window.save_npz_button.setEnabled(False) + self.window.STE_save_button.setEnabled(False) + self.window.MNI_save_button.setEnabled(False) + self.window.HIL_save_button.setEnabled(False) + self.window.Filter60Button.setEnabled(False) + + def set_mni_input_len(self, max_len=5): + self.window.mni_epoch_time_input.setMaxLength(max_len) + self.window.mni_epoch_chf_input.setMaxLength(max_len) + self.window.mni_chf_percentage_input.setMaxLength(max_len) + self.window.mni_min_window_input.setMaxLength(max_len) + self.window.mni_min_gap_time_input.setMaxLength(max_len) + self.window.mni_threshold_percentage_input.setMaxLength(max_len) + self.window.mni_baseline_window_input.setMaxLength(max_len) + self.window.mni_baseline_shift_input.setMaxLength(max_len) + self.window.mni_baseline_threshold_input.setMaxLength(max_len) + self.window.mni_baseline_min_time_input.setMaxLength(max_len) + + def set_ste_input_len(self, max_len=5): + self.window.ste_rms_window_input.setMaxLength(max_len) + self.window.ste_min_window_input.setMaxLength(max_len) + self.window.ste_min_gap_input.setMaxLength(max_len) + self.window.ste_epoch_length_input.setMaxLength(max_len) + self.window.ste_min_oscillation_input.setMaxLength(max_len) + self.window.ste_rms_threshold_input.setMaxLength(max_len) + self.window.ste_peak_threshold_input.setMaxLength(max_len) + + def set_hil_input_len(self, max_len=5): + self.window.hil_sample_freq_input.setMaxLength(max_len) + self.window.hil_pass_band_input.setMaxLength(max_len) + self.window.hil_stop_band_input.setMaxLength(max_len) + self.window.hil_epoch_time_input.setMaxLength(max_len) + self.window.hil_sd_threshold_input.setMaxLength(max_len) + self.window.hil_min_window_input.setMaxLength(max_len) + + def set_yasa_input_len(self, max_len=5): + self.window.yasa_freq_sp_input.setMaxLength(max_len) + self.window.yasa_freq_broad_input.setMaxLength(max_len) + self.window.yasa_duration_input.setMaxLength(max_len) + self.window.yasa_min_distance_input.setMaxLength(max_len) + self.window.yasa_thresh_rel_pow_input.setMaxLength(max_len) + self.window.yasa_thresh_corr_input.setMaxLength(max_len) + self.window.yasa_thresh_rms_input.setMaxLength(max_len) + + def close_other_window(self): + self.window.close_signal.emit() + + def set_n_jobs(self): + self.backend.n_jobs = int(self.window.n_jobs_spinbox.value()) + # print(f"n_jobs set to {self.hfo_app.n_jobs}") + + def set_channels_to_plot(self, channels_to_plot, display_all=True): + self.window.waveform_plot.set_channels_to_plot(channels_to_plot) + # print(f"Channels to plot: {self.channels_to_plot}") + self.window.n_channel_input.setMaximum(len(channels_to_plot)) + if display_all: + self.window.n_channel_input.setValue(len(channels_to_plot)) + self.waveform_plot_button_clicked() + + def save_ste_params(self): + # get filter parameters + rms_window_raw = self.window.ste_rms_window_input.text() + min_window_raw = self.window.ste_min_window_input.text() + min_gap_raw = self.window.ste_min_gap_input.text() + epoch_len_raw = self.window.ste_epoch_length_input.text() + min_osc_raw = self.window.ste_min_oscillation_input.text() + rms_thres_raw = self.window.ste_rms_threshold_input.text() + peak_thres_raw = self.window.ste_peak_threshold_input.text() + try: + param_dict = {"sample_freq": 2000, "pass_band": 1, "stop_band": 80, + # these are placeholder params, will be updated later + "rms_window": float(rms_window_raw), "min_window": float(min_window_raw), + "min_gap": float(min_gap_raw), + "epoch_len": float(epoch_len_raw), "min_osc": float(min_osc_raw), + "rms_thres": float(rms_thres_raw), + "peak_thres": float(peak_thres_raw), "n_jobs": self.backend.n_jobs} + detector_params = {"detector_type": "STE", "detector_param": param_dict} + self.backend.set_detector(ParamDetector.from_dict(detector_params)) + + # set display parameters + self.window.ste_epoch_display.setText(epoch_len_raw) + self.window.ste_min_window_display.setText(min_window_raw) + self.window.ste_rms_window_display.setText(rms_window_raw) + self.window.ste_min_gap_time_display.setText(min_gap_raw) + self.window.ste_min_oscillations_display.setText(min_osc_raw) + self.window.ste_peak_threshold_display.setText(peak_thres_raw) + self.window.ste_rms_threshold_display.setText(rms_thres_raw) + self.update_detector_tab("STE") + except: + msg = QMessageBox() + msg.setIcon(QMessageBox.Critical) + msg.setText("Error!") + msg.setInformativeText('Detector could not be constructed given the parameters') + msg.setWindowTitle("Detector Construction Failed") + msg.exec_() + + def save_mni_params(self): + try: + epoch_time = self.window.mni_epoch_time_input.text() + epo_CHF = self.window.mni_epoch_chf_input.text() + per_CHF = self.window.mni_chf_percentage_input.text() + min_win = self.window.mni_min_window_input.text() + min_gap = self.window.mni_min_gap_time_input.text() + thrd_perc = self.window.mni_threshold_percentage_input.text() + base_seg = self.window.mni_baseline_window_input.text() + base_shift = self.window.mni_baseline_shift_input.text() + base_thrd = self.window.mni_baseline_threshold_input.text() + base_min = self.window.mni_baseline_min_time_input.text() + + param_dict = {"sample_freq": 2000, "pass_band": 1, "stop_band": 80, + # these are placeholder params, will be updated later + "epoch_time": float(epoch_time), "epo_CHF": float(epo_CHF), "per_CHF": float(per_CHF), + "min_win": float(min_win), "min_gap": float(min_gap), "base_seg": float(base_seg), + "thrd_perc": float(thrd_perc) / 100, + "base_shift": float(base_shift), "base_thrd": float(base_thrd), "base_min": float(base_min), + "n_jobs": self.backend.n_jobs} + # param_dict = self.round_dict(param_dict, 3) + detector_params = {"detector_type": "MNI", "detector_param": param_dict} + self.backend.set_detector(ParamDetector.from_dict(detector_params)) + + # set display parameters + self.window.mni_epoch_display.setText(epoch_time) + self.window.mni_epoch_chf_display.setText(epo_CHF) + self.window.mni_chf_percentage_display.setText(per_CHF) + self.window.mni_min_window_display.setText(min_win) + self.window.mni_min_gap_time_display.setText(min_gap) + self.window.mni_threshold_percentile_display.setText(thrd_perc) + self.window.mni_baseline_window_display.setText(base_seg) + self.window.mni_baseline_shift_display.setText(base_shift) + self.window.mni_baseline_threshold_display.setText(base_thrd) + self.window.mni_baseline_min_time_display.setText(base_min) + + self.update_detector_tab("MNI") + except Exception as e: + msg = QMessageBox() + msg.setIcon(QMessageBox.Critical) + msg.setText("Error!") + msg.setInformativeText('Detector could not be constructed given the parameters') + msg.setWindowTitle("Detector Construction Failed") + msg.exec_() + + def save_hil_params(self): + try: + sample_freq = self.window.hil_sample_freq_input.text() + pass_band = self.window.hil_pass_band_input.text() + stop_band = self.window.hil_stop_band_input.text() + epoch_time = self.window.hil_epoch_time_input.text() + sd_threshold = self.window.hil_sd_threshold_input.text() + min_window = self.window.hil_min_window_input.text() + + param_dict = { + "sample_freq": float(sample_freq), + "pass_band": float(pass_band), + "stop_band": float(stop_band), + "epoch_time": float(epoch_time), + "sd_threshold": float(sd_threshold), + "min_window": float(min_window), + "n_jobs": self.backend.n_jobs, + } + + detector_params = {"detector_type": "HIL", "detector_param": param_dict} + self.backend.set_detector(ParamDetector.from_dict(detector_params)) + + self.window.hil_sample_freq_display.setText(sample_freq) + self.window.hil_pass_band_display.setText(pass_band) + self.window.hil_stop_band_display.setText(stop_band) + self.window.hil_epoch_time_display.setText(epoch_time) + self.window.hil_sd_threshold_display.setText(sd_threshold) + self.window.hil_min_window_display.setText(min_window) + + self.update_detector_tab("HIL") + + except Exception as e: + msg = QMessageBox() + msg.setIcon(QMessageBox.Critical) + msg.setText("Error!") + msg.setInformativeText(f'HIL Detector could not be constructed given the parameters. Error: {str(e)}') + msg.setWindowTitle("HIL Detector Construction Failed") + msg.exec_() + + def save_yasa_params(self): + # get filter parameters + + freq_sp_raw = self.window.yasa_freq_sp_input.text() + freq_broad_raw = self.window.yasa_freq_broad_input.text() + duration_raw = self.window.yasa_duration_input.text() + min_distance_raw = self.window.yasa_min_distance_input.text() + thresh_rel_pow_raw = self.window.yasa_thresh_rel_pow_input.text() + thresh_corr_raw = self.window.yasa_thresh_corr_input.text() + thresh_rms_raw = self.window.yasa_thresh_rms_input.text() + try: + param_dict = {"sample_freq": 2000, + # these are placeholder params, will be updated later + "freq_sp": ast.literal_eval(freq_sp_raw), "freq_broad": ast.literal_eval(freq_broad_raw), + "duration": ast.literal_eval(duration_raw), + "min_distance": float(min_distance_raw), "rel_pow": float(thresh_rel_pow_raw), + "corr": float(thresh_corr_raw), + "rms": float(thresh_rms_raw), "n_jobs": self.backend.n_jobs} + detector_params = {"detector_type": "YASA", "detector_param": param_dict} + self.backend.set_detector(ParamDetector.from_dict(detector_params)) + + # set display parameters + self.window.yasa_freq_sp_display.setText(freq_sp_raw) + self.window.yasa_freq_broad_display.setText(freq_broad_raw) + self.window.yasa_duration_display.setText(duration_raw) + self.window.yasa_min_distance_display.setText(min_distance_raw) + self.window.yasa_thresh_rel_pow_display.setText(thresh_rel_pow_raw) + self.window.yasa_thresh_corr_display.setText(thresh_corr_raw) + self.window.yasa_thresh_rms_display.setText(thresh_rms_raw) + # self.update_detector_tab("STE") + self.window.yasa_detect_button.setEnabled(True) + except: + msg = QMessageBox() + msg.setIcon(QMessageBox.Critical) + msg.setText("Error!") + msg.setInformativeText('Detector could not be constructed given the parameters') + msg.setWindowTitle("Detector Construction Failed") + msg.exec_() + + def update_detector_tab(self, index): + if index == "STE": + self.window.stacked_widget_detection_param.setCurrentIndex(0) + elif index == "MNI": + self.window.stacked_widget_detection_param.setCurrentIndex(1) + elif index == "HIL": + self.window.stacked_widget_detection_param.setCurrentIndex(2) + + def reinitialize(self): + # kill all threads in self.threadpool + self.close_other_window() + # self.backend = HFO_App() + self.set_biomarker_type_and_init_backend(self.biomarker_type) + self.window.waveform_plot.update_backend(self.backend, False) + self.window.main_filename.setText("") + self.window.main_sampfreq.setText("") + self.window.main_numchannels.setText("") + self.window.main_length.setText("") + self.window.statistics_label.setText("") + + def update_ste_params(self, ste_params): + rms_window = str(ste_params["rms_window"]) + min_window = str(ste_params["min_window"]) + min_gap = str(ste_params["min_gap"]) + epoch_len = str(ste_params["epoch_len"]) + min_osc = str(ste_params["min_osc"]) + rms_thres = str(ste_params["rms_thres"]) + peak_thres = str(ste_params["peak_thres"]) + + self.window.ste_rms_window_input.setText(rms_window) + self.window.ste_min_window_input.setText(min_window) + self.window.ste_min_gap_input.setText(min_gap) + self.window.ste_epoch_length_input.setText(epoch_len) + self.window.ste_min_oscillation_input.setText(min_osc) + self.window.ste_rms_threshold_input.setText(rms_thres) + self.window.ste_peak_threshold_input.setText(peak_thres) + + # set display parameters + self.window.ste_epoch_display.setText(epoch_len) + self.window.ste_min_window_display.setText(min_window) + self.window.ste_rms_window_display.setText(rms_window) + self.window.ste_min_gap_time_display.setText(min_gap) + self.window.ste_min_oscillations_display.setText(min_osc) + self.window.ste_peak_threshold_display.setText(peak_thres) + self.window.ste_rms_threshold_display.setText(rms_thres) + + self.update_detector_tab("STE") + self.window.detector_subtabs.setCurrentIndex(0) + + def update_mni_params(self, mni_params): + epoch_time = str(mni_params["epoch_time"]) + epo_CHF = str(mni_params["epo_CHF"]) + per_CHF = str(mni_params["per_CHF"]) + min_win = str(mni_params["min_win"]) + min_gap = str(mni_params["min_gap"]) + thrd_perc = str(mni_params["thrd_perc"]) + base_seg = str(mni_params["base_seg"]) + base_shift = str(mni_params["base_shift"]) + base_thrd = str(mni_params["base_thrd"]) + base_min = str(mni_params["base_min"]) + + self.window.mni_epoch_time_input.setText(epoch_time) + self.window.mni_epoch_chf_input.setText(epo_CHF) + self.window.mni_chf_percentage_input.setText(per_CHF) + self.window.mni_min_window_input.setText(min_win) + self.window.mni_min_gap_time_input.setText(min_gap) + self.window.mni_threshold_percentage_input.setText(thrd_perc) + self.window.mni_baseline_window_input.setText(base_seg) + self.window.mni_baseline_shift_input.setText(base_shift) + self.window.mni_baseline_threshold_input.setText(base_thrd) + self.window.mni_baseline_min_time_input.setText(base_min) + + # set display parameters + self.window.mni_epoch_display.setText(epoch_time) + self.window.mni_epoch_chf_display.setText(epo_CHF) + self.window.mni_chf_percentage_display.setText(per_CHF) + self.window.mni_min_window_display.setText(min_win) + self.window.mni_min_gap_time_display.setText(min_gap) + self.window.mni_threshold_percentile_display.setText(thrd_perc) + self.window.mni_baseline_window_display.setText(base_seg) + self.window.mni_baseline_shift_display.setText(base_shift) + self.window.mni_baseline_threshold_display.setText(base_thrd) + self.window.mni_baseline_min_time_display.setText(base_min) + + self.update_detector_tab("MNI") + self.window.detector_subtabs.setCurrentIndex(1) + + def update_hil_params(self, hil_params): + sample_freq = str(hil_params["sample_freq"]) + pass_band = str(hil_params["pass_band"]) + stop_band = str(hil_params["stop_band"]) + epoch_time = str(hil_params["epoch_time"]) + sd_threshold = str(hil_params["sd_threshold"]) + min_window = str(hil_params["min_window"]) + + self.window.hil_sample_freq_input.setText(sample_freq) + self.window.hil_pass_band_input.setText(pass_band) + self.window.hil_stop_band_input.setText(stop_band) + self.window.hil_epoch_time_input.setText(epoch_time) + self.window.hil_sd_threshold_input.setText(sd_threshold) + self.window.hil_min_window_input.setText(min_window) + + # set display parameters + self.window.hil_sample_freq_display.setText(sample_freq) + self.window.hil_pass_band_display.setText(pass_band) + self.window.hil_stop_band_display.setText(stop_band) + self.window.hil_epoch_time_display.setText(epoch_time) + self.window.hil_sd_threshold_display.setText(sd_threshold) + self.window.hil_min_window_display.setText(min_window) + + self.update_detector_tab("HIL") + self.window.detector_subtabs.setCurrentIndex(2) + + def set_detector_param_display(self): + detector_params = self.backend.param_detector + detector_type = detector_params.detector_type.lower() + if detector_type == "ste": + self.update_ste_params(detector_params.detector_param.to_dict()) + elif detector_type == "mni": + self.update_mni_params(detector_params.detector_param.to_dict()) + elif detector_type == "hil": + self.update_hil_params(detector_params.detector_param.to_dict()) + + def open_bipolar_channel_selection(self): + self.window.bipolar_channel_selection_window = BipolarChannelSelectionWindow(self, + self.backend, + self.window, + self.window.close_signal, + self.window.waveform_plot) + self.window.bipolar_channel_selection_window.show() + + def open_annotation(self): + self.window.save_csv_button.setEnabled(True) + annotation = Annotation(self.backend, self.window, self.window.close_signal) + annotation.show() \ No newline at end of file diff --git a/src/models/mini_plot_model.py b/src/models/mini_plot_model.py new file mode 100644 index 0000000..32bb479 --- /dev/null +++ b/src/models/mini_plot_model.py @@ -0,0 +1,70 @@ +import numpy as np +from src.hfo_app import HFO_App +from src.spindle_app import SpindleApp +import sys + + +class MiniPlotModel: + def __init__(self, backend: HFO_App): + self.backend = backend + self.color_dict={"artifact":(245,130,48), #orange + "spike":(240,30,250), #pink + "non_spike":(60,180,75), #green + "HFO":(60,180,75), #green + } + self.first_channel_to_plot = 0 + self.n_channels_to_plot = 10 + + def update_backend(self, new_backend): + self.backend = new_backend + + def init_eeg_data(self): + eeg_data, self.channel_names = self.backend.get_eeg_data() + self.edf_info = self.backend.get_edf_info() + self.sample_freq = self.edf_info['sfreq'] + self.time = np.arange(0, eeg_data.shape[1]/self.sample_freq, 1/self.sample_freq) + + self.channel_names = list(self.channel_names) + self.n_channels = len(self.channel_names) + self.n_channels_to_plot = min(self.n_channels,self.n_channels_to_plot) + self.channels_to_plot = self.channel_names.copy() + self.channel_indices_to_plot = np.arange(self.n_channels) + + def set_first_channel_to_plot(self, first_channel_to_plot): + self.first_channel_to_plot = first_channel_to_plot + + def set_n_channels_to_plot(self, n_channels_to_plot:int): + self.n_channels_to_plot = n_channels_to_plot + + def set_channels_to_plot(self, channels_to_plot:list): + self.channels_to_plot = channels_to_plot + self.channel_indices_to_plot = [self.channel_names.index(channel) for channel in channels_to_plot] + + def set_channel_indices_to_plot(self,channel_indices_to_plot:list): + self.channel_indices_to_plot = channel_indices_to_plot + self.channels_to_plot = [self.channel_names[index] for index in channel_indices_to_plot] + + def update_channel_names(self, new_channel_names): + self.channel_names = list(new_channel_names) + + def get_all_biomarkers_for_channel(self, channel, t_start=0, t_end=sys.maxsize): + return self.backend.event_features.get_biomarkers_for_channel(channel, t_start, t_end) + + def get_all_biomarkers_for_channel_and_color(self, channel, t_start=0, t_end=sys.maxsize): + starts, ends, artifacts, spikes = self.get_all_biomarkers_for_channel(channel, t_start, t_end) + colors = [] + for j in range(len(starts)): + try: + if int(artifacts[j])<1: + color = self.color_dict["artifact"] + elif spikes[j]: + color = self.color_dict["spike"] + else: + color = self.color_dict["non_spike"] + except: + color = self.color_dict["non_spike"] + colors.append(color) + + starts_in_time = [self.time[int(i)] for i in starts] + ends_in_time = [self.time[min(int(i), len(self.time)-1)] for i in ends] + return starts_in_time, ends_in_time, colors \ No newline at end of file diff --git a/src/param/param_detector.py b/src/param/param_detector.py index b044cb8..36c2b68 100644 --- a/src/param/param_detector.py +++ b/src/param/param_detector.py @@ -13,6 +13,10 @@ def from_dict(param_dict): param.detector_param = ParamSTE.from_dict(param_dict['detector_param']) elif detector_type.lower() == 'mni': param.detector_param = ParamMNI.from_dict(param_dict['detector_param']) + elif detector_type.lower() == 'hil': + param.detector_param = ParamHIL.from_dict(param_dict['detector_param']) + elif detector_type.lower() == 'yasa': + param.detector_param = ParamYASA.from_dict(param_dict['detector_param']) return param class ParamSTE: @@ -61,7 +65,6 @@ def from_dict(d): peak_thres = d["peak_thres"], n_jobs = d["n_jobs"] ) - class ParamMNI: def __init__(self,sample_freq, pass_band = 80, stop_band = 500, epoch_time=10, epo_CHF=60, per_CHF=95/100, @@ -118,3 +121,80 @@ def from_dict(d): d["base_min"], d["n_jobs"] ) + +class ParamHIL: + def __init__(self, sample_freq=2000, pass_band=80, stop_band=500, epoch_time=10, sd_threshold=5, min_window=0.006, n_jobs=32): + self.sample_freq = sample_freq + self.pass_band = pass_band + self.stop_band = stop_band + self.epoch_time = epoch_time + self.sd_threshold = sd_threshold + self.min_window = min_window + self.n_jobs = n_jobs + + def to_dict(self): + return { + "sample_freq": self.sample_freq, + "pass_band": self.pass_band, + "stop_band": self.stop_band, + "epoch_time": self.epoch_time, + "sd_threshold": self.sd_threshold, + "min_window": self.min_window, + "n_jobs": self.n_jobs + } + + @staticmethod + def from_dict(d): + return ParamHIL( + d["sample_freq"], + d["pass_band"], + d["stop_band"], + d["epoch_time"], + d["sd_threshold"], + d["min_window"], + d["n_jobs"] + ) + + +class ParamYASA: + def __init__(self, sample_freq=2000, freq_sp=(12, 15), freq_broad=(1, 30), duration=(0.5, 2), + min_distance=500, corr=0.65, rel_pow=0.2, rms=1.5, n_jobs=8): + self.sample_freq = sample_freq + self.freq_sp = freq_sp + self.freq_broad = freq_broad + self.duration = duration + self.min_distance = min_distance + # self.thresh = {'corr': corr, 'rel_pow': rel_pow, 'rms': rms} + self.corr = corr + self.rel_pow = rel_pow + self.rms = rms + self.n_jobs = n_jobs + + def to_dict(self): + return { + "sample_freq": self.sample_freq, + "freq_sp": self.freq_sp, + "freq_broad": self.freq_broad, + "duration": self.duration, + "min_distance": self.min_distance, + # "thresh": self.thresh, + "corr": self.corr, + "rel_pow": self.rel_pow, + "rms": self.rms, + "n_jobs": self.n_jobs + } + + @staticmethod + def from_dict(d): + return ParamYASA( + d["sample_freq"], + d["freq_sp"], + d["freq_broad"], + d["duration"], + d["min_distance"], + # d["thresh"], + d["corr"], + d["rel_pow"], + d["rms"], + d["n_jobs"] + ) \ No newline at end of file diff --git a/src/param/param_filter.py b/src/param/param_filter.py index b6ee99c..5679a14 100644 --- a/src/param/param_filter.py +++ b/src/param/param_filter.py @@ -9,6 +9,30 @@ def __init__(self, fp=80, fs=500, rp=0.5, rs=93, space=0.5, sample_freq=2000): def to_dict(self): return {'fp':self.fp, 'fs':self.fs, 'rp':self.rp, 'rs':self.rs, 'space':self.space, 'sample_freq':self.sample_freq} + @staticmethod + def from_dict(param_filter): + if not 'sample_freq' in param_filter: + param_filter['sample_freq'] = None + return ParamFilter( + fp = param_filter['fp'], + fs = param_filter['fs'], + rp = param_filter['rp'], + rs = param_filter['rs'], + sample_freq = param_filter['sample_freq'] + ) + + +class ParamFilterSpindle: + def __init__(self, fp=1, fs=30, rp=0.5, rs=93, space=0.5, sample_freq=2000): + self.fp = fp + self.fs = fs + self.rp = rp + self.rs = rs + self.space = space + self.sample_freq = sample_freq + def to_dict(self): + return {'fp':self.fp, 'fs':self.fs, 'rp':self.rp, 'rs':self.rs, 'space':self.space, 'sample_freq':self.sample_freq} + @staticmethod def from_dict(param_filter): if not 'sample_freq' in param_filter: diff --git a/src/spindle_app.py b/src/spindle_app.py new file mode 100644 index 0000000..6eb4547 --- /dev/null +++ b/src/spindle_app.py @@ -0,0 +1,563 @@ +import mne +import numpy as np +import scipy.signal as signal +# from models import ArtifactDetector, SpikeDetector +# import torch +from src.spindle_feature import SpindleFeature +from src.classifer import Classifier +from src.utils.utils_feature import * +from src.utils.utils_filter import construct_filter, filter_data +from src.utils.utils_detector import set_YASA_detector +from src.utils.utils_io import get_edf_info, read_eeg_data, dump_to_npz +from src.utils.utils_plotting import plot_feature + +from src.param.param_detector import ParamDetector +from src.param.param_filter import ParamFilter +from src.param.param_classifier import ParamClassifier + +import os +from p_tqdm import p_map +from pathlib import Path + + +class SpindleApp(object): + def __init__(self): + self.version = "1.0.0" + self.biomarker_type = 'Spindle' + self.n_jobs = 4 + ## eeg related + self.eeg_data = None + self.raw = None + self.channel_names = None + self.event_channel_names = None + self.sample_freq = 0 # Hz + self.edf_param = None + + ## filter related + self.sos = None + self.filter_data = None + self.param_filter = None + self.filtered = False + + # 60Hz filter related + self.eeg_data_un60 = None + self.filter_data_un60 = None + self.eeg_data_60 = None + self.filter_data_60 = None + + ## detector related + self.param_detector = None + self.detector = None + self.detected = False + + ## feature related + self.feature_param = None + self.event_features = None + + ## classifier related + self.param_classifier = None + self.classifier = None + self.classified = False + self.Spindles = None + + def load_edf(self, file_path): + print("Loading recording: " + file_path) + if file_path.split(".")[-1] == "edf": + self.raw = mne.io.read_raw_edf(file_path, verbose=0) + # otherwise if its a brainvision file + elif file_path.split(".")[-1] == "vhdr": + # first check if the .eeg and .vmrk files also exist + assert os.path.exists( + file_path.replace(".vhdr", ".eeg")), "The .eeg file does not exist, cannot load the data" + assert os.path.exists( + file_path.replace(".vhdr", ".vmrk")), "The .vmrk file does not exist, cannot load the data" + self.raw = mne.io.read_raw_brainvision(file_path, verbose=0) + elif file_path.split(".")[-1] == "eeg": + # first check if the .vhdr and .vmrk files also exist + assert os.path.exists( + file_path.replace(".eeg", ".vhdr")), "The .vhdr file does not exist, cannot load the data" + assert os.path.exists( + file_path.replace(".eeg", ".vmrk")), "The .vmrk file does not exist, cannot load the data" + self.raw = mne.io.read_raw_brainvision(file_path.replace(".eeg", ".vhdr") + , verbose=0) + elif file_path.split(".")[-1] == "vmrk": + # first check if the .vhdr and .eeg files also exist + assert os.path.exists( + file_path.replace(".vmrk", ".vhdr")), "The .vhdr file does not exist, cannot load the data" + assert os.path.exists( + file_path.replace(".vmrk", ".eeg")), "The .eeg file does not exist, cannot load the data" + self.raw = mne.io.read_raw_brainvision(file_path.replace(".vmrk", ".vhdr") + , verbose=0) + else: + raise ValueError("File type not supported") + self.edf_param = get_edf_info(self.raw) + self.sample_freq = int(self.edf_param['sfreq']) + self.edf_param["edf_fn"] = file_path + self.eeg_data, self.channel_names = read_eeg_data(self.raw) + self.eeg_data_un60 = self.eeg_data.copy() + self.eeg_data_60 = self.filter_60(self.eeg_data) + # print("channel names: ", self.channel_names) + # print("Loading COMPLETE!") + + def load_database(self): + # @TODO; load database + pass + + def get_edf_info(self): + return self.edf_param + + def get_eeg_data(self, start: int = None, end: int = None, filtered: bool = False): + data = self.eeg_data if not filtered else self.filter_data + if start is None and end is None: + return data, self.channel_names + elif start is None: + return data[:, :end], self.channel_names + elif end is None: + return data[:, start:], self.channel_names + else: + return data[:, start:end], self.channel_names + + def get_eeg_data_shape(self): + return self.eeg_data.shape + + def get_sample_freq(self): + return self.sample_freq + + def add_bipolar_channel(self, ch_1, ch_2): + + def bipolar(data, channels, ch1, ch2): + return data[channels == ch1] - data[channels == ch2] + + bipolar_signal = bipolar(self.eeg_data, self.channel_names, ch_1, ch_2) + bipolar_signalun60 = bipolar(self.eeg_data_un60, self.channel_names, ch_1, ch_2) + bipolar_signal60 = bipolar(self.eeg_data_60, self.channel_names, ch_1, ch_2) + + if self.filtered == True: + bipolar_filtered_60 = bipolar(self.filter_data_60, self.channel_names, ch_1, ch_2) + bipolar_filtered_un60 = bipolar(self.filter_data_un60, self.channel_names, ch_1, ch_2) + + self.channel_names = np.concatenate([[f"{ch_1}#-#{ch_2}"], self.channel_names]) + + # add filtered/unfiltered 60/un60 signals to different arrays + self.eeg_data = np.concatenate([bipolar_signal, self.eeg_data]) + self.eeg_data_un60 = np.concatenate([bipolar_signalun60, self.eeg_data_un60]) + self.eeg_data_60 = np.concatenate([bipolar_signal60, self.eeg_data_60]) + if self.filtered == True: + self.filtered_data_60 = np.concatenate([self.filter_data_60, bipolar_filtered_60]) + self.filtered_data_un60 = np.concatenate([self.filter_data_un60, bipolar_filtered_un60]) + self.filter_data = self.filtered_data_un60.copy() + + ''' + Filter API + ''' + + def set_filter_parameter(self, param_filter: ParamFilter): + self.param_filter = param_filter + self.param_filter.sample_freq = self.sample_freq + sos = construct_filter(param_filter.fp, param_filter.fs, param_filter.rp, param_filter.rs, param_filter.space, + param_filter.sample_freq) + # if any value in sos is nan, then raise error + if np.isnan(sos).any(): + raise ValueError("filter parameter is invalid") + self.sos = sos + + def filter_eeg_data(self, param_filter: ParamFilter = None): + ''' + This is a function should be linked to the filter button + ''' + + if param_filter is not None: + param_filter.sample_freq = self.sample_freq + self.set_filter_parameter(param_filter) + elif self.sos is None: + raise ValueError("filter parameter is not set") + + self.filter_data = [] + param_list = [{"data": self.eeg_data_un60[i], "sos": self.sos} for i in range(len(self.eeg_data_un60))] + # for i in range(len(param_list)): + # print("data shape:",param_list[i]["data"].shape, "sos shape:", param_list[i]["sos"].shape) + ret = parallel_process(param_list, filter_data, n_jobs=self.n_jobs, use_kwargs=True, front_num=2) + for r in ret: + self.filter_data.append(r) + self.filter_data = np.array(self.filter_data) + self.filter_data_un60 = self.filter_data.copy() + + # Spindle frequency range do not have 60hz problem + self.filter_data_60 = self.filter_60(self.filter_data) + if self.filter_data_60.size == 0: + self.filter_data_60 = self.filter_data.copy() + + self.filtered = True + + def has_filtered_data(self): + return self.filter_data is not None or len(self.filter_data) > 0 + + def filter_60(self, data): + """filter 60Hz noise""" + filter_sos = signal.butter(5, [58, 62], 'bandstop', fs=self.sample_freq, output='sos') + param_list = [{"data": data[i], "sos": filter_sos} for i in range(len(data))] + ret = parallel_process(param_list, filter_data, n_jobs=self.n_jobs, use_kwargs=True, front_num=2) + data_out = [] + for r in ret: + data_out.append(r) + return np.array(data_out) + + def set_filter_60(self): + # self.eeg_data_un60 = self.eeg_data.copy() + self.eeg_data = self.eeg_data_60.copy() + if self.filtered: + # self.filter_data_un60 = self.filter_data.copy() + self.filter_data = self.filter_data_60.copy() + + def set_unfiltered_60(self): + self.eeg_data = self.eeg_data_un60.copy() + if self.filtered: + self.filter_data = self.filter_data_un60.copy() + + ''' + Detector APIs + ''' + + def set_detector(self, param: ParamDetector): + ''' + This is the function should be linked to the confirm button in the set detector window + + param should be a type of ParamDetector param/param_detector.py + it should contain the following fields: + detector_type: str, "STE", "MNI" or "HIL" + detector_param: param_detector/ParamSTE or param_detector/ParamMNI + + ''' + + self.param_detector = param + # self.param_detector.detector_param.sample_freq = self.sample_freq + # self.param_detector.detector_param.pass_band = int(self.param_filter.fp) + # self.param_detector.detector_param.stop_band = int(self.param_filter.fs) + # print("detector param: ", param.detector_param.to_dict()) + if param.detector_type.lower() == "yasa": + self.detector = set_YASA_detector(param.detector_param) + else: + print('To be continued') + + + def detect_biomarker(self, param_filter: ParamFilter = None, param_detector: ParamDetector = None): + ''' + This the function should be linked to the detect button in the overview window, + it can also be called with a param to set the detector, the detector will be reseted if the param is not None + ''' + ## TODO: what is the detector param's filter is not the same as filter param + # if param_filter is not None and not self.has_filtered_data(): + # self.filter_eeg_data(param_filter) + # if param_detector is not None: + # self.set_detector(param_detector) + # if self.filter_data is None or len(self.filter_data) == 0: + # self.filter_eeg_data() + # self.event_channel_names, self.HFOs = self.detector.detect_multi_channels(self.filter_data, self.channel_names, + # filtered=True) + param_detector = self.detector['args'] + detector_yasa = self.detector['yasa'] + sp = detector_yasa.spindles_detect(self.eeg_data, sf=param_detector.sample_freq, + ch_names=self.channel_names.tolist(), freq_sp=param_detector.freq_sp, + freq_broad=param_detector.freq_broad, duration=param_detector.duration, + min_distance=param_detector.min_distance, thresh={'corr': param_detector.corr, + 'rel_pow': param_detector.rel_pow, + 'rms': param_detector.rms}) + if self.filter_data is None and sp is not None: + self.filter_data = sp._data_filt.copy() + self.filter_data_60 = sp._data_filt.copy() + self.filter_data_un60 = sp._data_filt.copy() + self.event_features = SpindleFeature.construct(sp, self.param_detector.detector_type, self.sample_freq) + self.detected = True + + ''' + Feature APIs + + ''' + + def generate_biomarker_features(self): + ''' + Todo: feature generation parameter. + ''' + freq_range = [10, self.edf_param['lowpass'] // 2] + win_size = 224 + time_range = [0, 1000] # 0~1000ms + + starts = self.event_features.starts + ends = self.event_features.ends + channel_names = self.event_features.channel_names + hfo_waveforms = extract_waveforms(self.eeg_data, starts, ends, channel_names, self.channel_names, + self.sample_freq, time_range) + param_list = [{"start": starts[i], "end": ends[i], "data": hfo_waveforms[i], "channel_name": channel_names[i], + "sample_rate": self.sample_freq, + "win_size": win_size, + "ps_MinFreqHz": freq_range[0], + "ps_MaxFreqHz": freq_range[1], + "time_window_ms": (time_range[1] - time_range[0]) // 2, + } for i in range(len(starts))] + ret = parallel_process(param_list, compute_biomarker_feature, n_jobs=self.n_jobs, use_kwargs=True, front_num=2) + starts, ends, channel_names, time_frequncy_img, amplitude_coding_plot = np.zeros(len(ret)), np.zeros( + len(ret)), np.empty(len(ret), dtype=object), np.zeros((len(ret), win_size, win_size)), np.zeros( + (len(ret), win_size, win_size)) + for i in range(len(ret)): + channel_names[i], starts[i], ends[i], time_frequncy_img[i], amplitude_coding_plot[i] = ret[i] + interval = np.concatenate([starts[:, None], ends[:, None]], axis=1) + feature = np.concatenate([time_frequncy_img[:, None, :, :], amplitude_coding_plot[:, None, :, :]], axis=1) + self.event_features = SpindleFeature(channel_names, interval, feature, sample_freq=self.sample_freq, + detector_type=self.param_detector.detector_type, feature_size=win_size, + freq_range=freq_range, time_range=time_range) + + ''' + Classifier APIs + + ''' + + def has_cuda(self): + ''' + first check if cuda is available then set param + if it returns true, then the the user can select device + else the user can only select cpu + ''' + # return torch.cuda.is_available() + return False + + def get_classifier_param(self): + ## todo: change it to database + ''' + return the param_classifier This is the information should be shown in the ovewview window + it also should be retrived when user set the classifier + ''' + return self.param_classifier + + def set_classifier(self, param: ParamClassifier): + ''' + This is the function should be linked to the confirm button in the set classifier window + + ''' + self.set_artifact_classifier(param) + self.set_spike_classifier(param) + + def set_artifact_classifier(self, param: ParamClassifier): + ''' + This is the function should be linked to the confirm button in the set artifact window + + ''' + + self.param_classifier = param + if self.classifier is None: + self.classifier = Classifier(param) + else: + self.classifier.update_model_a(param) + + def set_spike_classifier(self, param: ParamClassifier): + ''' + This is the function should be linked to the confirm button in the set spike window + + ''' + self.param_classifier = param + self.classifier.update_model_s(param) + + def set_default_cpu_classifier(self): + ''' + This is the function should be linked to the default cpu button in the set artifact window + ''' + artifact_path = os.path.join(Path(os.path.dirname(__file__)).parent, "ckpt", "model_a.tar") + spike_path = os.path.join(Path(os.path.dirname(__file__)).parent, "ckpt", "model_s.tar") + self.param_classifier = ParamClassifier(artifact_path=artifact_path, spike_path=spike_path, use_spike=True, + device="cpu", batch_size=32, model_type="default_cpu") + self.classifier = Classifier(self.param_classifier) + + def set_default_gpu_classifier(self): + ''' + This is the function should be linked to the default gpu button in the set artifact window + ''' + artifact_path = os.path.join(Path(os.path.dirname(__file__)).parent, "ckpt", "model_a.tar") + spike_path = os.path.join(Path(os.path.dirname(__file__)).parent, "ckpt", "model_s.tar") + self.param_classifier = ParamClassifier(artifact_path=artifact_path, spike_path=spike_path, use_spike=True, + device="cuda:0", batch_size=32, model_type="default_gpu") + self.classifier = Classifier(self.param_classifier) + + def classify_artifacts(self, ignore_region=[1, 1], threshold=0.5): + if not self.event_features.has_feature(): + self.generate_biomarker_features() + ignore_region = np.array(ignore_region) * self.sample_freq + ignore_region = np.array([ignore_region[0], len(self.eeg_data[0]) - ignore_region[1]]) + self.classifier.artifact_detection(self.event_features, ignore_region, threshold=threshold) + self.classified = True + + def classify_spikes(self): + if not self.event_features.has_feature(): + self.generate_biomarker_features() + self.classifier.spike_detection(self.event_features) + + ''' + results APIs + ''' + + def get_res_overview(self): + ''' + return the overview of the results + ''' + if not self.event_features.has_feature(): + self.generate_biomarker_features() + return { + "n_Spindle": self.event_features.num_spindle, + "n_artifact": self.event_features.num_artifact, + "n_real": self.event_features.num_real, + "n_spike": self.event_features.num_spike + } + + def export_report(self, path): + if not self.event_features: + return None + self.event_features.export_csv(path) + + def export_excel(self, path): + if not self.event_features: + return None + self.event_features.export_excel(path) + + def export_app(self, path): + ''' + export all the data from app to a tar file + ''' + checkpoint = { + "n_jobs": self.n_jobs, + "eeg_data": self.eeg_data, + "edf_param": self.edf_param, + "sample_freq": self.sample_freq, + "channel_names": self.channel_names, + "param_filter": self.param_filter.to_dict() if self.param_filter else None, + "Spindles": self.Spindles, + "param_detector": self.param_detector.to_dict() if self.param_detector else None, + "Spindle_features": self.event_features.to_dict() if self.event_features else None, + "param_classifier": self.param_classifier.to_dict() if self.param_classifier else None, + "classified": self.classified, + "filtered": self.filtered, + "detected": self.detected, + "artifact_predictions": np.array(self.event_features.artifact_predictions), + "spike_predictions": np.array(self.event_features.spike_predictions), + "artifact_annotations": np.array(self.event_features.artifact_annotations), + "spike_annotations": np.array(self.event_features.spike_annotations), + "annotated": np.array(self.event_features.annotated), + } + dump_to_npz(checkpoint, path) + + @staticmethod + def import_app(path): + ''' + import all the data from a tar file to app + ''' + checkpoint = np.load(path, allow_pickle=True) + app = SpindleApp() + app.n_jobs = checkpoint["n_jobs"].item() + app.eeg_data = checkpoint["eeg_data"] + app.edf_param = checkpoint["edf_param"].item() + app.sample_freq = checkpoint["sample_freq"] + app.channel_names = checkpoint["channel_names"] + app.classified = checkpoint["classified"].item() + app.filtered = checkpoint["filtered"].item() + app.detected = checkpoint["detected"].item() + app.event_features.artifact_predictions = checkpoint["artifact_predictions"].item() + app.event_features.spike_predictions = checkpoint["spike_predictions"].item() + app.event_features.artifact_annotations = checkpoint["artifact_annotations"].item() + app.event_features.spike_annotations = checkpoint["spike_annotations"].item() + app.event_features.annotated = checkpoint["annotated"].item() + if app.filtered: + app.param_filter = ParamFilter.from_dict(checkpoint["param_filter"].item()) + app.filter_eeg_data(app.param_filter) + if app.detected: + # print("detected Spindles") + app.Spindles = checkpoint["Spindles"] + app.param_detector = ParamDetector.from_dict(checkpoint["param_detector"].item()) + # print("new Spindle features") + app.event_features = SpindleFeature.from_dict(checkpoint["Spindle_features"].item()) + # print(app.event_features) + if app.classified: + app.param_classifier = ParamClassifier.from_dict(checkpoint["param_classifier"].item()) + return app + + def export_features(self, folder): + def clean_folder(folder): + import shutil + if os.path.exists(folder): + shutil.rmtree(folder) + os.makedirs(folder) + + def extract_data(data, data_filtered, start, end): + data = np.squeeze(data) + data_filtered = np.squeeze(data_filtered) + if start < self.sample_freq // 2: + plot_start, plot_end = 0, self.sample_freq + hfo_start, hfo_end = start, min(end, self.sample_freq) + elif end > len(data) - self.sample_freq // 2: + plot_start, plot_end = len(data) - self.sample_freq, len(data) + hfo_start, hfo_end = max(plot_start, start) - plot_start, min(plot_end, end) - plot_start + else: + plot_start, plot_end = (start + end) // 2 - self.sample_freq // 2, ( + start + end) // 2 + self.sample_freq // 2 + hfo_start, hfo_end = max(plot_start, start) - plot_start, min(plot_end, end) - plot_start + plot_start, plot_end, hfo_start, hfo_end = int(plot_start), int(plot_end), int(hfo_start), int(hfo_end) + channel_data = data[plot_start:plot_end] + channel_data_f = data_filtered[plot_start:plot_end] + # print(hfo_start, hfo_end, start, end, plot_start, plot_end, channel_data.shape, channel_data_f.shape) + return channel_data, channel_data_f, hfo_start, hfo_end + + def extract_waveform(data, data_filtered, starts, ends, channel_names, unique_channel_names): + hfo_waveform_l, hfo_waveform_f_l, hfo_start_l, hfo_end_l = np.zeros((len(starts), 2000)), np.zeros( + (len(starts), 2000)), [], [] + for i in tqdm(range(len(starts))): + channel_name = channel_names[i] + start = starts[i] + end = ends[i] + channel_index = np.where(unique_channel_names == channel_name)[0] + hfo_waveform, hfo_waveform_f, hfo_start, hfo_end = extract_data(data[channel_index], + data_filtered[channel_index], start, + end) + hfo_waveform_l[i] = hfo_waveform + hfo_waveform_f_l[i] = hfo_waveform_f + hfo_start_l.append(hfo_start) + hfo_end_l.append(hfo_end) + return hfo_waveform_l, hfo_waveform_f_l, np.array(hfo_start_l), np.array(hfo_end_l) + + if not self.event_features: + return None + os.makedirs(folder, exist_ok=True) + artifact_folder = os.path.join(folder, "artifact") + spike_folder = os.path.join(folder, "spike") + non_spike_folder = os.path.join(folder, "non_spike") + clean_folder(artifact_folder) + clean_folder(spike_folder) + clean_folder(non_spike_folder) + starts = self.event_features.starts + ends = self.event_features.ends + feature = self.event_features.features + channel_names = self.event_features.channel_names + spike_predictions = self.event_features.spike_predictions + index_s = np.where(spike_predictions == 1)[0] + start_s, end_s, feature_s, channel_names_s = starts[index_s], ends[index_s], feature[index_s], channel_names[ + index_s] + index_a = np.where(spike_predictions == -1)[0] + start_a, end_a, feature_a, channel_names_a = starts[index_a], ends[index_a], feature[index_a], channel_names[ + index_a] + index_r = np.where(spike_predictions == 0)[0] + start_r, end_r, feature_r, channel_names_r = starts[index_r], ends[index_r], feature[index_r], channel_names[ + index_r] + # print("plotting Spindle with spike") + waveform_s, waveform_f_s, hfo_start_s, hfo_end_s = extract_waveform(self.eeg_data, self.filter_data, start_s, + end_s, channel_names_s, self.channel_names) + param_list = [{"folder": spike_folder, "start": start_s[i], "end": end_s[i], "feature": feature_s[i], + "channel_name": channel_names_s[i], "data": waveform_s[i], "data_filtered": waveform_f_s[i], + "hfo_start": hfo_start_s[i], "hfo_end": hfo_end_s[i]} for i in range(len(start_s))] + ret = parallel_process(param_list, plot_feature, self.n_jobs, use_kwargs=True, front_num=3) + waveform_a, waveform_f_a, hfo_start_a, hfo_end_a = extract_waveform(self.eeg_data, self.filter_data, start_a, + end_a, channel_names_a, self.channel_names) + param_list = [{"folder": artifact_folder, "start": start_a[i], "end": end_a[i], "feature": feature_a[i], + "channel_name": channel_names_a[i], "data": waveform_a[i], "data_filtered": waveform_f_a[i], + "hfo_start": hfo_start_a[i], "hfo_end": hfo_end_a[i]} for i in range(len(start_a))] + ret = parallel_process(param_list, plot_feature, self.n_jobs, use_kwargs=True, front_num=3) + waveform_r, waveform_f_r, hfo_start_r, hfo_end_r = extract_waveform(self.eeg_data, self.filter_data, start_r, + end_r, channel_names_r, self.channel_names) + param_list = [{"folder": non_spike_folder, "start": start_r[i], "end": end_r[i], "feature": feature_r[i], + "channel_name": channel_names_r[i], "data": waveform_r[i], "data_filtered": waveform_f_r[i], + "hfo_start": hfo_start_r[i], "hfo_end": hfo_end_r[i]} for i in range(len(start_r))] + ret = parallel_process(param_list, plot_feature, self.n_jobs, use_kwargs=True, front_num=3) diff --git a/src/spindle_feature.py b/src/spindle_feature.py new file mode 100644 index 0000000..05cc57c --- /dev/null +++ b/src/spindle_feature.py @@ -0,0 +1,303 @@ +import numpy as np +import pandas as pd + + +class SpindleFeature(object): + def __init__(self, channel_names, starts, ends, features=[], detector_type="STE", sample_freq=2000, freq_range=[10, 500], + time_range=[0, 1000], feature_size=224): + self.channel_names = channel_names + if starts.size == 0: + self.starts = np.array([]) + self.ends = np.array([]) + self.artifact_predictions = np.array([]) + self.artifact_annotations = np.array([]) + self.spike_annotations = np.array([]) + self.annotated = np.array([]) + else: + self.starts = starts + self.ends = ends + self.features = features + self.artifact_predictions = np.zeros(self.starts.shape) + self.spike_predictions = [] + self.artifact_annotations = np.zeros(self.starts.shape) + self.spike_annotations = np.zeros(self.starts.shape) + self.annotated = np.zeros(self.starts.shape) + self.detector_type = detector_type + self.sample_freq = sample_freq + self.feature_size = 0 + self.freq_range = freq_range + self.time_range = time_range + self.feature_size = feature_size + self.num_artifact = 0 + self.num_spike = 0 + self.num_spindle = len(self.starts) + self.num_real = 0 + self.index = 0 + self.artifact_predicted = False + self.spike_predicted = False + + def __str__(self): + return "Spindle_Feature: {} Spindles, {} artifacts, {} spikes, {} real Spindles".format(self.num_spindle, self.num_artifact, + self.num_spike, self.num_real) + + @staticmethod + def construct(result, detector_type="STE", sample_freq=2000, freq_range=[10, 500], + time_range=[0, 1000], feature_size=224): + ''' + Construct SpindleFeature object from detector output + ''' + result_summary = result.summary() + channel_names = result_summary['Channel'].to_numpy() + start = result_summary['Start'].to_numpy() * sample_freq + end = result_summary['End'].to_numpy() * sample_freq + return SpindleFeature(channel_names, start, end, np.array([]), detector_type, sample_freq, freq_range, time_range, feature_size) + + def get_num_biomarker(self): + return self.num_spindle + + def has_prediction(self): + return self.artifact_predicted + + # def generate_psedo_label(self): + # self.artifact_predictions = np.ones(self.num_spindle) + # self.spike_predictions = np.zeros(self.num_spindle) + # self.artifact_predicted = True + + def doctor_annotation(self, annotation: str): + if annotation == "Artifact": + self.artifact_annotations[self.index] = 0 + elif annotation == "Spike": + self.spike_annotations[self.index] = 1 + self.artifact_annotations[self.index] = 1 + elif annotation == "Real": + self.spike_annotations[self.index] = 0 + self.artifact_annotations[self.index] = 1 + self.annotated[self.index] = 1 + + def get_next(self): + if self.index >= self.num_spindle - 1: + self.index = 0 + else: + self.index += 1 + # returns the next hfo start and end index instead of next window start and end index + return self.channel_names[self.index], self.starts[self.index], self.ends[self.index] + + def get_prev(self): + if self.index <= 0: + self.index = 0 + else: + self.index -= 1 + # the same as above + return self.channel_names[self.index], self.starts[self.index], self.ends[self.index] + + def get_jump(self, index): + self.index = index + # the same as above + return self.channel_names[self.index], self.starts[self.index], self.ends[self.index] + + def get_current(self): + return self.channel_names[self.index], self.starts[self.index], self.ends[self.index] + + def _get_prediction(self, artifact_prediction, spike_prediction): + if artifact_prediction < 1: + return "Artifact" + elif spike_prediction == 1: + return "Spike" + else: + return "Spindle" + + def get_current_info(self): + print("self.artifact_predicted:", self.artifact_predicted) + channel_name = self.channel_names[self.index] + start = self.starts[self.index] + end = self.ends[self.index] + prediction = self._get_prediction(self.artifact_predictions[self.index], + self.spike_predictions[self.index]) if self.artifact_predicted else None + annotation = self._get_prediction(self.artifact_annotations[self.index], self.spike_annotations[self.index]) if \ + self.annotated[self.index] else None + return {"channel_name": channel_name, "start_index": start, "end_index": end, "prediction": prediction, + "annotation": annotation} + + def get_num_artifact(self): + return self.num_artifact + + def get_num_spike(self): + return self.num_spike + + def get_num_real(self): + return self.num_real + + def has_feature(self): + return len(self.features) > 0 + + def get_features(self): + return self.features + + def to_dict(self): + channel_names = self.channel_names + starts = self.starts + ends = self.ends + artifact_predictions = np.array(self.artifact_predictions) + spike_predictions = np.array(self.spike_predictions) + feature = self.features + detector_type = self.detector_type + sample_freq = self.sample_freq + feature_size = self.feature_size + freq_range = self.freq_range + time_range = self.time_range + return {"channel_names": channel_names, "starts": starts, "ends": ends, + "artifact_predictions": artifact_predictions, "spike_predictions": spike_predictions, + "feature": feature, "detector_type": detector_type, "sample_freq": sample_freq, "feature_size": feature_size, + "freq_range": freq_range, "time_range": time_range} + + @staticmethod + def from_dict(data): + ''' + construct SpindleFeature object from dictionary + ''' + channel_names = data["channel_names"] + starts = data["starts"] + ends = data["ends"] + artifact_predictions = data["artifact_predictions"] + spike_predictions = data["spike_predictions"] + feature = data["feature"] + detector_type = data["detector_type"] + sample_freq = data["sample_freq"] + feature_size = data["feature_size"] + freq_range = data["freq_range"] + time_range = data["time_range"] + biomarker_feature = SpindleFeature(channel_names, np.array([starts, ends]).T, feature, detector_type, sample_freq, freq_range, + time_range, feature_size) + biomarker_feature.update_pred(artifact_predictions, spike_predictions) + + return biomarker_feature + + def update_artifact_pred(self, artifact_predictions): + self.artifact_predicted = True + self.artifact_predictions = artifact_predictions + self.num_artifact = np.sum(artifact_predictions < 1) + self.num_real = np.sum(artifact_predictions > 0) + + def update_spike_pred(self, spike_predictions): + self.spike_predicted = True + self.spike_predictions = spike_predictions + self.num_spike = np.sum(spike_predictions == 1) + + def update_pred(self, artifact_predictions, spike_predictions): + self.update_artifact_pred(artifact_predictions) + self.update_spike_pred(spike_predictions) + + def group_by_channel(self): + channel_names = self.channel_names + starts = self.starts + ends = self.ends + artifact_predictions = np.array(self.artifact_predictions) + spike_predictions = np.array(self.spike_predictions) + channel_names_unique = np.unique(channel_names) + interval = np.array([starts, ends]).T + channel_name_g, interval_g, artifact_predictions_g, spike_predictions_g = [], [], [], [] + for channel_name in channel_names_unique: + channel_index = np.where(channel_names == channel_name)[0] + interval_g.append(interval[channel_index]) + channel_name_g.append(channel_name) + if len(artifact_predictions) > 0: + artifact_predictions_g.append(artifact_predictions[channel_index]) + if len(spike_predictions) > 0: + spike_predictions_g.append(spike_predictions[channel_index]) + return channel_name_g, interval_g, artifact_predictions_g, spike_predictions_g + + def get_biomarkers_for_channel(self, channel_name: str, min_start: int = None, max_end: int = None): + channel_names = self.channel_names + starts = self.starts + ends = self.ends + artifact_predictions = np.array(self.artifact_predictions) + spike_predictions = np.array(self.spike_predictions) + indexes = channel_names == channel_name + if min_start is not None and max_end is not None: + indexes = indexes & (starts >= min_start) & (ends <= max_end) + starts = starts[indexes] + ends = ends[indexes] + try: + artifact_predictions = artifact_predictions[indexes] + spike_predictions = spike_predictions[indexes] == 1 + except: + artifact_predictions = [] + spike_predictions = [] + return starts, ends, artifact_predictions, spike_predictions + + def get_annotation_text(self, index): + channel_name = self.channel_names[index] + if self.annotated[index] == 0: + suffix = "Unannotated" + elif self.artifact_annotations[index] == 0: + suffix = "Artifact" + elif self.spike_annotations[index] == 1: + suffix = "Spike" + else: + suffix = "Real" + return f" No.{index + 1}: {channel_name} : {suffix}" + + def to_df(self): + channel_names = self.channel_names + starts = self.starts + ends = self.ends + artifact_predictions = np.array(self.artifact_predictions) + spike_predictions = np.array(self.spike_predictions) + artifact_annotations = np.array(self.artifact_annotations) + spike_annotations = np.array(self.spike_annotations) + annotated = np.array(self.annotated) + df = pd.DataFrame() + df["channel_names"] = channel_names + df["starts"] = starts + df["ends"] = ends + # df["doctor_annotation"] = self.doctor_annotation + if len(artifact_predictions) > 0: + df["artifact"] = artifact_predictions + if len(spike_predictions) > 0: + df["spike"] = spike_predictions + df['annotated'] = annotated + if len(artifact_annotations) > 0: + df["artifact annotations"] = artifact_annotations + if len(spike_annotations) > 0: + df["spike annotations"] = spike_annotations + return df + + def export_csv(self, file_path): + df = self.to_df() + df.to_csv(file_path, index=False) + + def export_excel(self, file_path): + df = self.to_df() + df_out = df.copy() + if "artifact" not in df_out.columns: + df_out["artifact"] = 0 + if "spike" not in df_out.columns: + df_out["spike"] = 0 + if "artifact annotations" not in df_out.columns: + df_out["artifact annotations"] = 0 + if "spike annotations" not in df_out.columns: + df_out["spike annotations"] = 0 + df_out["artifact"] = (df_out["artifact"] > 0).astype(int) + df_out["spike"] = (df_out["spike"] > 0).astype(int) + df_out['annotated'] = 1 - (df_out["annotated"] > 0).astype(int) + df_out["artifact annotations"] = (df_out["artifact annotations"] > 0).astype(int) + df_out["spike annotations"] = (df_out["spike annotations"] > 0).astype(int) + df_channel = df_out.groupby("channel_names").agg({"starts": "count", + "artifact": "sum", "spike": "sum", + "annotated": "sum", + "artifact annotations": "sum", + "spike annotations": "sum"}).reset_index() + df_channel.rename(columns={"starts": "Total Detection", + "artifact": "Spindle", "spike": "spk-Spindle", + "annotated": "Unannotated", + "artifact annotations": "Spindle annotations", + "spike annotations": "spk-Spindle annotations"}, inplace=True) + df.rename(columns={"artifact": "Spindle", "spike": "spk-Spindle", + "annotated": "Annotated", + "artifact annotations": "Spindle annotations", "spike annotations": "spk-Spindle annotations"}, + inplace=True) + df['Annotated'] = df["Annotated"] > 0 + df['Annotated'] = df['Annotated'].replace({True: 'Yes', False: 'No'}) + with pd.ExcelWriter(file_path) as writer: + df_channel.to_excel(writer, sheet_name="Channels", index=False) + df.to_excel(writer, sheet_name="Events", index=False) \ No newline at end of file diff --git a/src/ui/annotation.py b/src/ui/annotation.py index ed2c494..e40d81f 100644 --- a/src/ui/annotation.py +++ b/src/ui/annotation.py @@ -1,83 +1,71 @@ -from PyQt5 import QtCore, QtWidgets, QtGui from PyQt5 import uic -from PyQt5.QtWidgets import QFileDialog, QMessageBox -import matplotlib.pyplot as plt - -import pyqtgraph as pg - -import sys -from PyQt5 import QtCore, QtGui, QtWidgets -import pyqtgraph as pg # We will try using pyqtgraph for plotting -import time -import mne -# from superqt import QDoubleRangeSlider -from tqdm import tqdm -import os -from src.hfo_app import HFO_App -from src.hfo_feature import HFO_Feature - -from src.utils.utils_annotation import * - -import random -import scipy.fft as fft # FFT plot (5) -import numpy as np - -import re from pathlib import Path -from src.hfo_app import HFO_App -from src.param.param_classifier import ParamClassifier -from src.param.param_detector import ParamDetector, ParamSTE, ParamMNI -from src.param.param_filter import ParamFilter from src.utils.utils_gui import * -from src.ui.plot_waveform import * +# from src.ui.plot_waveform import * +from PyQt5 import QtCore, QtGui, QtWidgets from PyQt5.QtCore import QSize - -# from src.ui.plot_annotation_waveform import * -# from src.ui.a_channel_selection import AnnotationChannelSelection - -# from src.plot_time_frequency import PlotTimeFrequencyNoLabel from src.utils.utils_plotting import * from src.ui.annotation_plot import AnnotationPlot, FFTPlot -# from src.plot_time_frequency import MainWindow - -import multiprocessing as mp -import torch +from src.controllers import AnnotationController ROOT_DIR = Path(__file__).parent -class HFOAnnotation(QtWidgets.QMainWindow): - def __init__(self, hfo_app=None, main_window=None, close_signal=None): - super(HFOAnnotation, self).__init__(main_window) - print("initializing HFOAnnotation") - self.hfo_app = hfo_app +class Annotation(QtWidgets.QMainWindow): + def __init__(self, backend=None, main_window=None, close_signal=None, biomarker_type='HFO'): + super(Annotation, self).__init__(main_window) + self.annotation_controller = AnnotationController(self, backend) + + self.biomarker_type = biomarker_type + print(f"initializing {self.biomarker_type} Annotation") + self.backend = backend self.ui = uic.loadUi(os.path.join(ROOT_DIR, 'annotation.ui'), self) - self.setWindowTitle("HFO Annotator") + self.setWindowTitle(f"{self.biomarker_type} Annotator") self.setWindowIcon(QtGui.QIcon(os.path.join(ROOT_DIR, 'src/ui/images/icon.png'))) self.threadpool = QThreadPool() self.close_signal = close_signal - self.close_signal.connect(self.close) - self.PreviousButton.clicked.connect(self.plot_prev) - self.NextButton.clicked.connect(self.plot_next) - self.Accept.clicked.connect(self.update_button_clicked) - - self.IntervalDropdownBox.currentIndexChanged.connect(self.update_interval) # Connect the interval dropdown box - # create the main waveform plot which we want to embed in VisulaizationVerticalLayout - self.waveform_plot = AnnotationPlot(hfo_app=self.hfo_app) - self.VisulaizationVerticalLayout.addWidget(self.waveform_plot) + safe_connect_signal_slot(self.close_signal, self.close) + safe_connect_signal_slot(self.PreviousButton.clicked, self.plot_prev) + safe_connect_signal_slot(self.NextButton.clicked, self.plot_next) + safe_connect_signal_slot(self.Accept.clicked, self.update_button_clicked) + + # init event type selection dropdown box + self.EventDropdown_Box.clear() + if self.backend.biomarker_type == 'HFO': + self.EventDropdown_Box.addItems(["--- Event Type ---", "Spike", "Real", "Artifact"]) + elif self.backend.biomarker_type == "Spindle": + self.EventDropdown_Box.addItems(["--- Event Type ---", "Spike", "Real", "Artifact"]) + self.EventDropdown_Box.setCurrentIndex(0) + + # init interval selection dropdown box + self.IntervalDropdownBox.clear() + if self.backend.biomarker_type == 'HFO': + self.IntervalDropdownBox.addItems(["1s", "0.5s", "0.25s"]) + elif self.backend.biomarker_type == "Spindle": + self.IntervalDropdownBox.addItems(["4s", "3.5s"]) + self.IntervalDropdownBox.setCurrentIndex(0) + + # Connect the interval dropdown box + safe_connect_signal_slot(self.IntervalDropdownBox.currentIndexChanged, self.update_interval) + + # create the main waveform plot + self.init_waveform_plot() + + # create fft plot + self.init_fft_plot() - self.fft_plot = FFTPlot(hfo_app=self.hfo_app) - self.FFT_layout.addWidget(self.fft_plot) - - channel, start, end = self.hfo_app.hfo_features.get_current() - self.waveform_plot.plot(start, end, channel, interval=1.0) # Default interval - self.fft_plot.plot(start, end, channel, interval=1.0) # Default interval self.init_annotation_dropdown() self.update_infos() self.setInitialSize() self.setWindowModality(QtCore.Qt.ApplicationModal) # Set as modal dialog - + + def init_waveform_plot(self): + self.annotation_controller.create_waveform_plot() + + def init_fft_plot(self): + self.annotation_controller.create_fft_plot() + def setInitialSize(self): # Getting screen resolution of your monitor screen = QApplication.primaryScreen() @@ -101,43 +89,38 @@ def get_current_interval(self): def plot_prev(self): # start, end: index of the prev hfo - channel, start, end = self.hfo_app.hfo_features.get_prev() + channel, start, end = self.annotation_controller.get_previous_event() interval = self.get_current_interval() - # interval = float(self.IntervalDropdownBox.currentText().rstrip('s')) # Get the current interval - self.waveform_plot.plot(start, end, channel, interval=interval) - self.fft_plot.plot(start, end, channel, interval=interval) + self.annotation_controller.update_plots(start, end, channel, interval) self.update_infos() def plot_next(self): # start, end: index of the next hfo - channel, start, end = self.hfo_app.hfo_features.get_next() + channel, start, end = self.annotation_controller.get_next_event() interval = self.get_current_interval() - # interval = float(self.IntervalDropdownBox.currentText().rstrip('s')) # Get the current interval - self.waveform_plot.plot(start, end, channel, interval=interval) - self.fft_plot.plot(start, end, channel, interval=interval) + self.annotation_controller.update_plots(start, end, channel, interval) self.update_infos() def plot_jump(self): selected_index = self.AnotationDropdownBox.currentIndex() # start, end: index of the next hfo - channel, start, end = self.hfo_app.hfo_features.get_jump(selected_index) + channel, start, end = self.annotation_controller.get_jumped_event(selected_index) try: interval = float(self.IntervalDropdownBox.currentText().rstrip('s')) except (ValueError, AttributeError): interval = 1.0 # Default interval # interval = float(self.IntervalDropdownBox.currentText().rstrip('s')) # Get the current interval - self.waveform_plot.plot(start, end, channel, interval=interval) - self.fft_plot.plot(start, end, channel, interval=interval) + self.annotation_controller.update_plots(start, end, channel, interval) self.update_infos() def update_infos(self): - info = self.hfo_app.hfo_features.get_current_info() - fs = self.hfo_app.sample_freq + info = self.backend.event_features.get_current_info() + fs = self.backend.sample_freq self.channel_name_textbox.setText(info["channel_name"]) self.start_textbox.setText(str(round(info["start_index"] / fs, 3)) + " s") self.end_textbox.setText(str(round(info["end_index"] / fs, 3)) + " s") self.length_textbox.setText(str(round((info["end_index"] - info["start_index"]) / fs, 3)) + " s") - self.AnotationDropdownBox.setCurrentIndex(self.hfo_app.hfo_features.index) + self.AnotationDropdownBox.setCurrentIndex(self.backend.event_features.index) print(info["prediction"]) if info["prediction"] is not None: self.model_textbox.setText(info["prediction"]) @@ -150,31 +133,27 @@ def update_button_clicked(self): # print("updating now...") selected_text = self.EventDropdown_Box.currentText() if selected_text in ["Artifact", "Spike", "Real"]: - self.hfo_app.hfo_features.doctor_annotation(selected_text) - # Update the text of the selected item in the dropdown menu - selected_index = self.hfo_app.hfo_features.index - item_text = self.hfo_app.hfo_features.get_annotation_text(selected_index) + selected_index, item_text = self.annotation_controller.set_doctor_annotation(selected_text) self.AnotationDropdownBox.setItemText(selected_index, item_text) self.plot_next() def init_annotation_dropdown(self): # initialize the text in the dropdown menu - for i in range(len(self.hfo_app.hfo_features.annotated)): - text = self.hfo_app.hfo_features.get_annotation_text(i) + for i in range(len(self.backend.event_features.annotated)): + text = self.backend.event_features.get_annotation_text(i) self.AnotationDropdownBox.addItem(text) - self.AnotationDropdownBox.activated.connect(self.plot_jump) - + safe_connect_signal_slot(self.AnotationDropdownBox.activated, self.plot_jump) + def update_interval(self): interval = self.get_current_interval() # Update the plots to reflect the new interval - channel, start, end = self.hfo_app.hfo_features.get_current() - self.waveform_plot.plot(start, end, channel, interval=interval) - self.fft_plot.plot(start, end, channel, interval=interval) + channel, start, end = self.annotation_controller.get_current_event() + self.annotation_controller.update_plots(start, end, channel, interval) if __name__ == '__main__': app = QtWidgets.QApplication(sys.argv) - mainWindow = HFOAnnotation() + mainWindow = Annotation() mainWindow.show() sys.exit(app.exec_()) \ No newline at end of file diff --git a/src/ui/annotation_plot.py b/src/ui/annotation_plot.py index 6ffe4b1..c74bcb8 100644 --- a/src/ui/annotation_plot.py +++ b/src/ui/annotation_plot.py @@ -30,7 +30,7 @@ from src.param.param_detector import ParamDetector, ParamSTE, ParamMNI from src.param.param_filter import ParamFilter from src.utils.utils_gui import * -from src.ui.plot_waveform import * +# from src.ui.plot_waveform import * # import FormatStrFormatter from matplotlib.ticker import FormatStrFormatter @@ -58,17 +58,18 @@ def custom_formatter(x, pos): formatted_number = f' {x:.0f}' if x >= 0 else f'{x:.0f}' return f'{formatted_number:>{max_width}}' + class AnnotationPlot(FigureCanvasQTAgg): - def __init__(self, parent=None, width=10, height=4, dpi=100, hfo_app=None): + def __init__(self, parent=None, width=10, height=4, dpi=100, backend=None): fig,self.axs = plt.subplots(3,1,figsize=(width, height), dpi=dpi) super(AnnotationPlot, self).__init__(fig) - self.hfo_app = hfo_app + self.backend = backend FigureCanvasQTAgg.setSizePolicy(self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) FigureCanvasQTAgg.updateGeometry(self) # self.setParent(parent) # self.plot() - def plot(self,start_index: int = None, end_index: int = None, channel:str = None, interval=1.0): + def plot(self,start_index: int = None, end_index: int = None, channel:str = None, interval=1.0): #first clear the plot for ax in self.axs: ax.cla() @@ -77,16 +78,16 @@ def plot(self,start_index: int = None, end_index: int = None, channel:str = None return if start_index < 0: return - + channel_name = channel - fs = self.hfo_app.sample_freq - + fs = self.backend.sample_freq + #both sets of data (filtered/unfiltered) for plots - length = self.hfo_app.get_eeg_data_shape()[1] + length = self.backend.get_eeg_data_shape()[1] # window_start_index, window_end_index, relative_start_index, relative_end_end = calcuate_boundary(plot_start_index, plot_end_index, length, fs) - window_start_index, window_end_index, relative_start_index, relative_end_end = calcuate_boundary(start_index, end_index, length,fs) - unfiltered_eeg_data, self.channel_names = self.hfo_app.get_eeg_data(window_start_index, window_end_index) - filtered_eeg_data,_ = self.hfo_app.get_eeg_data(window_start_index, window_end_index, filtered=True) + window_start_index, window_end_index, relative_start_index, relative_end_end = calcuate_boundary(start_index, end_index, length, win_len=fs * interval) + unfiltered_eeg_data, self.channel_names = self.backend.get_eeg_data(window_start_index, window_end_index) + filtered_eeg_data,_ = self.backend.get_eeg_data(window_start_index, window_end_index, filtered=True) unfiltered_eeg_data_to_display_one = unfiltered_eeg_data[self.channel_names == channel_name,:][0] filtered_eeg_data_to_display = filtered_eeg_data[self.channel_names == channel_name,:][0] @@ -100,52 +101,52 @@ def plot(self,start_index: int = None, end_index: int = None, channel:str = None # print("this is filtered_eeg_data_to_display: ", filtered_eeg_data_to_display.shape) self.axs[0].set_title("EEG Tracing") self.axs[0].plot(time_to_display, unfiltered_eeg_data_to_display_one, color='blue') - # self.axs[0].plot(time_to_display[int(start_index - window_start_index):int(end_index - window_start_index)], + # self.axs[0].plot(time_to_display[int(start_index - window_start_index):int(end_index - window_start_index)], # unfiltered_eeg_data_to_display_one[int(start_index - window_start_index):int(end_index - window_start_index)], color='orange') self.axs[0].plot(time_to_display[relative_start_index:relative_end_end], unfiltered_eeg_data_to_display_one[relative_start_index:relative_end_end], color='orange') self.axs[0].set_xticks([]) # keep the y axis label fixed (not moving when the plot is updated) - + self.axs[0].set_ylabel('Amplitude (uV)', rotation=90, labelpad=5) self.axs[0].yaxis.set_major_formatter(ticker.FuncFormatter(custom_formatter)) - #self.axs[0].yaxis.set_label_coords(-0.1, 0.5) + #self.axs[0].yaxis.set_label_coords(-0.1, 0.5) # set the y axis label to the right side self.axs[0].yaxis.set_label_position("right") self.axs[0].set_ylim([unfiltered_eeg_data_to_display_one.min(), unfiltered_eeg_data_to_display_one.max()]) - + middle_index = (relative_start_index + relative_end_end) // 2 half_interval_samples = int((interval * fs) // 2) plot_start_index = max(0, int(middle_index - half_interval_samples)) - plot_end_index = int(min(self.hfo_app.get_eeg_data_shape()[1], middle_index + half_interval_samples)) + plot_end_index = int(min(self.backend.get_eeg_data_shape()[1], middle_index + half_interval_samples)) plot_start_index = max(0, min(len(time_to_display) - 1, plot_start_index)) plot_end_index = min(len(time_to_display) - 1, int(middle_index + half_interval_samples)) - + # print(f"time_to_display range: {time_to_display[0]} to {time_to_display[-1]}") # print(f"plot_start_index: {plot_start_index}") # print(f"plot_end_index: {plot_end_index}") # print(f"relative_start_index: {relative_start_index}") # print(f"relative_end_index: {relative_end_end}") - + self.axs[0].set_xlim(time_to_display[plot_start_index], time_to_display[plot_end_index]) - + #self.axs[0].grid() # print("this is time to display: ", time_to_display.shape) # print("this is filtered_eeg_data_to_display: ", filtered_eeg_data_to_display.shape) self.axs[1].set_title("Filtered Tracing") self.axs[1].plot(time_to_display, filtered_eeg_data_to_display, color='blue') - # self.axs[1].plot(time_to_display[int(start_index - window_start_index):int(end_index - window_start_index)], + # self.axs[1].plot(time_to_display[int(start_index - window_start_index):int(end_index - window_start_index)], # filtered_eeg_data_to_display[int(start_index - window_start_index):int(end_index - window_start_index)], color='orange') self.axs[1].plot(time_to_display[relative_start_index:relative_end_end], filtered_eeg_data_to_display[relative_start_index:relative_end_end], color='orange') - + self.axs[1].set_ylabel('Amplitude (uV)', rotation=90, labelpad=6) self.axs[1].set_xticks([]) self.axs[1].yaxis.set_major_formatter(ticker.FuncFormatter(custom_formatter)) - #self.axs[1].yaxis.set_label_coords(-0.1, 0.5) + #self.axs[1].yaxis.set_label_coords(-0.1, 0.5) # set the y axis label to the right side self.axs[1].yaxis.set_label_position("right") self.axs[1].set_ylim([filtered_eeg_data_to_display.min(), filtered_eeg_data_to_display.max()]) # Set y-axis limits self.axs[1].set_xlim(time_to_display[plot_start_index], time_to_display[plot_end_index]) - + #self.axs[1].grid() time_frequency = calculate_time_frequency(unfiltered_eeg_data_to_display_one,fs) @@ -159,11 +160,11 @@ def plot(self,start_index: int = None, end_index: int = None, channel:str = None self.axs[2].yaxis.set_major_formatter(ticker.FuncFormatter(custom_formatter)) self.axs[2].set_xlabel('Time (s)') self.axs[2].set_ylabel('Frequency (Hz)', rotation=90, labelpad=4) - #self.axs[2].yaxis.set_label_coords(-0.1, 0.5) + #self.axs[2].yaxis.set_label_coords(-0.1, 0.5) # set the y axis label to the right side self.axs[2].yaxis.set_label_position("right") self.axs[2].set_xlim(time_to_display[plot_start_index], time_to_display[plot_end_index]) - + #share x axis #self.axs[0].sharex(self.axs[1]) # self.axs[0].sharex(self.axs[2]) @@ -171,39 +172,39 @@ def plot(self,start_index: int = None, end_index: int = None, channel:str = None #call the draw function plt.tight_layout() self.draw() - + + class FFTPlot(FigureCanvasQTAgg): - def __init__(self, parent=None, width=5, height=4, dpi=100, hfo_app=None): - fig,self.axs = plt.subplots(1,1,figsize=(width, height), dpi=dpi) - super(FFTPlot, self).__init__(fig) - self.hfo_app = hfo_app - - FigureCanvasQTAgg.setSizePolicy(self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) - FigureCanvasQTAgg.updateGeometry(self) - - - def plot(self, start_index: int = None, end_index: int = None, channel: str = None, interval=1.0): - self.axs.cla() - start_index = int(start_index) - fs = self.hfo_app.sample_freq - middle_index = (start_index + end_index) // 2 - half_interval_samples = int((interval * fs) // 2) - plot_start_index = int(max(0, middle_index - half_interval_samples)) - plot_end_index = int(min(self.hfo_app.get_eeg_data_shape()[1], middle_index + half_interval_samples)) - - unfiltered_eeg_data, channel_names = self.hfo_app.get_eeg_data(plot_start_index, plot_end_index) - unfiltered_eeg_data = unfiltered_eeg_data[channel_names == channel, :][0] - # Compute the FFT - f, Pxx_den = signal.periodogram(unfiltered_eeg_data, fs) - - # Plotting the FFT - self.axs.semilogy(f, Pxx_den) - self.axs.set_xlabel('Frequency (Hz)') - self.axs.set_ylabel(r"PSD (V$^2$/Hz)") - - self.axs.set_ylim([1e-7, 1e3]) - # self.axs.set_ylim([0, Pxx_den.max()]) - self.axs.set_xlim([min(f), max(f)]) # Ensure the x-axis covers the full frequency range - self.axs.grid() - plt.tight_layout() - self.draw() \ No newline at end of file + def __init__(self, parent=None, width=5, height=4, dpi=100, backend=None): + fig,self.axs = plt.subplots(1,1,figsize=(width, height), dpi=dpi) + super(FFTPlot, self).__init__(fig) + self.backend = backend + + FigureCanvasQTAgg.setSizePolicy(self, QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Expanding) + FigureCanvasQTAgg.updateGeometry(self) + + def plot(self, start_index: int = None, end_index: int = None, channel: str = None, interval=1.0): + self.axs.cla() + start_index = int(start_index) + fs = self.backend.sample_freq + middle_index = (start_index + end_index) // 2 + half_interval_samples = int((interval * fs) // 2) + plot_start_index = int(max(0, middle_index - half_interval_samples)) + plot_end_index = int(min(self.backend.get_eeg_data_shape()[1], middle_index + half_interval_samples)) + + unfiltered_eeg_data, channel_names = self.backend.get_eeg_data(plot_start_index, plot_end_index) + unfiltered_eeg_data = unfiltered_eeg_data[channel_names == channel, :][0] + # Compute the FFT + f, Pxx_den = signal.periodogram(unfiltered_eeg_data, fs) + + # Plotting the FFT + self.axs.semilogy(f, Pxx_den) + self.axs.set_xlabel('Frequency (Hz)') + self.axs.set_ylabel(r"PSD (V$^2$/Hz)") + + self.axs.set_ylim([1e-7, 1e3]) + # self.axs.set_ylim([0, Pxx_den.max()]) + self.axs.set_xlim([min(f), max(f)]) # Ensure the x-axis covers the full frequency range + self.axs.grid() + plt.tight_layout() + self.draw() \ No newline at end of file diff --git a/src/ui/bipolar_channel_selection.py b/src/ui/bipolar_channel_selection.py index 638f7b4..bc8de67 100644 --- a/src/ui/bipolar_channel_selection.py +++ b/src/ui/bipolar_channel_selection.py @@ -15,17 +15,18 @@ class BipolarChannelSelectionWindow(QtWidgets.QDialog): - def __init__(self, hfo_app=None, main_window=None, close_signal = None,waveform_plot = None): + def __init__(self, main_window_model=None, backend=None, main_window=None, close_signal = None,waveform_plot = None): # print(ROOT_DIR) super(BipolarChannelSelectionWindow, self).__init__() self.ui = uic.loadUi(os.path.join(ROOT_DIR, 'bipolar_channel_selection.ui'), self) - self.hfo_app = hfo_app + self.main_window_model = main_window_model + self.backend = backend self.main_window = main_window self.setWindowTitle("Bipolar Channel Selection") self.setWindowIcon(QtGui.QIcon(os.path.join(ROOT_DIR, 'images/icon1.png'))) - eeg_data,channel_names = self.hfo_app.get_eeg_data() + eeg_data,channel_names = self.backend.get_eeg_data() for channel in channel_names: #check if channel is not already in the list, then concat @@ -34,12 +35,12 @@ def __init__(self, hfo_app=None, main_window=None, close_signal = None,waveform_ self.ch_2_dropdown.addItem((channel)) #connect cancel button to close window - self.cancel_button.clicked.connect(self.close) + safe_connect_signal_slot(self.cancel_button.clicked, self.close) #conncet ok button to get channels to show - self.ok_button.clicked.connect(self.check_channels) + safe_connect_signal_slot(self.ok_button.clicked, self.check_channels) self.waveform_plot = waveform_plot self.close_signal = close_signal - self.close_signal.connect(self.close) + safe_connect_signal_slot(self.close_signal, self.close) def check_channels(self): @@ -51,11 +52,11 @@ def check_channels(self): # print(self.channel_2) if str(self.channel_1) != str(self.channel_2): - if f"{self.channel_1}#-#{self.channel_2}" not in self.hfo_app.channel_names: + if f"{self.channel_1}#-#{self.channel_2}" not in self.backend.channel_names: #create bipolar channel and add to data, channel_name lists - self.hfo_app.add_bipolar_channel(self.channel_1,self.channel_2) - self.waveform_plot.update_channel_names(self.hfo_app.channel_names) - self.main_window.set_channels_to_plot(self.hfo_app.channel_names, display_all=False) + self.backend.add_bipolar_channel(self.channel_1,self.channel_2) + self.waveform_plot.update_channel_names(self.backend.channel_names) + self.main_window_model.set_channels_to_plot(self.backend.channel_names, display_all=False) self.close() else: msg = QMessageBox() diff --git a/src/ui/channels_selection.py b/src/ui/channels_selection.py index 99bf416..5585555 100644 --- a/src/ui/channels_selection.py +++ b/src/ui/channels_selection.py @@ -19,11 +19,11 @@ class ChannelSelectionWindow(QtWidgets.QDialog): - def __init__(self, hfo_app=None, main_window=None, close_signal = None): + def __init__(self, backend=None, main_window_model=None, close_signal = None): super(ChannelSelectionWindow, self).__init__() - self.hfo_app = hfo_app - self.main_window = main_window + self.backend = backend + self.main_window_model = main_window_model self.layout = QGridLayout() self.setWindowTitle("Channel Selection") self.setWindowIcon(QtGui.QIcon(os.path.join(ROOT_DIR, 'images/icon1.png'))) @@ -48,16 +48,16 @@ def __init__(self, hfo_app=None, main_window=None, close_signal = None): self.layout.addWidget(self.cancel_button, 1, 1) #connect cancel button to close window - self.cancel_button.clicked.connect(self.close) + safe_connect_signal_slot(self.cancel_button.clicked, self.close) #conncet ok button to get channels to show - self.ok_button.clicked.connect(self.get_channels_to_show) - + safe_connect_signal_slot(self.ok_button.clicked, self.get_channels_to_show) + self.close_signal = close_signal - self.close_signal.connect(self.close_me) + safe_connect_signal_slot(self.close_signal, self.close_me) def set_channels(self): - eeg_data,channels = self.hfo_app.get_eeg_data() - channels_indexes_to_plot = self.main_window.get_channel_indices_to_plot() + eeg_data,channels = self.backend.get_eeg_data() + channels_indexes_to_plot = self.main_window_model.get_channel_indices_to_plot() self.channel_checkboxes = {} self.n_channels = len(channels) self.channels = channels @@ -67,8 +67,8 @@ def set_channels(self): self.check_box_all = QtWidgets.QCheckBox('Select All') self.check_box_none.setCheckState(Qt.Unchecked) self.check_box_all.setCheckState(Qt.Checked) - self.check_box_none.stateChanged.connect(lambda: self.select_channels(False)) - self.check_box_all.stateChanged.connect(lambda: self.select_channels(True)) + safe_connect_signal_slot(self.check_box_none.stateChanged, lambda: self.select_channels(False)) + safe_connect_signal_slot(self.check_box_all.stateChanged, lambda: self.select_channels(True)) self.scroll_layout.addWidget(self.check_box_none, 0, 0) self.scroll_layout.addWidget(self.check_box_all, 0, 1) for i,channel in enumerate(channels): @@ -78,7 +78,7 @@ def set_channels(self): checkbox.setObjectName(f"channel_{i}") self.channel_checkboxes[channel]=checkbox self.__dict__[f"channel_{i}"] = checkbox - checkbox.stateChanged.connect(self.channel_clicked) + safe_connect_signal_slot(checkbox.stateChanged, self.channel_clicked) # checkbox.setChecked(True) self.scroll_layout.addWidget(QtWidgets.QCheckBox(f"{channel}, amplitude: {round(np.ptp(eeg_data[i]),3)} uV"), i//2 + 1, i % 2) @@ -86,7 +86,7 @@ def set_channels(self): for i in range(self.n_channels): if i in channels_indexes_to_plot: self.scroll_layout.itemAtPosition(i//2 +1 ,i%2).widget().setChecked(True) - self.scroll_layout.itemAtPosition(i // 2 + 1, i % 2).widget().stateChanged.connect(self.check_channel_state) + safe_connect_signal_slot(self.scroll_layout.itemAtPosition(i // 2 + 1, i % 2).widget().stateChanged, self.check_channel_state) def channel_clicked(self): #print("clicked") @@ -115,12 +115,12 @@ def get_channels_to_show(self): if self.scroll_layout.itemAtPosition(1+i//2,i%2).widget().isChecked(): channels_to_show.append(self.channels[i]) - if self.main_window is not None: - self.main_window.set_channels_to_plot(channels_to_show) + if self.main_window_model is not None: + self.main_window_model.set_channels_to_plot(channels_to_show) # else: # print("main window is none") # print(channels_to_show) - self.main_window.channel_selection_update() + self.main_window_model.channel_selection_update() self.close() diff --git a/src/ui/main_window.py b/src/ui/main_window.py new file mode 100644 index 0000000..4a587a9 --- /dev/null +++ b/src/ui/main_window.py @@ -0,0 +1,33 @@ +import os +import re +import sys +import traceback +from queue import Queue +from PyQt5.QtWidgets import QMessageBox +from src.hfo_app import HFO_App +from src.controllers.main_window_controller import MainWindowController +from src.models.main_window_model import MainWindowModel +from src.views.main_window_view import MainWindowView +from src.utils.utils_gui import * +from PyQt5.QtCore import pyqtSignal + + +class MainWindow(QMainWindow): + close_signal = pyqtSignal() + + def __init__(self): + super(MainWindow, self).__init__() + + self.model = MainWindowModel(self) + self.view = MainWindowView(self) + self.main_window_controller = MainWindowController(self.view, self.model) + + # initialize general UI + self.main_window_controller.init_general_window() + + # initialize biomarker type + self.main_window_controller.init_biomarker_type() + + # initialize biomarker specific UI + biomarker = self.main_window_controller.get_biomarker_type() + self.main_window_controller.init_biomarker_window(biomarker) diff --git a/src/ui/main_window.ui b/src/ui/main_window.ui index d4b47fe..643ec2f 100644 --- a/src/ui/main_window.ui +++ b/src/ui/main_window.ui @@ -15,7 +15,67 @@ - + + + + + + + Qt::Horizontal + + + + 908 + 20 + + + + + + + + + 0 + 0 + + + + QComboBox::AdjustToContentsOnFirstShow + + + + HFO + + + + + Spindle + + + + + Spike + + + + + + + + Qt::Horizontal + + + + 907 + 20 + + + + + + + + QLayout::SetDefaultConstraint @@ -354,199 +414,13 @@ - + QFrame::StyledPanel QFrame::Raised - - - - - Qt::Horizontal - - - - 40 - 20 - - - - - - - - - - - - - 245 - 130 - 48 - - - - - - - - - 245 - 130 - 48 - - - - - - - - - 240 - 240 - 240 - - - - - - - - true - - - - - - - - Arial - 11 - - - - Artifact - - - - - - - - - - - - 240 - 50 - 230 - - - - - - - - - 240 - 50 - 230 - - - - - - - - - 240 - 240 - 240 - - - - - - - - true - - - - - - - - Arial - 11 - - - - spk-HFO - - - - - - - - - - - - 60 - 180 - 75 - - - - - - - - - 60 - 180 - 75 - - - - - - - - - 240 - 240 - 240 - - - - - - - - true - - - - - - - - Arial - 11 - - - - HFO - - - - @@ -576,6 +450,13 @@ 0 + + + + Qt::Horizontal + + + @@ -589,13 +470,6 @@ - - - - Qt::Horizontal - - - @@ -796,8114 +670,15 @@ - - - - - 0 - 0 - - - - 1 - - - - - - - - Arial - 13 - - - - Detection Parameters (MNI) - - - - - - - Arial - 11 - - - - Min Window (Sec) - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(255, 255, 255); - - - - - - - - - - - Arial - 11 - - - - Epoch CHF (Sec) - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(255, 255, 255); - - - - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(255, 255, 255); - - - - - - - - - - - Arial - 11 - - - - Min Gap Time (Sec) - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(255, 255, 255); - - - - - - - - - - - Arial - 11 - - - - Epoch (Sec) - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(255, 255, 255); - - - - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(255, 255, 255); - - - - - - - - - - - Arial - 11 - - - - Detect - - - - - - - - Arial - 11 - - - - CHF Percentage - - - - - - - - Arial - 11 - - - - Threshold Percentile - - - - - - - - Arial - 12 - - - - Baseline - - - - - - - Arial - 11 - - - - Window (Sec) - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(255, 255, 255); - - - - - - - - - - - Arial - 11 - - - - Shift - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(255, 255, 255); - - - - - - - - - - - Arial - 11 - - - - Threshold - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(255, 255, 255); - - - - - - - - - - - Arial - 11 - - - - Min Time - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(255, 255, 255); - - - - - - - - - - - - - - - - - - - - - Arial - 13 - - - - Detection Parameters (STE) - - - - - - - Arial - 11 - - - - Min Window (s) - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(235, 235, 235); - - - - - - - - - - - Arial - 11 - - - - RMS Window (s) - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(235, 235, 235); - - - - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(235, 235, 235); - - - - - - - - - - - Arial - 11 - - - - Min Oscillations - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(235, 235, 235); - - - - - - - - - - - Arial - 11 - - - - Epoch (s) - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(235, 235, 235); - - - - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(235, 235, 235); - - - - - - - - - - - Arial - 11 - - - - Peak Threshold - - - - - - - - Arial - 11 - - - - RMS Threshold - - - - - - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 0 - 0 - 0 - - - - - - - 255 - 255 - 255 - - - - - - - 0 - 0 - 0 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 212 - 127 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - 85 - 0 - 127 - - - - - - - 235 - 235 - 235 - - - - - - - 213 - 127 - 255 - - - - - - - 191 - 63 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 113 - 0 - 170 - - - - - - - 85 - 0 - 127 - - - - - - - 255 - 255 - 255 - - - - - - - 85 - 0 - 127 - - - - - - - 235 - 235 - 235 - - - - - - - 235 - 235 - 235 - - - - - - - 0 - 0 - 0 - - - - - - - 170 - 0 - 255 - - - - - - - 255 - 255 - 220 - - - - - - - 0 - 0 - 0 - - - - - - - 0 - 0 - 0 - - - - - - - - - Arial - 11 - - - - background-color: rgb(235, 235, 235); - - - - - - - - - - - Arial - 11 - - - - Min Gap Time (s) - - - - - - - - Arial - 11 - - - - Detect - - - - - - - - - - - - + + 0 - - + + - + Arial @@ -8911,313 +686,11 @@ - Classifier Parameters + Filter Parameters - - - - - - - - Arial - 11 - true - - - - Device - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter - - - - - - - - Arial - 11 - true - - - - background-color: rgb(235, 235, 235); - - - CPU - - - true - - - - - - - - Arial - 11 - true - - - - Batch Size - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter - - - - - - - - Arial - 11 - true - - - - background-color: rgb(235, 235, 235); - - - 64 - - - true - - - - - - - - - - - - Arial - 11 - true - - - - Use spk-HFO - - - - - - - - Arial - 11 - true - - - - spk-HFO Path - - - - - - - - Arial - 11 - false - false - true - - - - background-color: rgb(235, 235, 235); - - - - - - true - - - File Name... - - - - - - - - - - - Qt::Horizontal - - - - 40 - 20 - - - - - - - - - Arial - 11 - - - - HFO Classification - - - - - - - - - - - - Arial - 11 - true - - - - Artifact Path - - - - - - - - Arial - 11 - false - false - true - - - - background-color: rgb(235, 235, 235); - - - - - - true - - - File Name... - - - - - - - - - - - - Arial - 11 - true - - - - Ignore - - - - - - - - Arial - 11 - - - - 1 - - - - - - - - Arial - 11 - - - - Sec Before - - - - - - - - Arial - 11 - - - - 1 - - - - - - - - Arial - 11 - - - - Sec After Recording - - - - - - - - - - - - - - - - - - - - - - Arial - 13 - - - - Filter Parameters - - - - + + + @@ -9409,504 +882,395 @@ - - - - - - - - - - - Arial - 13 - false - - - - Statistics - - - - - - - Arial - 11 - false - - - - Save As npz - - - - - - - - Arial - 11 - false - - - - Save As Excel - - - - - - - QFrame::NoFrame - - - - - - - - - - - Arial - 11 - false - - - - Annotation - - - - - - - - - - - - - - - - Detector - - - - - - Qt::LeftToRight - - - QTabWidget::Rounded - - - 0 - - - - STE - - - - - - - Arial - 11 - - - - Detection Parameters - - - Qt::AlignCenter - - - - - - sec - - - Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter - - - - - - - RMS Window - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter - - - - - - - sec - - - Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter - - - - - - - sec - - - Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter - - - - - - - Epoch Length - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter - - - - - - - Min Window - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter - - - - - - - Min Oscillations - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter - - - - - - - Min Gap - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter - - - - - - - sec - - - Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter - - - - - - - RMS Threshold - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter - - - - - - - Peak Threshold - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter - - - - - - - Save - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - MNI - - - - - - - Arial - 11 - - - - Detection Parameters - - - Qt::AlignCenter - - - - - - Sec - - - Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter - - - - - - - Baseline Window - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter - - - - - - - Baseline Shift - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter - - - - - - - Baseline Threshold - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter - - - - - - - Sec - - - Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter - - - - - - - Baseline Minimum Time - - - Qt::AlignCenter - - - - - - - - - - Qt::AlignCenter - - - - - - - CHF Percentage - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter - - - - - - - % - - - Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter - - - - - - - Sec - - - Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter - - - + + + + 0 + + + + + + + + Arial + 13 + + + + Classifier Parameters + + + + + + + + + Arial + 11 + true + + + + Device + + + Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter + + + + + + + + Arial + 11 + true + + + + background-color: rgb(235, 235, 235); + + + CPU + + + true + + + + + + + + Arial + 11 + true + + + + Batch Size + + + Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter + + + + + + + + Arial + 11 + true + + + + background-color: rgb(235, 235, 235); + + + 64 + + + true + + + + + + + + + + + + Arial + 11 + true + + + + Use spk-HFO + + + + + + + + Arial + 11 + true + + + + spk-HFO Path + + + + + + + + Arial + 11 + false + false + true + + + + background-color: rgb(235, 235, 235); + + + + + + true + + + File Name... + + + + + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + + Arial + 11 + + + + HFO Classification + + + + + + + + + + + + Arial + 11 + true + + + + Artifact Path + + + + + + + + Arial + 11 + false + false + true + + + + background-color: rgb(235, 235, 235); + + + + + + true + + + File Name... + + + + + + + + + + + + Arial + 11 + true + + + + Ignore + + + + + + + + Arial + 11 + + + + 1 + + + + + + + + Arial + 11 + + + + Sec Before + + + + + + + + Arial + 11 + + + + 1 + + + + + + + + Arial + 11 + + + + Sec After Recording + + + + + + + + + + + + + + + + + + 0 + 0 + + + + -1 + + + + + + + + + 0 + + + + + + + + Arial + 13 + false + + + + Statistics + + - - - Epoch CHF - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter - - - - - - - Min Gap Time - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter - - - - - - - % - - - Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + + Arial + 11 + false + - - - - - Epoch Time - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter + Save As npz - - - - Sec - - - Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + + + Arial + 11 + false + - - - - - Sec - - - Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + Save As Excel - - - - Min Window - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter + + + + QFrame::NoFrame - - - - - Threshold Percentage - - - Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + Arial + 11 + false + + - Save + Annotation @@ -9915,6 +1279,27 @@ + + + + + + + + Detector + + + + + + Qt::LeftToRight + + + QTabWidget::Rounded + + + -1 + @@ -9926,10 +1311,10 @@ - - + + - + Arial @@ -9937,19 +1322,16 @@ - Artifact Model + spk-HFO Model - + Arial 11 - false - false - true @@ -9964,7 +1346,7 @@ - + Arial @@ -9978,43 +1360,24 @@ - - - - - - - Arial - 12 - - - - Device - - - - - - - - Arial - 12 - - - - CPU - - - false - - - - + + + + + Arial + 11 + true + + + + Use spk-HFO + + - - + + - + Arial @@ -10022,16 +1385,19 @@ - spk-HFO Model + Artifact Model - + Arial 11 + false + false + true @@ -10046,7 +1412,7 @@ - + Arial @@ -10093,7 +1459,7 @@ - + @@ -10123,6 +1489,39 @@ + + + + + + + Arial + 12 + + + + Device + + + + + + + + Arial + 12 + + + + CPU + + + false + + + + + @@ -10153,19 +1552,35 @@ - - - - - Arial - 11 - true - - - - Use spk-HFO - - + + + + + + + Arial + 11 + + + + Download + + + + + + + + Arial + 11 + + + + Upload + + + + diff --git a/src/ui/plot_waveform.py b/src/ui/plot_waveform.py index ce8c5f9..bf2bb78 100644 --- a/src/ui/plot_waveform.py +++ b/src/ui/plot_waveform.py @@ -8,33 +8,26 @@ from tqdm import tqdm import os from src.hfo_app import HFO_App +from src.spindle_app import SpindleApp import random +from src.controllers import MiniPlotController, MainWaveformPlotController curr_dir = os.path.dirname(os.path.realpath(__file__)) sys.path.append(os.path.dirname(curr_dir)) -class PlotWaveform(QtWidgets.QGraphicsView): - def __init__(self, plot_loc:pg.PlotWidget, hfo_loc:pg.PlotWidget, backend: HFO_App): - super().__init__() - self.hfo_display = hfo_loc - self.hfo_display.setMouseEnabled(x=False, y=False) - self.hfo_display.getPlotItem().hideAxis('bottom') - self.hfo_display.getPlotItem().hideAxis('left') - self.hfo_loc = hfo_loc - self.hfo_display.setBackground('w') +class CenterWaveformAndMiniPlotController(): + def __init__(self, main_waveform_plot_widget: pg.PlotWidget, mini_plot_widget: pg.PlotWidget, backend: HFO_App): + self.mini_plot_controller = MiniPlotController(mini_plot_widget, backend) + self.main_waveform_plot_controller = MainWaveformPlotController(main_waveform_plot_widget, backend) - self.waveform_display = plot_loc #pg.PlotWidget(plot_loc) - # self.waveform_display.getPlotItem().getAxis('bottom').setHeight(10) - self.waveform_display.setMouseEnabled(x=False, y=False) - self.waveform_display.getPlotItem().hideAxis('bottom') - self.waveform_display.getPlotItem().hideAxis('left') - self.plot_loc = plot_loc - self.waveform_display.setBackground('w') + # clear everything if exit + self.main_waveform_plot_controller.clear() + self.mini_plot_controller.clear() self.time_window = 20 #20 second time window - self.time_increment =20 - self.old_size = (self.plot_loc.x(),self.plot_loc.y(),self.plot_loc.width(),self.plot_loc.height()) + self.time_increment = 20 + # self.old_size = (self.waveform_display.x(),self.waveform_display.y(),self.waveform_display.width(),self.waveform_display.height()) self.t_start = 0 self.first_channel_to_plot = 0 self.n_channels_to_plot = 10 @@ -48,67 +41,39 @@ def __init__(self, plot_loc:pg.PlotWidget, hfo_loc:pg.PlotWidget, backend: HFO_A self.HFO_color=self.non_spike_color self.color_dict={"artifact":self.artifact_color,"spike":self.spike_color, "non_spike":self.non_spike_color,"HFO":self.HFO_color} - self.plot_HFOs = False + self.plot_biomarkers = False self.normalize_vertical = False self.stds = None def set_filtered(self,filtered:bool): - self.filtered = filtered + self.main_waveform_plot_controller.set_waveform_filter(filtered) self.plot(self.t_start) def update_backend(self,new_backend:HFO_App,init_eeg_data:bool=True): self.backend = new_backend + self.mini_plot_controller.update_backend(new_backend) + self.main_waveform_plot_controller.update_backend(new_backend) if init_eeg_data: self.init_eeg_data() def init_eeg_data(self): - # print("reinit eeg data") - #reinitalize self - # self = PlotWaveform(self.plot_loc,self.hfo_loc,self.backend) - # self.eeg_data = eeg_data - # #normalize each to 0-1 - # self.eeg_data = (self.eeg_data-self.eeg_data.min(axis = 1,keepdims = True))/(np.ptp(self.eeg_data,axis = 1,keepdims = True)) - # #shift the ith channel by 1.1*i - # self.eeg_data = self.eeg_data-1.1*np.arange(self.eeg_data.shape[0])[:,None] - self.filtered = False - self.plot_HFOs = False - self.hfo_display.clear() - self.waveform_display.clear() - eeg_data,self.channel_names=self.backend.get_eeg_data() - ## print("eeg_data.shape",eeg_data.shape) - # print("self.channel_names",self.channel_names) - self.channel_names = list(self.channel_names) - self.edf_info=self.backend.get_edf_info() - # print(self.edf_info) - # self.channel_names_locs=np.mean(self.eeg_data,axis = 1) - self.sample_freq = self.edf_info['sfreq'] - self.time = np.arange(0,eeg_data.shape[1]/self.sample_freq,1/self.sample_freq) # time in seconds - self.n_channels = len(self.channel_names) - # print("here") - self.n_channels_to_plot = min(self.n_channels,self.n_channels_to_plot) - self.channels_to_plot = self.channel_names.copy() - self.channel_indices_to_plot = np.arange(self.n_channels) - self.init_hfo_display() - self.waveform_display.getPlotItem().showAxis('bottom') - self.waveform_display.getPlotItem().showAxis('left') - # print(self.plot_loc.x(),self.plot_loc.y(),self.plot_loc.width(),self.plot_loc.height()) - - def init_hfo_display(self): - # print("init hfo display") - self.hfo_display.getPlotItem().showAxis('bottom') - self.hfo_display.getPlotItem().showAxis('left') - self.lr = pg.LinearRegionItem([0,0], movable = False) - self.lr.setZValue(-20) - self.hfo_display.addItem(self.lr) + self.mini_plot_controller.clear() + self.main_waveform_plot_controller.clear() + self.mini_plot_controller.init_eeg_data() + self.main_waveform_plot_controller.init_eeg_data() + + self.mini_plot_controller.init_biomarker_display() + self.main_waveform_plot_controller.init_waveform_display() + def get_n_channels(self): - return self.n_channels + return self.main_waveform_plot_controller.model.n_channels def get_n_channels_to_plot(self): return self.n_channels_to_plot def get_total_time(self): - return self.time[-1] + return self.main_waveform_plot_controller.model.time[-1] def get_time_window(self): return self.time_window @@ -118,9 +83,11 @@ def get_time_increment(self): def set_normalize_vertical(self,normalize_vertical:bool): self.normalize_vertical = normalize_vertical + self.main_waveform_plot_controller.set_normalize_vertical(normalize_vertical) def set_time_window(self,time_window:float): self.time_window = time_window + self.main_waveform_plot_controller.set_time_window(time_window) #replot # self.plot(self.t_start) @@ -129,203 +96,66 @@ def set_time_increment(self,time_increment:float): def set_n_channels_to_plot(self,n_channels_to_plot:int): self.n_channels_to_plot = n_channels_to_plot + self.main_waveform_plot_controller.set_n_channels_to_plot(n_channels_to_plot) + self.mini_plot_controller.set_n_channels_to_plot(n_channels_to_plot) - def set_plot_HFOs(self,plot_HFOs:bool): - self.plot_HFOs = plot_HFOs - self.plot(self.t_start, update_hfo=True) + def set_plot_biomarkers(self,plot_biomarkers:bool): + self.plot_biomarkers = plot_biomarkers + self.main_waveform_plot_controller.set_plot_biomarkers(plot_biomarkers) + self.plot(self.t_start, update_biomarker=True) def get_channels_to_plot(self): - return self.channels_to_plot + return self.main_waveform_plot_controller.model.channels_to_plot def get_channel_indices_to_plot(self): - return self.channel_indices_to_plot + return self.main_waveform_plot_controller.model.channel_indices_to_plot def update_channel_names(self,new_channel_names): - self.channel_names = list(new_channel_names) + self.mini_plot_controller.update_channel_names(new_channel_names) + self.main_waveform_plot_controller.update_channel_names(new_channel_names) def set_channels_to_plot(self,channels_to_plot:list): - self.channels_to_plot = channels_to_plot - self.channel_indices_to_plot = [self.channel_names.index(channel) for channel in channels_to_plot] - # self.n_channels_to_plot = len(self.channels_to_plot) - # self.plot(self.t_start) + self.main_waveform_plot_controller.set_channels_to_plot(channels_to_plot) + self.mini_plot_controller.set_channels_to_plot(channels_to_plot) def set_channel_indices_to_plot(self,channel_indices_to_plot:list): - self.channel_indices_to_plot = channel_indices_to_plot - self.channels_to_plot = [self.channel_names[index] for index in channel_indices_to_plot] - # self.n_channels_to_plot = len(self.channels_to_plot) - # self.plot(self.t_start) + self.main_waveform_plot_controller.set_channel_indices_to_plot(channel_indices_to_plot) + self.mini_plot_controller.set_channel_indices_to_plot(channel_indices_to_plot) - def plot(self,t_start:float = None,first_channel_to_plot:int = None, empty=False, update_hfo=False): - # print("plot HFOs",self.plot_HFOs) + def plot(self, start_in_time:float = None, first_channel_to_plot:int = None, empty=False, update_biomarker=False): + if empty: - self.waveform_display.clear() - self.hfo_display.clear() + self.main_waveform_plot_controller.clear() + self.mini_plot_controller.clear() return - if t_start is None: - t_start = self.t_start - else: - self.t_start = t_start #this allows us to keep track of the start time of the plot and thus replot when the time window changes or when the number of channels - if first_channel_to_plot is None: - first_channel_to_plot = self.first_channel_to_plot - else: - self.first_channel_to_plot = first_channel_to_plot - self.waveform_display.clear() - if update_hfo: - self.hfo_display.clear() - self.init_hfo_display() - #to show changes - t_end = min(t_start+self.time_window,self.time[-1]) - # print(t_start,t_end,self.time[-1]) - # start_time = time.time() - eeg_data_to_display,_=self.backend.get_eeg_data(int(t_start*self.sample_freq),int(t_end*self.sample_freq), self.filtered) - #normalize each to 0-1 - eeg_data_to_display = eeg_data_to_display[self.channel_indices_to_plot,:] - if self.normalize_vertical: - eeg_data_to_display = (eeg_data_to_display-eeg_data_to_display.min(axis = 1,keepdims = True)) - eeg_data_to_display = eeg_data_to_display/np.max(eeg_data_to_display) - else: - # eeg_data_to_display = (eeg_data_to_display-eeg_data_to_display.min(axis = 1,keepdims = True))/(np.ptp(eeg_data_to_display,axis = 1,keepdims = True)) - - # standardized signal by channel - # means = np.mean(eeg_data_to_display, axis=1, keepdims=True) - # stds = np.std(eeg_data_to_display, axis=1, keepdims=True) - if self.filtered: - means = np.mean(eeg_data_to_display) - self.stds = np.std(eeg_data_to_display) * 2 - eeg_data_to_display = (eeg_data_to_display - means) / self.stds - eeg_data_to_display[np.isnan(eeg_data_to_display)] = 0 - else: - # standardized signal globally - means = np.mean(eeg_data_to_display) - self.stds = np.std(eeg_data_to_display) - eeg_data_to_display = (eeg_data_to_display - means) / self.stds - #replace nans with 0 - eeg_data_to_display[np.isnan(eeg_data_to_display)] = 0 - #shift the ith channel by 1.1*i - # eeg_data_to_display = eeg_data_to_display-1.1*np.arange(eeg_data_to_display.shape[0])[:,None] - if self.filtered: - # Add scale indicators - # Set the length of the scale lines - y_100_length = 50 # 100 microvolts - offset_value = 6 - y_scale_length = y_100_length / self.stds - else: - y_100_length = 100 # 100 microvolts - offset_value = 6 - y_scale_length = y_100_length / self.stds - time_to_display = self.time[int(t_start*self.sample_freq):int(t_end*self.sample_freq)] - top_value=eeg_data_to_display[first_channel_to_plot].max() - # print("top value:",top_value) - # bottom_value=eeg_data_to_display[-1].min() - # print("bottom value:",bottom_value) - # print("channel means",np.mean(eeg_data_to_display,axis = 1)) - for disp_i, ch_i in enumerate(range(first_channel_to_plot,first_channel_to_plot+self.n_channels_to_plot)): - channel = self.channels_to_plot[ch_i] - - self.waveform_display.plot(time_to_display, eeg_data_to_display[ch_i] - disp_i*offset_value, pen=pg.mkPen(color=self.waveform_color, width=0.5)) - if self.plot_HFOs: - starts, ends, artifacts, spikes = self.backend.hfo_features.get_HFOs_for_channel(channel,int(t_start*self.sample_freq),int(t_end*self.sample_freq)) - # print("channel:", channel, starts,ends, artifacts, spikes) - for j in range(len(starts)): - try: - if int(artifacts[j])<1: - color = self.artifact_color - name="artifact" - elif spikes[j]: - color = self.spike_color - name="spike" - else: - color = self.non_spike_color - name="non-spike" - except: - color = self.HFO_color - name="HFO" - # print(time_to_display[starts[j]:ends[j]]) - self.waveform_display.plot(self.time[int(starts[j]):int(ends[j])], - eeg_data_to_display[ch_i, int(starts[j])-int(t_start*self.sample_freq):int(ends[j])-int(t_start*self.sample_freq)]-disp_i*offset_value, - pen=pg.mkPen(color=color, width=2)) - # print("plotting",self.time[int(starts[j])],self.time[int(ends[j])],"name:",name,"channel:",channel) - # print(starts[j],ends[j]) - # print(eeg_data_to_display[i,int(starts[j]):int(ends[j])]) - self.waveform_display.plot([self.time[int(starts[j])],self.time[int(ends[j])]],[ - top_value+0.2,top_value+0.2 - ],pen = pg.mkPen(color = color,width=10)) - - # mini plot - if self.plot_HFOs and update_hfo: - starts, ends, artifacts, spikes = self.backend.hfo_features.get_HFOs_for_channel(channel, - 0, - sys.maxsize) - for j in range(len(starts)): - try: - if int(artifacts[j])<1: - color = self.artifact_color - name="artifact" - elif spikes[j]: - color = self.spike_color - name="spike" - else: - color = self.non_spike_color - name="non-spike" - except: - color = self.HFO_color - name="HFO" - # x = self.time[int(starts[j]):int(ends[j])] - # y = eeg_data_to_display[i, int(starts[j]):int(ends[j])] - # # self.waveform_mini_item.setData(x, y, pen=pg.mkPen(color=color, width=2)) - # self.hfo_display.plot(x, y, pen=pg.mkPen(color=color, width=2)) - end = min(int(ends[j]), len(self.time)-1) - self.hfo_display.plot([self.time[int(starts[j])], self.time[end]], [ - top_value, top_value - ], pen=pg.mkPen(color=color, width=5)) - - # Determine the position for the scale indicator (bottom right corner of the plot) - x_pos = t_end #+ 0.15 - # y_pos = top_value - 0.1 * (top_value - np.min(eeg_data_to_display)) # Adjust as needed - y_pos = np.min(eeg_data_to_display[-1]) - self.n_channels_to_plot * offset_value + 0.8 * offset_value + + if start_in_time is not None: + self.main_waveform_plot_controller.set_current_time_window(start_in_time) + start_in_time, end_in_time = self.main_waveform_plot_controller.get_current_start_end() - # # Draw the x and y scale lines - # self.waveform_display.plot([x_pos, x_pos], [y_pos, y_pos + y_scale_length], pen=pg.mkPen('black', width=2)) + if first_channel_to_plot is not None: + self.main_waveform_plot_controller.set_first_channel_to_plot(first_channel_to_plot) + self.mini_plot_controller.set_first_channel_to_plot(first_channel_to_plot) + first_channel_to_plot = self.main_waveform_plot_controller.get_first_channel_to_plot() - # # Add text annotations for the scale lines - # text_item = pg.TextItem(f'{y_100_length} μV', color='black', anchor=(1, 0.5)) - # text_item.setPos(x_pos, y_pos + y_scale_length / 2) - # self.waveform_display.addItem(text_item) + self.main_waveform_plot_controller.clear() - # Use a dashed line for the scale - scale_line = pg.PlotDataItem([x_pos, x_pos], [y_pos, y_pos + y_scale_length], - pen=pg.mkPen('black', width=10), fill=(0, 128, 255, 150)) - self.waveform_display.addItem(scale_line) - - text_item = pg.TextItem(f'Scale: {y_100_length} μV ', color='black', anchor=(1, 0.5)) - text_item.setFont(QtGui.QFont('Arial', 10, QtGui.QFont.Bold)) - text_item.setPos(x_pos, y_pos + y_scale_length / 2) - self.waveform_display.addItem(text_item) + if update_biomarker: + self.mini_plot_controller.clear() + self.mini_plot_controller.init_biomarker_display() - # print("time to plot:",time.time()-start_time) - #set y ticks to channel names - # channel_names_locs = -offset_value * np.arange(eeg_data_to_display.shape[0])[:, None] # + offset_value/2 - channel_names_locs = -offset_value * np.arange(self.n_channels_to_plot)[:, None] # + offset_value/2 + eeg_data_to_display, y_100_length, y_scale_length, offset_value = self.main_waveform_plot_controller.plot_all_current_channels_for_window() + top_value = eeg_data_to_display[first_channel_to_plot].max() + top_value_mini = 0.5 + if self.plot_biomarkers: + self.main_waveform_plot_controller.plot_all_current_biomarkers_for_window(eeg_data_to_display, offset_value, top_value) - self.waveform_display.getAxis('left').setTicks([[(channel_names_locs[disp_i], self.channels_to_plot[chi_i]) - for disp_i, chi_i in enumerate(range(first_channel_to_plot,first_channel_to_plot+self.n_channels_to_plot))]]) - #set the max and min of the x axis - self.waveform_display.setXRange(t_start,t_end) + if self.plot_biomarkers and update_biomarker: + self.mini_plot_controller.plot_all_current_biomarkers_for_all_channels(top_value_mini) - self.hfo_display.getAxis('left').setTicks([[(top_value, ' HFO ')]]) - self.hfo_display.setXRange(0, int(self.time.shape[0] / self.sample_freq)) - self.hfo_display.setYRange(top_value-0.25, top_value+0.25) - - self.lr.setRegion([t_start,t_end]) - self.lr.setZValue(top_value) - #set background to white - # self.waveform_display.setBackground('w') + self.main_waveform_plot_controller.draw_scale_bar(eeg_data_to_display, offset_value, y_100_length, y_scale_length) + self.main_waveform_plot_controller.draw_channel_names(offset_value) - # plot out on top bars of where the HFOs are - # all_channels_starts = np.array(all_channels_starts) - # all_channels_ends = np.array(all_channels_ends) - # all_channels_names = np.array(all_channels_names) - # for name in np.unique(all_channels_names): - # self.waveform_display.plot(self.time[all_channels_starts[all_channels_names==name]], - # [0]* - # ,pen = pg.mkPen(color = self.color_dict[name],width=5)) + self.mini_plot_controller.set_miniplot_title('biomarker', top_value_mini) + self.mini_plot_controller.set_total_x_y_range(top_value_mini) + self.mini_plot_controller.update_highlight_window(start_in_time, end_in_time, top_value_mini) diff --git a/src/ui/quick_detection.py b/src/ui/quick_detection.py index 3954319..de71d33 100644 --- a/src/ui/quick_detection.py +++ b/src/ui/quick_detection.py @@ -6,8 +6,9 @@ import os from pathlib import Path from src.hfo_app import HFO_App +from src.spindle_app import SpindleApp from src.param.param_classifier import ParamClassifier -from src.param.param_detector import ParamDetector, ParamSTE, ParamMNI +from src.param.param_detector import ParamDetector, ParamSTE, ParamMNI, ParamHIL from src.param.param_filter import ParamFilter from src.utils.utils_gui import * @@ -18,7 +19,7 @@ class HFOQuickDetector(QtWidgets.QDialog): - def __init__(self, hfo_app=None, main_window=None, close_signal = None): + def __init__(self, backend=None, main_window=None, close_signal = None): super(HFOQuickDetector, self).__init__() # print("initializing HFOQuickDetector") self.ui = uic.loadUi(os.path.join(ROOT_DIR, 'quick_detection.ui'), self) @@ -27,26 +28,29 @@ def __init__(self, hfo_app=None, main_window=None, close_signal = None): # print("loaded ui") self.filename = None self.threadpool = QThreadPool() - self.detectionTypeComboBox.currentIndexChanged['int'].connect( - lambda: self.update_detector_tab(self.detectionTypeComboBox.currentText())) # type: ignore + safe_connect_signal_slot(self.detectionTypeComboBox.currentIndexChanged['int'], + lambda: self.update_detector_tab(self.detectionTypeComboBox.currentText())) self.detectionTypeComboBox.setCurrentIndex(2) QtCore.QMetaObject.connectSlotsByName(self) # self.qd_loadEDF_button.clicked.connect(hfoMainWindow.Ui_MainWindow.openFile) - self.qd_loadEDF_button.clicked.connect(self.open_file) + safe_connect_signal_slot(self.qd_loadEDF_button.clicked, self.open_file) #print("hfo_app: ", hfo_app) - if hfo_app is None: + if backend is None: #print("hfo_app is None creating new HFO_App") - self.hfo_app = HFO_App() + self.backend = HFO_App() else: #print("hfo_app is not None") - self.hfo_app = hfo_app + self.backend = backend self.init_default_filter_input_params() self.init_default_mni_input_params() self.init_default_ste_input_params() - self.qd_choose_artifact_model_button.clicked.connect(lambda: self.choose_model_file("artifact")) - self.qd_choose_spike_model_button.clicked.connect(lambda: self.choose_model_file("spike")) + self.init_default_hil_input_params() + safe_connect_signal_slot(self.qd_choose_artifact_model_button.clicked, + lambda: self.choose_model_file("artifact")) + safe_connect_signal_slot(self.qd_choose_spike_model_button.clicked, + lambda: self.choose_model_file("spike")) - self.run_button.clicked.connect(self.run) + safe_connect_signal_slot(self.run_button.clicked, self.run) self.run_button.setEnabled(False) #set n_jobs min and max @@ -54,7 +58,7 @@ def __init__(self, hfo_app=None, main_window=None, close_signal = None): self.n_jobs_spinbox.setMaximum(mp.cpu_count()) #set default n_jobs - self.n_jobs_spinbox.setValue(self.hfo_app.n_jobs) + self.n_jobs_spinbox.setValue(self.backend.n_jobs) self.main_window = main_window self.stdout = Queue() @@ -62,31 +66,31 @@ def __init__(self, hfo_app=None, main_window=None, close_signal = None): # sys.stdout = WriteStream(self.stdout) # sys.stderr = WriteStream(self.stderr) self.thread_stdout = STDOutReceiver(self.stdout) - self.thread_stdout.std_received_signal.connect(self.main_window.message_handler) + safe_connect_signal_slot(self.thread_stdout.std_received_signal, self.main_window.message_handler) self.thread_stdout.start() # print("not here 2") self.thread_stderr = STDErrReceiver(self.stderr) - self.thread_stderr.std_received_signal.connect(self.main_window.message_handler) + safe_connect_signal_slot(self.thread_stderr.std_received_signal, self.main_window.message_handler) self.thread_stderr.start() #classifier default buttons - self.default_cpu_button.clicked.connect(self.set_classifier_param_cpu_default) - self.default_gpu_button.clicked.connect(self.set_classifier_param_gpu_default) + safe_connect_signal_slot(self.default_cpu_button.clicked, self.set_classifier_param_cpu_default) + safe_connect_signal_slot(self.default_gpu_button.clicked, self.set_classifier_param_gpu_default) if not torch.cuda.is_available(): self.default_gpu_button.setEnabled(False) - self.cancel_button.clicked.connect(self.close) + safe_connect_signal_slot(self.cancel_button.clicked, self.close) self.running = False # self.setWindowFlags( QtCore.Qt.CustomizeWindowHint ) self.close_signal = close_signal - self.close_signal.connect(self.close) + safe_connect_signal_slot(self.close_signal, self.close) def open_file(self): fname, _ = QFileDialog.getOpenFileName(self, "Open File", "", "Recordings Files (*.edf *.eeg *.vhdr *.vmrk)") if fname: worker = Worker(self.read_edf, fname) - worker.signals.result.connect(self.update_edf_info) + safe_connect_signal_slot(worker.signals.result, self.update_edf_info) # worker.signals.finished.connect(lambda: self.message_handler('Open File thread COMPLETE!')) # worker.signals.progress.connect(self.progress_fn) # Execute @@ -94,14 +98,14 @@ def open_file(self): def read_edf(self, fname, progress_callback): self.fname = fname - self.hfo_app.load_edf(fname) - eeg_data,channel_names=self.hfo_app.get_eeg_data() - edf_info=self.hfo_app.get_edf_info() + self.backend.load_edf(fname) + eeg_data,channel_names=self.backend.get_eeg_data() + edf_info=self.backend.get_edf_info() filename = os.path.basename(fname) self.filename = filename - sample_freq = str(self.hfo_app.sample_freq) - num_channels = str(len(self.hfo_app.channel_names)) - length = str(self.hfo_app.eeg_data.shape[1]) + sample_freq = str(self.backend.sample_freq) + num_channels = str(len(self.backend.channel_names)) + length = str(self.backend.eeg_data.shape[1]) return [filename, sample_freq, num_channels, length] @pyqtSlot(list) @@ -120,6 +124,9 @@ def update_detector_tab(self, index): elif index == "STE": self.stackedWidget.setCurrentIndex(1) self.detector = "STE" + elif index == "HIL": + self.stackedWidget.setCurrentIndex(2) + self.detector = "HIL" # filter stuff def init_default_filter_input_params(self): @@ -178,7 +185,7 @@ def get_mni_params(self): "min_win":float(min_win), "min_gap":float(min_gap), "base_seg":float(base_seg), "thrd_perc":float(thrd_perc)/100, "base_shift":float(base_shift), "base_thrd":float(base_thrd), "base_min":float(base_min), - "n_jobs":self.hfo_app.n_jobs} + "n_jobs":self.backend.n_jobs} detector_params = {"detector_type":"MNI", "detector_param":param_dict} return ParamDetector.from_dict(detector_params) @@ -203,10 +210,39 @@ def get_ste_params(self): param_dict={"sample_freq":2000,"pass_band":1, "stop_band":80, #these are placeholder params, will be updated later "rms_window":float(rms_window_raw), "min_window":float(min_window_raw), "min_gap":float(min_gap_raw), "epoch_len":float(epoch_len_raw), "min_osc":float(min_osc_raw), "rms_thres":float(rms_thres_raw), - "peak_thres":float(peak_thres_raw),"n_jobs":self.hfo_app.n_jobs} + "peak_thres":float(peak_thres_raw),"n_jobs":self.backend.n_jobs} detector_params={"detector_type":"STE", "detector_param":param_dict} return ParamDetector.from_dict(detector_params) + def init_default_hil_input_params(self): + default_params = ParamHIL(2000) + self.qd_hil_sample_freq_input.setText(str(default_params.sample_freq)) + self.qd_hil_pass_band_input.setText(str(default_params.pass_band)) + self.qd_hil_stop_band_input.setText(str(default_params.stop_band)) + self.qd_hil_epoch_time_input.setText(str(default_params.epoch_time)) + self.qd_hil_sd_threshold_input.setText(str(default_params.sd_threshold)) + self.qd_hil_min_window_input.setText(str(default_params.min_window)) + + def get_hil_params(self): + sample_freq_raw = self.qd_hil_sample_freq_input.text() + pass_band_raw = self.qd_hil_pass_band_input.text() + stop_band_raw = self.qd_hil_stop_band_input.text() + epoch_time_raw = self.qd_hil_epoch_time_input.text() + sd_threshold_raw = self.qd_hil_sd_threshold_input.text() + min_window_raw = self.qd_hil_min_window_input.text() + + param_dict = { + "sample_freq": float(sample_freq_raw), + "pass_band": float(pass_band_raw), + "stop_band": float(stop_band_raw), + "epoch_time": float(epoch_time_raw), + "sd_threshold": float(sd_threshold_raw), + "min_window": float(min_window_raw), + "n_jobs": self.backend.n_jobs + } + detector_params = {"detector_type": "HIL", "detector_param": param_dict} + return ParamDetector.from_dict(detector_params) + def get_classifier_param(self): artifact_path = self.qd_classifier_artifact_filename_display.text() spike_path = self.qd_classifier_spike_filename_display.text() @@ -222,7 +258,7 @@ def get_classifier_param(self): return {"classifier_param":classifier_param,"use_spike":use_spike, "seconds_before":seconds_before, "seconds_after":seconds_after} def set_classifier_param_display(self): - classifier_param = self.hfo_app.get_classifier_param() + classifier_param = self.backend.get_classifier_param() #set also the input fields self.qd_classifier_artifact_filename_display.setText(classifier_param.artifact_path) @@ -232,11 +268,11 @@ def set_classifier_param_display(self): self.qd_classifier_batch_size_input.setText(str(classifier_param.batch_size)) def set_classifier_param_gpu_default(self): - self.hfo_app.set_default_gpu_classifier() + self.backend.set_default_gpu_classifier() self.set_classifier_param_display() def set_classifier_param_cpu_default(self): - self.hfo_app.set_default_cpu_classifier() + self.backend.set_default_cpu_classifier() self.set_classifier_param_display() def choose_model_file(self, model_type): @@ -248,13 +284,13 @@ def choose_model_file(self, model_type): def _detect(self, progress_callback): #call detect HFO function on backend - self.hfo_app.detect_HFO() + self.backend.detect_biomarker() return [] - def detect_HFOs(self): + def detect_biomarkers(self): # print("Detecting HFOs...") worker=Worker(self._detect) - worker.signals.result.connect(self._detect_finished) + safe_connect_signal_slot(worker.signals.result, self._detect_finished) self.threadpool.start(worker) # def _detect_finished(self): @@ -265,19 +301,19 @@ def detect_HFOs(self): def filter_data(self): # print("Filtering data...") worker=Worker(self._filter) - worker.signals.finished.connect(self.filtering_complete) + safe_connect_signal_slot(worker.signals.finished, self.filtering_complete) self.threadpool.start(worker) def _filter(self, progress_callback): - self.hfo_app.filter_eeg_data(self.filter_params) + self.backend.filter_eeg_data(self.filter_params) # def filtering_complete(self): # self.message_handler('Filtering COMPLETE!') def _classify(self,classify_spikes,seconds_to_ignore_before=0,seconds_to_ignore_after=0): - self.hfo_app.classify_artifacts([seconds_to_ignore_before,seconds_to_ignore_after]) + self.backend.classify_artifacts([seconds_to_ignore_before,seconds_to_ignore_after]) if classify_spikes: - self.hfo_app.classify_spikes() + self.backend.classify_spikes() return [] # def _classify_finished(self): @@ -285,7 +321,7 @@ def _classify(self,classify_spikes,seconds_to_ignore_before=0,seconds_to_ignore_ def classify(self,params): #set the parameters - self.hfo_app.set_classifier(params["classifier_param"]) + self.backend.set_classifier(params["classifier_param"]) seconds_to_ignore_before = params["seconds_before"] seconds_to_ignore_after = params["seconds_after"] self._classify(params["use_spike"],seconds_to_ignore_before,seconds_to_ignore_after) @@ -295,7 +331,7 @@ def classify(self,params): def _run(self, progress_callback): self.run_button.setEnabled(False) - self.hfo_app.n_jobs = int(self.n_jobs_spinbox.value()) + self.backend.n_jobs = int(self.n_jobs_spinbox.value()) # get the filter parameters filter_param = self.get_filter_param() # get the detector parameters @@ -303,6 +339,8 @@ def _run(self, progress_callback): detector_param = self.get_mni_params() elif self.detector == "STE": detector_param = self.get_ste_params() + elif self.detector == "HIL": + detector_param = self.get_hil_params() #print("filter_param: ", filter_param.to_dict()) #print("detector_param: ", detector_param.to_dict()) # get the classifier parameters @@ -318,11 +356,11 @@ def _run(self, progress_callback): return [] # run the filter - self.hfo_app.filter_eeg_data(filter_param) + self.backend.filter_eeg_data(filter_param) # print("Filtering COMPLETE!") #run the detector - self.hfo_app.set_detector(detector_param) - self.hfo_app.detect_HFO() + self.backend.set_detector(detector_param) + self.backend.detect_biomarker() # print("HFOs DETECTED!") #if we use classifier, run the classifier use_classifier = self.qd_use_classifier_checkbox.isChecked() @@ -332,10 +370,10 @@ def _run(self, progress_callback): # print("Classification FINISH!") if save_as_excel: fname = self.fname.split(".")[0]+".xlsx" - self.hfo_app.export_excel(fname) + self.backend.export_excel(fname) if save_as_npz: fname = self.fname.split(".")[0]+".npz" - self.hfo_app.export_app(fname) + self.backend.export_app(fname) # print(f"Exporting {fname} FINISH!") return [] @@ -344,7 +382,7 @@ def run(self): self.running = True #disable cancel button self.cancel_button.setEnabled(False) - worker.signals.result.connect(self._run_finished) + safe_connect_signal_slot(worker.signals.result, self._run_finished) self.threadpool.start(worker) def _run_finished(self): diff --git a/src/ui/quick_detection.ui b/src/ui/quick_detection.ui index 66d9167..cd27dfb 100644 --- a/src/ui/quick_detection.ui +++ b/src/ui/quick_detection.ui @@ -1019,6 +1019,11 @@ STE + + + HIL + + @@ -1816,6 +1821,282 @@ + + + + + + + Arial + 11 + true + + + + Qt::LeftToRight + + + HIL Detector + + + Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + + + + + Arial + 11 + false + + + + SD Threshold + + + Qt::AlignCenter + + + + + + + + Arial + 11 + false + + + + + + + + + Arial + 11 + false + + + + Pass Band + + + Qt::AlignCenter + + + + + + + + Arial + 11 + false + + + + + + + + + Arial + 11 + false + + + + + + + + + Arial + 11 + false + + + + Hz + + + Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + + + + + + Arial + 11 + false + + + + + + + + + Arial + 11 + false + + + + Sample Frequency + + + Qt::AlignCenter + + + + + + + + Arial + 11 + false + + + + + + + + + Arial + 11 + false + + + + sec + + + Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + + + + + + Arial + 11 + false + + + + Epoch Time + + + Qt::AlignCenter + + + + + + + + Arial + 11 + false + + + + + + + + + Arial + 11 + false + + + + Min Window + + + Qt::AlignCenter + + + + + + + + Arial + 11 + false + + + + Stop Band + + + Qt::AlignCenter + + + + + + + + Arial + 11 + false + + + + Hz + + + Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + + + + + + Arial + 11 + false + + + + sec + + + Qt::AlignLeading|Qt::AlignLeft|Qt::AlignVCenter + + + + + + + + Arial + 11 + false + + + + Hz + + + + + + + + diff --git a/src/utils/utils_detector.py b/src/utils/utils_detector.py index 1e59983..f080637 100644 --- a/src/utils/utils_detector.py +++ b/src/utils/utils_detector.py @@ -1,4 +1,6 @@ -from HFODetector import ste, mni +from HFODetector import ste, mni, hil +import yasa + def set_STE_detector(args): detector = ste.STEDetector(sample_freq=args.sample_freq, filter_freq=[args.pass_band, args.stop_band], @@ -6,9 +8,23 @@ def set_STE_detector(args): epoch_len=args.epoch_len, min_osc=args.min_osc, rms_thres=args.rms_thres, peak_thres=args.peak_thres, n_jobs=args.n_jobs, front_num=1) return detector + + def set_MNI_detector(args): detector = mni.MNIDetector(sample_freq=args.sample_freq, filter_freq=[args.pass_band, args.stop_band],epoch_time=args.epoch_time, epo_CHF=args.epo_CHF,per_CHF=args.per_CHF,min_win=args.min_win,min_gap=args.min_gap, thrd_perc=args.thrd_perc,base_seg=args.base_seg,base_shift=args.base_shift, base_thrd=args.base_thrd,base_min=args.base_min,n_jobs=args.n_jobs,front_num=1) - return detector \ No newline at end of file + return detector + + +def set_HIL_detector(args): + detector = hil.HILDetector(sample_freq=args.sample_freq, filter_freq=[args.pass_band, args.stop_band], + sd_thres=args.sd_threshold, min_window=args.min_window, + epoch_len=args.epoch_time, n_jobs=args.n_jobs, front_num=1) + return detector + + +def set_YASA_detector(args): + detector = yasa + return {'yasa': detector, 'args': args} diff --git a/src/utils/utils_feature.py b/src/utils/utils_feature.py index 379dacd..b587434 100644 --- a/src/utils/utils_feature.py +++ b/src/utils/utils_feature.py @@ -26,20 +26,20 @@ def calcuate_boundary(start, end, length, win_len=2000): def extract_data(data, start, end, sampling_rate=2000): data = np.squeeze(data) start, end = calcuate_boundary(start, end, len(data), win_len=sampling_rate) - hfo_waveform = data[start:end] - return hfo_waveform + biomarker_waveform = data[start:end] + return biomarker_waveform win_len = int(sampling_rate*time_range[1]/1000) - hfo_waveforms = np.zeros((len(starts), win_len*2)) + biomarker_waveforms = np.zeros((len(starts), win_len*2)) for i in range(len(starts)): channel_name = channel_names[i] start = starts[i] end = ends[i] channel_index = np.where(unique_channel_names == channel_name)[0] - hfo_waveform = extract_data(data[channel_index], start, end, win_len) - hfo_waveforms[i] = hfo_waveform - return hfo_waveforms + biomarker_waveform = extract_data(data[channel_index], start, end, win_len) + biomarker_waveforms[i] = biomarker_waveform + return biomarker_waveforms -def compute_hfo_feature(start, end, channel_name, data, sample_rate, win_size, ps_MinFreqHz, ps_MaxFreqHz, time_window_ms): +def compute_biomarker_feature(start, end, channel_name, data, sample_rate, win_size, ps_MinFreqHz, ps_MaxFreqHz, time_window_ms): # generate one sec time-freqeucny image spectrum_img = compute_spectrum(data, ps_SampleRate=sample_rate, ps_FreqSeg=win_size, ps_MinFreqHz=ps_MinFreqHz, ps_MaxFreqHz=ps_MaxFreqHz) left_index = int((time_window_ms/1000)*sample_rate) diff --git a/src/utils/utils_gui.py b/src/utils/utils_gui.py index bdac1a5..ff81a5e 100644 --- a/src/utils/utils_gui.py +++ b/src/utils/utils_gui.py @@ -71,6 +71,7 @@ def run(self): finally: self.signals.finished.emit() + class WriteStream(object): def __init__(self, q: Queue): self.queue = q @@ -122,3 +123,38 @@ def run(self): def stop(self): self._isRunning = False + + +def clear_layout(layout): + if layout is not None: + while layout.count(): + item = layout.takeAt(0) + widget = item.widget() + if widget is not None: + widget.deleteLater() # Safely delete the widget + else: + layout.removeItem(item) + + +def clear_stacked_widget(stacked_widget): + # Remove all pages from the QStackedWidget + while stacked_widget.count() > 0: + widget = stacked_widget.widget(0) # Get the first page + stacked_widget.removeWidget(widget) # Remove the widget + widget.deleteLater() # Delete the widget + + +# def clear_frame(frame): +# # Clear the frame by removing all child widgets +# for child in frame.children(): +# # if isinstance(child, QLayout): +# # clear_layout(child) +# if isinstance(child, QWidget): +# child.deleteLater() +def safe_connect_signal_slot(signal, slot): + """Ensure the signal is connected only once.""" + try: + signal.disconnect(slot) + except TypeError: + pass + signal.connect(slot) diff --git a/src/utils/utils_plotting.py b/src/utils/utils_plotting.py index 90fd339..774904d 100644 --- a/src/utils/utils_plotting.py +++ b/src/utils/utils_plotting.py @@ -1,7 +1,9 @@ import matplotlib.pyplot as plt import numpy as np import os -def plot_feature(folder, feature, start, end, data, data_filtered, channel_name, hfo_start, hfo_end): + + +def plot_feature(folder, feature, start, end, data, data_filtered, channel_name, biomarker_start, biomarker_end): channel_data = data channel_data_f = data_filtered fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 5)) @@ -9,9 +11,9 @@ def plot_feature(folder, feature, start, end, data, data_filtered, channel_name, channel_data_f = np.squeeze(channel_data_f) ax1.imshow(feature[0]) ax2.plot(channel_data, color='blue') - ax2.plot(np.arange(hfo_start, hfo_end), channel_data[hfo_start:hfo_end], color='red') + ax2.plot(np.arange(biomarker_start, biomarker_end), channel_data[biomarker_start:biomarker_end], color='red') ax3.plot(channel_data_f, color='blue') - ax3.plot(np.arange(hfo_start, hfo_end), channel_data_f[hfo_start:hfo_end], color='red') + ax3.plot(np.arange(biomarker_start, biomarker_end), channel_data_f[biomarker_start:biomarker_end], color='red') plt.suptitle(f"{channel_name}_{start}_{end} with length: {(end - start)*0.5} ms") fn = f'{channel_name}_{start}_{end}.jpg' plt.savefig(os.path.join(folder,fn)) diff --git a/src/views/__init__.py b/src/views/__init__.py new file mode 100644 index 0000000..eaab7af --- /dev/null +++ b/src/views/__init__.py @@ -0,0 +1,4 @@ +from .mini_plot_view import MiniPlotView +from .main_waveform_plot_view import MainWaveformPlotView +from .annotation_view import AnnotationView +from .main_window_view import MainWindowView diff --git a/src/views/annotation_view.py b/src/views/annotation_view.py new file mode 100644 index 0000000..65ced7b --- /dev/null +++ b/src/views/annotation_view.py @@ -0,0 +1,13 @@ +import numpy as np + + +class AnnotationView: + def __init__(self, window_widget): + self.window_widget = window_widget + # self._init_plot_widget(plot_widget) + + def add_widget(self, layout, widget): + attr = getattr(self.window_widget, layout) + method = getattr(attr, 'addWidget') + method(widget) + diff --git a/src/views/main_waveform_plot_view.py b/src/views/main_waveform_plot_view.py new file mode 100644 index 0000000..1a5cc83 --- /dev/null +++ b/src/views/main_waveform_plot_view.py @@ -0,0 +1,46 @@ +from PyQt5 import QtWidgets +import pyqtgraph as pg +from PyQt5 import QtGui +import numpy as np + + +class MainWaveformPlotView(QtWidgets.QGraphicsView): + def __init__(self, plot_widget: pg.PlotWidget): + super(MainWaveformPlotView, self).__init__() + self.plot_widget = plot_widget + self._init_plot_widget(plot_widget) + + def _init_plot_widget(self, plot_widget: pg.PlotWidget): + plot_widget.setMouseEnabled(x=False, y=False) + plot_widget.getPlotItem().hideAxis('bottom') + plot_widget.getPlotItem().hideAxis('left') + plot_widget.setBackground('w') + + def clear(self): + self.plot_widget.clear() + + def enable_axis_information(self): + self.plot_widget.getPlotItem().showAxis('bottom') + self.plot_widget.getPlotItem().showAxis('left') + + def plot_waveform(self, x, y, color, width): + self.plot_widget.plot(x, y, pen=pg.mkPen(color=color, width=width)) + + def draw_scale_bar(self, x_pos, y_pos, y_100_length, y_scale_length): + scale_line = pg.PlotDataItem([x_pos, x_pos], [y_pos, y_pos + y_scale_length], + pen=pg.mkPen('black', width=10), fill=(0, 128, 255, 150)) + self.plot_widget.addItem(scale_line) + + text_item = pg.TextItem(f'Scale: {y_100_length} μV ', color='black', anchor=(1, 0.5)) + text_item.setFont(QtGui.QFont('Arial', 10, QtGui.QFont.Bold)) + text_item.setPos(x_pos, y_pos + y_scale_length / 2) + self.plot_widget.addItem(text_item) + + def draw_channel_names(self, offset_value, n_channels_to_plot, channels_to_plot, first_channel_to_plot, start_in_time, end_in_time): + #set y ticks to channel names + channel_names_locs = -offset_value * np.arange(n_channels_to_plot)[:, None] # + offset_value/2 + + self.plot_widget.getAxis('left').setTicks([[(channel_names_locs[disp_i], channels_to_plot[chi_i]) + for disp_i, chi_i in enumerate(range(first_channel_to_plot, first_channel_to_plot + n_channels_to_plot))]]) + #set the max and min of the x axis + self.plot_widget.setXRange(start_in_time, end_in_time) diff --git a/src/views/main_window_view.py b/src/views/main_window_view.py new file mode 100644 index 0000000..ff78fcd --- /dev/null +++ b/src/views/main_window_view.py @@ -0,0 +1,707 @@ +import numpy as np +from pathlib import Path +from PyQt5 import uic +from PyQt5 import QtCore, QtGui, QtWidgets +from src.utils.utils_gui import * + +ROOT_DIR = Path(__file__).parent.parent.parent + + +class MainWindowView(QObject): + def __init__(self, window): + super(MainWindowView, self).__init__() + self.window = window + # self._init_plot_widget(plot_widget) + + def init_general_window(self): + self.window.ui = uic.loadUi(os.path.join(ROOT_DIR, 'src/ui/main_window.ui'), self.window) + self.window.setWindowIcon(QtGui.QIcon(os.path.join(ROOT_DIR, 'src/ui/images/icon1.png'))) + self.window.setWindowTitle("pyBrain") + + self.window.threadpool = QThreadPool() + self.window.replace_last_line = False + + def get_biomarker_type(self): + return self.window.combo_box_biomarker.currentText() + + def create_stacked_widget_detection_param(self, biomarker_type='HFO'): + if biomarker_type == 'HFO': + clear_stacked_widget(self.window.stacked_widget_detection_param) + page_ste = self.create_detection_parameter_page_ste('Detection Parameters (STE)') + page_mni = self.create_detection_parameter_page_mni('Detection Parameters (MNI)') + page_hil = self.create_detection_parameter_page_hil('Detection Parameters (HIL)') + self.window.stacked_widget_detection_param.addWidget(page_ste) + self.window.stacked_widget_detection_param.addWidget(page_mni) + self.window.stacked_widget_detection_param.addWidget(page_hil) + + self.window.detector_subtabs.clear() + tab_ste = self.create_detection_parameter_tab_ste() + tab_mni = self.create_detection_parameter_tab_mni() + tab_hil = self.create_detection_parameter_tab_hil() + self.window.detector_subtabs.addTab(tab_ste, 'STE') + self.window.detector_subtabs.addTab(tab_mni, 'MNI') + self.window.detector_subtabs.addTab(tab_hil, 'HIL') + + elif biomarker_type == 'Spindle': + clear_stacked_widget(self.window.stacked_widget_detection_param) + page_yasa = self.create_detection_parameter_page_yasa('Detection Parameters (YASA)') + self.window.stacked_widget_detection_param.addWidget(page_yasa) + + self.window.detector_subtabs.clear() + tab_yasa = self.create_detection_parameter_tab_yasa() + self.window.detector_subtabs.addTab(tab_yasa, 'YASA') + + def create_frame_biomarker(self, biomarker_type='HFO'): + if biomarker_type == 'HFO': + clear_layout(self.window.frame_biomarker_layout) + self.create_frame_biomarker_hfo() + elif biomarker_type == 'Spindle': + clear_layout(self.window.frame_biomarker_layout) + self.create_frame_biomarker_spindle() + + def create_detection_parameter_page_ste(self, groupbox_title): + page = QWidget() + layout = QGridLayout() + + detection_groupbox_ste = QGroupBox(groupbox_title) + ste_parameter_layout = QGridLayout(detection_groupbox_ste) + + clear_layout(ste_parameter_layout) + # Create widgets + text_font = QFont('Arial', 11) + label1 = QLabel('Epoch (s)') + label2 = QLabel('Min Window (s)') + label3 = QLabel('RMS Window (s)') + label4 = QLabel('Min Gap Time (s)') + label5 = QLabel('Min Oscillations') + label6 = QLabel('Peak Threshold') + label7 = QLabel('RMS Threshold') + + self.window.ste_epoch_display = QLabel() + self.window.ste_epoch_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.ste_epoch_display.setFont(text_font) + self.window.ste_min_window_display = QLabel() + self.window.ste_min_window_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.ste_min_window_display.setFont(text_font) + self.window.ste_rms_window_display = QLabel() + self.window.ste_rms_window_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.ste_rms_window_display.setFont(text_font) + self.window.ste_min_gap_time_display = QLabel() + self.window.ste_min_gap_time_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.ste_min_gap_time_display.setFont(text_font) + self.window.ste_min_oscillations_display = QLabel() + self.window.ste_min_oscillations_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.ste_min_oscillations_display.setFont(text_font) + self.window.ste_peak_threshold_display = QLabel() + self.window.ste_peak_threshold_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.ste_peak_threshold_display.setFont(text_font) + self.window.ste_rms_threshold_display = QLabel() + self.window.ste_rms_threshold_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.ste_rms_threshold_display.setFont(text_font) + self.window.ste_detect_button = QPushButton('Detect') + + # Add widgets to the grid layout + ste_parameter_layout.addWidget(label1, 0, 0) # Row 0, Column 0 + ste_parameter_layout.addWidget(label2, 0, 1) # Row 0, Column 1 + ste_parameter_layout.addWidget(self.window.ste_epoch_display, 1, 0) # Row 1, Column 0 + ste_parameter_layout.addWidget(self.window.ste_min_window_display, 1, 1) # Row 1, Column 1 + ste_parameter_layout.addWidget(label3, 2, 0) + ste_parameter_layout.addWidget(label4, 2, 1) + ste_parameter_layout.addWidget(self.window.ste_rms_window_display, 3, 0) + ste_parameter_layout.addWidget(self.window.ste_min_gap_time_display, 3, 1) + ste_parameter_layout.addWidget(label5, 4, 0) + ste_parameter_layout.addWidget(label6, 4, 1) + ste_parameter_layout.addWidget(self.window.ste_min_oscillations_display, 5, 0) + ste_parameter_layout.addWidget(self.window.ste_peak_threshold_display, 5, 1) + ste_parameter_layout.addWidget(label7, 6, 0) + ste_parameter_layout.addWidget(self.window.ste_rms_threshold_display, 7, 0) + ste_parameter_layout.addWidget(self.window.ste_detect_button, 7, 1) + + # Set the layout for the page + layout.addWidget(detection_groupbox_ste) + page.setLayout(layout) + return page + + def create_detection_parameter_page_mni(self, groupbox_title): + page = QWidget() + layout = QGridLayout() + + detection_groupbox_mni = QGroupBox(groupbox_title) + mni_parameter_layout = QGridLayout(detection_groupbox_mni) + + clear_layout(mni_parameter_layout) + # self.detection_groupbox_mni.setTitle("Detection Parameters (MNI)") + + # Create widgets + text_font = QFont('Arial', 11) + label1 = QLabel('Epoch (s)') + label2 = QLabel('Min Window (s)') + label3 = QLabel('Epoch CHF (s)') + label4 = QLabel('Min Gap Time (s)') + label5 = QLabel('CHF Percentage') + label6 = QLabel('Threshold Percentile') + label7 = QLabel('Window (s)') + label8 = QLabel('Shift') + label9 = QLabel('Threshold') + label10 = QLabel('Min Time') + + self.window.mni_epoch_display = QLabel() + self.window.mni_epoch_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.mni_epoch_display.setFont(text_font) + self.window.mni_min_window_display = QLabel() + self.window.mni_min_window_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.mni_min_window_display.setFont(text_font) + self.window.mni_epoch_chf_display = QLabel() + self.window.mni_epoch_chf_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.mni_epoch_chf_display.setFont(text_font) + self.window.mni_min_gap_time_display = QLabel() + self.window.mni_min_gap_time_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.mni_min_gap_time_display.setFont(text_font) + self.window.mni_chf_percentage_display = QLabel() + self.window.mni_chf_percentage_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.mni_chf_percentage_display.setFont(text_font) + self.window.mni_threshold_percentile_display = QLabel() + self.window.mni_threshold_percentile_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.mni_threshold_percentile_display.setFont(text_font) + self.window.mni_baseline_window_display = QLabel() + self.window.mni_baseline_window_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.mni_baseline_window_display.setFont(text_font) + self.window.mni_baseline_shift_display = QLabel() + self.window.mni_baseline_shift_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.mni_baseline_shift_display.setFont(text_font) + self.window.mni_baseline_threshold_display = QLabel() + self.window.mni_baseline_threshold_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.mni_baseline_threshold_display.setFont(text_font) + self.window.mni_baseline_min_time_display = QLabel() + self.window.mni_baseline_min_time_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.mni_baseline_min_time_display.setFont(text_font) + self.window.mni_detect_button = QPushButton('Detect') + + # Add widgets to the grid layout + mni_parameter_layout.addWidget(label1, 0, 0) # Row 0, Column 0 + mni_parameter_layout.addWidget(label2, 0, 1) # Row 0, Column 1 + mni_parameter_layout.addWidget(self.window.mni_epoch_display, 1, 0) # Row 1, Column 0 + mni_parameter_layout.addWidget(self.window.mni_min_window_display, 1, 1) # Row 1, Column 1 + mni_parameter_layout.addWidget(label3, 2, 0) + mni_parameter_layout.addWidget(label4, 2, 1) + mni_parameter_layout.addWidget(self.window.mni_epoch_chf_display, 3, 0) + mni_parameter_layout.addWidget(self.window.mni_min_gap_time_display, 3, 1) + mni_parameter_layout.addWidget(label5, 4, 0) + mni_parameter_layout.addWidget(label6, 4, 1) + mni_parameter_layout.addWidget(self.window.mni_chf_percentage_display, 5, 0) + mni_parameter_layout.addWidget(self.window.mni_threshold_percentile_display, 5, 1) + + group_box = QGroupBox('Baseline') + baseline_parameter_layout = QVBoxLayout(group_box) + baseline_parameter_layout.addWidget(label7) + baseline_parameter_layout.addWidget(self.window.mni_baseline_window_display) + baseline_parameter_layout.addWidget(label8) + baseline_parameter_layout.addWidget(self.window.mni_baseline_shift_display) + baseline_parameter_layout.addWidget(label9) + baseline_parameter_layout.addWidget(self.window.mni_baseline_threshold_display) + baseline_parameter_layout.addWidget(label10) + baseline_parameter_layout.addWidget(self.window.mni_baseline_min_time_display) + + mni_parameter_layout.addWidget(group_box, 0, 2, 6, 1) # Row 0, Column 2, span 1 row, 6 columns + mni_parameter_layout.addWidget(self.window.mni_detect_button, 6, 2) + + # Set the layout for the page + layout.addWidget(detection_groupbox_mni) + page.setLayout(layout) + return page + + def create_detection_parameter_page_hil(self, groupbox_title): + page = QWidget() + layout = QGridLayout() + + detection_groupbox_hil = QGroupBox(groupbox_title) + hil_parameter_layout = QGridLayout(detection_groupbox_hil) + + clear_layout(hil_parameter_layout) + # self.detection_groupbox_hil.setTitle("Detection Parameters (HIL)") + + # Create widgets + text_font = QFont('Arial', 11) + label1 = QLabel('Epoch Length (s)') + label2 = QLabel('Min Window (s)') + label3 = QLabel('Pass Band (Hz)') + label4 = QLabel('Stop Band (Hz)') + label5 = QLabel('Sample Frequency') + label6 = QLabel('SD Threshold') + + self.window.hil_epoch_time_display = QLabel() + self.window.hil_epoch_time_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.hil_epoch_time_display.setFont(text_font) + self.window.hil_min_window_display = QLabel() + self.window.hil_min_window_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.hil_min_window_display.setFont(text_font) + self.window.hil_pass_band_display = QLabel() + self.window.hil_pass_band_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.hil_pass_band_display.setFont(text_font) + self.window.hil_stop_band_display = QLabel() + self.window.hil_stop_band_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.hil_stop_band_display.setFont(text_font) + self.window.hil_sample_freq_display = QLabel() + self.window.hil_sample_freq_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.hil_sample_freq_display.setFont(text_font) + self.window.hil_sd_threshold_display = QLabel() + self.window.hil_sd_threshold_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.hil_sd_threshold_display.setFont(text_font) + self.window.hil_detect_button = QPushButton('Detect') + + # Add widgets to the grid layout + hil_parameter_layout.addWidget(label1, 0, 0) # Row 0, Column 0 + hil_parameter_layout.addWidget(label2, 0, 1) # Row 0, Column 1 + hil_parameter_layout.addWidget(self.window.hil_epoch_time_display, 1, 0) # Row 1, Column 0 + hil_parameter_layout.addWidget(self.window.hil_min_window_display, 1, 1) # Row 1, Column 1 + hil_parameter_layout.addWidget(label3, 2, 0) + hil_parameter_layout.addWidget(label4, 2, 1) + hil_parameter_layout.addWidget(self.window.hil_pass_band_display, 3, 0) + hil_parameter_layout.addWidget(self.window.hil_stop_band_display, 3, 1) + hil_parameter_layout.addWidget(label5, 4, 0) + hil_parameter_layout.addWidget(label6, 4, 1) + hil_parameter_layout.addWidget(self.window.hil_sample_freq_display, 5, 0) + hil_parameter_layout.addWidget(self.window.hil_sd_threshold_display, 5, 1) + hil_parameter_layout.addWidget(self.window.hil_detect_button, 6, 1) + + # Set the layout for the page + layout.addWidget(detection_groupbox_hil) + page.setLayout(layout) + return page + + def create_detection_parameter_page_yasa(self, groupbox_title): + page = QWidget() + layout = QGridLayout() + + detection_groupbox_yasa = QGroupBox(groupbox_title) + yasa_parameter_layout = QGridLayout(detection_groupbox_yasa) + + clear_layout(yasa_parameter_layout) + # self.detection_groupbox_hil.setTitle("Detection Parameters (HIL)") + + # Create widgets + text_font = QFont('Arial', 11) + label1 = QLabel('Freq Spindle (Hz)') + label2 = QLabel('Freq Broad (Hz)') + label3 = QLabel('Duration (s)') + label4 = QLabel('Min Distance (ms)') + label5 = QLabel('rel_pow') + label6 = QLabel('corr') + label7 = QLabel('rms') + + self.window.yasa_freq_sp_display = QLabel() + self.window.yasa_freq_sp_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.yasa_freq_sp_display.setFont(text_font) + self.window.yasa_freq_broad_display = QLabel() + self.window.yasa_freq_broad_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.yasa_freq_broad_display.setFont(text_font) + self.window.yasa_duration_display = QLabel() + self.window.yasa_duration_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.yasa_duration_display.setFont(text_font) + self.window.yasa_min_distance_display = QLabel() + self.window.yasa_min_distance_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.yasa_min_distance_display.setFont(text_font) + self.window.yasa_thresh_rel_pow_display = QLabel() + self.window.yasa_thresh_rel_pow_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.yasa_thresh_rel_pow_display.setFont(text_font) + self.window.yasa_thresh_corr_display = QLabel() + self.window.yasa_thresh_corr_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.yasa_thresh_corr_display.setFont(text_font) + self.window.yasa_thresh_rms_display = QLabel() + self.window.yasa_thresh_rms_display.setStyleSheet("background-color: rgb(235, 235, 235);") + self.window.yasa_thresh_rms_display.setFont(text_font) + + self.window.yasa_detect_button = QPushButton('Detect') + + # Add widgets to the grid layout + yasa_parameter_layout.addWidget(label1, 0, 0) # Row 0, Column 0 + yasa_parameter_layout.addWidget(label2, 0, 1) # Row 0, Column 1 + yasa_parameter_layout.addWidget(self.window.yasa_freq_sp_display, 1, 0) # Row 1, Column 0 + yasa_parameter_layout.addWidget(self.window.yasa_freq_broad_display, 1, 1) # Row 1, Column 1 + yasa_parameter_layout.addWidget(label3, 2, 0) + yasa_parameter_layout.addWidget(label4, 2, 1) + yasa_parameter_layout.addWidget(self.window.yasa_duration_display, 3, 0) + yasa_parameter_layout.addWidget(self.window.yasa_min_distance_display, 3, 1) + + group_box = QGroupBox('thresh') + thresh_parameter_layout = QVBoxLayout(group_box) + thresh_parameter_layout.addWidget(label5) + thresh_parameter_layout.addWidget(self.window.yasa_thresh_rel_pow_display) + thresh_parameter_layout.addWidget(label6) + thresh_parameter_layout.addWidget(self.window.yasa_thresh_corr_display) + thresh_parameter_layout.addWidget(label7) + thresh_parameter_layout.addWidget(self.window.yasa_thresh_rms_display) + + yasa_parameter_layout.addWidget(group_box, 0, 2, 4, 1) # Row 0, Column 2, span 1 row, 6 columns + yasa_parameter_layout.addWidget(self.window.yasa_detect_button, 4, 2) + + # Set the layout for the page + layout.addWidget(detection_groupbox_yasa) + page.setLayout(layout) + return page + + def create_detection_parameter_tab_ste(self): + tab = QWidget() + layout = QGridLayout() + + detection_groupbox = QGroupBox('Detection Parameters') + parameter_layout = QGridLayout(detection_groupbox) + + clear_layout(parameter_layout) + + # Create widgets + text_font = QFont('Arial', 11) + label1 = QLabel('RMS Window') + label2 = QLabel('Min Window') + label3 = QLabel('Min Gap') + label4 = QLabel('Epoch Length') + label5 = QLabel('Min Oscillations') + label6 = QLabel('RMS Threshold') + label7 = QLabel('Peak Threshold') + label8 = QLabel('sec') + label9 = QLabel('sec') + label10 = QLabel('sec') + label11 = QLabel('sec') + + self.window.ste_rms_window_input = QLineEdit() + self.window.ste_rms_window_input.setFont(text_font) + self.window.ste_min_window_input = QLineEdit() + self.window.ste_min_window_input.setFont(text_font) + self.window.ste_min_gap_input = QLineEdit() + self.window.ste_min_gap_input.setFont(text_font) + self.window.ste_epoch_length_input = QLineEdit() + self.window.ste_epoch_length_input.setFont(text_font) + self.window.ste_min_oscillation_input = QLineEdit() + self.window.ste_min_oscillation_input.setFont(text_font) + self.window.ste_rms_threshold_input = QLineEdit() + self.window.ste_rms_threshold_input.setFont(text_font) + self.window.ste_peak_threshold_input = QLineEdit() + self.window.ste_peak_threshold_input.setFont(text_font) + self.window.STE_save_button = QPushButton('Save') + + # Add widgets to the grid layout + parameter_layout.addWidget(label1, 0, 0) # Row 0, Column 0 + parameter_layout.addWidget(self.window.ste_rms_window_input, 0, 1) # Row 0, Column 1 + parameter_layout.addWidget(label8, 0, 2) + parameter_layout.addWidget(label2, 1, 0) + parameter_layout.addWidget(self.window.ste_min_window_input, 1, 1) + parameter_layout.addWidget(label9, 1, 2) + parameter_layout.addWidget(label3, 2, 0) + parameter_layout.addWidget(self.window.ste_min_gap_input, 2, 1) + parameter_layout.addWidget(label10, 2, 2) + parameter_layout.addWidget(label4, 3, 0) + parameter_layout.addWidget(self.window.ste_epoch_length_input, 3, 1) + parameter_layout.addWidget(label11, 3, 2) + + parameter_layout.addWidget(label5, 4, 0) + parameter_layout.addWidget(self.window.ste_min_oscillation_input, 4, 1) + parameter_layout.addWidget(label6, 5, 0) + parameter_layout.addWidget(self.window.ste_rms_threshold_input, 5, 1) + parameter_layout.addWidget(label7, 6, 0) + parameter_layout.addWidget(self.window.ste_peak_threshold_input, 6, 1) + + parameter_layout.addWidget(self.window.STE_save_button, 7, 2) + + # Set the layout for the page + layout.addWidget(detection_groupbox) + tab.setLayout(layout) + return tab + + def create_detection_parameter_tab_mni(self): + tab = QWidget() + layout = QGridLayout() + + detection_groupbox = QGroupBox('Detection Parameters') + parameter_layout = QGridLayout(detection_groupbox) + + clear_layout(parameter_layout) + + # Create widgets + text_font = QFont('Arial', 11) + label1 = QLabel('Epoch Time') + label2 = QLabel('Epoch CHF') + label3 = QLabel('CHF Percentage') + label4 = QLabel('Min Window') + label5 = QLabel('Min Gap Time') + label6 = QLabel('Threshold Percentage') + label7 = QLabel('Baseline Window') + label8 = QLabel('Baseline Shift') + label9 = QLabel('Baseline Threshold') + label10 = QLabel('Baseline Minimum Time') + label11 = QLabel('sec') + label12 = QLabel('sec') + label13 = QLabel('sec') + label14 = QLabel('sec') + label15 = QLabel('sec') + label16 = QLabel('%') + label17 = QLabel('%') + + self.window.mni_epoch_time_input = QLineEdit() + self.window.mni_epoch_time_input.setFont(text_font) + self.window.mni_epoch_chf_input = QLineEdit() + self.window.mni_epoch_chf_input.setFont(text_font) + self.window.mni_chf_percentage_input = QLineEdit() + self.window.mni_chf_percentage_input.setFont(text_font) + self.window.mni_min_window_input = QLineEdit() + self.window.mni_min_window_input.setFont(text_font) + self.window.mni_min_gap_time_input = QLineEdit() + self.window.mni_min_gap_time_input.setFont(text_font) + self.window.mni_threshold_percentage_input = QLineEdit() + self.window.mni_threshold_percentage_input.setFont(text_font) + self.window.mni_baseline_window_input = QLineEdit() + self.window.mni_baseline_window_input.setFont(text_font) + self.window.mni_baseline_shift_input = QLineEdit() + self.window.mni_baseline_shift_input.setFont(text_font) + self.window.mni_baseline_threshold_input = QLineEdit() + self.window.mni_baseline_threshold_input.setFont(text_font) + self.window.mni_baseline_min_time_input = QLineEdit() + self.window.mni_baseline_min_time_input.setFont(text_font) + self.window.MNI_save_button = QPushButton('Save') + + # Add widgets to the grid layout + parameter_layout.addWidget(label1, 0, 0) # Row 0, Column 0 + parameter_layout.addWidget(self.window.mni_epoch_time_input, 0, 1) # Row 0, Column 1 + parameter_layout.addWidget(label11, 0, 2) + parameter_layout.addWidget(label2, 1, 0) + parameter_layout.addWidget(self.window.mni_epoch_chf_input, 1, 1) + parameter_layout.addWidget(label12, 1, 2) + parameter_layout.addWidget(label3, 2, 0) + parameter_layout.addWidget(self.window.mni_chf_percentage_input, 2, 1) + parameter_layout.addWidget(label16, 2, 2) + parameter_layout.addWidget(label4, 3, 0) + parameter_layout.addWidget(self.window.mni_min_window_input, 3, 1) + parameter_layout.addWidget(label13, 3, 2) + parameter_layout.addWidget(label5, 4, 0) + parameter_layout.addWidget(self.window.mni_min_gap_time_input, 4, 1) + parameter_layout.addWidget(label14, 4, 2) + parameter_layout.addWidget(label6, 5, 0) + parameter_layout.addWidget(self.window.mni_threshold_percentage_input, 5, 1) + parameter_layout.addWidget(label17, 5, 2) + parameter_layout.addWidget(label7, 6, 0) + parameter_layout.addWidget(self.window.mni_baseline_window_input, 6, 1) + parameter_layout.addWidget(label15, 6, 2) + + parameter_layout.addWidget(label8, 7, 0) + parameter_layout.addWidget(self.window.mni_baseline_shift_input, 7, 1) + parameter_layout.addWidget(label9, 8, 0) + parameter_layout.addWidget(self.window.mni_baseline_threshold_input, 8, 1) + parameter_layout.addWidget(label10, 9, 0) + parameter_layout.addWidget(self.window.mni_baseline_min_time_input, 9, 1) + + parameter_layout.addWidget(self.window.MNI_save_button, 10, 2) + + # Set the layout for the page + layout.addWidget(detection_groupbox) + tab.setLayout(layout) + return tab + + def create_detection_parameter_tab_hil(self): + tab = QWidget() + layout = QGridLayout() + + detection_groupbox = QGroupBox('Detection Parameters') + parameter_layout = QGridLayout(detection_groupbox) + + clear_layout(parameter_layout) + + # Create widgets + text_font = QFont('Arial', 11) + label1 = QLabel('Sample Frequency') + label2 = QLabel('Pass Band') + label3 = QLabel('Stop Band') + label4 = QLabel('Epoch Length') + label5 = QLabel('SD Threshold') + label6 = QLabel('Min Window') + label8 = QLabel('sec') + label9 = QLabel('sec') + label10 = QLabel('sec') + label11 = QLabel('Hz') + label12 = QLabel('Hz') + label13 = QLabel('Hz') + + self.window.hil_sample_freq_input = QLineEdit() + self.window.hil_sample_freq_input.setFont(text_font) + self.window.hil_pass_band_input = QLineEdit() + self.window.hil_pass_band_input.setFont(text_font) + self.window.hil_stop_band_input = QLineEdit() + self.window.hil_stop_band_input.setFont(text_font) + self.window.hil_epoch_time_input = QLineEdit() + self.window.hil_epoch_time_input.setFont(text_font) + self.window.hil_sd_threshold_input = QLineEdit() + self.window.hil_sd_threshold_input.setFont(text_font) + self.window.hil_min_window_input = QLineEdit() + self.window.hil_min_window_input.setFont(text_font) + self.window.HIL_save_button = QPushButton('Save') + + # Add widgets to the grid layout + parameter_layout.addWidget(label1, 0, 0) # Row 0, Column 0 + parameter_layout.addWidget(self.window.hil_sample_freq_input, 0, 1) # Row 0, Column 1 + parameter_layout.addWidget(label11, 0, 2) + parameter_layout.addWidget(label2, 1, 0) + parameter_layout.addWidget(self.window.hil_pass_band_input, 1, 1) + parameter_layout.addWidget(label12, 1, 2) + parameter_layout.addWidget(label3, 2, 0) + parameter_layout.addWidget(self.window.hil_stop_band_input, 2, 1) + parameter_layout.addWidget(label13, 2, 2) + parameter_layout.addWidget(label4, 3, 0) + parameter_layout.addWidget(self.window.hil_epoch_time_input, 3, 1) + parameter_layout.addWidget(label8, 3, 2) + parameter_layout.addWidget(label5, 4, 0) + parameter_layout.addWidget(self.window.hil_sd_threshold_input, 4, 1) + # parameter_layout.addWidget(label9, 4, 2) + parameter_layout.addWidget(label6, 5, 0) + parameter_layout.addWidget(self.window.hil_min_window_input, 5, 1) + parameter_layout.addWidget(label10, 5, 2) + + parameter_layout.addWidget(self.window.HIL_save_button, 6, 2) + + # Set the layout for the page + layout.addWidget(detection_groupbox) + tab.setLayout(layout) + return tab + + def create_detection_parameter_tab_yasa(self): + tab = QWidget() + layout = QGridLayout() + + detection_groupbox = QGroupBox('Detection Parameters') + parameter_layout = QGridLayout(detection_groupbox) + + clear_layout(parameter_layout) + + # Create widgets + text_font = QFont('Arial', 11) + label1 = QLabel('Freq Spindle') + label2 = QLabel('Freq Broad') + label3 = QLabel('Duration') + label4 = QLabel('Min Distance') + label5 = QLabel('rel_pow (thresh)') + label6 = QLabel('corr (thresh)') + label7 = QLabel('rms (thresh)') + label8 = QLabel('Hz') + label9 = QLabel('Hz') + label10 = QLabel('sec') + label11 = QLabel('ms') + + self.window.yasa_freq_sp_input = QLineEdit() + self.window.yasa_freq_sp_input.setFont(text_font) + self.window.yasa_freq_broad_input = QLineEdit() + self.window.yasa_freq_broad_input.setFont(text_font) + self.window.yasa_duration_input = QLineEdit() + self.window.yasa_duration_input.setFont(text_font) + self.window.yasa_min_distance_input = QLineEdit() + self.window.yasa_min_distance_input.setFont(text_font) + self.window.yasa_thresh_rel_pow_input = QLineEdit() + self.window.yasa_thresh_rel_pow_input.setFont(text_font) + self.window.yasa_thresh_corr_input = QLineEdit() + self.window.yasa_thresh_corr_input.setFont(text_font) + self.window.yasa_thresh_rms_input = QLineEdit() + self.window.yasa_thresh_rms_input.setFont(text_font) + self.window.YASA_save_button = QPushButton('Save') + + # Add widgets to the grid layout + parameter_layout.addWidget(label1, 0, 0) # Row 0, Column 0 + parameter_layout.addWidget(self.window.yasa_freq_sp_input, 0, 1) # Row 0, Column 1 + parameter_layout.addWidget(label8, 0, 2) + parameter_layout.addWidget(label2, 1, 0) + parameter_layout.addWidget(self.window.yasa_freq_broad_input, 1, 1) + parameter_layout.addWidget(label9, 1, 2) + parameter_layout.addWidget(label3, 2, 0) + parameter_layout.addWidget(self.window.yasa_duration_input, 2, 1) + parameter_layout.addWidget(label10, 2, 2) + parameter_layout.addWidget(label4, 3, 0) + parameter_layout.addWidget(self.window.yasa_min_distance_input, 3, 1) + parameter_layout.addWidget(label11, 3, 2) + + parameter_layout.addWidget(label5, 4, 0) + parameter_layout.addWidget(self.window.yasa_thresh_rel_pow_input, 4, 1) + parameter_layout.addWidget(label6, 5, 0) + parameter_layout.addWidget(self.window.yasa_thresh_corr_input, 5, 1) + parameter_layout.addWidget(label7, 6, 0) + parameter_layout.addWidget(self.window.yasa_thresh_rms_input, 6, 1) + + parameter_layout.addWidget(self.window.YASA_save_button, 7, 2) + + # Set the layout for the page + layout.addWidget(detection_groupbox) + tab.setLayout(layout) + return tab + + def create_frame_biomarker_hfo(self): + # self.frame_biomarker_layout = QHBoxLayout(self.frame_biomarker_type) + self.window.frame_biomarker_layout.addStretch(1) + + # Add three QLabel widgets to the QFrame + label_type1 = QLabel("Artifact") + label_type1.setFixedWidth(150) + label_type2 = QLabel("spk-HFO") + label_type2.setFixedWidth(150) + label_type3 = QLabel("HFO") + label_type3.setFixedWidth(150) + + line_type1 = QLineEdit() + line_type1.setReadOnly(True) + line_type1.setFrame(True) + line_type1.setFixedWidth(50) + line_type1.setStyleSheet("background-color: orange;") + line_type2 = QLineEdit() + line_type2.setReadOnly(True) + line_type2.setFrame(True) + line_type2.setFixedWidth(50) + line_type2.setStyleSheet("background-color: purple;") + line_type3 = QLineEdit() + line_type3.setReadOnly(True) + line_type3.setFrame(True) + line_type3.setFixedWidth(50) + line_type3.setStyleSheet("background-color: green;") + + # Add labels to the layout + self.window.frame_biomarker_layout.addWidget(line_type1) + self.window.frame_biomarker_layout.addWidget(label_type1) + self.window.frame_biomarker_layout.addWidget(line_type2) + self.window.frame_biomarker_layout.addWidget(label_type2) + self.window.frame_biomarker_layout.addWidget(line_type3) + self.window.frame_biomarker_layout.addWidget(label_type3) + self.window.frame_biomarker_layout.addStretch(1) + + def create_frame_biomarker_spindle(self): + # self.frame_biomarker_layout = QHBoxLayout(self.frame_biomarker_type) + self.window.frame_biomarker_layout.addStretch(1) + + # Add three QLabel widgets to the QFrame + label_type1 = QLabel("Artifact") + label_type1.setFixedWidth(150) + label_type2 = QLabel("spk-Spindle") + label_type2.setFixedWidth(150) + label_type3 = QLabel("Spindle") + label_type3.setFixedWidth(150) + + line_type1 = QLineEdit() + line_type1.setReadOnly(True) + line_type1.setFrame(True) + line_type1.setFixedWidth(50) + line_type1.setStyleSheet("background-color: orange;") + line_type2 = QLineEdit() + line_type2.setReadOnly(True) + line_type2.setFrame(True) + line_type2.setFixedWidth(50) + line_type2.setStyleSheet("background-color: purple;") + line_type3 = QLineEdit() + line_type3.setReadOnly(True) + line_type3.setFrame(True) + line_type3.setFixedWidth(50) + line_type3.setStyleSheet("background-color: green;") + + # Add labels to the layout + self.window.frame_biomarker_layout.addWidget(line_type1) + self.window.frame_biomarker_layout.addWidget(label_type1) + self.window.frame_biomarker_layout.addWidget(line_type2) + self.window.frame_biomarker_layout.addWidget(label_type2) + self.window.frame_biomarker_layout.addWidget(line_type3) + self.window.frame_biomarker_layout.addWidget(label_type3) + self.window.frame_biomarker_layout.addStretch(1) + + def add_widget(self, layout, widget): + attr = getattr(self.window.window_widget, layout) + method = getattr(attr, 'addWidget') + method(widget) \ No newline at end of file diff --git a/src/views/mini_plot_view.py b/src/views/mini_plot_view.py new file mode 100644 index 0000000..003924b --- /dev/null +++ b/src/views/mini_plot_view.py @@ -0,0 +1,41 @@ +from PyQt5 import QtWidgets +import pyqtgraph as pg + + +class MiniPlotView(QtWidgets.QGraphicsView): + def __init__(self, plot_widget: pg.PlotWidget): + super(MiniPlotView, self).__init__() + self.plot_widget = plot_widget + self._init_plot_widget(plot_widget) + + def _init_plot_widget(self, plot_widget): + plot_widget.setMouseEnabled(x=False, y=False) + plot_widget.getPlotItem().hideAxis('bottom') + plot_widget.getPlotItem().hideAxis('left') + plot_widget.setBackground('w') + + def enable_axis_information(self): + self.plot_widget.getPlotItem().showAxis('bottom') + self.plot_widget.getPlotItem().showAxis('left') + + def add_linear_region(self): + self.linear_region = pg.LinearRegionItem([0, 0], movable=False) + self.linear_region.setZValue(-20) + self.plot_widget.addItem(self.linear_region) + + def update_highlight_window(self, start, end, height): + self.linear_region.setRegion([start,end]) + self.linear_region.setZValue(height) + + def plot_biomarker(self, start_time, end_time, top_value, color, width): + self.plot_widget.plot([start_time, end_time], [top_value, top_value], pen=pg.mkPen(color=color, width=width)) + + def set_miniplot_title(self, title, height): + self.plot_widget.getAxis('left').setTicks([[(height, f' {title} ')]]) + + def set_x_y_range(self, x_range, y_range): + self.plot_widget.setXRange(x_range[0], x_range[1]) + self.plot_widget.setYRange(y_range[0], y_range[1]) + + def clear(self): + self.plot_widget.clear()