tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
https://docs.tenstorrent.com/ttnn/latest/index.html
Apache License 2.0

Improvements Needed for Dropout Feature Implementation #4669

Open dongjin-na opened 10 months ago

dongjin-na commented 10 months ago

Dear TT members,

LLK implementations of the dropout op exist (llk_math_eltwise_unary_sfpu_dropout_init, llk_math_eltwise_unary_sfpu_dropout), but tt-metal has no compute API interface for them and no code that uses them.

I wrote a compute API header that calls those LLK functions and modified the SFPI code of the calculate_dropout function so that the dropout op runs properly. (For more information, please refer to https://github.com/tenstorrent-metal/tt-metal/commit/0dcdd1879d862c13c1184b7eef3564ec5ddf5782)

After testing this feature, I found two things to check.

The management of random state.

In general, a compute API follows the call sequence xxx_init() followed by xxx(). For dropout, the random state is initialized from the seed argument of dropout_init() and is then used in dropout().

In other words, the same random state is used every time the dropout API is executed.

Instead of the current method, the random state should be initialized once, stored in memory, and passed to the dropout API as an argument so that each call updates it. To keep the random state in memory, a tile data format of uint16_t or uint32_t also needs to be supported. (Related to https://github.com/tenstorrent-metal/tt-metal/issues/4624)
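The proposal above can be illustrated with a small host-side model. This is only a sketch: RandState, rand_state_init, rand_next, and dropout_tile are hypothetical names that do not exist in tt-metal, and xorshift32 merely stands in for whatever generator the hardware uses. It shows the state being created once from a seed and then advanced by every dropout call, rather than being reset each time.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side model; none of these names are tt-metal APIs.
struct RandState {
    uint32_t s;  // 32-bit generator state, kept by the caller between calls
};

// Initialize the state once from a seed.
inline RandState rand_state_init(uint32_t seed) {
    // xorshift32 must not start from zero, so remap a zero seed.
    return RandState{seed != 0 ? seed : 0x9E3779B9u};
}

// Advance the state by one xorshift32 step and return the new value.
inline uint32_t rand_next(RandState& st) {
    uint32_t x = st.s;
    x ^= x << 13;
    x ^= x >> 17;
    x ^= x << 5;
    st.s = x;
    return x;
}

// Apply dropout to one tile. The state is passed by reference, so the
// next call continues the random sequence instead of restarting it.
// p_int is the drop probability scaled to 16 bits; returns the zero count.
inline std::size_t dropout_tile(std::vector<float>& tile, uint32_t p_int,
                                float scale, RandState& st) {
    assert(p_int <= 65536u);
    std::size_t zeros = 0;
    for (float& v : tile) {
        uint32_t r = rand_next(st) >> 16;  // use the upper 16 bits
        if (r < p_int) {
            v = 0.0f;
            ++zeros;
        } else {
            v *= scale;
        }
    }
    return zeros;
}
```

In this model, calling dropout_tile twice with the same RandState object yields two different masks, while re-creating the state from the same seed before each call reproduces the same mask every time, which corresponds to the current behavior described above.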

Functional aspects of dropout

When torch.nn.Dropout(p=0.5) is executed, about 50% of the elements in the output should be zero. When I tested the dropout API on a single tile, the ratio of zeros in the output tile differed from the value of the p (probability) argument. Moreover, for the same p, the number of zeros in the output differs significantly between seeds.

Test env

| p (integer representation) | Ratio of zeros in the output tile (seed = 0) | Ratio of zeros in the output tile (seed = 3407) |
| --- | --- | --- |
| 0.9 (58982) | 0.95 | 0.87 |
| 0.7 (45875) | 0.83 | 0.73 |
| 0.5 (32768) | 0.7 | 0.45 |
| 0.3 (19660) | 0.55 | 0.2 |
| 0.1 (6553) | 0.38 | 0.04 |
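For comparison, here is a host-side sketch of what the zero ratio would look like with a well-behaved generator. Two things are assumptions on my side: the integer values in the table are consistent with truncating p * 65536 (inferred from the numbers, not confirmed), and a tile is taken to hold 32 x 32 = 1024 elements. prob_to_int and reference_zero_ratio are hypothetical helper names, and std::mt19937 is only a stand-in for the device generator.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <random>

// Assumption: the integer representation of p is floor(p * 2^16).
// This reproduces the table values, e.g. 0.9 -> 58982 and 0.3 -> 19660.
inline uint32_t prob_to_int(double p) {
    assert(p >= 0.0 && p <= 1.0);
    return static_cast<uint32_t>(p * 65536.0);
}

// Reference zero ratio for n elements: an element counts as dropped when
// a uniform 16-bit random value is below p_int.
inline double reference_zero_ratio(uint32_t p_int, uint32_t seed,
                                   std::size_t n = 1024) {
    std::mt19937 rng(seed);  // stand-in generator, not the hardware PRNG
    std::size_t zeros = 0;
    for (std::size_t i = 0; i < n; ++i) {
        uint32_t r = static_cast<uint32_t>(rng()) & 0xFFFFu;  // low 16 bits
        if (r < p_int) {
            ++zeros;
        }
    }
    return static_cast<double>(zeros) / static_cast<double>(n);
}
```

With 1024 elements and p = 0.5, the standard deviation of the zero ratio is about 0.016, so a healthy generator should stay within a few hundredths of p for any seed. The deviations in the table (for example 0.7 versus 0.45 at p = 0.5) are much larger than that.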
dongjin-na commented 10 months ago

I pushed the moreh/dropout branch so that this feature can be checked. :) The test code is test_eltwise_unary_dropout.cpp.