matthieutrs closed this 8 months ago.
Thanks a lot for the careful review!
Regarding `output_padding` in `torch.nn.ConvTranspose2d`, I've added an assertion in the definition of `conv_transpose2d` (see above). Let me know if there are other things that need to be updated!
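For reference, `output_padding` in `torch.nn.ConvTranspose2d` only affects the output shape. When `stride > 1`, several input sizes map to the same forward-convolution output size, and `output_padding` selects which one the transposed convolution returns. Below is a minimal sketch of the per-dimension shape formula from the PyTorch docs. The function name is hypothetical, and its validity check mirrors PyTorch's own rule (`output_padding` must be smaller than either `stride` or `dilation`). It is not necessarily the assertion added in this PR.

```python
def conv_transpose_out_size(size, kernel_size, stride=1, padding=0,
                            output_padding=0, dilation=1):
    """Output length along one spatial dim of a ConvTranspose2d (PyTorch formula).

    Hypothetical helper for illustration; not part of torch2jax.
    """
    # Same constraint PyTorch enforces: output_padding < stride or < dilation.
    if not output_padding < max(stride, dilation):
        raise ValueError("output_padding must be smaller than stride or dilation")
    return ((size - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)
```

For example, `conv_transpose_out_size(4, 3, stride=2, padding=1, output_padding=1)` gives 8, exactly double the input size. Without `output_padding` the result is 7.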
Thanks a lot! It indeed works on my architectures.
Thanks so much @matthieutrs! You get a special prize for being the first person to merge a PR on torch2jax! 🌟
Thanks for this very nice repo! I needed to convert models containing `torch.nn.ConvTranspose2d`, but this is not completely straightforward, as torch and lax do not compute the same transposed convolution. This PR essentially merges in the changes from https://github.com/google/jax/pull/5772 to solve this issue.
At the moment it relies on NumPy. If this is a problem, I think the dependency could be avoided, but I haven't had time to remove it yet.
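The mismatch can be made concrete with a NumPy-only sketch. The function names are hypothetical and are not torch2jax's API. The sketch assumes `groups=1`, `dilation=1`, PyTorch's weight layout `(C_in, C_out, kH, kW)`, and `padding <= kernel_size - 1`.

The first function writes PyTorch's transposed convolution as a scatter-add. The second rebuilds the same result from a plain cross-correlation, which is what `jax.lax.conv_general_dilated` computes. It does this in four steps:

- dilate the input by the stride;
- pad by `kernel_size - 1 - padding`, plus `output_padding` on the high side;
- flip the kernel spatially;
- swap the kernel's in/out channel axes.

```python
import numpy as np


def conv_transpose2d_scatter(x, w, stride=1, padding=0, output_padding=0):
    """PyTorch-style transposed conv as a scatter-add (groups=1, dilation=1 assumed).

    x: (N, C_in, H, W); w: (C_in, C_out, kH, kW) in PyTorch's layout.
    """
    n, c_in, h, wd = x.shape
    _, c_out, kh, kw = w.shape
    # "Full" output, with output_padding extra rows/cols of zeros at the end.
    full_h = (h - 1) * stride + kh + output_padding
    full_w = (wd - 1) * stride + kw + output_padding
    full = np.zeros((n, c_out, full_h, full_w), dtype=x.dtype)
    for i in range(h):
        for j in range(wd):
            # Each input pixel stamps the kernel into the output at a stride offset.
            full[:, :, i * stride:i * stride + kh, j * stride:j * stride + kw] += (
                np.einsum("nc,cokl->nokl", x[:, :, i, j], w))
    # `padding` crops both sides of the full output.
    return full[:, :, padding:full_h - padding, padding:full_w - padding]


def conv_transpose2d_via_correlation(x, w, stride=1, padding=0, output_padding=0):
    """Same result built from dilate + pad + flipped-kernel cross-correlation."""
    n, c_in, h, wd = x.shape
    _, c_out, kh, kw = w.shape
    assert padding <= min(kh, kw) - 1, "sketch assumes padding <= kernel_size - 1"
    # 1) Dilate the input: insert (stride - 1) zeros between neighbouring pixels.
    xd = np.zeros((n, c_in, (h - 1) * stride + 1, (wd - 1) * stride + 1), dtype=x.dtype)
    xd[:, :, ::stride, ::stride] = x
    # 2) Pad by (k - 1 - padding) per side, plus output_padding on the high side.
    ph, pw = kh - 1 - padding, kw - 1 - padding
    xp = np.pad(xd, ((0, 0), (0, 0), (ph, ph + output_padding), (pw, pw + output_padding)))
    # 3) Flip the kernel spatially and swap in/out channels -> (C_out, C_in, kH, kW).
    wf = np.flip(w, axis=(2, 3)).transpose(1, 0, 2, 3)
    # 4) "Valid" cross-correlation (no kernel flip), like lax.conv_general_dilated.
    out_h, out_w = xp.shape[2] - kh + 1, xp.shape[3] - kw + 1
    out = np.zeros((n, c_out, out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            out[:, :, i, j] = np.einsum("nckl,ockl->no", xp[:, :, i:i + kh, j:j + kw], wf)
    return out
```

In JAX, steps 1 and 2 would map onto the `lhs_dilation` and `padding` arguments of `jax.lax.conv_general_dilated`. That leaves only the kernel flip and channel swap in step 3 to be done by hand. This mapping is offered as an illustration, not as a description of what this PR or jax#5772 implements.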