triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

how to implement backward propagation #1189

Open XiaoyuShi97 opened 1 year ago

XiaoyuShi97 commented 1 year ago

Hi. I am new to Triton and CUDA. From my understanding, when we implement a custom PyTorch operator in CUDA, we need to define both a forward and a backward function so that gradients are propagated properly. However, the Triton tutorials (e.g., vector addition, fused softmax, ...) don't seem to mention a backward function at all. Does that mean the backward function is automatically generated by Triton? Thanks!

yenchenlin commented 1 year ago

No, it's not. See this tutorial for an example. The HTML generated from that tutorial hasn't been added to the website yet, but you can run `make html` yourself to render the docs and preview it.
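
To make the pattern concrete, here is a minimal sketch (not the tutorial's code) of how a Triton forward kernel is typically paired with a hand-written backward via `torch.autograd.Function`. The kernel, class name, and block size below are illustrative assumptions, and vector addition is used only because its backward is trivial:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the vectors.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


class TritonAdd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        out = torch.empty_like(x)
        n = out.numel()
        grid = (triton.cdiv(n, 1024),)
        add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # d(x + y)/dx = d(x + y)/dy = 1, so the gradient passes straight through.
        # A less trivial op would launch a dedicated backward kernel here;
        # Triton does not derive this function for you.
        return grad_out, grad_out


# Usage: z = TritonAdd.apply(a, b); z.sum().backward() populates a.grad and b.grad.
```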

XiaoyuShi97 commented 1 year ago

Hi, thanks for your reply. I see. Another question: will a function similar to PyTorch's grid_sample be implemented? I want to write a custom operator that combines grid_sample with a dot product. Doing this naively in PyTorch runs out of memory, I guess because too many points are sampled. Is Triton a good choice for this case? Thanks!

yenchenlin commented 1 year ago

@btwbtm I am not a project owner, so take my answer with a grain of salt. I think it's entirely possible to implement the grid sampling inside the Triton kernel yourself (basically copying values from the input based on the grid index). Since this issue's title is "how to implement backward propagation", you may get more help by closing this issue and opening a new one.
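
A rough sketch of that idea, assuming a 1D nearest-neighbor sample fused with a per-point dot product (so the gathered features are never materialized as a full tensor). The kernel name, memory layout, and shapes are illustrative assumptions, not an existing Triton or PyTorch API:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def gather_dot_kernel(
    feat_ptr,    # (M, C) source features, row-major
    coord_ptr,   # (N,) normalized sample coordinates in [-1, 1]
    query_ptr,   # (N, C) query features, row-major
    out_ptr,     # (N,) fused sample-and-dot result
    M, C,
    BLOCK_C: tl.constexpr,
):
    # One program instance per query point.
    pid = tl.program_id(axis=0)
    coord = tl.load(coord_ptr + pid)

    # Nearest-neighbor "grid sample" along one dimension:
    # map [-1, 1] to [0, M - 1] and round to the closest row index.
    pos = (coord + 1.0) * 0.5 * (M - 1)
    idx = tl.minimum(tl.maximum(pos + 0.5, 0.0), M - 1.0).to(tl.int32)

    offs = tl.arange(0, BLOCK_C)
    mask = offs < C
    f = tl.load(feat_ptr + idx * C + offs, mask=mask, other=0.0)
    q = tl.load(query_ptr + pid * C + offs, mask=mask, other=0.0)

    # Fuse the dot product so only a scalar per point is written back.
    tl.store(out_ptr + pid, tl.sum(f * q, axis=0))


# Example launch (hypothetical shapes):
# N, M, C = 4096, 256, 64
# feat = torch.randn(M, C, device="cuda")
# coord = torch.rand(N, device="cuda") * 2 - 1
# query = torch.randn(N, C, device="cuda")
# out = torch.empty(N, device="cuda")
# gather_dot_kernel[(N,)](feat, coord, query, out, M, C, BLOCK_C=triton.next_power_of_2(C))
```

A backward pass would still have to be written by hand (e.g., a second kernel scattering `grad_out * q` into the feature gradient), as discussed above.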

XiaoyuShi97 commented 1 year ago

@yenchenlin thanks for your kind reply and suggestion!