Hi,
Thanks for pointing that out. The four lines you mentioned are from an old version of the frequency augmentation and should be commented out. The augmentations are implemented as follows:
```python
import torch

def DataTransform_FD(sample, config):
    aug_1 = remove_frequency(sample, pertub_ratio=0.1)
    aug_2 = add_frequency(sample, pertub_ratio=0.1)
    aug_F = aug_1 + aug_2
    return aug_F

def remove_frequency(x, pertub_ratio=0.0):
    # roughly pertub_ratio of the entries are False, i.e. those frequency components are zeroed out
    mask = torch.FloatTensor(x.shape).uniform_() > pertub_ratio
    mask = mask.to(x.device)
    return x * mask

def add_frequency(x, pertub_ratio=0.0):
    # only pertub_ratio of the entries are True; small random amplitudes are added at those positions
    mask = torch.FloatTensor(x.shape).uniform_() > (1 - pertub_ratio)
    mask = mask.to(x.device)
    max_amplitude = x.max()
    random_am = torch.rand(mask.shape).to(x.device) * (max_amplitude * 0.1)
    pertub_matrix = mask * random_am
    return x + pertub_matrix
```
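For reference, here is a minimal usage sketch of the snippet above. It assumes the input is the magnitude of the FFT of a batch of time-domain windows; the batch shape, channel count, and the unused `config` are illustrative assumptions, not the repo's exact dataloader code.

```python
import torch
import torch.fft as fft

# Illustrative only: a batch of 8 windows, 1 channel, 178 time steps.
x_time = torch.randn(8, 1, 178)
x_freq = fft.fft(x_time).abs()           # frequency-domain view of the batch

config = None                            # config is not used by the snippet above
aug_f = DataTransform_FD(x_freq, config) # perturbed spectrum used as the augmented view
```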
Moreover, we have updated the TFC implementation; please see the Updates on Jan 2023 section of the repo README for more details. In summary:
In the DataTransform_FD function:

```python
li = np.random.randint(0, 2, size=[sample.shape[0]])  # there are two augmentations in the frequency domain
li_onehot = one_hot_encoding(li)
aug_1[1 - li_onehot[:, 0]] = 0  # the rows that are not selected are set to zero
aug_2[1 - li_onehot[:, 1]] = 0
```
The above four lines of code only assign 0 to rows 0 and 1 of aug_1 and aug_2 (the index is an integer array, not a boolean mask; see the sketch below), which is different from removing frequency components and adding frequency components.
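As a small illustration of why (a toy sketch, not the repo's code: one_hot_encoding here is a hypothetical stand-in for the repo's helper, and the shapes are made up), indexing a tensor with an integer array of 0s and 1s performs fancy indexing over rows 0 and 1, whereas the intended masking would require a boolean index:

```python
import numpy as np
import torch

def one_hot_encoding(labels, num_classes=2):
    # hypothetical stand-in: one row of 0/1 indicators per sample
    return np.eye(num_classes, dtype=int)[labels]

aug_1 = torch.ones(5, 4)                              # toy batch: 5 samples, 4 frequency bins
li = np.random.randint(0, 2, size=[aug_1.shape[0]])
li_onehot = one_hot_encoding(li)

idx = 1 - li_onehot[:, 0]                             # integer array of 0s and 1s, NOT a boolean mask
aug_1[idx] = 0                                        # fancy indexing: only rows 0 and/or 1 are zeroed

# the intended behaviour needs a boolean mask instead, e.g.
# aug_1[li_onehot[:, 0] == 0] = 0                     # zero out the rows not selected for this augmentation
```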