ivy-llc / ivy

The Unified ML Representation
https://ivy.dev

fix: torch.masked_fill #28740

Closed. Kacper-W-Kozdon closed this 2 months ago.

Kacper-W-Kozdon commented 2 months ago

PR Description

Changed how type casting is done in the torch frontend's .masked_fill(). The previous implementation did not enforce keeping the input's dtype, and the paddle backend applies type promotion in the function that the torch frontend's .masked_fill() calls. As a result, the dtype assertions failed and the paddle pytest did not pass. All of the backends use .where() to reproduce the results of torch.masked_fill(). Paddle's version is probably a bit overcomplicated, but the tests for .where() pass, so I did not want to alter it.
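The idea described above (express masked_fill through where, then make sure the result keeps the input's dtype instead of a promoted one) can be sketched in NumPy. This is a hypothetical illustration, not the code from this PR or from Ivy's torch frontend. The function name `masked_fill` and the use of NumPy in place of Ivy's backends are assumptions made for the example.

```python
import numpy as np


def masked_fill(x, mask, value):
    """Hypothetical sketch: replace entries of x where mask is True with value.

    Not Ivy's implementation; NumPy stands in for a backend here.
    """
    x = np.asarray(x)
    mask = np.asarray(mask, dtype=bool)
    # Cast the fill value to the input's dtype *before* calling where.
    # Without this, np.where promotes to a common dtype, e.g. an int32
    # input combined with a Python float value would come back as float64.
    fill = np.asarray(value).astype(x.dtype)
    out = np.where(mask, fill, x)
    # fill and x now share a dtype, so this is a no-op safeguard that
    # guarantees the output dtype matches the input dtype.
    return out.astype(x.dtype, copy=False)


if __name__ == "__main__":
    x = np.array([1, 2, 3], dtype=np.int32)
    mask = [True, False, True]
    print(masked_fill(x, mask, 9), masked_fill(x, mask, 9).dtype)
    # For comparison, a bare where() with a float fill value promotes.
    print(np.where(np.asarray(mask), 2.5, x).dtype)
```

Casting the value up front, rather than only casting the output at the end, keeps the intermediate result in the input's dtype as well. That matters when a backend's where applies its own promotion rules, which is the kind of dtype mismatch described above.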

Potentially to-do:

Related Issue

Closes #28437
