lianchengmingjue closed this issue 2 years ago.
@lianchengmingjue good catch! By default, torch.load() first deserializes the snapshot on the CPU and then moves its tensors to the device they were saved from (I guess it's GPU0). In this case, every rank loads the snapshot onto GPU0. We should always pass map_location to torch.load() when loading files saved in another environment, because a file saved from GPUx will fail to load on a host that has no GPUx. Please feel free to send a PR for the fix. cc: @suraj813
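A minimal sketch of that advice, assuming the script is launched with torchrun, which sets the LOCAL_RANK environment variable for each worker. The helper names `local_device_str` and `load_snapshot_for_rank` are hypothetical; only `torch.load` and its `map_location` parameter come from PyTorch. Falling back to `"cpu"` when LOCAL_RANK is unset also covers the "GPUx doesn't exist on this host" case.

```python
import os


def local_device_str(fallback: str = "cpu") -> str:
    """Return the device string this rank should load onto, e.g. "cuda:2".

    Hypothetical helper: reads LOCAL_RANK (set per worker by torchrun) and
    returns `fallback` when it is not set.
    """
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is None:
        return fallback
    return f"cuda:{int(local_rank)}"


def load_snapshot_for_rank(snapshot_path: str):
    """Load a snapshot onto this rank's own device, not the saving device."""
    import torch  # imported lazily so local_device_str has no torch dependency

    # Without map_location, CUDA tensors are restored to the device they were
    # saved from (often cuda:0), so every rank would allocate memory on GPU0.
    return torch.load(snapshot_path, map_location=local_device_str())
```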
https://github.com/pytorch/examples/blob/2ee8d43dbe420be152fd5ce0d80b43b419a0e352/distributed/ddp-tutorial-series/multigpu_torchrun.py#L39 When I run the code and resume from an existing .pt file, the memory usage of GPU0 is significantly higher than that of the other GPUs. It can be fixed by passing the map_location argument:
snapshot = torch.load(snapshot_path, map_location=torch.device('cuda', int(os.environ["LOCAL_RANK"])))
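As an alternative to naming the target device directly, `map_location` also accepts a dict that remaps saved storage locations (for example `{"cuda:0": "cuda:1"}`); locations not listed in the dict are left unchanged. The sketch below assumes the snapshot was saved from `cuda:0`, as guessed above, and that torchrun has set LOCAL_RANK. The helper name `remap_saved_device` is hypothetical.

```python
import os


def remap_saved_device(saved_device: str = "cuda:0") -> dict:
    """Build a map_location dict moving tensors saved on `saved_device`
    onto this process's GPU. Hypothetical helper; assumes LOCAL_RANK is set."""
    local_rank = int(os.environ["LOCAL_RANK"])
    return {saved_device: f"cuda:{local_rank}"}


# Usage (requires PyTorch and a visible GPU for this rank):
#   snapshot = torch.load(snapshot_path, map_location=remap_saved_device())
```

Naming the target device directly is simpler when you don't know which GPU the file was saved from; the dict form is useful only when the saved location is known.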
My Environment
cudatoolkit 10.2, PyTorch 1.12.1