pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

adaptive_max_pool2d is not supported #3049

Closed luchangli03 closed 3 years ago

luchangli03 commented 3 years ago

🚀 Feature

Support adaptive_max_pool2d

Motivation

Currently adaptive_avg_pool2d is supported, but adaptive_max_pool2d is not. When I try to train U-GAT-IT (https://github.com/znxlwm/UGATIT-pytorch), this causes the model to be lowered into multiple graphs, which results in bad performance.
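A minimal repro of the op in question (the tensor shapes here are illustrative, not from U-GAT-IT). On a backend without a lowering for adaptive_max_pool2d, this call would fall back to the CPU and split the traced graph:

```python
import torch
import torch.nn.functional as F

# adaptive_max_pool2d takes the desired *output* size directly;
# the kernel size and stride are derived per output window.
x = torch.randn(1, 3, 10, 10)
y = F.adaptive_max_pool2d(x, output_size=(4, 4))
print(y.shape)  # torch.Size([1, 3, 4, 4])
```

Note that 10 % 4 != 0 here, which is exactly the case discussed below where the derived windows have non-uniform stride.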

luchangli03 commented 3 years ago

@JackCaoG I have also hit a problem where upsample can't be lowered when I try to train U-GAT-IT with the XLA:GPU backend. Do you know where the custom call for upsample on TPU is implemented, and do you have any tips on how to implement the custom call for the GPU backend? Thank you very much.

JackCaoG commented 3 years ago

I implemented the custom calls EmitPadToStatic and EmitSliceToDynamic for GPU; you can check my CL https://github.com/tensorflow/tensorflow/commit/64a5248407af3532c5eb282414fe95b62ee3bfec. The code currently lives here

JackCaoG commented 3 years ago

I am working on the adaptive_max_pool2d lowering, but we can only support the input_size % output_size == 0 case. The reason is that PyTorch/XLA uses xla::MaxPool to implement the max pool, which requires a fixed kernel size and stride. For an input size of [10] (using 1d as an example) and an output size of [4], PyTorch will max-pool over the windows [0, 2], [2, 4], [5, 7], [7, 9], where the stride is not constant.