samuela / torch2jax

Run PyTorch in JAX. 🤝
168 stars, 5 forks

adding torch.nn.ConvTranspose2d #3

Closed: matthieutrs closed this 8 months ago

matthieutrs commented 8 months ago

Thanks for this very nice repo! I needed to convert models containing torch.nn.ConvTranspose2d, but this is not completely straightforward, as torch and lax do not implement transposed convolution the same way.
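For context on the mismatch: PyTorch's ConvTranspose2d stores its weight as (C_in, C_out, kH, kW) and is defined as a scatter-add, where each input pixel adds a kernel-sized patch to the output. The standard way to express the same operation as an ordinary convolution is to insert stride - 1 zeros between input pixels, pad by k - 1 - padding (plus output_padding on the trailing side), and cross-correlate with a spatially flipped kernel whose channel axes are swapped. The NumPy sketch below is not the code from this PR. The function names are hypothetical, and it assumes groups=1, dilation=1 and padding <= k - 1. It only checks that the two formulations agree. In JAX, the zero-insertion step corresponds to the `lhs_dilation` argument of `jax.lax.conv_general_dilated`.

```python
import numpy as np


def conv_transpose2d_scatter(x, w, stride=1, padding=0, output_padding=0):
    """PyTorch-style ConvTranspose2d by definition (groups=1, dilation=1).

    x: (N, C_in, H, W); w: (C_in, C_out, kH, kW), PyTorch's weight layout.
    Each input pixel scatters a kernel-sized patch into the output.
    """
    n, _, h, wd = x.shape
    _, c_out, kh, kw = w.shape
    h_out = (h - 1) * stride - 2 * padding + kh + output_padding
    w_out = (wd - 1) * stride - 2 * padding + kw + output_padding
    # Uncropped buffer, with room for the extra output_padding rows/cols.
    full = np.zeros((n, c_out, (h - 1) * stride + kh + output_padding,
                     (wd - 1) * stride + kw + output_padding), dtype=x.dtype)
    for i in range(h):
        for j in range(wd):
            # (N, C_in) . (C_in, C_out, kH, kW) -> (N, C_out, kH, kW)
            patch = np.einsum("nc,cokl->nokl", x[:, :, i, j], w)
            full[:, :, i * stride:i * stride + kh, j * stride:j * stride + kw] += patch
    # `padding` crops both sides; output_padding keeps extra rows/cols at the end.
    return full[:, :, padding:padding + h_out, padding:padding + w_out]


def conv_transpose2d_via_dilation(x, w, stride=1, padding=0, output_padding=0):
    """Same result as a stride-1 correlation on a zero-dilated, padded input."""
    n, _, h, wd = x.shape
    _, c_out, kh, kw = w.shape
    if padding > min(kh, kw) - 1:
        raise ValueError("sketch assumes padding <= kernel_size - 1")
    # 1) Insert stride - 1 zeros between input pixels (lhs dilation).
    xd = np.zeros(x.shape[:2] + ((h - 1) * stride + 1, (wd - 1) * stride + 1), dtype=x.dtype)
    xd[:, :, ::stride, ::stride] = x
    # 2) Pad k - 1 - padding on each side, plus output_padding at the end.
    ph, pw = kh - 1 - padding, kw - 1 - padding
    xd = np.pad(xd, ((0, 0), (0, 0), (ph, ph + output_padding), (pw, pw + output_padding)))
    # 3) Flip the kernel spatially and reorder it to (C_out, C_in, kH, kW).
    k = np.flip(w, axis=(2, 3)).transpose(1, 0, 2, 3)
    # 4) "Valid" cross-correlation with stride 1.
    h_out, w_out = xd.shape[2] - kh + 1, xd.shape[3] - kw + 1
    out = np.zeros((n, c_out, h_out, w_out), dtype=x.dtype)
    for i in range(h_out):
        for j in range(w_out):
            out[:, :, i, j] = np.einsum("ncab,ocab->no", xd[:, :, i:i + kh, j:j + kw], k)
    return out
```

Forgetting either the kernel flip or the (C_in, C_out) to (C_out, C_in) swap in step 3 is enough to make a naive translation disagree with PyTorch.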

This PR essentially brings in the changes from https://github.com/google/jax/pull/5772 to solve this issue.

At the moment it relies on NumPy; if this is a problem, I think the dependency could be avoided, but I haven't had time to remove it yet.

matthieutrs commented 8 months ago

Thanks a lot for the careful review!

  1. The dependency on NumPy has been removed.
  2. The failing tests for some stride/kernel-size combinations were due to JAX assuming a particular output_padding in torch.nn.ConvTranspose2d; I've added an assertion in the definition of conv_transpose2d (see above).
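Why output_padding matters: with stride > 1, several Conv2d input sizes map to the same output size, so the shape of the transposed operation is ambiguous, and PyTorch uses output_padding to pick one. A minimal pure-Python sketch of the two per-axis shape formulas documented for torch.nn.Conv2d and torch.nn.ConvTranspose2d follows. The helper names are hypothetical, and this does not reproduce the exact assertion added in this PR.

```python
def conv2d_out(size, kernel, stride, padding):
    """Conv2d output length along one axis (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(size, kernel, stride, padding, output_padding=0, dilation=1):
    """ConvTranspose2d output length along one axis, per the PyTorch docs formula."""
    # PyTorch requires output_padding to be smaller than stride or dilation.
    if not 0 <= output_padding < max(stride, dilation):
        raise ValueError("output_padding must be smaller than stride or dilation")
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


# kernel=3, stride=2, padding=1: heights 7 and 8 both shrink to 4 ...
assert conv2d_out(7, 3, 2, 1) == conv2d_out(8, 3, 2, 1) == 4
# ... so transposing 4 is ambiguous; output_padding selects which one.
assert conv_transpose2d_out(4, 3, 2, 1, output_padding=0) == 7
assert conv_transpose2d_out(4, 3, 2, 1, output_padding=1) == 8
```

A converter that silently assumes one output_padding value will therefore produce the wrong output shape for the other, which is the kind of failure an explicit assertion guards against.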

Let me know if there are other things that need to be updated!

matthieutrs commented 8 months ago

Thanks a lot! It does indeed work on my architectures.

samuela commented 8 months ago

Thanks so much @matthieutrs! You get a special prize for being the first person to merge a PR on torch2jax! 🌟