ycq091044 / ContraWR

JMIR AI'23: EEG dataset processing and EEG Self-supervised Learning

About augment function in utils #1

Closed: YoloEliwa closed this issue 2 years ago

YoloEliwa commented 2 years ago
```python
def noise_channel(ts, mode, degree, bound):
    """
    Add noise to ts

    mode: high, low, both
    degree: degree of noise, compared with range of ts

    Input:
        ts: (n_length)
    Output:
        out_ts: (n_length)

    """
```

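`noise_channel` expects a single-channel time series, but its caller `add_noise` passes it `x[i,:]` (and `remove_noise` likewise passes `x[i,:]` to `denoise_channel`). The docstrings say `x: (n_length, n_channel)`, so wouldn't each `x[i,:]` passed to `noise_channel` be the signal of all channels at a single sample point? Could the author confirm whether there is a problem here?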
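The indexing question can be checked directly with a small made-up NumPy array (the sizes and variable names below are illustrative, not from the repository). `x[i, :]` always selects row `i`, so what it returns depends on which axis holds the channels:

```python
import numpy as np

n_length, n_channel = 6, 2

# Layout the docstrings claim: (n_length, n_channel), i.e. time-major
x_time_major = np.arange(n_length * n_channel).reshape(n_length, n_channel)
# Layout with channels on the first axis: (n_channel, n_length)
x_channel_major = x_time_major.T.copy()

# On the time-major array, x[i, :] is one sample point across all channels
row_time_major = x_time_major[0, :]
# On the channel-major array, x[i, :] is one whole channel across all samples
row_channel_major = x_channel_major[0, :]

print(row_time_major.shape)     # (2,)
print(row_channel_major.shape)  # (6,)
```

Only the second layout hands `noise_channel` the `(n_length)` series its docstring asks for.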

```python
def add_noise(self, x, ratio):
    """
    Add noise to multiple ts
    Input:
        x: (n_length, n_channel)
    Output:
        x: (n_length, n_channel)
    """
    for i in range(self.n_channels):
        if np.random.rand() > ratio:
            mode = np.random.choice(['high', 'low', 'both', 'no'])
            x[i,:] = noise_channel(x[i,:], mode=mode, degree=0.05, bound=self.bound)
    return x

def remove_noise(self, x, ratio):
    """
    Remove noise from multiple ts
    Input:
        x: (n_length, n_channel)
    Output:
        x: (n_length, n_channel)
    """
    for i in range(self.n_channels):
        rand = np.random.rand()
        if rand > 0.75:
            x[i, :] = denoise_channel(x[i, :], self.bandpass1, self.signal_freq, bound=self.bound) +\
                    denoise_channel(x[i, :], self.bandpass2, self.signal_freq, bound=self.bound)
        elif rand > 0.5:
            x[i, :] = denoise_channel(x[i, :], self.bandpass1, self.signal_freq, bound=self.bound)
        elif rand > 0.25:
            x[i, :] = denoise_channel(x[i, :], self.bandpass2, self.signal_freq, bound=self.bound)
        else:
            pass

    return x
```

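The per-channel loop in `add_noise` can be sketched as a standalone helper with an explicit shape check, so that a transposed input fails loudly instead of silently augmenting sample points. This is a minimal sketch, not the repository's code: `apply_per_channel` and its parameters are hypothetical names, and `channel_fn` stands in for something like `noise_channel`, whose body is not quoted in this thread.

```python
import numpy as np

def apply_per_channel(x, channel_fn, n_channels, ratio, rng):
    """Apply channel_fn to a random subset of channels of x.

    Hypothetical helper mirroring the loop in add_noise.
    x must be (n_channel, n_length), the layout confirmed in this thread.
    """
    if x.ndim != 2 or x.shape[0] != n_channels:
        raise ValueError(
            f"expected shape (n_channel={n_channels}, n_length), got {x.shape}"
        )
    out = x.copy()  # work on a copy instead of mutating the caller's array
    for i in range(n_channels):
        # As in add_noise: a channel is transformed only if the draw exceeds ratio
        if rng.random() > ratio:
            out[i, :] = channel_fn(out[i, :])  # one full channel, shape (n_length,)
    return out

rng = np.random.default_rng(0)
x = np.zeros((2, 5))  # 2 channels, 5 samples
y = apply_per_channel(x, lambda ts: ts + 1.0, n_channels=2, ratio=0.0, rng=rng)
```

Note that with this comparison a larger `ratio` means fewer channels are augmented, exactly as in `add_noise`. Copying instead of assigning into `x` in place is a design choice of the sketch; the original methods modify and return the same array.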
ycq091044 commented 2 years ago

Thanks for pointing this out, and sorry for the misunderstanding!

I have corrected the comments: (n_length, n_channel) -> (n_channel, n_length).

The input x is actually (n_channel, n_length). Thanks again!
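If a recording arrives time-major, i.e. `(n_length, n_channel)`, it needs transposing before it is passed to these functions. A hedged sketch follows; `to_channel_major` is a hypothetical helper that infers the layout from which axis has size `n_channels`, so it assumes `n_length` differs from the number of channels.

```python
import numpy as np

def to_channel_major(x, n_channels):
    """Return a 2-D array in (n_channel, n_length) layout.

    Hypothetical helper: infers the layout from which axis has size
    n_channels, so it assumes n_length != n_channels.
    """
    if x.ndim != 2:
        raise ValueError(f"expected a 2-D array, got shape {x.shape}")
    if x.shape[0] == n_channels:
        return x  # already (n_channel, n_length)
    if x.shape[1] == n_channels:
        # (n_length, n_channel) -> (n_channel, n_length); the contiguous copy
        # keeps each channel row contiguous in memory
        return np.ascontiguousarray(x.T)
    raise ValueError(f"no axis of size {n_channels} in shape {x.shape}")

# Example: 2 channels, 3000 samples stored time-major
x = np.zeros((3000, 2))
print(to_channel_major(x, n_channels=2).shape)  # (2, 3000)
```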