Seventeen17 opened this issue 1 year ago

❓ Questions and Help

I have noticed during testing that enabling FSDP's `flatten_parameters=True` results in a significant increase in GPU peak memory. In fact, the memory usage is several times larger than the total parameter size. Have any of you encountered this issue before? If so, are there any known solutions or workarounds to address this problem? @ronghanghu @JackCaoG

Thank you in advance for your assistance!
@AlexWertheim @alanwaketan any insight?
It seems that the unflattened views of the parameters stay alive throughout the forward and backward passes and are not freed by `_free_full_params`.
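For context, here is a rough conceptual sketch of what parameter flattening does. It follows the general FlattenParamsWrapper idea (one flat buffer, with the original parameters re-exposed as views into it); it is not the actual torch_xla implementation, just an illustration of why the "unflattened views" are slices of a larger buffer.

```python
import torch

# Hypothetical illustration only -- not the torch_xla FSDP code.
params = [torch.randn(1280, 1280), torch.randn(1280, 1280)]

# One flat buffer holding every wrapped parameter.
flat_param = torch.cat([p.reshape(-1) for p in params])

# The "unflattened" parameters are just views (slices) of the flat buffer.
views, offset = [], 0
for p in params:
    n = p.numel()
    views.append(flat_param[offset:offset + n].view_as(p))
    offset += n

# Under XLA lazy tensors each view lowers to a slice of the flat tensor;
# as long as a view is still referenced when the graph is cut, its buffer
# has to stay live.
```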
For example, here is a repro:
```python
import functools

import torch
import torch.nn.functional as F
from torch_xla.core import xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy


class Submodule1(torch.nn.Module):
    def __init__(self):
        super(Submodule1, self).__init__()
        self.fc1 = torch.nn.Linear(1280, 1280, bias=False)
        self.fc2 = torch.nn.Linear(1280, 1280, bias=False)
        self.fc3 = torch.nn.Linear(1280, 1280, bias=False)
        self.fc4 = torch.nn.Linear(1280, 1280, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        output = F.softmax(x, dim=1)
        return output


class Submodule2(torch.nn.Module):
    def __init__(self):
        super(Submodule2, self).__init__()
        self.fc1 = torch.nn.Linear(1280, 1280, bias=False)
        self.fc2 = torch.nn.Linear(1280, 1280, bias=False)
        self.fc3 = torch.nn.Linear(1280, 1280, bias=False)
        self.fc4 = torch.nn.Linear(1280, 1280, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        x = self.fc4(x)
        output = F.softmax(x, dim=1)
        return output


class Layer(torch.nn.Module):
    def __init__(self):
        super(Layer, self).__init__()
        self.sb1 = Submodule1()
        self.sb2 = Submodule2()

    def forward(self, x):
        a = self.sb1(x)
        b = self.sb2(a)
        output = F.softmax(b, dim=1)
        return output


class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = Layer()
        self.fc2 = Layer()

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        output = F.softmax(x, dim=1)
        return output


def train(model, device, optimizer):
    model.train()
    for i in range(3):
        optimizer.zero_grad()
        input = torch.randn(64, 1280, device=device)
        loss = model(input).sum()
        loss.backward()
        optimizer.step()
        # Cut the lazy-tensor graph; the unflattened parameter views should
        # be cleaned up here.
        xm.mark_step()


def main():
    model = Net()
    device = xm.xla_device()
    xm.set_replication(device, [device])
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={Layer},
    )
    model = FSDP(model, auto_wrap_policy=auto_wrap_policy, flatten_parameters=True)
    model.to(device)
    train(model, device, optimizer)


if __name__ == '__main__':
    main()
```
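As a rough way to watch device memory from the script itself, something like the following can be added around the training step. This is only a sketch: `xm.get_memory_info` is an existing torch_xla helper, but the keys it returns differ between releases, and the peak numbers quoted below come from XLA's own memory/live-range dumps rather than from this call.

```python
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met


def log_device_memory(device, tag):
    # Free/total device memory as reported by the runtime; the exact keys in
    # this dict vary across torch_xla versions, so treat it as illustrative.
    print(f'[{tag}]', xm.get_memory_info(device))

# e.g. call log_device_memory(device, f'step {i}') right after xm.mark_step(),
# and print met.metrics_report() at the end of training for compile/execute
# counters.
```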
- `flatten_parameters=False`: peak mem is 324102928 bytes (309.09 MiB)
- `flatten_parameters=True`: peak mem is 787751448 bytes (751.26 MiB)

Live ranges at the peak with `flatten_parameters=True`:

```
Live ranges at 508 (peak):
reduce-scatter.0: 52428800 bytes
p2.308: 52428800 bytes
p1.2: 8 bytes
concatenate.2: 327680 bytes
p0.1: 4 bytes
slice.11: 6553600 bytes
custom-call: 327680 bytes
slice.12: 6553600 bytes
custom-call.1: 327680 bytes
slice.13: 6553600 bytes
custom-call.2: 327680 bytes
slice.14: 6553600 bytes
constant_254: 4 bytes
divide.11: 327680 bytes
slice.15: 6553600 bytes
custom-call.4: 327680 bytes
slice.16: 6553600 bytes
slice.17: 6553600 bytes
slice.18: 6553600 bytes
p3.437: 52428800 bytes
slice.19: 6553600 bytes
slice.20: 6553600 bytes
slice.21: 6553600 bytes
slice.22: 6553600 bytes
slice.23: 6553600 bytes
slice.24: 6553600 bytes
slice.25: 6553600 bytes
slice.26: 6553600 bytes
divide.9: 327680 bytes
constant_266: 13107200 bytes
broadcast.168: 52428800 bytes
constant_267: 13107200 bytes
constant_269: 13107200 bytes
constant_270: 13107200 bytes
constant_271: 13107200 bytes
constant_272: 13107200 bytes
constant_273: 13107200 bytes
constant_274: 13107200 bytes
tuple.859: 16 bytes
all-reduce-start.2: 16 bytes
custom-call.31: 327680 bytes
pad.10: 52428800 bytes
custom-call.33: 327680 bytes
pad.11: 52428800 bytes
fusion.23: 16 bytes
fusion.23: 16 bytes
fusion.23: 16 bytes
p5.1188: 52428800 bytes
p4.1180: 52428800 bytes
```
The slices are the unflattened views of the parameters: each 6553600-byte slice is exactly one 1280x1280 fp32 weight (1280 * 1280 * 4 bytes). Looking at the HLO computation graph, the output tuple contains many bitcast operations, which come from these slices:
```
ROOT tuple.41 = (f32[13107200]{0}, f32[13107200]{0}, f32[64,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=5*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=10*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, /*index=15*/f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[1280,1280]{1,0}, f32[], /*index=20*/f32[13107200]{0}, f32[13107200]{0}) tuple(get-tuple-element.83, get-tuple-element.84, get-tuple-element.85, bitcast.1759, bitcast.2852, /*index=5*/bitcast.2824, bitcast.2792, bitcast.2739, bitcast.2699, bitcast.2667, /*index=10*/bitcast.2635, bitcast.2537, bitcast.2488, bitcast.2458, bitcast.2428, /*index=15*/bitcast.2377, bitcast.2339, bitcast.2309, bitcast.2276, reduce.1209, /*index=20*/get-tuple-element.103, get-tuple-element.104)
```
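The reason the views end up as graph outputs is a general lazy-tensor property: any XLA tensor that is still referenced when `xm.mark_step()` cuts the graph has to be materialized, so it becomes an output and its buffer stays live. Below is a minimal, FSDP-free sketch of the same effect; the sizes are chosen to match the repro, where `f32[13107200]` in the tuple above has exactly the element count of one wrapped `Layer`'s flat parameter (8 x 1280 x 1280).

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# A flat buffer comparable to one wrapped Layer (8 Linear(1280, 1280) weights).
flat = torch.randn(8 * 1280 * 1280, device=device)

# A "parameter view": a slice of the flat buffer reshaped to 1280 x 1280.
view = flat[:1280 * 1280].view(1280, 1280)
loss = view.sum()

# `view` is still referenced here, so the compiled graph must output the
# sliced buffer and it remains live on device.
xm.mark_step()

# Dropping the reference releases the view's buffer, so it no longer has to
# be kept live across subsequent steps.
del view
xm.mark_step()
```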
I kept `flatten_parameters=True` and removed these unflattened parameter views from the graph output. With the views removed, the output tuple becomes:
```
ROOT tuple.41 = (f32[13107200]{0}, f32[13107200]{0}, f32[64,1280]{1,0}, f32[], f32[13107200]{0}, /*index=5*/f32[13107200]{0}) tuple(get-tuple-element.83, get-tuple-element.84, get-tuple-element.85, reduce.1205, get-tuple-element.87, /*index=5*/get-tuple-element.88)
```
Peak memory then drops to 698950040 bytes (666.57 MiB), with fewer slices live at the peak:

```
Live ranges at 476 (peak):
p2.308: 52428800 bytes
p1.2: 8 bytes
concatenate.2: 327680 bytes
p0.1: 4 bytes
custom-call: 327680 bytes
slice.12: 6553600 bytes
custom-call.1: 327680 bytes
slice.13: 6553600 bytes
custom-call.2: 327680 bytes
slice.14: 6553600 bytes
constant_254: 4 bytes
divide.11: 327680 bytes
slice.15: 6553600 bytes
custom-call.4: 327680 bytes
slice.16: 6553600 bytes
custom-call.5: 327680 bytes
slice.17: 6553600 bytes
custom-call.6: 327680 bytes
slice.18: 6553600 bytes
divide.2: 327680 bytes
divide.3: 327680 bytes
p3.437: 52428800 bytes
slice.19: 6553600 bytes
divide.9: 327680 bytes
multiply.20: 327680 bytes
constant_266: 13107200 bytes
broadcast.168: 52428800 bytes
constant_267: 13107200 bytes
constant_269: 13107200 bytes
constant_270: 13107200 bytes
constant_271: 13107200 bytes
constant_272: 13107200 bytes
constant_273: 13107200 bytes
constant_274: 13107200 bytes
custom-call.29: 327680 bytes
add.120: 52428800 bytes
tuple.859: 16 bytes
all-reduce-start.2: 16 bytes
broadcast.172: 52428800 bytes
add.121: 52428800 bytes
p5.1188: 52428800 bytes
p4.1180: 52428800 bytes
```
However, the peak memory is still much larger than with `flatten_parameters=False`, and I'm not sure whether there is further room to reduce it.
On TPU we recommend that users don't use `flatten_parameters`, so this path is not tested very thoroughly.
If you have a fix for it, feel free to open a PR and we can discuss it there. Otherwise, I will follow up later.
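For completeness, the recommendation above amounts to flipping a single flag in the repro (a sketch; `model` and `auto_wrap_policy` are the same objects as in the script earlier in this thread):

```python
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

# Same wrapping as in the repro, but with flattening disabled, which is the
# setting recommended on TPU in the comment above.
model = FSDP(model, auto_wrap_policy=auto_wrap_policy, flatten_parameters=False)
```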
> On TPU we recommend that users don't use `flatten_parameters`, so this path is not tested very thoroughly. If you have a fix for it, feel free to open a PR and we can discuss it there. Otherwise, I will follow up later.
OK, I have some fixes, I will submit a PR for this later.