kekmodel / FixMatch-pytorch

Unofficial PyTorch implementation of "FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence"
MIT License
758 stars, 170 forks

Lack of seed during random labeled image selection may lead to better performance when training is resumed #41

Closed: dsouzinator closed this issue 3 years ago

dsouzinator commented 3 years ago

In the function x_u_split(...) in dataset/cifar.py, the labeled images are sampled without a seed. If training is stopped and resumed several times, the labeled subset is redrawn at each restart, so the total number of distinct labeled images the model sees can exceed the configured value. For instance, training on 40 labels with 2 stops (3 runs in total) can expose the model to up to 120 distinct labeled images over the course of training, even though it only sees 40 labeled images at any one time. I think this may explain the much higher accuracy obtained by this implementation, especially on the low-label tasks.
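To make the restart argument concrete, here is a minimal NumPy simulation, not the repository's code. It assumes a toy CIFAR-10-like label array (10 classes with 5,000 images each), draws 4 labeled indices per class (40 labels) in each of 3 runs (2 stops), and counts how many distinct indices appear across all runs. The helpers `sample_labeled` and `distinct_labels_seen` are hypothetical, and the sketch uses NumPy's `Generator` API in place of the global `np.random` state; the effect of seeding or not seeding is the same.

```python
import numpy as np

# Toy CIFAR-10-like labels: 10 classes x 5000 images (an assumption for illustration).
labels = np.repeat(np.arange(10), 5000)


def sample_labeled(labels, num_labeled, num_classes, rng):
    """Hypothetical helper: draw num_labeled // num_classes indices per class."""
    per_class = num_labeled // num_classes
    picked = []
    for c in range(num_classes):
        idx = np.where(labels == c)[0]
        picked.extend(rng.choice(idx, per_class, replace=False))
    return np.array(picked)


def distinct_labels_seen(num_runs, seed=None):
    """Size of the union of labeled indices over num_runs restarts."""
    seen = set()
    for _ in range(num_runs):
        # Each restart builds a fresh RNG, as a new training process would.
        # seed=None means the RNG is seeded from OS entropy (unseeded run).
        rng = np.random.default_rng(seed)
        seen.update(sample_labeled(labels, 40, 10, rng).tolist())
    return len(seen)


print(distinct_labels_seen(3))          # unseeded: usually at or near 120
print(distinct_labels_seen(3, seed=5))  # seeded: exactly 40
```

With a fixed seed every restart redraws the same 40 indices; without one, each restart adds a nearly disjoint set of 40, so the union approaches 3 × 40.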

A quick fix would be to add the snippet below before the random label selection in the x_u_split(...) function.

np.random.seed(args.seed)
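Where that line would go can be sketched with a simplified, hypothetical `x_u_split` that keeps only the per-class sampling step and returns only the labeled indices; it is not the repository's function. It assumes an `args` object with `seed`, `num_labeled` and `num_classes` attributes, built here with `SimpleNamespace`.

```python
from types import SimpleNamespace

import numpy as np


def x_u_split(args, labels):
    """Simplified sketch: pick num_labeled // num_classes labeled indices per class."""
    # Proposed quick fix: reseed the global NumPy RNG right before sampling,
    # so every (re)start draws the same labeled subset for a given args.seed.
    np.random.seed(args.seed)
    label_per_class = args.num_labeled // args.num_classes
    labels = np.array(labels)
    labeled_idx = []
    for c in range(args.num_classes):
        idx = np.where(labels == c)[0]
        labeled_idx.extend(np.random.choice(idx, label_per_class, replace=False))
    return np.array(labeled_idx)


# Usage: two calls stand in for an initial run and a resumed run.
args = SimpleNamespace(seed=5, num_labeled=40, num_classes=10)
labels = np.repeat(np.arange(10), 5000)  # toy CIFAR-10-like labels
first = x_u_split(args, labels)
second = x_u_split(args, labels)
print(np.array_equal(first, second))  # True
```

Reseeding inside the function makes the split independent of any earlier draws from the global RNG, at the cost of resetting the global stream partway through the program. Note that if `args.seed` is None, `np.random.seed(None)` reseeds from OS entropy and the split is again unreproducible.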

lesterlitch commented 3 years ago

This feels like it should be a high priority. Will you accept a PR?

zzzjoey commented 3 years ago

Hello @dsouzinator, I think the seed is already set in https://github.com/kekmodel/FixMatch-pytorch/blob/f54946074fba383e28320d8f50b627eabd0c7e3c/train.py#L35

import random

import numpy as np
import torch

def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

and https://github.com/kekmodel/FixMatch-pytorch/blob/f54946074fba383e28320d8f50b627eabd0c7e3c/train.py#L175

    if args.seed is not None:
        set_seed(args)

So the seed applies to all of the code, doesn't it? If I'm wrong, please correct me.

dsouzinator commented 3 years ago

@zou-yiqi That is correct; however, the seed isn't used in the x_u_split() function, so the labels are selected randomly.
I would suggest modifying the code as described above, since this repo no longer seems to be maintained.

zzzjoey commented 3 years ago

What I mean is that the seed is set in the main() function, and x_u_split() is called by the dataset getter obtained from DATASET_GETTERS inside main(), after set_seed() has run. Because np.random uses a single global random state, the seed set in main() also determines the sampling in x_u_split(). Thus, we don't need to add np.random.seed(args.seed) to the x_u_split(...) function. ^ ^
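The point about global state can be checked with a small sketch: NumPy's legacy `np.random` functions share one global generator, so seeding it once in a caller before the split is built also fixes the split, with no reseeding inside the split function. `split_labeled` and `main` are hypothetical stand-ins that only mimic the call order described above, not the repository's code.

```python
import numpy as np

labels = np.repeat(np.arange(10), 5000)  # toy CIFAR-10-like labels


def split_labeled(labels, num_labeled=40, num_classes=10):
    """Hypothetical stand-in for x_u_split: note there is no seeding in here."""
    per_class = num_labeled // num_classes
    picked = []
    for c in range(num_classes):
        idx = np.where(labels == c)[0]
        picked.extend(np.random.choice(idx, per_class, replace=False))
    return np.array(picked)


def main(seed):
    """Mimics the described order: seed the global RNG (if given), then build data."""
    if seed is not None:
        np.random.seed(seed)
    return split_labeled(labels)


run_1 = main(seed=5)
run_2 = main(seed=5)  # a "restart" with the same seed
print(np.array_equal(run_1, run_2))  # True
```

Two caveats follow from the guard shown earlier: if `args.seed` is None, `set_seed` is skipped and each fresh process seeds NumPy from OS entropy, so the original concern applies; and the split is only reproducible if the same sequence of global-RNG calls runs between `set_seed` and the split on every restart.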

dsouzinator commented 3 years ago

@zou-yiqi Yes, you're right. The seed does indeed take effect in that function. Thanks!