mindspore-lab / mindone

one for all, Optimal generator with No Exception
https://mindspore-lab.github.io/mindone/
Apache License 2.0

DiT model definition #363

Closed wtomin closed 4 months ago

wtomin commented 4 months ago

Adds the model definition file, mindone/models/dit.py, with flash-attention (FA) support, based on @geniuspatrick's implementation. With FA enabled, the inference results on Ascend 910B with MindSpore 2.2.10 were OK.

Usage:

from mindone.models.dit import DiT_models
model = DiT_models["DiT_XL_2"](input_size=32, block_kwargs={"enable_flash_attention": True})
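
For readers without a MindSpore environment, the lookup-and-construct pattern in the usage above can be sketched in plain Python. This is only an illustration under stated assumptions: `ToyDiTConfig`, `toy_dit_xl_2`, and `toy_dit_models` are hypothetical names that are not part of mindone, and the XL/2 hyperparameters (patch size 2, hidden size 1152, depth 28, 16 heads) are taken from the original DiT paper and may not match this PR's defaults.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict


@dataclass
class ToyDiTConfig:
    """Hypothetical stand-in for a DiT model; it only stores configuration."""

    input_size: int = 32
    patch_size: int = 2
    hidden_size: int = 1152
    depth: int = 28
    num_heads: int = 16
    # Extra options forwarded to every transformer block (assumed behaviour).
    block_kwargs: Dict[str, Any] = field(default_factory=dict)

    @property
    def num_patches(self) -> int:
        # A square latent of side input_size is cut into patch_size x patch_size patches.
        return (self.input_size // self.patch_size) ** 2


def toy_dit_xl_2(**kwargs: Any) -> ToyDiTConfig:
    # XL/2 values from the DiT paper; assumed, not read from mindone.
    return ToyDiTConfig(patch_size=2, hidden_size=1152, depth=28, num_heads=16, **kwargs)


# Name-to-constructor registry, mirroring how DiT_models is indexed in the usage above.
toy_dit_models: Dict[str, Callable[..., ToyDiTConfig]] = {"DiT_XL_2": toy_dit_xl_2}

model = toy_dit_models["DiT_XL_2"](
    input_size=32, block_kwargs={"enable_flash_attention": True}
)
print(model.num_patches, model.block_kwargs)
```

Under these assumptions, a 32x32 latent with patch size 2 gives (32 / 2)^2 = 256 tokens per sample, which is the sequence length the attention layers (with or without FA) would operate on.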

TODO: