johndpope / MegaPortrait-hack

Using Claude Opus to reverse engineer code from MegaPortraits: One-shot Megapixel Neural Head Avatars
https://arxiv.org/abs/2207.07621

Stage 1 training - WarpGenerator blows up with OOM #11

Closed. johndpope closed this issue 6 months ago.

johndpope commented 6 months ago

QUESTIONS

What is the size of the volumetric features (vs)?

Currently I get vs.shape: torch.Size([1, 96, 16, 32, 32]); using the idk_avgpool instead gives vs.shape: torch.Size([1, 96, 16, 64, 64]).
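For context, a minimal sketch of how both candidate shapes can come out of the encoder tail (the layer choices and names here are assumptions, not the actual Eapp code): a 2D feature map with 96*16 channels is pooled to the target spatial size and then reshaped into a [B, 96, 16, H, W] volume, so the avgpool choice only decides whether H and W end up as 32 or 64.

```python
import torch
import torch.nn as nn

# Hypothetical encoder tail: 96*16 = 1536 2D channels reshaped into a
# volumetric tensor [B, 96, 16, H, W]. Not the repo's actual Eapp definition.
feat2d = torch.randn(1, 96 * 16, 64, 64)   # assumed 2D feature map

vs_64 = nn.AdaptiveAvgPool2d(64)(feat2d).view(1, 96, 16, 64, 64)
vs_32 = nn.AdaptiveAvgPool2d(32)(feat2d).view(1, 96, 16, 32, 32)

print(vs_64.shape)  # torch.Size([1, 96, 16, 64, 64])
print(vs_32.shape)  # torch.Size([1, 96, 16, 32, 32])
```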

What should the size of the WarpField output be?

```
WarpField > zs.shape: torch.Size([1, 512, 8, 8])
conv1x1 > zs.shape: torch.Size([1, 512, 8, 8])
reshape_layer > zs.shape: torch.Size([1, 512, 8, 8])
ResBlock_Custom > x.shape: torch.Size([512, 512, 4, 8, 8])
upsample1 > zs.shape: torch.Size([1, 512, 8, 8])
ResBlock_Custom > x.shape: torch.Size([512, 256, 8, 16, 16])
upsample2 > zs.shape: torch.Size([1, 512, 8, 8])
ResBlock_Custom > x.shape: torch.Size([512, 128, 16, 32, 32])
```

`x = self.upsample3(self.resblock3(x))` 💣💣💣💣💣 blows up here.
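Some back-of-the-envelope numbers on the shapes printed above (a sketch assuming float32 activations; the helper below is not the repo's code): once the leading dimension of the 5D tensors is 512 rather than the batch size of 1, a single [512, 128, 16, 32, 32] activation already costs 4 GiB, and the conv3d in the traceback below ends up requesting 108 GiB.

```python
# Quick memory arithmetic for the logged shapes (float32, 4 bytes/element).
# Illustrates the scale of the problem only.
def nbytes(shape, bytes_per_elem=4):
    n = 1
    for d in shape:
        n *= d
    return n * bytes_per_elem

leaked = nbytes((512, 128, 16, 32, 32))   # leading dim 512 instead of batch 1
kept   = nbytes((1, 128, 16, 32, 32))     # same activation with batch kept at 1

print(f"{leaked / 2**30:.1f} GiB")        # 4.0 GiB for a single activation
print(f"{kept / 2**20:.1f} MiB")          # 8.0 MiB with the batch dim intact
print(f"{115964116992 / 2**30:.0f} GiB")  # 108 GiB requested by the failing conv3d
```

Full log of the failing run: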
```
__getitem__
frame_idx
frame_idx
🌸
.
source_frame.shape: torch.Size([1, 3, 256, 256])
ResBlock_Custom > x.shape: torch.Size([1, 64, 256, 256])
ResBlock_Custom > x.shape: torch.Size([1, 128, 128, 128])
ResBlock_Custom > x.shape: torch.Size([1, 256, 64, 64])
ResBlock_Custom > x.shape: torch.Size([1, 96, 16, 32, 32])
ResBlock_Custom > x.shape: torch.Size([1, 96, 16, 32, 32])
vs.shape: torch.Size([1, 96, 16, 32, 32])
👤 head_pose shape Should print: torch.Size([1, 6]): torch.Size([1, 6])
📐 rotation shape Should print: torch.Size([1, 3]): torch.Size([1, 3])
📷 translation shape Should print: torch.Size([1, 3]): torch.Size([1, 3])
pitch: tensor([0.0011], grad_fn=<SelectBackward0>)
yaw: tensor([-0.0213], grad_fn=<SelectBackward0>)
roll: tensor([0.0251], grad_fn=<SelectBackward0>)
x.shape: torch.Size([1, 3, 256, 256])
👤 head_pose shape Should print: torch.Size([1, 6]): torch.Size([1, 6])
📐 rotation shape Should print: torch.Size([1, 3]): torch.Size([1, 3])
📷 translation shape Should print: torch.Size([1, 3]): torch.Size([1, 3])
pitch: tensor([0.0041], grad_fn=<SelectBackward0>)
yaw: tensor([-0.0328], grad_fn=<SelectBackward0>)
roll: tensor([0.0263], grad_fn=<SelectBackward0>)
x.shape: torch.Size([1, 3, 256, 256])
es shape: torch.Size([1, 512, 8, 8])
zs shape: torch.Size([1, 512, 8, 8])
WarpField > zs.shape: torch.Size([1, 512, 8, 8])
conv1x1 > zs.shape: torch.Size([1, 512, 8, 8])
reshape_layer > zs.shape: torch.Size([1, 512, 8, 8])
ResBlock_Custom > x.shape: torch.Size([512, 512, 4, 8, 8])
upsample1 > zs.shape: torch.Size([1, 512, 8, 8])
ResBlock_Custom > x.shape: torch.Size([512, 256, 8, 16, 16])
upsample2 > zs.shape: torch.Size([1, 512, 8, 8])
ResBlock_Custom > x.shape: torch.Size([512, 128, 16, 32, 32])
Traceback (most recent call last):
  File "/media/oem/12TB/MegaPortrait-hack/train.py", line 310, in <module>
    main(config)
  File "/media/oem/12TB/MegaPortrait-hack/train.py", line 294, in main
    train_base(cfg, Gbase, Dbase, dataloader)
  File "/media/oem/12TB/MegaPortrait-hack/train.py", line 130, in train_base
    output_frame = Gbase(source_frame, driving_frame)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/oem/12TB/MegaPortrait-hack/model.py", line 957, in forward
    w_em_s2c = self.warp_generator_s2c(Rs, ts, zs, es)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/oem/12TB/MegaPortrait-hack/model.py", line 803, in forward
    w_em_s2c = self.warpfield(zs_sum)
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/oem/12TB/MegaPortrait-hack/model.py", line 343, in forward
    x = self.upsample3(self.resblock3(x))
                       ^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/media/oem/12TB/MegaPortrait-hack/model.py", line 69, in forward
    out2 = self.conv_res(x)
           ^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 610, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/oem/miniconda3/envs/ani/lib/python3.11/site-packages/torch/nn/modules/conv.py", line 605, in _conv_forward
    return F.conv3d(
           ^^^^^^^^^
RuntimeError: [enforce fail at alloc_cpu.cpp:83] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 115964116992 bytes. Error code 12 (Cannot allocate memory)
```
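zs enters the WarpField as [1, 512, 8, 8], but every ResBlock_Custom after reshape_layer reports a leading dimension of 512, which suggests the reshape is folding the 512 channels into the batch axis. A minimal sketch of a batch-preserving reshape, assuming a hypothetical (C=128, D=4) split and module name (not the actual WarpField code):

```python
import torch
import torch.nn as nn

class WarpFieldStem(nn.Module):
    """Hypothetical stem: lift the 4D latent into a 5D volume without
    touching the batch dimension. The (C, D) split is an assumption."""
    def __init__(self, in_channels=512, depth=4):
        super().__init__()
        assert in_channels % depth == 0
        self.depth = depth
        self.channels = in_channels // depth  # 512 -> C=128, D=4

    def forward(self, zs_sum):                # zs_sum: [B, 512, 8, 8]
        b, c, h, w = zs_sum.shape
        # Keep batch first; split channels into (C, D) instead of letting
        # the channel axis leak into the batch axis.
        return zs_sum.view(b, self.channels, self.depth, h, w)

zs_sum = torch.randn(1, 512, 8, 8)
print(WarpFieldStem()(zs_sum).shape)  # torch.Size([1, 128, 4, 8, 8])
```

With the batch dimension kept at 1, the later 3D resblocks and upsamples stay in the megabyte range instead of being multiplied by 512.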

UPDATE - I changed this line to use the same avgpool as in the previous steps: https://github.com/johndpope/MegaPortrait-hack/blob/main/model.py#L149

[Screenshot from 2024-05-22 13-57-17]

```
source_frame.shape: torch.Size([1, 3, 256, 256])
ResBlock_Custom > x.shape: torch.Size([1, 64, 256, 256])
ResBlock_Custom > x.shape: torch.Size([1, 128, 128, 128])
ResBlock_Custom > x.shape: torch.Size([1, 256, 64, 64])
ResBlock_Custom > x.shape: torch.Size([1, 96, 16, 32, 32])
ResBlock_Custom > x.shape: torch.Size([1, 96, 16, 32, 32])
```

fyi @xuzheyuan624

johndpope commented 6 months ago

Using 1D for the features / ResNet seems to fix this.