tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

Vanilla Unet Bringup #13343

Open punithsekar opened 1 week ago

punithsekar commented 1 week ago

Model card - https://github.com/tenstorrent/tt-metal/issues/13272

The ttnn implementation of vanilla UNet is in the branch punith/vanilla_unet_inference. The test files are in path.

Current PCC is 0.93.
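For readers unfamiliar with the metric: PCC here is the Pearson correlation coefficient between the reference (torch) output and the ttnn output, both flattened. Below is a minimal, dependency-free sketch of that computation. The sample values are hypothetical and are not taken from the model, and the repository's own comparison helper may treat edge cases (such as constant tensors) differently.

```python
import math

def pcc(expected, actual):
    """Pearson correlation coefficient between two equal-length sequences."""
    if len(expected) != len(actual):
        raise ValueError("inputs must have the same number of elements")
    n = len(expected)
    mean_e = sum(expected) / n
    mean_a = sum(actual) / n
    cov = sum((e - mean_e) * (a - mean_a) for e, a in zip(expected, actual))
    var_e = sum((e - mean_e) ** 2 for e in expected)
    var_a = sum((a - mean_a) ** 2 for a in actual)
    if var_e == 0 or var_a == 0:
        # Correlation is undefined for a constant input. Assumption made here:
        # report 1.0 only when both inputs match exactly, otherwise 0.0.
        return 1.0 if list(expected) == list(actual) else 0.0
    return cov / math.sqrt(var_e * var_a)

# Hypothetical flattened outputs, for illustration only.
torch_out = [0.10, 0.40, 0.35, 0.80]
ttnn_out = [0.12, 0.38, 0.36, 0.79]
print(f"PCC = {pcc(torch_out, ttnn_out):.4f}")
```

A value close to 1.0 means the two outputs agree up to scale and offset, so 0.93 indicates a noticeable but not catastrophic divergence.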

Pending issues related to vanilla UNet model:

- #6326
- #13324
- #13336

Below is the command to run the test for the vanilla UNet module: Unet - pytest tests/ttnn/integration_tests/vanilla_unet/test_ttnn_unet.py

Currently, torch ConvTranspose is still used, along with one torch convolution, one torch maxpool, and one torch sigmoid. Replacing these with their ttnn equivalents reduces the PCC to 0.0. This needs to be checked.
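One way to narrow down which op causes the drop is to compare the reference and candidate implementations stage by stage, feeding the reference output forward so that errors do not accumulate. The sketch below illustrates the idea with toy stages on plain Python lists. The stage names, the lambdas, and the `first_divergent_stage` helper are all hypothetical stand-ins, not the model's actual layers or any ttnn API.

```python
import math

def pcc(a, b):
    """Pearson correlation of two equal-length sequences (0.0 if undefined)."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    va = sum((x - ma) ** 2 for x in a)
    vb = sum((y - mb) ** 2 for y in b)
    return cov / math.sqrt(va * vb) if va > 0 and vb > 0 else 0.0

def first_divergent_stage(stages, x, threshold=0.99):
    """Return (name, pcc) of the first stage whose candidate output falls
    below the threshold, or (None, None) if every stage passes.

    stages: list of (name, reference_fn, candidate_fn) tuples.
    """
    for name, ref_fn, cand_fn in stages:
        ref_out = ref_fn(x)
        score = pcc(ref_out, cand_fn(x))
        if score < threshold:
            return name, score
        x = ref_out  # continue from the reference output
    return None, None

def sigmoid(values):
    return [1.0 / (1.0 + math.exp(-v)) for v in values]

# Toy pipeline: the "sigmoid" candidate is deliberately wrong.
stages = [
    ("conv", lambda v: [2 * t + 1 for t in v], lambda v: [2 * t + 1 for t in v]),
    ("sigmoid", sigmoid, lambda v: [1.0 - s for s in sigmoid(v)]),
]
x = [-2.0, -1.0, 0.5, 3.0]
name, score = first_divergent_stage(stages, x)
print(name, score)
```

Feeding the reference output into the next stage isolates each op, so a stage is flagged only for its own error rather than for errors inherited from earlier stages.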

Many TM operations such as to_device, ttnn.reshape, and from_device are used to work around an OOM issue and a static buffer issue in the pipeline. These will be reduced.

Commit: WIP for UNet inference, removing additional TM ops.

punithsekar commented 1 week ago

fyi @saichandax