Doubiiu / DynamiCrafter

[ECCV 2024] DynamiCrafter: Animating Open-domain Images with Video Diffusion Priors
Apache License 2.0
2.09k stars 165 forks

How to train DynamiCrafter model? #26

Closed dailingx closed 4 months ago

dailingx commented 4 months ago

Hi, first of all, thank you for the work you've done. I have used DynamiCrafter extensively, and its performance is quite impressive in most scenarios. However, I would like to further train DynamiCrafter on certain domains. May I ask if you have any plans to open source the training code?

Doubiiu commented 4 months ago

Hi, thanks for your recognition of our work! I am so excited to hear that! Currently we cannot release the training code due to company policy. You can adapt the training code from repos that use a similar Stable Diffusion codebase (e.g., DreamBooth). I am willing to help with the training code implementation in case you have any questions. Thanks for your understanding.

dailingx commented 4 months ago

If I could receive your help, I would be extremely grateful! I will try it in my free time.

Doubiiu commented 4 months ago

Hi. I find that the open-source code from MotionCtrl and LVDM contains nearly-complete training code applicable to our DynamiCrafter (I think that code can be used with minor modifications)!

dailingx commented 4 months ago

All right, thank you very much for your enthusiastic suggestions. I will try them out and if they work, I will get back to you with feedback.

dailingx commented 4 months ago

I am currently modifying the training code based on MotionCtrl and LVDM, but I have encountered a problem. When executing xc = torch.cat([x] + c_concat, dim=1) in ddpm3d.py, I received an exception. I discovered that c_concat is None, but I am unsure why. Could you offer me some help or suggestions?

Traceback (most recent call last):
  File "/root/DynamiCrafter/train.py", line 913, in <module>
    run_inference(opt, unknown)
  File "/root/DynamiCrafter/train.py", line 394, in run_inference
    trainer.fit(model, data, ckpt_path=ckpt_resume_path)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 582, in fit
    call._call_and_handle_interrupt(
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 36, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 90, in launch
    return function(*args, **kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 624, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1061, in _run
    results = self._run_stage()
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1140, in _run_stage
    self._run_train()
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1163, in _run_train
    self.fit_loop.run()
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 214, in advance
    batch_output = self.batch_loop.run(kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(optimizers, kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 200, in advance
    result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 239, in _run_optimization
    closure()
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 147, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 133, in closure
    step_output = self._step_fn()
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 406, in _training_step
    training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1443, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 352, in training_step
    return self.model(*args, **kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/pytorch_lightning/overrides/base.py", line 98, in forward
    output = self._forward_module.training_step(*inputs, **kwargs)
  File "/root/DynamiCrafter/lvdm/models/ddpm3d.py", line 945, in training_step
    loss, loss_dict = self.shared_step(batch, random_uncond=self.classifier_free_guidance)
  File "/root/DynamiCrafter/lvdm/models/ddpm3d.py", line 941, in shared_step
    loss, loss_dict = self(x, c, is_imgbatch=is_imgbatch, **kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/DynamiCrafter/lvdm/models/ddpm3d.py", line 635, in forward
    return self.p_losses(x, c, t, **kwargs)
  File "/root/DynamiCrafter/lvdm/models/ddpm3d.py", line 582, in p_losses
    model_output = self.apply_model(x_noisy, t, cond, **kwargs)
  File "/root/DynamiCrafter/lvdm/models/ddpm3d.py", line 654, in apply_model
    x_recon = self.model(x_noisy, t, **cond, **kwargs)
  File "/root/DynamiCrafter/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/DynamiCrafter/lvdm/models/ddpm3d.py", line 1022, in forward
    xc = torch.cat([x] + c_concat, dim=1)
TypeError: can only concatenate list (not "NoneType") to list

The code I've written is a bit messy at the moment for the sake of making debugging easier. If you think it's necessary, I can clean up my code and post it here.

dailingx commented 4 months ago

Also, before executing the x_recon = self.model(x_noisy, t, **cond, **kwargs) line in the apply_model method, I printed out the values of the arguments.

before apply_model x_noisy:tensor([[[[[ 0.6330,  1.6359, -0.3459,  ...,  0.4664, -0.2241,  1.1442],
           [ 0.5582, -1.2860, -0.4914,  ...,  0.5169,  1.2178, -0.9432],
           [-0.2439,  0.2047,  0.2144,  ...,  1.4254, -0.9282, -0.6778],
           ...,
           [ 0.2135,  1.1049,  1.6460,  ..., -0.5320,  0.1514,  0.8292],
           [-1.7580, -1.3754, -1.5203,  ..., -0.0060,  0.1375, -0.1591],
           [ 0.3521, -0.9101,  0.5028,  ...,  1.1446,  0.4072, -0.4115]],

          [[-0.1798, -1.3728, -0.2989,  ..., -0.5143, -0.5442,  0.8477],
           [ 0.3904, -0.0622,  0.8833,  ...,  0.2036, -0.6342,  1.0128],
           [-0.1993,  1.6014,  0.4741,  ...,  0.4338, -1.0898, -1.6103],
           ...,
           [-0.3446,  0.1603, -0.5498,  ..., -0.8580, -0.7490, -0.9224],
           [ 0.7675,  0.5692, -0.6401,  ..., -1.1986, -0.6331,  2.0890],
           [ 1.4175,  0.9586,  1.3511,  ..., -0.5519, -1.1311, -1.0093]],

          [[ 0.8701,  1.1316,  1.2218,  ..., -0.4378, -0.8028, -0.3502],
           [-0.2945,  1.1588,  0.2435,  ..., -0.4715, -0.6702,  0.7000],
           [ 1.6519,  1.4466,  0.8277,  ...,  0.9485,  1.1874,  0.3733],
           ...,
           [ 0.4615,  0.8175, -0.2471,  ..., -1.0801,  0.3121,  0.3518],
           [ 0.4879, -0.6510, -0.4202,  ..., -0.8701,  0.9345,  0.5041],
           [-0.0404,  0.2464,  0.9424,  ...,  0.1231, -0.3309,  1.6892]],

          ...,

          [[ 1.4276, -0.8785, -0.6550,  ..., -0.4282, -0.4426, -0.5956],
           [-0.8304, -0.0327, -0.7125,  ...,  1.5553, -0.1035,  0.5816],
           [-1.2710,  0.8426,  1.0085,  ...,  0.2933,  0.9275,  1.0752],
           ...,
           [ 0.1568, -0.8294,  0.2037,  ..., -0.1425, -1.7909,  0.4976],
           [ 0.3881, -1.1931,  1.6041,  ...,  0.0126, -0.7115,  1.1592],
           [-1.3840,  0.5085,  1.0985,  ...,  0.5966, -0.8277,  0.7804]],

          [[ 1.2921,  0.7771, -0.3826,  ..., -0.3636,  2.8077, -0.7929],
           [-0.1544, -1.5512, -0.7539,  ..., -1.4662,  0.5829, -0.9639],
           [ 1.0731,  1.2996, -0.1486,  ..., -0.4025,  0.1680,  0.3997],
           ...,
           [-0.5194, -1.8464,  0.7514,  ...,  0.9617, -0.4599, -1.5068],
           [-0.5993,  1.5639,  0.9057,  ..., -0.5003, -0.3627,  0.6609],
           [-0.3751, -0.4982, -0.1645,  ..., -0.5126,  0.4400, -0.8680]],

          [[ 0.0506, -0.5334, -1.7426,  ..., -0.0828, -0.9772,  0.2350],
           [ 0.1841, -0.0824, -0.9359,  ...,  0.8865, -0.5761,  0.6088],
           [-0.6396,  0.0224, -0.3830,  ..., -0.2759, -1.4744, -0.1004],
           ...,
           [-0.6023, -0.7397, -1.0657,  ..., -0.3171, -0.4474, -0.7467],
           [ 0.7569,  0.6212,  1.2672,  ..., -1.3187, -1.3011, -1.2606],
           [-0.2566,  0.1753, -1.0995,  ...,  0.3700, -0.7587,  0.5180]]],

         [[[-0.7007, -0.3968, -0.6673,  ..., -0.0582,  0.5101, -0.0788],
           [-0.8553,  0.2696, -0.2799,  ..., -1.0320,  1.1036, -0.9813],
           [-1.1871, -1.6514,  0.5676,  ...,  0.1189, -0.9516, -0.6213],
           ...,
           [-0.5523,  0.1833, -0.1929,  ..., -0.5798, -0.0127,  0.0605],
           [ 0.3247, -0.3776, -1.2141,  ...,  0.6695, -0.9669, -2.3689],
           [ 2.8152, -0.5247, -1.5819,  ..., -0.2648, -0.4664, -0.3421]],

          [[-0.4205, -0.3026,  0.6852,  ...,  1.7331,  0.4727,  0.8688],
           [ 0.9051,  0.1339, -0.3265,  ...,  0.3224,  2.2012, -2.6522],
           [ 0.0907,  0.5114,  1.1341,  ...,  0.2336, -0.4208,  0.7859],
           ...,
           [ 1.7445,  0.2195,  0.7405,  ...,  1.7385, -0.6866, -1.2096],
           [ 0.0139,  0.0258,  0.2761,  ..., -0.4645,  1.5053,  0.5627],
           [-0.3358, -0.5730, -1.4222,  ..., -0.2574, -0.0991,  0.3812]],

          [[-0.9797, -0.6903, -0.4875,  ..., -0.0689, -0.2700, -0.7630],
           [ 0.8948, -1.9615, -1.8383,  ...,  0.2718, -0.1236, -0.9666],
           [ 0.8852, -1.2421,  0.2449,  ..., -0.2101, -1.0958, -0.9756],
           ...,
           [ 1.5825, -1.2854,  1.4575,  ..., -0.1846,  1.2341, -1.8712],
           [-1.3595, -0.7827,  1.0277,  ..., -0.2393,  2.1104, -0.0974],
           [ 0.4974,  0.7591,  0.2696,  ...,  0.0406,  1.4318,  0.0970]],

          ...,

          [[ 0.2071,  1.1945,  0.7944,  ...,  0.2637, -0.0711,  0.3732],
           [ 1.0452, -0.4186, -0.6442,  ..., -1.1041,  1.2582, -0.5950],
           [ 0.9942,  0.2659, -0.3252,  ...,  0.2385,  0.2041, -0.4829],
           ...,
           [ 0.7003,  0.6114, -1.1819,  ..., -0.9066,  0.3443,  0.3000],
           [-0.3537, -0.7937, -0.0695,  ..., -0.2473, -0.6041,  1.0689],
           [-2.5677,  0.6440,  0.8454,  ..., -3.2232, -0.1839,  0.0879]],

          [[-0.9494, -0.3865, -0.3185,  ...,  1.2244,  0.0397, -0.3752],
           [ 0.2035,  0.6046,  0.7864,  ...,  1.0381, -0.0193,  0.8743],
           [ 0.2603,  0.1765, -0.2286,  ..., -1.3656,  0.0446, -0.5018],
           ...,
           [-1.6689,  0.5927,  0.8846,  ..., -0.0513,  0.5365, -0.1375],
           [ 1.0883, -0.0448,  0.0073,  ..., -0.1346,  0.9589, -0.0138],
           [ 0.3376, -2.3550, -0.5476,  ...,  0.7114, -1.4700, -0.9710]],

          [[-0.8473, -2.6393,  0.4379,  ..., -0.4854, -0.6140,  0.5018],
           [-1.8717,  0.6351, -1.3148,  ...,  0.3458, -0.6454, -1.3258],
           [-0.0486, -2.1555, -0.3378,  ...,  0.1525, -0.7810, -0.4953],
           ...,
           [ 1.1660,  1.4899,  0.8824,  ...,  0.4967,  1.2138, -0.0317],
           [ 1.7983,  0.7902, -0.1636,  ..., -0.3580,  1.1348,  0.6644],
           [ 0.9908,  1.9705, -0.1984,  ...,  0.7564, -0.4021,  0.0264]]],

         [[[ 0.3202,  0.2512,  0.3615,  ...,  0.7096, -2.0672,  0.6247],
           [-0.8046,  0.4753, -0.6846,  ..., -0.9557,  1.3594,  2.9055],
           [-0.1680,  0.9408, -0.3982,  ..., -1.6581, -0.2824, -1.8855],
           ...,
           [ 2.7290, -0.2444,  0.4384,  ...,  2.5135,  0.5893, -0.7923],
           [-0.2419,  0.7847,  0.8562,  ...,  1.1125,  1.0775, -0.6251],
           [ 1.1884, -1.0511, -0.2957,  ..., -0.8401,  1.0509, -1.7649]],

          [[-0.8535, -1.5618, -1.0714,  ..., -3.0885,  0.8315,  1.3512],
           [ 0.2799,  0.7976,  1.6919,  ...,  0.7147,  0.5061,  1.3861],
           [ 1.6430,  1.2088,  1.4906,  ...,  0.1258,  1.0114,  0.2699],
           ...,
           [ 2.1476, -0.5004, -0.0744,  ...,  0.1957, -0.7652, -0.4899],
           [ 1.5919, -0.2760, -0.9083,  ..., -0.5874,  1.4495, -1.7833],
           [ 0.0430,  0.1627,  2.0945,  ...,  0.0439, -0.6050, -0.4374]],

          [[ 0.6050,  0.6711,  2.5406,  ..., -0.7991,  1.3779,  0.5299],
           [-0.3638,  0.2717, -0.7251,  ..., -0.1393,  0.6029, -0.2234],
           [-1.2485, -0.0918, -0.3179,  ..., -0.2198, -0.5713, -0.1962],
           ...,
           [-1.0437, -0.3164, -1.0581,  ...,  0.3052, -0.8312, -0.1260],
           [-0.1294,  0.2990, -0.2964,  ...,  0.1750,  0.9149,  0.6246],
           [ 0.0073, -0.1184,  0.6248,  ...,  0.5212, -0.0157, -1.1507]],

          ...,

          [[-0.0860, -0.4144, -2.1284,  ...,  0.9717, -0.7091,  1.2441],
           [ 0.4012, -0.8169, -0.0241,  ..., -0.1819,  0.5799, -0.9830],
           [-0.6765,  0.2017,  0.1920,  ..., -0.1320,  0.4345, -0.9356],
           ...,
           [ 0.1575,  0.1463,  0.2260,  ...,  1.3836, -0.0705, -0.5410],
           [-0.4299, -0.4388,  0.5853,  ...,  0.0089,  0.5869, -0.0290],
           [-1.2200, -1.7734,  0.0141,  ...,  1.4657, -1.1417,  0.1477]],

          [[ 0.7264,  0.7879,  0.2507,  ..., -0.6537, -0.2886,  0.2381],
           [-0.1846, -0.8822,  0.7409,  ..., -0.1555, -0.0400, -0.0699],
           [ 0.6763,  0.6368,  1.5692,  ..., -2.1924,  1.4581, -0.5598],
           ...,
           [-0.6149,  0.8785,  1.0608,  ...,  1.6276,  0.1647, -0.4927],
           [-0.2470,  0.1338,  0.5615,  ...,  0.5672, -0.0107,  0.7778],
           [-1.7780, -0.9217,  0.2153,  ...,  0.0656, -0.4935, -1.3488]],

          [[ 2.2285,  0.0637,  0.8748,  ...,  0.9770,  1.6146, -1.0328],
           [ 0.9438, -2.3367,  0.0236,  ...,  1.8826,  0.3429, -0.1575],
           [-0.5714,  0.1644,  1.7720,  ...,  0.0268, -0.6838, -1.3976],
           ...,
           [-0.2027,  0.3566,  0.3144,  ...,  0.6676,  0.7599,  1.9493],
           [ 0.2768,  1.0905,  1.5165,  ..., -0.7399, -0.3187, -0.3157],
           [ 1.1925, -0.6830,  1.0635,  ..., -0.6697, -1.4324,  0.7507]]],

         [[[-1.6935, -0.6795, -1.3030,  ..., -0.9283,  1.1761,  0.2018],
           [-0.7361,  0.4294,  0.0250,  ..., -1.1221, -0.1320, -0.2501],
           [ 0.3519,  1.3046,  1.6062,  ..., -0.6646,  0.2397, -2.6570],
           ...,
           [ 0.8527, -1.0824, -0.2106,  ...,  0.2457, -0.5659, -0.2388],
           [ 0.6503,  1.3721,  0.8686,  ..., -0.5459, -0.8351, -0.5538],
           [-0.9465,  0.6821, -0.5678,  ...,  0.6059,  0.1053,  1.2048]],

          [[-1.0123, -0.3732,  0.5756,  ...,  1.1956, -0.2499, -2.2606],
           [-0.7265,  0.6544, -0.5124,  ..., -0.2572, -0.4256, -0.1808],
           [-0.1108, -0.3307,  1.5815,  ...,  0.9133,  1.6324, -0.9623],
           ...,
           [ 0.4119,  0.6768, -1.4462,  ..., -2.4618, -0.9176,  0.0505],
           [-0.1560, -0.3143, -0.9139,  ..., -1.3517, -1.9616, -0.4699],
           [-0.8434, -1.3859,  0.4787,  ...,  0.8527, -0.1170,  1.2096]],

          [[ 0.9238,  1.5613,  0.6908,  ..., -1.8594,  0.2781, -0.9570],
           [ 0.9362, -0.4721, -1.0131,  ...,  2.3112, -0.6876, -0.1232],
           [-1.6031, -0.3794,  0.4095,  ...,  0.0373,  1.0330, -0.4273],
           ...,
           [ 0.7764,  0.2602, -2.0235,  ..., -0.8423, -0.3088, -1.4515],
           [-0.4438,  0.7066, -0.7716,  ...,  0.5780, -1.4184, -0.7170],
           [ 0.9330, -1.2661, -0.1468,  ..., -0.8263, -1.3447, -0.2290]],

          ...,

          [[-1.9911, -0.4204,  0.2456,  ..., -0.5817,  0.8184,  0.4457],
           [-0.8076, -0.5372,  0.1418,  ..., -0.8302,  1.9818,  0.4974],
           [ 0.4776,  1.0264,  0.1151,  ...,  0.4285,  0.4674,  0.0764],
           ...,
           [-1.3951,  2.1954,  0.4899,  ...,  0.3668, -0.3205,  0.8705],
           [-0.3102,  0.9328, -0.2645,  ...,  0.4023,  2.5277, -0.0227],
           [ 0.3007, -0.7179,  0.7867,  ...,  1.1380, -0.3468, -0.8741]],

          [[-0.9943,  0.0998,  0.4430,  ...,  0.8403, -0.8254,  1.1970],
           [ 0.3770, -1.8025,  0.1744,  ...,  1.3204,  1.9701,  0.2375],
           [ 0.0392,  0.5184, -0.6400,  ..., -0.4708,  1.0103, -1.2775],
           ...,
           [ 1.6931, -2.1019,  2.2762,  ..., -1.5555, -0.5061,  1.2843],
           [ 0.0399,  0.7755, -1.5403,  ..., -1.7261, -0.2246,  0.4361],
           [ 0.4483, -1.1071, -0.1735,  ..., -0.9498,  0.4710,  0.6874]],

          [[-1.1826,  0.8728,  0.5715,  ...,  0.1922, -1.3633, -0.4317],
           [ 0.3604,  0.1500, -0.5989,  ..., -0.5828,  0.1793,  0.2418],
           [-0.0564, -0.2495, -0.4312,  ..., -1.4631,  0.0740,  0.0290],
           ...,
           [ 0.8562,  0.6165, -0.3792,  ...,  2.4318,  0.8448, -0.0323],
           [-1.3332, -0.1700, -0.7215,  ..., -1.5220, -0.7542, -0.1331],
           [-0.2321,  0.9192,  0.7963,  ..., -0.8235,  1.0373,  1.3134]]]]],
       device='cuda:0'), t: tensor([992], device='cuda:0'), cond: {'c_crossattn': [tensor([[[-3.1343e-01, -4.4757e-01, -8.2413e-03,  ...,  2.5421e-01,
          -3.2432e-02, -2.9603e-01],
         [ 1.4114e+00,  7.5768e-03, -4.2885e-01,  ...,  1.0365e+00,
          -6.7341e-01,  1.5006e+00],
         [ 1.8676e+00, -1.0888e+00, -1.0656e+00,  ...,  2.0319e+00,
          -1.1396e+00, -1.8952e-01],
         ...,
         [ 6.0626e-03, -1.4990e+00, -3.9075e-01,  ..., -1.7927e-01,
          -3.2250e-01, -1.5238e-02],
         [ 2.7659e-02, -1.5532e+00, -4.1506e-01,  ..., -4.4162e-01,
          -3.9632e-01,  1.9090e-01],
         [-7.0735e-02, -2.6132e+00, -1.0513e+00,  ...,  1.0092e-03,
          -5.0300e-01,  4.0609e-01]]], device='cuda:0')]}, kwargs: {'is_imgbatch': False}
Doubiiu commented 4 months ago

Hi, I see there is no c_concat in your arguments. c_concat is the tensor concatenated with the noisy latents, while c_crossattn is the tensor fed into the U-Net for the cross-attention operation. I think you forgot to pass c_concat into cond.
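For reference, the missing piece can be sketched like this. This is a minimal illustration of the hybrid-conditioning concat, assuming representative latent shapes and variable names; it is not the repo's exact code.

```python
import torch

# Illustrative shapes for video latents: [B, C, T, H, W] (values assumed).
B, C, T, H, W = 1, 4, 16, 32, 32

x_noisy = torch.randn(B, C, T, H, W)     # noisy video latents
ctx_emb = torch.randn(B, 77, 1024)       # cross-attention context (shape assumed)
img_latent = torch.randn(B, C, T, H, W)  # conditional image latent, repeated over time

# Both keys must be present in cond. apply_model does
# torch.cat([x] + c_concat, dim=1), so if 'c_concat' is missing (None),
# it raises the TypeError shown in the traceback above.
cond = {
    "c_crossattn": [ctx_emb],
    "c_concat": [img_latent],
}

# Channel-wise concat: the U-Net input channel count doubles.
xc = torch.cat([x_noisy] + cond["c_concat"], dim=1)
print(xc.shape)  # torch.Size([1, 8, 16, 32, 32])
```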

dailingx commented 4 months ago

Alright, how many dimensions should the c_concat tensor have?

Doubiiu commented 4 months ago

c_concat is the VDG in the paper; it should have the same dimensions as the noisy latents: [B,C,T,H,W]. The input conditional image latent has shape [B,C,1,H,W] and is then duplicated along the time axis to [B,C,T,H,W].
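The duplication along the time axis can be sketched as follows. This is a minimal example under assumed shapes; whether the repo uses `repeat` or another expansion op is an assumption.

```python
import torch

# Assumed latent dimensions: batch, channels, frames, height, width.
B, C, T, H, W = 2, 4, 16, 32, 32

img_latent = torch.randn(B, C, 1, H, W)  # single-frame conditional image latent
vdg = img_latent.repeat(1, 1, T, 1, 1)   # duplicate across the T time steps

print(vdg.shape)  # torch.Size([2, 4, 16, 32, 32])
# Every frame along the time axis is an identical copy of the image latent.
assert torch.equal(vdg[:, :, 0], vdg[:, :, T - 1])
```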

dailingx commented 4 months ago

Thank you for your response. I will try to modify my code to correct it.

dwanggit commented 3 months ago

Thanks @Doubiiu for redirecting me to this post. @dailingx that is very impressive to reimplement the training code. I would like to help on this work. Feel free to add me on wechat if you are open to discuss, thanks.

dailingx commented 3 months ago

> Thanks @Doubiiu for redirecting me to this post. @dailingx that is very impressive to reimplement the training code. I would like to help on this work. Feel free to add me on wechat (_dwang) if you are open to discuss, thanks.

Hi, thank you for reaching out and offering your help! I've sent a WeChat friend request to connect. Looking forward to discussing and collaborating on the project together.

dwanggit commented 3 months ago

Connected! Let's do it!

caojiehui commented 2 months ago

> Thanks @Doubiiu for redirecting me to this post. @dailingx that is very impressive to reimplement the training code. I would like to help on this work. Feel free to add me on wechat (_dwang) if you are open to discuss, thanks.
>
> Hi, thank you for reaching out and offering your help! I've sent a WeChat friend request to connect. Looking forward to discussing and collaborating on the project together.

Can I still join the discussion on the open-source training project? Can you add me?

dailingx commented 2 months ago

> Thanks @Doubiiu for redirecting me to this post. @dailingx that is very impressive to reimplement the training code. I would like to help on this work. Feel free to add me on wechat (_dwang) if you are open to discuss, thanks.
>
> Hi, thank you for reaching out and offering your help! I've sent a WeChat friend request to connect. Looking forward to discussing and collaborating on the project together.
>
> Can I still join the discussion on the open-source training project? Can you add me?

I don't think that is necessary now, because @Doubiiu will release it in the next few days. See #66.

Doubiiu commented 2 months ago

The training code is available now! @dailingx @dwanggit @caojiehui

ypflll commented 2 months ago

@Doubiiu The training code is very helpful, thanks a lot. I wonder if a V100 GPU with 32GB memory is enough for training an interpolation model at a resolution of 1024x576? Could you share some memory consumption data?