FSDP flatten_parameters=True causing excessive memory consumption #5464

Open · Seventeen17 opened 1 year ago

Seventeen17 commented 1 year ago

❓ Questions and Help

I have noticed during testing that enabling FSDP's flatten_parameters=True results in a significant increase in GPU peak memory; the usage is several times larger than the total parameter size. Have any of you encountered this issue before? If so, are there any known solutions or workarounds? @ronghanghu @JackCaoG

Thank you in advance for your assistance!

JackCaoG commented 1 year ago

@AlexWertheim @alanwaketan any insight?

Seventeen17 commented 1 year ago

It seems that the unflattened views of the parameters stay alive throughout the forward and backward passes and are not freed by _free_full_params.
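
A minimal sketch of the underlying PyTorch semantics (plain CPU tensors, outside FSDP; the names below are illustrative, not FSDP internals): a view shares its base tensor's storage, so dropping the flat parameter does not release the memory while any view is still referenced.

import torch

# Mimic a flattened parameter and its unflattened views.
flat = torch.randn(4 * 1280 * 1280)
views = [chunk.view(1280, 1280) for chunk in flat.split(1280 * 1280)]
ptr = flat.untyped_storage().data_ptr()

del flat  # analogous to _free_full_params dropping the flat parameter

# The storage is still pinned by the views, so nothing was actually freed.
assert all(v.untyped_storage().data_ptr() == ptr for v in views)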

Seventeen17 commented 1 year ago

For example:

import functools

import torch
import torch.nn.functional as F
from torch_xla.core import xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy

class Submodule1(torch.nn.Module):
  def __init__(self):
    super(Submodule1, self).__init__()
    self.fc1 = torch.nn.Linear(1280, 1280, bias=False)
    self.fc2 = torch.nn.Linear(1280, 1280, bias=False)
    self.fc3 = torch.nn.Linear(1280, 1280, bias=False)
    self.fc4 = torch.nn.Linear(1280, 1280, bias=False)

  def forward(self, x):
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    x = self.fc4(x)
    output = F.softmax(x, dim=1)
    return output

class Submodule2(torch.nn.Module):
  def __init__(self):
    super(Submodule2, self).__init__()
    self.fc1 = torch.nn.Linear(1280, 1280, bias=False)
    self.fc2 = torch.nn.Linear(1280, 1280, bias=False)
    self.fc3 = torch.nn.Linear(1280, 1280, bias=False)
    self.fc4 = torch.nn.Linear(1280, 1280, bias=False)

  def forward(self, x):
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    x = self.fc4(x)
    output = F.softmax(x, dim=1)
    return output

class Layer(torch.nn.Module):
  def __init__(self):
    super(Layer, self).__init__()
    self.sb1 = Submodule1()
    self.sb2 = Submodule2()

  def forward(self, x):
    a = self.sb1(x)
    b = self.sb2(a)
    output = F.softmax(b, dim=1)
    return output

class Net(torch.nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.fc1 = Layer()
    self.fc2 = Layer()

  def forward(self, x):
    x = self.fc1(x)
    x = self.fc2(x)
    output = F.softmax(x, dim=1)
    return output

def train(model, device, optimizer):
  model.train()
  for i in range(3):
    optimizer.zero_grad()
    input = torch.randn(64, 1280, device=device)
    loss = model(input).sum()
    loss.backward()
    optimizer.step()
    # Cut the lazy graph; the unflattened parameter views are cleaned up here
    xm.mark_step()

def main():
  model = Net()
  device = xm.xla_device()
  xm.set_replication(device, [device])
  optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

  auto_wrap_policy = functools.partial(
      transformer_auto_wrap_policy,
      transformer_layer_cls={Layer},
  )
  model = FSDP(model, auto_wrap_policy=auto_wrap_policy, flatten_parameters=True)
  model.to(device)

  train(model, device, optimizer)

if __name__ == '__main__':
  main()

With flatten_parameters=False:

Peak mem is 324102928 bytes (309.09 MiB).

With flatten_parameters=True:

Peak mem is 787751448 bytes (751.26 MiB).
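
For scale, a back-of-the-envelope check of the model's parameter footprint (fp32), which lines up with the buffer sizes in the dump below:

# 2 Layers x 2 submodules x 4 Linear(1280, 1280, bias=False) = 16 weights
per_weight = 1280 * 1280 * 4      # 6,553,600 bytes: the size of each slice.N
per_layer_flat = 8 * per_weight   # 52,428,800 bytes: the size of p2.308 etc.
total_params = 16 * per_weight    # 104,857,600 bytes = 100 MiB in total
print(per_weight, per_layer_flat, total_params)

So the 751.26 MiB peak is roughly 7.5x the 100 MiB of parameters.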

Live ranges at 508 (peak):
    reduce-scatter.0: 52428800 bytes
    p2.308: 52428800 bytes
    p1.2: 8 bytes
    concatenate.2: 327680 bytes
    p0.1: 4 bytes
    slice.11: 6553600 bytes
    custom-call: 327680 bytes
    slice.12: 6553600 bytes
    custom-call.1: 327680 bytes
    slice.13: 6553600 bytes
    custom-call.2: 327680 bytes
    slice.14: 6553600 bytes
    constant_254: 4 bytes
    divide.11: 327680 bytes
    slice.15: 6553600 bytes
    custom-call.4: 327680 bytes
    slice.16: 6553600 bytes
    slice.17: 6553600 bytes
    slice.18: 6553600 bytes
    p3.437: 52428800 bytes
    slice.19: 6553600 bytes
    slice.20: 6553600 bytes
    slice.21: 6553600 bytes
    slice.22: 6553600 bytes
    slice.23: 6553600 bytes
    slice.24: 6553600 bytes
    slice.25: 6553600 bytes
    slice.26: 6553600 bytes
    divide.9: 327680 bytes
    constant_266: 13107200 bytes
    broadcast.168: 52428800 bytes
    constant_267: 13107200 bytes
    constant_269: 13107200 bytes
    constant_270: 13107200 bytes
    constant_271: 13107200 bytes
    constant_272: 13107200 bytes
    constant_273: 13107200 bytes
    constant_274: 13107200 bytes
    tuple.859: 16 bytes
    all-reduce-start.2: 16 bytes
    custom-call.31: 327680 bytes
    pad.10: 52428800 bytes
    custom-call.33: 327680 bytes
    pad.11: 52428800 bytes
    fusion.23: 16 bytes
    fusion.23: 16 bytes
    fusion.23: 16 bytes
    p5.1188: 52428800 bytes
    p4.1180: 52428800 bytes

The slices are the unflattened views of the parameters. Cross-referencing with the HLO computation graph, the output contains many bitcast operations that come from these slices.

ROOT tuple.41 = (f32[13107200]{0}, f32[13107200]{0}, f32[64,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=5*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=10*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=15*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[], /*index=20*/f32[13107200]{0}, f32[13107200]{0}) tuple(get-tuple-element.83, get-tuple-element.84, get-tuple-element.85, bitcast.1759, bitcast.2852, /*index=5*/bitcast.2824, bitcast.2792, bitcast.2739, bitcast.2699, bitcast.2667, /*index=10*/bitcast.2635, bitcast.2537, bitcast.2488, bitcast.2458, bitcast.2428, /*index=15*/bitcast.2377, bitcast.2339, bitcast.2309, bitcast.2276, reduce.1209, /*index=20*/get-tuple-element.103, get-tuple-element.104)
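
A hedged sketch of the mechanism (it assumes a working torch_xla install; the shapes are illustrative): any XLA tensor that is still referenced from Python when mark_step() cuts the graph must be materialized, so it shows up in the ROOT tuple, and keeping the unflattened views alive adds one output per view.

import torch
from torch_xla.core import xla_model as xm

device = xm.xla_device()
flat = torch.randn(2 * 1280 * 1280, device=device)
views = [c.view(1280, 1280) for c in flat.split(1280 * 1280)]
y = (views[0] @ views[1]).sum()

# Every tensor still referenced here (flat, both views, y) has to be
# materialized as an output of the compiled graph; dropping `views`
# before this point would keep them out of the ROOT tuple.
xm.mark_step()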

With flatten_parameters=True and the unflattened views removed from the output:

When the unflattened views of the parameters are removed from the output, it becomes:

ROOT tuple.41 = (f32[13107200]{0}, f32[13107200]{0}, f32[64,1280]{1,0}, f32[], f32[13107200]{0}, /*index=5*/f32[13107200]{0}) tuple(get-tuple-element.83, get-tuple-element.84, get-tuple-element.85, reduce.1205, get-tuple-element.87, /*index=5*/get-tuple-element.88)

Peak mem becomes 698950040 bytes (666.57 MiB), with fewer slices live at the peak.

Live ranges at 476 (peak):
    p2.308: 52428800 bytes
    p1.2: 8 bytes
    concatenate.2: 327680 bytes
    p0.1: 4 bytes
    custom-call: 327680 bytes
    slice.12: 6553600 bytes
    custom-call.1: 327680 bytes
    slice.13: 6553600 bytes
    custom-call.2: 327680 bytes
    slice.14: 6553600 bytes
    constant_254: 4 bytes
    divide.11: 327680 bytes
    slice.15: 6553600 bytes
    custom-call.4: 327680 bytes
    slice.16: 6553600 bytes
    custom-call.5: 327680 bytes
    slice.17: 6553600 bytes
    custom-call.6: 327680 bytes
    slice.18: 6553600 bytes
    divide.2: 327680 bytes
    divide.3: 327680 bytes
    p3.437: 52428800 bytes
    slice.19: 6553600 bytes
    divide.9: 327680 bytes
    multiply.20: 327680 bytes
    constant_266: 13107200 bytes
    broadcast.168: 52428800 bytes
    constant_267: 13107200 bytes
    constant_269: 13107200 bytes
    constant_270: 13107200 bytes
    constant_271: 13107200 bytes
    constant_272: 13107200 bytes
    constant_273: 13107200 bytes
    constant_274: 13107200 bytes
    custom-call.29: 327680 bytes
    add.120: 52428800 bytes
    tuple.859: 16 bytes
    all-reduce-start.2: 16 bytes
    broadcast.172: 52428800 bytes
    add.121: 52428800 bytes
    p5.1188: 52428800 bytes
    p4.1180: 52428800 bytes

The peak memory is still much larger than with flatten_parameters=False, though, and I'm not sure whether there is further room to reduce it.
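
A rough tally of the dump above (a sketch; only the two dominant buffer sizes are counted) suggests the remaining peak is dominated by whole-layer flat buffers rather than per-weight slices:

# 8 live entries of 52,428,800 bytes each
# (p2.308, p3.437, p4.1180, p5.1188, broadcast.168, broadcast.172,
#  add.120, add.121)
flat_bufs = 8 * 52_428_800        # ~400 MiB
# 8 live entries of 13,107,200 bytes each (constant_266..constant_274)
shard_consts = 8 * 13_107_200     # ~100 MiB
print((flat_bufs + shard_consts) / 2**20)  # 500.0 MiB of the 666.57 MiB peak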

alanwaketan commented 1 year ago

On TPU, we recommend that users don't use flatten_parameters, so this path has not been tested very thoroughly.

If you have any fix for it, feel free to open up a PR and then we can discuss there. Otherwise, I will follow up later.

Seventeen17 commented 1 year ago

> On TPU, we recommend that users don't use flatten_parameters, so this path has not been tested very thoroughly.
>
> If you have any fix for it, feel free to open up a PR and then we can discuss there. Otherwise, I will follow up later.

OK, I have some fixes; I will submit a PR for this later.