NVIDIA / modulus

Open-source deep-learning framework for building, training, and fine-tuning deep learning models using state-of-the-art Physics-ML methods
https://developer.nvidia.com/modulus
Apache License 2.0

Fix crash in SongUNet when positional embedding is disabled #507

Closed: daviddpruitt closed this pull request 1 month ago

daviddpruitt commented 1 month ago

Modulus Pull Request

Description

SongUNetPosEmbd has an option to disable the positional embedding by setting N_grid_channels to 0. This change keeps the forward pass from crashing in that case by checking that the positional embedding exists before it is used. It also eliminates some code duplication between the positional-embedding and non-positional-embedding versions of the model.
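The guard and the deduplication described above can be sketched as follows. This is a minimal, hypothetical illustration, not the Modulus implementation: `BaseUNetSketch`, `PosEmbdUNetSketch`, `n_grid_channels`, and `num_pixels` are invented names, and plain Python lists stand in for PyTorch tensors so the example runs without dependencies. When the grid-channel count is 0, no embedding is built; the forward pass only concatenates the embedding when it exists and then reuses the parent's forward instead of duplicating it.

```python
from typing import List, Optional

# Hypothetical stand-in for a tensor: each inner list is one flattened channel.
Channels = List[List[float]]


class BaseUNetSketch:
    """Hypothetical base network; forward only validates the channel count."""

    def __init__(self, in_channels: int) -> None:
        self.in_channels = in_channels

    def forward(self, x: Channels) -> Channels:
        if len(x) != self.in_channels:
            raise ValueError(f"expected {self.in_channels} channels, got {len(x)}")
        return x


class PosEmbdUNetSketch(BaseUNetSketch):
    """Hypothetical subclass that optionally appends positional-embedding channels."""

    def __init__(self, img_channels: int, n_grid_channels: int, num_pixels: int) -> None:
        # The base network sees the image channels plus any grid channels.
        super().__init__(in_channels=img_channels + n_grid_channels)
        # n_grid_channels == 0 means "embedding disabled": store None.
        self.pos_embd: Optional[Channels] = (
            [[float(c)] * num_pixels for c in range(n_grid_channels)]
            if n_grid_channels > 0
            else None
        )

    def forward(self, x: Channels) -> Channels:
        # Guard: without this check, `x + None` would raise a TypeError,
        # an analogue of the crash the PR fixes.
        if self.pos_embd is not None:
            x = x + self.pos_embd
        # Reuse the parent forward rather than duplicating its logic.
        return super().forward(x)


image = [[0.5, 0.5, 0.5, 0.5]]  # one channel of a 2x2 image, flattened
print(len(PosEmbdUNetSketch(1, 0, 4).forward(image)))  # 1 (embedding disabled)
print(len(PosEmbdUNetSketch(1, 4, 4).forward(image)))  # 5 (1 image + 4 grid channels)
```

Having the subclass delegate to the parent's forward means the disabled-embedding case and the plain model share one code path, which is the kind of duplication the PR description says was removed.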

Closes https://github.com/NVIDIA/modulus/issues/504

Closes https://github.com/NVIDIA/modulus/issues/529


mnabian commented 1 month ago

/blossom-ci

mnabian commented 1 month ago

@akshaysubr we do have unit tests for this model: https://github.com/NVIDIA/modulus/blob/main/test/models/diffusion/test_song_unet_pos_embd.py#L41

daviddpruitt commented 1 month ago

/blossom-ci