tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

[Feature Request] Cross entropy loss forward and backward functions #14113

Open rfurko-tt opened 3 weeks ago

rfurko-tt commented 3 weeks ago

Is your feature request related to a problem? Please describe.

We can train a classification model in two ways:

A fused cross-entropy op will reduce the number of programs that we need to execute and improve performance.

Describe the solution you'd like

Introduce cross_entropy_loss and cross_entropy_loss_backward, which take logits and targets (class indices, int32_t or uint32_t).
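For reference, here is a minimal pure-Python sketch of the math the two requested ops would compute. It is not TT-NN code. The function signatures, the list-of-lists layout, the mean reduction over the batch, and the grad_output parameter are all assumptions made for illustration; only the two function names and the (logits, integer targets) inputs come from the request above. The backward pass uses the standard identity that the gradient of the mean loss with respect to the logits is (softmax(logits) - one_hot(targets)) / batch_size.

```python
import math


def _log_softmax(row):
    # Subtract the row maximum before exponentiating for numerical stability.
    m = max(row)
    log_sum = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum for x in row]


def cross_entropy_loss(logits, targets):
    """Mean negative log-likelihood over the batch (mean reduction is an assumption).

    logits:  list of rows, shape (batch, classes)
    targets: list of integer class indices, one per row
    """
    total = 0.0
    for row, t in zip(logits, targets):
        total -= _log_softmax(row)[t]
    return total / len(targets)


def cross_entropy_loss_backward(logits, targets, grad_output=1.0):
    """Gradient of the mean loss w.r.t. logits: (softmax - one_hot) / batch."""
    n = len(targets)
    grads = []
    for row, t in zip(logits, targets):
        probs = [math.exp(v) for v in _log_softmax(row)]
        probs[t] -= 1.0  # subtract the one-hot target
        grads.append([grad_output * p / n for p in probs])
    return grads


# Hypothetical example inputs: 2 samples, 3 classes.
logits = [[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]]
targets = [0, 1]
loss = cross_entropy_loss(logits, targets)
grad = cross_entropy_loss_backward(logits, targets)
```

Because the forward and backward passes both start from the same log-softmax, a fused kernel can share that intermediate instead of launching separate softmax, log, and gather programs.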

Describe alternatives you've considered

dmakoviichuk-tt commented 3 weeks ago

@ayerofieiev-tt moving this to you: I think Brian is the only person who can implement it quickly, given his experience with embedding_bw.