bubbliiiing / yolox-pytorch

This is the source code of a YOLOX implementation in PyTorch, which can be used to train your own models.
Apache License 2.0

Setting the seed seems to have a bug #154

Closed illrayy closed 6 months ago

illrayy commented 8 months ago
import numpy as np
import random
import torch
from torch.utils.data import Dataset, DataLoader

class RandomDataset(Dataset):
    def __getitem__(self, index):
        # Each item is drawn from the np.random state of whichever worker fetches it
        return np.random.randint(0, 1000, 3)

    def __len__(self):
        return 8

def seed_everything(seed=11):
    # Seed every RNG used in the main process and make cuDNN deterministic
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def seed_everything_wrap(seed=11):
    # Returns a worker_init_fn; note that worker_id is ignored, so every
    # DataLoader worker is re-seeded with the same fixed seed
    def _init_fn(worker_id):
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    return _init_fn

dataset = RandomDataset()
seed = 23
seed_everything(seed)

# Each of the 4 workers is initialized with the same seed, 23
dataloader = DataLoader(dataset, batch_size=2, num_workers=4, worker_init_fn=seed_everything_wrap(seed))
for epoch in range(3):
    print(f"epoch: {epoch}")
    for batch in dataloader:
        print(batch)

The output is:

epoch: 0
tensor([[595, 742,  40],
        [969, 950, 488]])
tensor([[595, 742,  40],
        [969, 950, 488]])
tensor([[595, 742,  40],
        [969, 950, 488]])
tensor([[595, 742,  40],
        [969, 950, 488]])
epoch: 1
tensor([[595, 742,  40],
        [969, 950, 488]])
tensor([[595, 742,  40],
        [969, 950, 488]])
tensor([[595, 742,  40],
        [969, 950, 488]])
tensor([[595, 742,  40],
        [969, 950, 488]])
epoch: 2
tensor([[595, 742,  40],
        [969, 950, 488]])
tensor([[595, 742,  40],
        [969, 950, 488]])
tensor([[595, 742,  40],
        [969, 950, 488]])
tensor([[595, 742,  40],
        [969, 950, 488]])

The output is identical across every worker and every epoch.
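
A minimal sketch (not from the original post) of what the wrapper above effectively runs in every worker; because workers are re-created at the start of each epoch by default, it runs again every epoch, so both workers and epochs repeat the same sequence:

import numpy as np

np.random.seed(23)                    # every worker re-seeds np.random with the same constant
print(np.random.randint(0, 1000, 3))  # so every worker draws the identical sequence of items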

illrayy commented 8 months ago

In the code above, every worker is initialized with the same seed, 23, so all workers produce the same random numbers. worker_init_fn=seed_everything_wrap(seed) should be changed to worker_init_fn=seed_everything_wrap.

Output:

epoch: 0
tensor([[ 81, 169, 427],
        [182, 596, 650]])
tensor([[977, 758, 359],
        [110, 376, 906]])
tensor([[202, 234, 280],
        [ 52, 717, 142]])
tensor([[337, 227, 759],
        [236, 373, 282]])
epoch: 1
tensor([[582, 419, 678],
        [522, 126, 356]])
tensor([[572,  10, 893],
        [164, 870, 733]])
tensor([[107, 345, 285],
        [702, 769, 716]])
tensor([[719, 632, 451],
        [749, 765, 522]])
epoch: 2
tensor([[661, 964, 120],
        [590, 768, 306]])
tensor([[405, 747, 811],
        [331, 718, 927]])
tensor([[315, 521,  82],
        [417, 380, 333]])
tensor([[856, 861, 761],
        [422, 719, 197]])
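
When seed_everything_wrap is passed directly, the DataLoader calls it with the worker_id and the returned _init_fn is never executed, so the workers fall back to PyTorch's default per-worker seeding (in recent PyTorch versions this also seeds NumPy differently per worker and per epoch), which is why the outputs above now vary. A more explicit sketch, roughly following the pattern from the PyTorch reproducibility notes; seed_worker and g are names introduced here, and dataset and seed are reused from the snippet above:

import random
import numpy as np
import torch
from torch.utils.data import DataLoader

def seed_worker(worker_id):
    # Inside a worker, torch.initial_seed() already combines the DataLoader's base seed
    # with the worker_id, so each worker and each epoch gets a distinct, reproducible stream
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(seed)

dataloader = DataLoader(dataset, batch_size=2, num_workers=4,
                        worker_init_fn=seed_worker, generator=g)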