Changed how type casting is done in the frontend. The previous implementation did not enforce preserving the input's dtype, and the paddle backend applies type promotion in the function that the torch frontend's .masked_fill() calls. As a result, the dtype assertions failed and so did the paddle pytest. All of the backends use .where() to reproduce the behaviour of torch.masked_fill(). Paddle's solution is probably a bit overcomplicated, but the tests for .where() are passing, so I did not want to alter it.
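To illustrate the dtype problem, here is a minimal sketch that uses NumPy as a stand-in for a backend whose where() applies type promotion. The helper name masked_fill_sketch is hypothetical; it is not the actual Ivy frontend code, only an illustration of filling via where() and then casting the result back to the input's dtype.

```python
import numpy as np


def masked_fill_sketch(x: np.ndarray, mask: np.ndarray, value) -> np.ndarray:
    """Hypothetical masked_fill built on where(), keeping x's dtype."""
    # where() promotes across its operands: an int32 input combined with
    # a float fill value yields a float64 result.
    filled = np.where(mask, value, x)
    # Cast back so the output dtype matches the input, which is what the
    # frontend's dtype assertions expect.
    return filled.astype(x.dtype, copy=False)


x = np.array([1, 2, 3], dtype=np.int32)
mask = np.array([True, False, True])
out = masked_fill_sketch(x, mask, 1.5)
print(out, out.dtype)  # [1 2 1] int32
```

Without the final astype(), the same call would return a float64 array, which is the kind of dtype mismatch described above.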
Potentially to-do:
if it's not too redundant: adding .masked_fill() for the superset behaviour (even the torch backend uses .where() when .masked_fill() is called);
refactoring the paddle backend's implementation of .where().
Related Issue
Closes #28437
Checklist
[ ] Did you add a function?
[ ] Did you add the tests?
[x] Did you run your tests and are your tests passing?