To improve speed, the following modifications are made:

- Skip `attention_mask_compress` if `compress_kv_factor` is 1;
- Set the VAE to fp16 in `train_t2v.py` using the argument `vae_precision`;
- Add `vae_keep_gn_fp32` to allow `nn.GroupNorm` in `custom_fp32_cells`. Defaults to `False`. This was verified by inference quality, and the training experiment with this commit looks visually fine.

Some minor changes:
- `ProfilerCallback` changes to `ProfilerCallbackEpoch`
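
As a rough illustration of the speed-related changes above, here is a minimal, framework-free sketch of how the two VAE flags and the `compress_kv_factor` guard could fit together. It is not the code from this PR. The parser, the accepted `vae_precision` values and its default, the `store_true` form of `vae_keep_gn_fp32`, and the helpers `build_vae_options`, `compress_mask`, and `maybe_compress_mask` are all assumptions made for illustration. Plain strings stand in for the real cell classes.

```python
import argparse


def build_vae_options(argv=None):
    """Parse the two VAE flags named in this PR (hypothetical parser)."""
    parser = argparse.ArgumentParser()
    # Assumption: the choices and default are illustrative; the PR only says fp16 is used.
    parser.add_argument("--vae_precision", default="fp16", choices=["fp16", "bf16", "fp32"])
    # Assumption: modeled as a store_true flag; the PR only states it defaults to False.
    parser.add_argument("--vae_keep_gn_fp32", action="store_true")
    args = parser.parse_args(argv)

    # Names of cell types to keep in fp32; strings stand in for the real classes.
    custom_fp32_cells = []
    if args.vae_keep_gn_fp32:
        custom_fp32_cells.append("nn.GroupNorm")
    return args.vae_precision, custom_fp32_cells


def compress_mask(mask, factor):
    """Hypothetical stand-in for attention_mask_compress: max-pool a 0/1 mask."""
    return [max(mask[i:i + factor]) for i in range(0, len(mask), factor)]


def maybe_compress_mask(mask, compress_kv_factor):
    """Skip the compression entirely when the factor is 1."""
    # In this sketch a factor of 1 leaves the mask unchanged, so the call is pure overhead.
    if compress_kv_factor == 1:
        return mask
    return compress_mask(mask, compress_kv_factor)
```

In this sketch, `maybe_compress_mask(mask, 1)` returns the input mask object untouched, while a larger factor falls through to the pooling helper. Keeping `nn.GroupNorm` in fp32 stays opt-in, matching the `False` default described above.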