mims-harvard / TFC-pretraining

Self-supervised contrastive learning for time series via time-frequency consistency
https://zitniklab.hms.harvard.edu/projects/TF-C/
MIT License

DataTransform_FD function implementation is inconsistent with the paper? #10

Closed iorange-77 closed 1 year ago

iorange-77 commented 1 year ago

In the DataTransform_FD function:

li = np.random.randint(0, 2, size=[sample.shape[0]])  # there are two augmentations in the frequency domain
li_onehot = one_hot_encoding(li)
aug_1[1 - li_onehot[:, 0]] = 0  # rows that are not selected are set to zero
aug_2[1 - li_onehot[:, 1]] = 0

These four lines only assign 0 to rows 0 and 1 of aug_1 and aug_2: because 1 - li_onehot[:, k] is an integer array containing only 0s and 1s, it is interpreted as row indices rather than as a boolean mask, so the unselected rows are never zeroed. That is not the same as removing frequency components and adding frequency components.
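
A minimal sketch (my own, not from the repo) of the integer-indexing pitfall, with illustrative shapes:

import numpy as np

aug = np.ones((5, 3))             # 5 samples, 3 frequency bins
li = np.array([1, 0, 1, 1, 0])    # per-sample augmentation choice
sel = 1 - li                      # integer array of 0s and 1s, NOT booleans

aug[sel] = 0                      # integer indexing: only rows 0 and 1 are written
print(aug)                        # rows 0-1 are zero, rows 2-4 are untouched

aug_fixed = np.ones((5, 3))
aug_fixed[sel.astype(bool)] = 0   # boolean mask: zeroes exactly the rows where li == 0
print(aug_fixed)                  # rows 1 and 4 are zero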

xiangzhang1015 commented 1 year ago

Hi,

Thanks for pointing that out. The four lines you mention are from an old version of the frequency augmentation and should have been commented out. The augmentations are implemented by the following:

import torch

def DataTransform_FD(sample, config):
    # apply both frequency-domain augmentations and combine the two views
    aug_1 = remove_frequency(sample, pertub_ratio=0.1)
    aug_2 = add_frequency(sample, pertub_ratio=0.1)
    aug_F = aug_1 + aug_2
    return aug_F

def remove_frequency(x, pertub_ratio=0.0):
    # roughly pertub_ratio of the entries are False; those frequency components are zeroed out
    mask = torch.FloatTensor(x.shape).uniform_() > pertub_ratio
    mask = mask.to(x.device)
    return x * mask

def add_frequency(x, pertub_ratio=0.0):
    # only roughly pertub_ratio of the entries are True; those components receive added energy
    mask = torch.FloatTensor(x.shape).uniform_() > (1 - pertub_ratio)
    mask = mask.to(x.device)
    # add random amplitudes, capped at 10% of the maximum amplitude, at the masked positions
    max_amplitude = x.max()
    random_am = torch.rand(mask.shape).to(x.device) * (max_amplitude * 0.1)
    pertub_matrix = mask * random_am
    return x + pertub_matrix
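
For context, a minimal usage sketch (mine, not from the repo) under the assumption that the input is the magnitude of the FFT of a batch of raw signals; the shapes and the unused config argument are illustrative:

import torch

batch = torch.randn(32, 1, 178)         # (batch, channels, timesteps)
sample_fd = torch.fft.fft(batch).abs()  # frequency-domain view of the batch

aug_F = DataTransform_FD(sample_fd, config=None)  # config is not used in the snippet above
print(aug_F.shape)                      # same shape as the input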

Moreover, we have updated the TFC implementation; please see the "Updates on Jan 2023" section of the repo README for more details. In summary: