owenliang / mnist-dits

Diffusion Transformers (DiTs) trained on the MNIST dataset

Error when running training #1

Open zhangzhizhongz3 opened 7 months ago

zhangzhizhongz3 commented 7 months ago

After downloading the code and installing the dependencies, running train.py fails with the error below. Could someone help me figure out the cause?

```
/Users/zhangzhizhong/PycharmProjects/mnist-dits/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 10 worker processes in total. Our suggested max number of worker in current system is 8 (cpuset is not taken into account), which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
/Users/zhangzhizhong/PycharmProjects/mnist-dits/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 10 worker processes in total. Our suggested max number of worker in current system is 8 (cpuset is not taken into account), which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/spawn.py", line 120, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/spawn.py", line 129, in _main
    prepare(preparation_data)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/spawn.py", line 240, in prepare
    _fixup_main_from_path(data['init_main_from_path'])
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/spawn.py", line 291, in _fixup_main_from_path
    main_content = runpy.run_path(main_path,
  File "<frozen runpy>", line 291, in run_path
  File "<frozen runpy>", line 98, in _run_module_code
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/zhangzhizhong/PycharmProjects/mnist-dits/train.py", line 38, in <module>
    for imgs,labels in dataloader:
  File "/Users/zhangzhizhong/PycharmProjects/mnist-dits/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 434, in __iter__
    self._iterator = self._get_iterator()
  File "/Users/zhangzhizhong/PycharmProjects/mnist-dits/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 387, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/Users/zhangzhizhong/PycharmProjects/mnist-dits/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1040, in __init__
    w.start()
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/popen_spawn_posix.py", line 42, in _launch
    prep_data = spawn.get_preparation_data(process_obj._name)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/spawn.py", line 158, in get_preparation_data
    _check_not_importing_main()
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/spawn.py", line 138, in _check_not_importing_main
    raise RuntimeError('''
RuntimeError: An attempt has been made to start a new process before the current process has finished its bootstrapping phase.

    This probably means that you are not using fork to start your
    child processes and you have forgotten to use the proper idiom
    in the main module:

        if __name__ == '__main__':
            freeze_support()
            ...

    The "freeze_support()" line can be omitted if the program
    is not going to be frozen to produce an executable.

Traceback (most recent call last):
  File "/Users/zhangzhizhong/PycharmProjects/mnist-dits/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1133, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/queues.py", line 113, in get
    if not self._poll(timeout):
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/connection.py", line 256, in poll
    return self._poll(timeout)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/connection.py", line 423, in _poll
    r = wait([self], timeout)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/connection.py", line 930, in wait
    ready = selector.select(timeout)
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
  File "/Users/zhangzhizhong/PycharmProjects/mnist-dits/venv/lib/python3.11/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 5519) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. Rerunning with num_workers=0 may give better error trace.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/zhangzhizhong/PycharmProjects/mnist-dits/train.py", line 38, in <module>
    for imgs,labels in dataloader:
  File "/Users/zhangzhizhong/PycharmProjects/mnist-dits/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File "/Users/zhangzhizhong/PycharmProjects/mnist-dits/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1329, in _next_data
    idx, data = self._get_data()
  File "/Users/zhangzhizhong/PycharmProjects/mnist-dits/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1295, in _get_data
    success, data = self._try_get_data()
  File "/Users/zhangzhizhong/PycharmProjects/mnist-dits/venv/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1146, in _try_get_data
    raise RuntimeError(f'DataLoader worker (pid(s) {pids_str}) exited unexpectedly') from e
RuntimeError: DataLoader worker (pid(s) 5517, 5519) exited unexpectedly
```
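The `UserWarning` at the top of the log is a side issue, not the cause of the crash: the DataLoader asks for 10 worker processes while PyTorch suggests at most 8 on this machine. Below is a minimal sketch for clamping the count, assuming you pass the result as `DataLoader(..., num_workers=...)`. `capped_num_workers` is a hypothetical helper, not part of PyTorch or this repo. Setting `num_workers=0`, as the last error message suggests, avoids worker processes entirely and so also sidesteps the crash, at the cost of loading data in the main process.

```python
import os


def capped_num_workers(requested: int) -> int:
    """Clamp a requested DataLoader worker count to the number of CPUs.

    Hypothetical helper. It uses os.cpu_count(), which does not take
    cpuset/affinity limits into account.
    """
    available = os.cpu_count() or 1  # os.cpu_count() can return None
    return max(0, min(requested, available))


if __name__ == "__main__":
    # e.g. DataLoader(dataset, batch_size=..., num_workers=capped_num_workers(10))
    print(capped_num_workers(10))
```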

IcedAmericanoooo commented 6 months ago

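This error means you need to move the code in train.py (the training loop) under `if __name__ == '__main__':` and run it from there, like this:

```python
if __name__ == '__main__':
    # dataloader, model, loss_fn, optimzer, EPOCH, T, DEVICE, forward_add_noise,
    # torch and os are assumed to be defined/imported above this guard in train.py;
    # only the training loop moves under the guard.
    iter_count = 0
    for epoch in range(EPOCH):
        for imgs, labels in dataloader:
            x = imgs * 2 - 1  # map pixels from [0,1] to [-1,1] to match the Gaussian noise range
            t = torch.randint(0, T, (imgs.size(0),))  # random timestep t for each image
            y = labels

            x, noise = forward_add_noise(x, t)  # x: noised images, noise: the added noise
            pred_noise = model(x.to(DEVICE), t.to(DEVICE), y.to(DEVICE))

            loss = loss_fn(pred_noise, noise.to(DEVICE))

            optimzer.zero_grad()
            loss.backward()
            optimzer.step()

            if iter_count % 1000 == 0:
                print('epoch:{} iter:{},loss:{}'.format(epoch, iter_count, loss))
                torch.save(model.state_dict(), '.model.pth')
                os.replace('.model.pth', 'model.pth')  # atomic swap so model.pth is never half-written
            iter_count += 1
```

Watch the indentation when you make this change; otherwise you will still get an error.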
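Why the guard helps: on macOS, Python 3.8+ starts multiprocessing children with the `spawn` method (the traceback goes through `popen_spawn_posix.py`), and each spawned DataLoader worker re-runs the parent's main script through `runpy.run_path(main_path, run_name="__mp_main__")`, which is the `_fixup_main_from_path` frame in the traceback. Without the guard, that re-run reaches `for imgs,labels in dataloader:` and tries to start workers of its own, which `_check_not_importing_main` rejects. The stdlib-only sketch below imitates just that re-import step with two hypothetical toy scripts; it starts no processes and does not use PyTorch, and the variable `started` is a stand-in for the training loop.

```python
import os
import runpy
import tempfile
import textwrap

# Hypothetical toy scripts. `started` stands in for the training loop.
# UNGUARDED runs its "loop" at import time, like the original train.py.
UNGUARDED = textwrap.dedent("""
    started = True  # like `for imgs, labels in dataloader:` at top level
""")

# GUARDED only runs its "loop" when executed as the main script.
GUARDED = textwrap.dedent("""
    started = False

    def main():
        global started
        started = True  # the training loop would go here

    if __name__ == '__main__':
        main()
""")


def loop_runs_on_reimport(source: str) -> bool:
    """Execute `source` the way a spawned child re-imports the parent's script."""
    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
        f.write(source)
        path = f.name
    try:
        # multiprocessing's spawn.py re-runs the main script under this name.
        namespace = runpy.run_path(path, run_name="__mp_main__")
    finally:
        os.remove(path)
    return namespace["started"]


if __name__ == "__main__":
    print("unguarded:", loop_runs_on_reimport(UNGUARDED))  # unguarded: True
    print("guarded:", loop_runs_on_reimport(GUARDED))  # guarded: False
```

Because `__name__` is `"__mp_main__"` during the re-run, only module-level code outside the guard executes in the worker, which is why moving the loop under `if __name__ == '__main__':` is enough.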