Vchitect / Latte

Latte: Latent Diffusion Transformer for Video Generation.
Apache License 2.0

Error once speed up training #81

Open moeinheidari7829 opened 1 month ago

moeinheidari7829 commented 1 month ago

Dear authors, thank you for the great work and open source code,

I am training the model on my own dataset. However, when I set the following options to True (to speed up training), I get the error below. When I set them to False, the model trains without error.

SPEED UP COMMANDS:

use_compile: True
mixed_precision: True
enable_xformers_memory_efficient_attention: True
gradient_checkpointing: True

Error:

Traceback (most recent call last):
  File "/scratch/st-ilker-1/moein/code/Latte/train.py", line 285, in <module>
    main(OmegaConf.load(args.config))
  File "/scratch/st-ilker-1/moein/code/Latte/train.py", line 162, in main
    update_ema(ema, model.module) #, decay=0) # Ensure EMA is initialized with synced weights
  File "/project/st-ilker-1/moein/moein-envs/latte-env/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/scratch/st-ilker-1/moein/code/Latte/utils.py", line 200, in update_ema
    ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
KeyError: '_orig_mod.pos_embed'
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 47691 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 47692 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 47694 closing signal SIGTERM
Traceback (most recent call last):
    sys.exit(main())
    return f(*args, **kwargs)
    return launch_agent(self._config, self._entrypoint, list(args))
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

Has anyone had the same experience?

moeinheidari7829 commented 1 month ago

Update:

The option that is causing the problem is use_compile.

maxin-cn commented 1 month ago

Thank you for your interest. In the Latte code, I provide some APIs that can accelerate training, but I have not verified their correctness. torch.compile wraps the model in another module, so the PyTorch class differs between ema and model, and the model's parameter names gain the `_orig_mod.` prefix seen in the KeyError. If you solve these problems, we welcome your PR~
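One possible workaround, sketched here rather than taken from the Latte repository, is to strip the `_orig_mod.` prefix before looking up the EMA parameter. The helper name `strip_compile_prefix` and the default `decay=0.9999` are assumptions for illustration. The update line follows the one shown in the traceback, with `.data` on the target so the sketch does not need to import torch.

```python
from collections import OrderedDict

# Prefix seen in the KeyError above; torch.compile's wrapper module adds it
# to every parameter name of the wrapped model.
COMPILE_PREFIX = "_orig_mod."


def strip_compile_prefix(name: str) -> str:
    """Return the parameter name without a leading '_orig_mod.' prefix."""
    if name.startswith(COMPILE_PREFIX):
        return name[len(COMPILE_PREFIX):]
    return name


def update_ema(ema_model, model, decay=0.9999):
    """Step the EMA weights toward `model`, matching parameters by de-prefixed name.

    Hypothetical variant of the update_ema in utils.py; `decay` default is assumed.
    """
    ema_params = OrderedDict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        target = ema_params[strip_compile_prefix(name)]
        # Operating on .data keeps the in-place update out of autograd.
        target.data.mul_(decay).add_(param.data, alpha=1 - decay)
```

With this variant, the existing call `update_ema(ema, model.module)` should find `pos_embed` in the EMA model even when `model.module` is the compiled wrapper. Another option would be to build the EMA copy and pass the uncompiled model to `update_ema`, compiling only the module used for the forward pass.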