Below is the path to run the test of the vanilla UNet module:
UNet - pytest tests/ttnn/integration_tests/vanilla_unet/test_ttnn_unet.py
Currently:
- Torch ConvTranspose is used.
- One torch convolution.
- One torch maxpool.
- One torch sigmoid - converting it to ttnn drops the PCC to 0.0; this should be investigated.
- Many TM (tensor-manipulation) ops such as to_device, ttnn.reshape, and from_device are used to bypass an OOM issue and a static buffer issue in the pipeline. These will be reduced.
Commit - WIP for UNet inference removing additional TM ops.
Model card - https://github.com/tenstorrent/tt-metal/issues/13272
The ttnn implementation of vanilla UNet is in the branch punith/vanilla_unet_inference. The test file is at the path above.
Current PCC is 0.93.
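The PCC numbers above refer to the Pearson correlation coefficient between the torch reference output and the ttnn output, flattened to 1-D. A minimal pure-Python sketch of the metric (the function name is illustrative, not the repo's actual test helper):

```python
import math

def compute_pcc(expected, actual):
    """Pearson correlation coefficient between two flat lists of floats.

    Returns 1.0 for perfectly correlated outputs; values near 0.0
    (as seen with the ttnn sigmoid) indicate the output is unrelated
    to the reference.
    """
    n = len(expected)
    mean_e = sum(expected) / n
    mean_a = sum(actual) / n
    cov = sum((e - mean_e) * (a - mean_a) for e, a in zip(expected, actual))
    var_e = sum((e - mean_e) ** 2 for e in expected)
    var_a = sum((a - mean_a) ** 2 for a in actual)
    return cov / math.sqrt(var_e * var_a)
```

A PCC of 0.93 therefore means the ttnn output is strongly but not perfectly correlated with the torch reference; the fallback ops listed above are where the remaining gap is being chased.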
Pending issues related to the vanilla UNet model:
- #6326
- #13324
- #13336