pytorch / examples

A set of examples around pytorch in Vision, Text, Reinforcement Learning, etc.
https://pytorch.org/examples
BSD 3-Clause "New" or "Revised" License

The GPU load is unbalanced #1078

Closed: lianchengmingjue closed this issue 2 years ago

lianchengmingjue commented 2 years ago

https://github.com/pytorch/examples/blob/2ee8d43dbe420be152fd5ce0d80b43b419a0e352/distributed/ddp-tutorial-series/multigpu_torchrun.py#L39
When I run this script and resume from an existing .pt snapshot, memory usage on GPU0 is significantly higher than on the other GPUs. Passing the map_location argument fixes it: snapshot = torch.load(snapshot_path, map_location=torch.device('cuda', int(os.environ["LOCAL_RANK"])))
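A minimal sketch of the fix described above, assuming the script runs under torchrun (which sets LOCAL_RANK for each process) and that the snapshot is a dict like the tutorial's {"MODEL_STATE": ..., "EPOCHS_RUN": ...}. The helper name load_snapshot and the CPU fallback for machines without CUDA are illustrative additions, not part of the tutorial code.

```python
import os

import torch


def load_snapshot(snapshot_path: str) -> dict:
    """Hypothetical helper: load a training snapshot onto this process's device."""
    if torch.cuda.is_available():
        # torchrun sets LOCAL_RANK per process; rank k should use cuda:k
        device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
    else:
        # Illustrative fallback so the sketch also runs on CPU-only machines
        device = torch.device("cpu")
    # map_location remaps every saved storage onto `device`, so no rank
    # materializes its copy of the snapshot on GPU0 first
    return torch.load(snapshot_path, map_location=device)
```

With this in place, each rank places its copy of the snapshot directly on its own GPU, which should keep memory usage balanced across devices after resuming.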

My Environment

cudatoolkit 10.2, pytorch 12.1

hudeven commented 2 years ago

@lianchengmingjue good catch! By default, torch.load() first deserializes the snapshot onto the CPU and then moves each tensor to the device it was saved from (I guess that's GPU0). As a result, every rank loads its copy of the snapshot onto GPU0. We should always pass map_location to torch.load() when loading files saved in another environment: they may have been saved on a GPU (cuda:x) that doesn't exist on your host, which makes loading fail. Please feel free to send a PR with the fix. cc: @suraj813
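To see the save-time device tags that the default torch.load() behavior relies on, note that map_location also accepts a callable. PyTorch calls it once per serialized storage with the CPU-resident storage and the location tag recorded at save time (e.g. "cpu" or "cuda:0"). The sketch below uses a hypothetical helper, load_on_cpu_and_report, that keeps everything on the CPU and returns the tags it saw. Inspecting those tags is one way to check whether a checkpoint refers to GPUs the loading host may not have.

```python
import torch


def load_on_cpu_and_report(path: str):
    """Hypothetical helper: load `path` onto the CPU and return (object, saved tags)."""
    seen = []

    def keep_on_cpu(storage, location):
        # `storage` is the initial deserialization, already on the CPU;
        # `location` is the device tag recorded when the file was saved.
        seen.append(location)
        return storage  # returning it unchanged leaves the data on the CPU

    obj = torch.load(path, map_location=keep_on_cpu)
    return obj, sorted(set(seen))
```

A checkpoint written by rank 0 of the tutorial would be expected to report "cuda:0" here, while the loaded tensors still stay on the CPU until you move them explicitly.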