braindecode / braindecode

Deep learning software to decode EEG, ECG or MEG signals
https://braindecode.org/
BSD 3-Clause "New" or "Revised" License

A minimal rewriting of PhysioNet Chambon2018 #616

Closed · OverLordGoldDragon closed this issue 4 months ago

OverLordGoldDragon commented 4 months ago

Hello,

I was looking for understandable code that decodes PhysioNet into (x, y) that can be fed to a model. I found the original example too high-level, and the source code too dense to manipulate for my purposes, so I rewrote it - code & details below.

Is braindecode interested in featuring this as an example? I'm willing to make a few changes if needed. (Also, I mostly intended it as a script rather than an HTML page, but I can adjust it.) If there's interest, I could also host it as a gist, which could be linked from the example.

My issues with it aside, I'm glad for the example! Spares lots of trouble.

```python
# -*- coding: utf-8 -*-
"""Minimal-code rewriting of BrainDecode's

https://braindecode.org/stable/auto_examples/applied_examples/plot_sleep_staging_chambon2018.html
https://github.com/braindecode/braindecode/blob/master/examples/applied_examples/plot_sleep_staging_chambon2018.py

It

  - Eliminates all (explicit) dependence on `braindecode`
  - Minimizes dependence on `mne`

My motivation was understandable code that decodes PhysioNet into `(x, y)`
that can be fed to a model. I found the original example too high-level, and
source code too dense to manipulate for my purposes.

Rewritten code is

  - heavily cut down in total number of lines
  - tested to completely reproduce the original BrainDecode script's results
    (note, requires `torch.use_deterministic_algorithms(True)`). Caveat:
    that's with a preprocessing step (lowpass filtering) omitted, just comment
    that out in original source (or reimplement it here).
  - includes some fixes and extensions of functionality
  - commented, explaining/justifying omission of code with respect to
    `braindecode` source code (meant to be read alongside)
  - split into "helpers" (function definitions) and "execution"; former could
    be moved into its own file and imported
  - a rewriting of v0.8.0, https://github.com/braindecode/braindecode/tree/v0.8

Tip: download PhysioNet from Google Cloud bucket, then set `data_loaddir`; it
was ~10x faster for me than downloading directly from physionet.org (what the
original script does). Disclaimer: while it was free for me, it says it's
paid, so I can't guarantee that. Instructions at
https://physionet.org/content/sleep-edfx/1.0.0/

Note: example uses different `subject_ids` and `recording_ids`.
"""
# USER CONFIG ----------------------------------------------------------------
# where PhysioNet source data is already stored, if exists; if None, will
# install from scratch
data_loaddir = None
# where to save processed data to; defaults to current working directory
data_savedir = None
# subject and recording IDs
subject_ids = [0, 1, 2, 3]
recording_ids = [1, 2]
# use GPU if available
use_gpu = True

#%% ##########################################################################
# Imports
# -------
# should set env var before running other imports
import os
if data_loaddir is not None:
    os.environ['PHYSIONET_SLEEP_PATH'] = data_loaddir

import random
import bisect
import numpy as np
import torch
import torch.nn as nn
# torch.use_deterministic_algorithms(True)

import mne
from mne.datasets.sleep_physionet.age import fetch_data
from sklearn.preprocessing import scale as standard_scale

##############################################################################
# HELPER FUNCTIONS
# ****************

#%% ##########################################################################
# Loading & saving data
# ---------------------
# Files are named e.g.
#
#     SC4001E0-PSG.edf
#        01234567
#
# where
#
#     `34` = `subject_id` (i.e. 0)
#     `5`  = `recording_id` (i.e. 1)
#
# and,
#
#     SC4001EC-Hypnogram.edf
#
# where
#
#     `PSG` = data
#     `Hypnogram` = metadata (including labels)
#
# Notes:
#
#   - "Data" includes stuff other than EEG, exclude via `exclude_chs`.
#   - `p = paths[0]` is a pair of paths, where `p[0]` = PSG, `p[1]` = Hypnogram.
#

# define a function to reuse later
def process_path(p, exclude_chs, data_savedir):
    raw, annots, subj_num, sess_num = _load_data(p, exclude_chs)
    raw = _drop_unlabeled_data(raw, annots)
    raw = _preprocess(raw)

    # by reading further code (window processing), we determine that we need
    # to keep only the following:
    #
    # for `raw`:
    #     `raw._data`
    # for `annotations` (== `raw.annotations`):
    #     `annotations.onset`
    #     `annotations.description`
    #     `annotations.duration`
    #
    data = {
        'data': raw._data,
        'onset': raw.annotations.onset,
        'description': raw.annotations.description,
        'duration': raw.annotations.duration,
        '_first_time': raw._first_time,  # for `braindecode_ver = True`
    }
    savename = os.path.basename(p[0]).replace('-PSG', '').replace('.edf', '.npz')
    savepath = os.path.join(data_savedir, savename)
    np.savez(savepath, **data)
    return data, savepath


def _load_data(p, exclude_chs):
    # load data & labels -----------------------------------------------------
    p_data, p_meta = p
    # (`preload` to load the data into RAM)
    raw = mne.io.read_raw_edf(p_data, preload=True, exclude=exclude_chs)
    annots = mne.read_annotations(p_meta)

    # Get subject and recording number ---------------------------------------
    basename = os.path.basename(p_data)
    subj_num = int(basename[3:5])
    sess_num = int(basename[5])
    return raw, annots, subj_num, sess_num


def _drop_unlabeled_data(raw, annots):
    # set `raw.annotations` from `annots` ------------------------------------
    # - each `a = annots[0]` has `a['onset']` and `a['duration']`
    # - `a` are cropped such that they remain within (inclusive) `tmin = 0` and
    #   `tmax = raw.times[-1] + 1 / raw.info['sfreq']`, where sfreq = sampling
    #   freq = 100 Hz. "Cropped" means their `'onset'` and `'duration'` are
    #   adjusted.
    raw.set_annotations(annots, emit_warning=False)

    # crop data to exclude unlabeled segments --------------------------------
    braindecode_ver = True
    if not braindecode_ver:
        # "labels" are over data's time segments. E.g. `x[:20000]` can be
        # "Sleep stage W" (wake), and `x[20000:25000]` be "Sleep stage 1", etc.
        # "Sleep stage ?" is unlabeled.
        mask = [x[-1] != '?' for x in annots.description]
        sleep_event_inds = np.where(mask)[0]

        # above assumes there's only one such sleep stage, and that it's the
        # last one; check both assumptions
        assert mask.count(False) == 1, mask
        assert not mask[-1], mask
    else:
        # see `not braindecode_ver` comments
        mask = [
            x[-1] in ['1', '2', '3', '4', 'R'] for x in annots.description]
        sleep_event_inds = np.where(mask)[0]

    # Crop raw (also crops labels)
    # determine `tmax` as first timestamp of last stage before
    # "Sleep stage ?", plus the duration of that stage, minus one sample (`dT`)
    # (otherwise we end up at first timestamp of "Sleep stage ?", and
    # `crop` is inclusive on `tmax`).
    dT = 1 / raw.info["sfreq"]
    a_tmin = annots[sleep_event_inds[0]]
    a_tmax = annots[sleep_event_inds[-1]]
    if not braindecode_ver:
        tmin = a_tmin['onset']
        tmax = a_tmax['onset'] + a_tmax['duration'] - dT
    else:
        crop_wake_mins = 30
        tmin = a_tmin['onset'] - crop_wake_mins * 60
        tmax = a_tmax['onset'] + a_tmax['duration'] - dT + crop_wake_mins * 60

    # internally, this converts `tmin`, `tmax` to indices, and does something
    # like `x = x[tmin:tmax]`.
    # it also correspondingly crops labels, via `set_annotations`.
    # Internals (for labels):
    #   - `tmin` for `set_annotations` is set from `_first_time` (and another
    #     var we can treat as zero).
    #     This is updated in `crop` from the `tmin` we specify (see
    #     `_first_time` definition as `@property`).
    #   - `tmax` for `set_annotations` is set from `times[-1]` (+ dT).
    #     This is updated in `crop` from the `tmax` we specify (see `times`
    #     definition as `@property`).
    #   - This drops annotations that are completely out of range of `tmin`,
    #     `tmax`. Here, it amounts to dropping the annotation for
    #     "Sleep stage ?".
    raw.crop(tmin=max(tmin, raw.times[0]), tmax=min(tmax, raw.times[-1]))

    # Rename EEG channels ----------------------------------------------------
    raw.rename_channels({nm: nm.replace('EEG ', '') for nm in raw.ch_names})
    return raw


def _preprocess(raw):
    # internally, this just updates `raw._data`, with
    #   - handling for multiprocessing
    #   - checking that out.shape == in.shape
    #   - loading `_data` if it isn't (for us, it is, via `preload=True`)
    # V -> uV
    raw._data = raw._data * 1e6
    return raw

#%% ##########################################################################
# Creating windows
# ----------------
def _create_windows_from_events(p, mapping, sfreq):
    # load data
    d = np.load(p)
    # `a_` for `annotations_`
    data, a_description, a_onset, a_duration, _first_time = [
        d[nm] for nm in
        ('data', 'description', 'onset', 'duration', '_first_time')
    ]

    events = _events_from_annotations(
        a_description, a_onset, _first_time, mapping, sfreq)
    # the lib method also returns `event_id_`, which is redundant for us
    #   - (it is a copy of `mapping`, unless `mapping` has some keys that
    #     `annotations.description` doesn't, but we then only use `event_id_`
    #     to check whether what's in `annotations.description` is in
    #     `event_id_`, which is circular and same as checking directly against
    #     `mapping`)

    onsets = events[:, 0]
    # Onsets are relative to the beginning of the recording
    filtered_durations = np.array([
        dur for dur, desc in zip(a_duration, a_description)
        if desc in mapping
    ])
    stops = onsets + (filtered_durations * sfreq).astype(int)

    # sanity check; note, `stops` is used exclusively, i.e. `start:stop`
    # don't need the `raw.first_samp` in `raw.first_samp + raw.n_times` since
    # we commented out `+= raw.first_samp` (upon `onsets`, in
    # `_events_from_annotations`) earlier; and,
    # `raw.n_times == raw._data.shape[-1]`
    # (i.e. the full assert is `stops[-1] <= raw.first_samp + raw.n_times`)
    assert stops[-1] <= data.shape[-1]

    # this no longer executes since we commented out `+= raw.first_samp` earlier
    # onsets = onsets - raw.first_samp
    # stops = stops - raw.first_samp

    # generate windows
    window_size_samples = 3000
    window_stride_samples = 3000
    drop_last_window = False
    i_trials, starts, stops = _compute_window_inds(
        onsets, stops, window_size_samples, window_stride_samples,
        drop_last_window)

    # generate window events
    description = events[:, -1]
    # events = [[start, window_size_samples, description[i_trial]]
    #           for start, i_trial in zip(starts, i_trials)]
    # events = np.array(events)
    # description_windows = events[:, -1]
    description_windows = np.array([description[i_trial]
                                    for i_trial in i_trials])

    windows_ds = WindowsDataset(
        data,
        target=description_windows,
        i_start_in_trial=starts,
        i_stop_in_trial=stops
    )
    return windows_ds


def _events_from_annotations(description, onset, _first_time, mapping, fs):
    # minimally implements `_select_annotations_based_on_description` --------
    event_sel = [i for i, d in enumerate(description) if d in mapping]

    # Convert onsets to sample indices. Internals: ===========================
    # - `annotations.onset` = timestamps, in seconds, of start times of labels
    #   (where each "label", again, is over a time interval over data, e.g.
    #   `x[20000:25000]`).
    # - `len(annotations.onset) == len(labels)`.
    # - `annotations.orig_time` = `raw.info["meas_date"]`, as long as
    #   `annotations = mne.read_annotations(path)` is used.
    # - `raw.info["meas_date"]` = measurement date, a `datetime` object created
    #   from metadata in the (-PSG) EDF file.

    # minimally implements ---------------------------------------------------
    #
    #     inds = raw.time_as_index(times=annotations.onset, use_rounding=True,
    #                              origin=annotations.orig_time)
    #
    # we won't end up needing `origin` (see below), so don't fetch it.
    # origin = annotations.orig_time
    times = onset

    # Internals:
    # - `self._first_time = self.first_samp / self.info['sfreq']`
    #    - `self.first_samp = self._cropped_samp`
    #    - `self._cropped_samp = first_samps[0]` if `raw.crop()` wasn't used
    #      with `tmin != 0`, else it's modified (in this case to
    #      `tmin * self.info['sfreq']`)
    #    - `first_samps = (0,)`
    # - (`braindecode_ver = False` only) Hence, `raw._first_time == 0`, and
    #   since `origin == raw.info["meas_date"]` (see above), `delta == 0`,
    #   so we can skip all below code (and `times` is already 1d)
    # Since `braindecode_ver = True` is supported, execute the relevant portion
    # in `raw.time_as_index`, but rewritten:
    #
    #     `origin - first_samp_in_abs_time`
    #     <=>
    #     `raw.info["meas_date"] - (raw.info["meas_date"] + raw._first_time)`
    #     <=>
    #     `- raw._first_time`
    delta = - _first_time
    times += delta

    # `raw.times[0]` is always zero in our case (RawEDF loaded the way we have)
    # index = (np.atleast_1d(times) - raw.times[0]) * fs
    index = times * fs
    inds = np.round(index).astype(int)
    # ========================================================================

    # Executes if `if annotations.orig_time is not None:`, which is the case
    # here. Do not execute this, so we don't have to `-= raw.first_samp` later,
    # so we don't have to store `raw.first_samp` anywhere
    # inds += raw.first_samp

    # `annotations.description` -> numeric values based on `mapping`. E.g. if
    #
    #     mapping == {'Sleep stage W': 0, 'Sleep stage 1': 1}
    #     annotations.description == ['Sleep stage 1', 'Sleep stage W',
    #                                 'Sleep stage W']
    #
    # then
    #
    #     values == [1, 0, 0]
    #
    # but ignoring `event_sel` (indices of selected labels based on whether
    # they were in `mapping`).
    values = [mapping[kk] for kk in description[event_sel]]

    # Apply `event_sel` to `inds`
    inds = inds[event_sel]

    # This simply concatenates the arrays into `(n_events, 3)`, and casts to int
    events = np.c_[inds, np.zeros(len(inds)), values].astype(int)
    return events


def _compute_window_inds(starts, stops, size, stride, drop_last_window):
    assert not any(size > (stops - starts))

    i_trials, window_starts = [], []
    for start_i, (start, stop) in enumerate(zip(starts, stops)):
        # Generate possible window starts, with given stride, between
        # starts and stops (i.e. original trial onsets and stops, shifted by
        # start_offset and stop_offset, respectively)
        possible_starts = np.arange(start, stop, stride)

        # Possible window start is actually a start, if window size fits in
        # trial start and stop
        for i_window, s in enumerate(possible_starts):
            if (s + size) <= stop:
                window_starts.append(s)
                i_trials.append(start_i)

        # If the last window start + window size is not the same as
        # stop + stop_offset, create another window that overlaps and stops
        # at onset + stop_offset
        if not drop_last_window:
            if window_starts[-1] + size != stop:
                window_starts.append(stop - size)
                i_trials.append(start_i)

    # Set window stops to be event stops (rather than trial stops)
    window_stops = np.array(window_starts) + size
    assert len(i_trials) == len(window_starts) == len(window_stops)
    return i_trials, window_starts, window_stops


def _preprocess_windows(windows_ds):
    windows_ds.data = standard_scale(windows_ds.data, axis=1)
    windows_ds.data = windows_ds.data.copy().astype('float32')


class WindowsDataset():
    def __init__(self, data, target, i_start_in_trial, i_stop_in_trial):
        self.data = data
        self.y = np.asarray(target, dtype='int64')  # skorch expects int64
        self.i_start_in_trial = i_start_in_trial
        self.inds = np.c_[i_start_in_trial, i_stop_in_trial]

    def __getitem__(self, index):
        i_start, i_end = self.inds[index]
        X = self.data[:, i_start:i_end]
        y = self.y[index]
        return X, y

    def __len__(self):
        return len(self.y)

#%% ##########################################################################
# Creating dataset objects
# ------------------------
class Sampler():
    def __init__(self, i_start_in_trial_all, n_windows, n_windows_stride,
                 randomize=False):
        self.n_windows = n_windows
        self.n_windows_stride = n_windows_stride
        self.randomize = randomize

        # braindecode applies `groupby` by `'subject'`, `'recording'` upon
        # dataframe, then resets index, and operates on said indices below,
        # meaning we first generate the indices within each
        # `i_start_in_trial_all` (where "each" refers to `'subject'` and
        # `'recording'`), operate on those, then concatenate
        idxs_all = [list(range(len(i_start_in_trial_all[0])))]
        for length in map(len, i_start_in_trial_all[1:]):
            idx_last = idxs_all[-1][-1]
            idxs_all.append(list(range(idx_last + 1, idx_last + 1 + length)))

        end_offset = 1 - n_windows
        self.start_inds = np.concatenate(
            [idxs[:end_offset:self.n_windows_stride] for idxs in idxs_all]
        )

    def __len__(self):
        return len(self.start_inds)

    def __iter__(self):
        if self.randomize:
            start_inds = np.random.permutation(self.start_inds)
        else:
            start_inds = self.start_inds
        for start_ind in start_inds:
            yield tuple(range(start_ind, start_ind + self.n_windows))


class ConcatDataset(torch.utils.data.Dataset):
    # Merges `BaseConcatDataset` and `ConcatDataset` classes.
    # `skorch` requires this to be an instance of `Dataset`
    def __init__(self, datasets, target_transform=None):
        self.datasets = datasets
        # for script readability, we assign this later
        self.target_transform = (target_transform if target_transform
                                 is not None else lambda x: x)
        self.cumulative_sizes = np.cumsum(list(map(len, self.datasets)))

    def __getitem__(self, idxs):
        """
        idxs : tuple / list
            Indices of windows and targets to return (concatenated). The
            target output can be modified on the fly by the
            ``target_transform`` parameter.
        """
        X, y = [], []
        for idx in idxs:
            out_i = self._getitem(idx)
            X.append(out_i[0])
            y.append(out_i[1])
        X = np.stack(X, axis=0)
        y = self.target_transform(np.array(y))
        return X, y

    def _getitem(self, idx):
        assert idx >= 0 and idx < len(self)
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        sample_idx = (idx if dataset_idx == 0 else
                      idx - self.cumulative_sizes[dataset_idx - 1])
        return self.datasets[dataset_idx][sample_idx]

    def __len__(self):
        return self.cumulative_sizes[-1]

#%% ##########################################################################
# Creating model
# --------------
class SleepStagerChambon2018(nn.Module):
    """Feature extractor only."""
    def __init__(self, n_chans, n_outputs, n_times, sfreq, n_conv_chs=8,
                 apply_batch_norm=False):
        super().__init__()
        self.n_chans = n_chans
        self.n_outputs = n_outputs
        self.n_times = n_times
        self.n_conv_chs = n_conv_chs
        self.apply_batch_norm = apply_batch_norm

        assert self.n_chans > 1

        # handle params
        time_conv_size_s = 0.5
        max_pool_size_s = time_conv_size_s / 4
        pad_size_s = time_conv_size_s / 2
        time_conv_size, max_pool_size, pad_size = (
            int(np.ceil(x * sfreq)) for x in
            (time_conv_size_s, max_pool_size_s, pad_size_s)
        )

        # handle certain layers
        self.spatial_conv = nn.Conv2d(1, self.n_chans, (self.n_chans, 1))
        batch_norm = (nn.BatchNorm2d if apply_batch_norm else nn.Identity)

        # make feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(
                1, n_conv_chs, (1, time_conv_size), padding=(0, pad_size)),
            batch_norm(n_conv_chs),
            nn.ReLU(),
            nn.MaxPool2d((1, max_pool_size)),
            nn.Conv2d(
                n_conv_chs, n_conv_chs, (1, time_conv_size),
                padding=(0, pad_size)),
            batch_norm(n_conv_chs),
            nn.ReLU(),
            nn.MaxPool2d((1, max_pool_size)),
            nn.Flatten(),
        )

        # length of last layer (for later)
        with torch.no_grad():
            self.len_last_layer = len(self.feature_extractor(
                torch.Tensor(1, 1, self.n_chans, self.n_times)
            ).flatten())

    def forward(self, x):
        """x: batch of EEG windows of shape (batch_size, n_channels, n_times)"""
        x = x.unsqueeze(1)
        x = self.spatial_conv(x)
        x = x.transpose(1, 2)
        return self.feature_extractor(x)


class TimeDistributed(nn.Module):
    """Apply module on a sequence of windows and return their concatenation
    (see `forward`):

        `(batch_size, seq_len, n_channels, n_times)` ->
        `(batch_size, seq_len, output_size)`

    Useful with sequence-to-prediction models (e.g. sleep stager which must
    map a sequence of consecutive windows to the label of the middle window
    in the sequence).
    """
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        """
        x: sequence of windows of shape
           (batch_size, seq_len, n_channels, n_times)

        Returns output of shape (batch_size, seq_len, output_size)
        """
        b, s, c, t = x.shape
        out = self.module(x.view(b * s, c, t))
        return out.view(b, s, -1)

#%% ##########################################################################
# Creating train loop
# -------------------
from skorch.helper import predefined_split
from skorch.classifier import NeuralNetClassifier
from skorch.callbacks import BatchScoring, EpochScoring, EpochTimer, PrintLog
from skorch.utils import noop, train_loss_score, valid_loss_score


class EEGClassifier(NeuralNetClassifier):
    """
    Is `NeuralNetClassifier`, with `_default_callbacks` overridden.
    All arguments are passed straight into `NeuralNetClassifier`.

    Note, `NeuralNetClassifier` doesn't assume softmax activation and calls
    the loss function directly (without applying e.g. log).

    Parameter note: `iterator_train__shuffle` (default True) defines whether
    train dataset will be shuffled. As `skorch` does not shuffle the train
    dataset by default, this one overwrites this option.
    """
    def __init__(
            self,
            module,
            criterion=None,
            callbacks=None,
            iterator_train__shuffle=True,
            iterator_train__drop_last=True,
            **kwargs
    ):
        super().__init__(
            module,
            criterion=criterion,
            callbacks=callbacks,
            iterator_train__shuffle=iterator_train__shuffle,
            iterator_train__drop_last=iterator_train__drop_last,
            **kwargs,
        )

    @property
    def _default_callbacks(self):
        return [
            ("epoch_timer", EpochTimer()),
            (
                "train_loss",
                BatchScoring(
                    train_loss_score,
                    name="train_loss",
                    on_train=True,
                    target_extractor=noop,
                ),
            ),
            (
                "valid_loss",
                BatchScoring(
                    valid_loss_score,
                    name="valid_loss",
                    target_extractor=noop,
                ),
            ),
            ("print_log", PrintLog()),
            (
                "valid_acc",
                EpochScoring(
                    "accuracy",
                    name="valid_acc",
                    lower_is_better=False,
                )
            )
        ]
        # `skorch` default is
        # return [
        #     ('epoch_timer', EpochTimer()),
        #     ('train_loss', PassthroughScoring(
        #         name='train_loss',
        #         on_train=True,
        #     )),
        #     ('valid_loss', PassthroughScoring(
        #         name='valid_loss',
        #     )),
        #     ('print_log', PrintLog()),
        # ]
        # this unites `_EEGNeuralNet._default_callbacks` and
        # `EEGClassifier._default_callbacks` (latter excluding the fact that
        # it appends to former)

    # Excluded inherited classes explanation ---------------------------------
    # _EEGNeuralNet:
    #     Running original script with and without this class changed nothing.
    #     Inspecting the code, it appears to handle certain configurations
    #     that aren't used here (e.g. "cropping").

    # Excluded arguments explanation ------------------------------------------
    # aggregate_predictions:
    #     Was only used in `predict_proba`, which was dropped

    # Excluded methods explanation --------------------------------------------
    # get_iterator:
    #     only does something via `ThrowAwayIndexLoader`, which only does
    #     something if iterator returns x.ndim==3, which doesn't happen here
    # predict_proba:
    #     only does something if `cropped=True`, which isn't the case here
    # get_loss:
    #     only does something if `isinstance(self.criterion_, torch.nn.NLLLoss)`,
    #     which isn't the case here
    # predict:
    #     completely identical to inherited class's definition, is likely
    #     redefined for docs clarity or future changes
    # predict_trials:
    #     meant to be used with `cropped=True`, which isn't the case here
    # _get_n_outputs:
    #     unused in `clf.fit()`, checked by inserting `1/0` here

    # Excluded attributes explanation ------------------------------------------
    # _last_window_inds_:
    #     TL;DR unused

##############################################################################
# EXECUTION
# *********

#%% ##########################################################################
# Convert and save data as numpy arrays
# -------------------------------------
# handle configs
if data_savedir is None:
    data_savedir = os.getcwd()

# set excluded channels
exclude_chs = ('EOG horizontal', 'Resp oro-nasal', 'EMG submental',
               'Temp rectal', 'Event marker')
# merge stages 3 and 4 following AASM standards
mapping = {
    'Sleep stage W': 0,
    'Sleep stage 1': 1,
    'Sleep stage 2': 2,
    'Sleep stage 3': 3,
    'Sleep stage 4': 3,
    'Sleep stage R': 4
}
# sampling freq, obtained via `raw.info['sfreq']`
# (hard-coded here for cleaner code)
sfreq = 100

# Fetch paths, generate ids
paths = fetch_data(subject_ids, recording=recording_ids, on_missing='warn')
# For rest of this script, generalize original script's example to any
# `subject_ids` and `recording_ids` by not assuming there's two of former and
# one of latter.
# Below for-loop ordering follows that of `fetch_data`.
ids = [(sid, rid) for sid in subject_ids for rid in recording_ids]
# Map ids to paths
ipaths = {id_: p for id_, p in zip(ids, paths)}

# Store processed data paths
isavepaths = {}
for _id, p in ipaths.items():
    raw, psave = process_path(p, exclude_chs, data_savedir)
    isavepaths[_id] = psave

#%% ##########################################################################
# Create windows
# --------------
iwindows_ds = {}
for id_, p in isavepaths.items():
    windows_ds = _create_windows_from_events(p, mapping, sfreq)
    _preprocess_windows(windows_ds)
    iwindows_ds[id_] = windows_ds

#%% ##########################################################################
# Create dataset objects, make train-test split
# ---------------------------------------------
# split by subject, so train and validation have different subjects
split_ids = dict(
    train=[_id for _id in ids if _id[0] in subject_ids[::2]],
    valid=[_id for _id in ids if _id[0] in subject_ids[1::2]],
)
train_set = ConcatDataset([iwindows_ds[id_] for id_ in split_ids['train']])
valid_set = ConcatDataset([iwindows_ds[id_] for id_ in split_ids['valid']])

# make samplers
n_windows = 3
n_windows_stride = 3

train_sampler = Sampler(
    [windows_ds.i_start_in_trial for windows_ds in train_set.datasets],
    n_windows, n_windows_stride, randomize=True
)
valid_sampler = Sampler(
    [windows_ds.i_start_in_trial for windows_ds in valid_set.datasets],
    n_windows, n_windows_stride
)

#%% Make label transformer
# Use label of center window in the sequence
def get_center_label(x):
    if isinstance(x, int):
        return x
    return x[np.ceil(len(x) / 2).astype(int)] if len(x) > 1 else x

train_set.target_transform = get_center_label
valid_set.target_transform = get_center_label

#%% Compute class weights
y_train = [train_set[idx][1] for idx in train_sampler]
# replicate `sklearn.utils.compute_class_weight` for `class_weight='balanced'`
classes = np.unique(y_train)
class_weights = len(y_train) / (
    len(classes) * np.array([y_train.count(c) for c in classes]))

#%% ##########################################################################
# Create model
# ------------
# check if GPU is available, assuming the user wants it
cuda = use_gpu and torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
if cuda:
    # faster but lowers reproducibility
    torch.backends.cudnn.benchmark = True

# set seeds for reproducibility
seed = 31
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed_all(seed)

# Reproducibility caveats:
# - More info on reproducibility in torch:
#   https://pytorch.org/docs/stable/notes/randomness.html
# - In some cases, may need to set `PYTHONHASHSEED` env var before running
#   the script:
#   https://forums.fast.ai/t/solved-reproducibility-where-is-the-randomness-coming-in/31628/14
# - `torch.use_deterministic_algorithms(True)` isn't used, also plays a role

n_classes = 5
# Extract number of channels and time steps from dataset
n_channels, input_size_samples = train_set[(0,)][0][0].shape

feat_extractor = SleepStagerChambon2018(
    n_channels,
    n_outputs=n_classes,
    n_times=input_size_samples,
    sfreq=sfreq,
)

model = nn.Sequential(
    TimeDistributed(feat_extractor),  # apply model on each 30-s window
    nn.Sequential(  # apply linear layer on concatenated feature vectors
        nn.Flatten(start_dim=1),
        nn.Dropout(0.5),
        nn.Linear(feat_extractor.len_last_layer * n_windows, n_classes)
    )
)
if cuda:
    model = model.cuda()

#%% ##########################################################################
# Create train loop
# -----------------
lr = 1e-3
batch_size = 32
n_epochs = 10

train_bal_acc = EpochScoring(
    scoring='balanced_accuracy', on_train=True, name='train_bal_acc',
    lower_is_better=False)
valid_bal_acc = EpochScoring(
    scoring='balanced_accuracy', on_train=False, name='valid_bal_acc',
    lower_is_better=False)
callbacks = [
    ('train_bal_acc', train_bal_acc),
    ('valid_bal_acc', valid_bal_acc)
]

clf = EEGClassifier(
    model,
    criterion=torch.nn.CrossEntropyLoss,
    criterion__weight=torch.Tensor(class_weights).to(device),
    optimizer=torch.optim.Adam,
    iterator_train__shuffle=False,
    iterator_train__sampler=train_sampler,
    iterator_valid__sampler=valid_sampler,
    train_split=predefined_split(valid_set),  # using valid_set for validation
    optimizer__lr=lr,
    batch_size=batch_size,
    callbacks=callbacks,
    device=device,
    classes=np.unique(y_train),
)

#%% ##########################################################################
# Run training
# ------------
# Model training for a specified number of epochs. `y` is None as it is
# already supplied in the dataset.
clf.fit(train_set, y=None, epochs=n_epochs)

#%% ##########################################################################
# The rest of the code (e.g. plotting) is same as in original script.
```
agramfort commented 4 months ago

Is it doing something conceptually different? If the API of braindecode is not ideal for your use case, do you have a suggestion to improve it? Note that the idea is that the boilerplate code you would need to write to run a similar experiment in the future (e.g. on a different dataset) is minimal.

OverLordGoldDragon commented 4 months ago

Braindecode isn't doing anything "wrong" here; it's sort of like me rewriting Conv1d just so I can understand it. In this case, I wanted to understand exactly how data is routed from .edf into model(x, y), and Braindecode was too high-level for that, with too many if config1, elif config2, etc. branches to make it easy to follow.
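For context, a rough sketch of that route using only `mne` and `numpy`: read the PSG and Hypnogram files, then cut the labeled segments into 30-s windows. The paths are placeholders (`fetch_data` in the script above returns (PSG, Hypnogram) pairs), and it assumes the hypnogram onsets align with the start of the PSG recording; the full script in the opening comment handles the offset (`_first_time`), wake cropping, unit scaling, stage mapping, and other edge cases that this skips.

```python
import numpy as np
import mne

# placeholder paths for one Sleep-EDF recording
psg_path = "SC4001E0-PSG.edf"        # signals
hyp_path = "SC4001EC-Hypnogram.edf"  # sleep-stage annotations

raw = mne.io.read_raw_edf(psg_path, preload=True)
annots = mne.read_annotations(hyp_path)

sfreq = int(raw.info["sfreq"])  # 100 Hz for Sleep-EDF
win = 30 * sfreq                # 30-s windows, 3000 samples

data = raw.get_data()           # (n_channels, n_times)
x, y = [], []
for onset, duration, desc in zip(annots.onset, annots.duration,
                                 annots.description):
    if not desc.startswith("Sleep stage") or desc.endswith("?"):
        continue  # keep only labeled sleep stages
    start = int(round(onset * sfreq))
    stop = min(start + int(round(duration * sfreq)), data.shape[-1])
    for s in range(start, stop - win + 1, win):
        x.append(data[:, s:s + win])
        y.append(desc)

x = np.stack(x)  # (n_windows, n_channels, 3000) -> feed to a model with y
```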

A few specifics: I wanted to 1) store the processed .edf as .npy for future reuse, and 2) make sure the lib isn't doing anything I disagree with. For 1), it wasn't clear to me how to do it without involving windows (I don't necessarily wish to replicate Chambon2018 there). For 2), I did find something I'd rather do differently - I favor not excluding data (undersampling).
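On 2), a common alternative to undersampling is to keep every window and weight the loss by class frequency, which is what the class-weight computation in the script above amounts to. A minimal sketch with made-up labels:

```python
import numpy as np
import torch

# made-up integer sleep-stage labels for the training windows
y_train = np.array([0, 0, 0, 0, 1, 2, 2, 3, 4])

# 'balanced' weights, as in sklearn's compute_class_weight:
# n_samples / (n_classes * count_per_class)
classes, counts = np.unique(y_train, return_counts=True)
class_weights = len(y_train) / (len(classes) * counts)

# weight the loss instead of dropping windows from the majority class
criterion = torch.nn.CrossEntropyLoss(
    weight=torch.tensor(class_weights, dtype=torch.float32))
```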

So I imagine this code is helpful to someone who wants more transparency and control over .edf -> model(x, y). I don't see it as something for Braindecode itself to improve on.

agramfort commented 4 months ago

OK, thanks. Then I would just put this in https://gist.github.com for people to look at in the future, so it's not lost.

OverLordGoldDragon commented 4 months ago

It won't really be visible here; I think it'd help some people if it were linked in the example. Up to the maintainers, though.

https://gist.github.com/OverLordGoldDragon/b6709fd266929f90c7979fb1d5635c4b

agramfort commented 4 months ago

OK to have a PR that links to it.