jonasrauber / eagerpy

PyTorch, TensorFlow, JAX and NumPy — all of them natively using the same code
https://eagerpy.jonasrauber.de
MIT License
695 stars, 40 forks

Implement eagerpy.diag #22

Closed (Holt59 closed this 1 year ago)

Holt59 commented 3 years ago

Closes #21

Implements eagerpy.diag, following numpy / torch convention:
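For readers unfamiliar with that convention, here is a minimal pure-Python sketch of how `numpy.diag(v, k)` and `torch.diag(input, diagonal)` behave: a 1-D input becomes a square matrix with the values on the k-th diagonal, and a 2-D input yields its k-th diagonal. The helper name `diag_reference` is hypothetical and is not the code from this PR.

```python
def diag_reference(x, k=0):
    """Hypothetical reference for the numpy/torch diag convention (nested lists only)."""
    if len(x) > 0 and isinstance(x[0], (list, tuple)):
        # 2-D input: extract the k-th diagonal (k > 0 above, k < 0 below the main one).
        rows, cols = len(x), len(x[0])
        return [x[i][i + k] for i in range(rows) if 0 <= i + k < cols]
    # 1-D input: build a square matrix of size len(x) + |k| with x on the k-th diagonal.
    n = len(x) + abs(k)
    out = [[0] * n for _ in range(n)]
    for i, value in enumerate(x):
        row = i + max(-k, 0)  # shift down for negative k
        col = i + max(k, 0)   # shift right for positive k
        out[row][col] = value
    return out


matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(diag_reference(matrix))        # main diagonal
print(diag_reference(matrix, k=1))   # first superdiagonal
print(diag_reference([1, 2], k=-1))  # 3x3 matrix with values below the main diagonal
```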

This cannot handle batched inputs the way TensorFlow's tf.linalg.diag or tf.linalg.diag_part do, but I think it's cleaner that way. Let me know what you think.
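To make the batched/non-batched distinction concrete, the sketch below contrasts the two behaviours with plain nested lists. In the batched style (as in `tf.linalg.diag`), the innermost dimension holds the diagonal values and one matrix is produced per batch entry. Both helper names are hypothetical, and only the main diagonal is handled.

```python
def diag_single(v):
    """Non-batched: a 1-D list of length N becomes an N x N matrix."""
    n = len(v)
    return [[v[i] if i == j else 0 for j in range(n)] for i in range(n)]


def diag_batched(batch):
    """Batched: a [B, N] nested list becomes a [B, N, N] nested list."""
    return [diag_single(v) for v in batch]


print(diag_single([1, 2]))             # one 2x2 matrix
print(diag_batched([[1, 2], [3, 4]]))  # two 2x2 matrices, one per batch entry
```

With a non-batched `diag`, a caller who needs the batched result has to loop (or map) over the batch dimension explicitly, as `diag_batched` does here.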

Holt59 commented 3 years ago

So apparently there is a bug in the version of TensorFlow used for the tests (2.1.2?): the k parameter of tf.linalg.diag and tf.linalg.diag_part is completely ignored.
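A quick way to check for this kind of problem is to compare the result for `k=1` with the result for `k=0`: if they are equal, the offset is being ignored. The sketch below uses pure-Python stand-ins rather than TensorFlow itself. `correct_diag`, `buggy_diag`, and `ignores_k` are hypothetical names, and `buggy_diag` only imitates the behaviour described above.

```python
def correct_diag(v, k=0):
    """Stand-in for a diag that honours k: values go on the k-th diagonal."""
    n = len(v) + abs(k)
    out = [[0] * n for _ in range(n)]
    for i, value in enumerate(v):
        out[i + max(-k, 0)][i + max(k, 0)] = value
    return out


def buggy_diag(v, k=0):
    """Stand-in for the reported behaviour: k is accepted but has no effect."""
    return correct_diag(v, 0)


def ignores_k(diag_fn):
    """Return True if diag_fn yields the same result for k=1 as for k=0."""
    v = [1, 2, 3]
    return diag_fn(v, k=1) == diag_fn(v, k=0)


print(ignores_k(correct_diag))
print(ignores_k(buggy_diag))
```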

jonasrauber commented 3 years ago

Let's call the method diag, not _diag. No need to limit diag to ep.diag. It might be nice to chain diag like other methods.
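The pattern being asked for can be sketched as follows: a public `diag` method on the tensor class, so calls can be chained, plus a module-level function that simply delegates to it. `MiniTensor` and `total` are hypothetical stand-ins for illustration only; this is not eagerpy's actual class hierarchy.

```python
class MiniTensor:
    """Toy tensor wrapping nested lists, just enough to show method chaining."""

    def __init__(self, data):
        self.data = data

    def diag(self, k=0):
        """Extract the k-th diagonal of a 2-D tensor as a new 1-D MiniTensor."""
        rows = self.data
        cols = len(rows[0]) if rows else 0
        values = [rows[i][i + k] for i in range(len(rows)) if 0 <= i + k < cols]
        return MiniTensor(values)

    def total(self):
        """Sum the elements of a 1-D tensor."""
        return sum(self.data)


def diag(t, k=0):
    """Functional form that delegates to the public method."""
    return t.diag(k)


t = MiniTensor([[1, 2], [3, 4]])
print(t.diag().total())   # method form, chained
print(diag(t, k=1).data)  # functional form
```

Because the method is public, both spellings stay available and the functional form needs no logic of its own.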

Holt59 commented 3 years ago

> Let's call the method diag, not _diag. No need to limit diag to ep.diag. It might be nice to chain diag like other methods.

Done. I've also reverted the change to test_transpose_1d I made by mistake.

jonasrauber commented 3 years ago

Thanks! Seems a couple of tests fail because the frameworks return inconsistent results.

Holt59 commented 3 years ago

> Thanks! Seems a couple of tests fail because the frameworks return inconsistent results.

Yes, as I said in a previous comment, the TensorFlow version used in the CI tests has a bug regarding the k parameter of tf.linalg.diag (see the link in the previous comment).

Holt59 commented 3 years ago

I've rebased the PR onto master to get the upgraded TF in the tests. Hopefully this fixes the bug mentioned previously.