jonasrauber / eagerpy

PyTorch, TensorFlow, JAX and NumPy — all of them natively using the same code
https://eagerpy.jonasrauber.de
MIT License

Why restrict cross entropy to 2D inputs only? #57

Open hristo-vrigazov opened 1 year ago

hristo-vrigazov commented 1 year ago

First, congrats on such a great project!

Basically the title. PyTorch and TensorFlow both support cross entropy for N-dimensional inputs, and your NumPy implementation would work for the multi-dimensional case too. However, every backend's implementation asserts that the logits are a 2D array. I propose removing those asserts 😄
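To illustrate why the 2D restriction is not needed mathematically, here is a minimal NumPy sketch of softmax cross entropy for logits of any rank. `crossentropy_nd` is a hypothetical helper, not part of EagerPy. It assumes the class axis is last (the convention of TensorFlow's `tf.nn.sparse_softmax_cross_entropy_with_logits`) and that `labels` holds integer class indices with shape `logits.shape[:-1]`. The reductions simply run over `axis=-1` instead of `axis=1`.

```python
import numpy as np


def crossentropy_nd(logits: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Per-position softmax cross entropy for logits of shape (..., num_classes).

    Hypothetical helper (not EagerPy's API). Assumes the class axis is last
    and labels are integer class indices with shape logits.shape[:-1].
    Returns an array of losses with shape logits.shape[:-1].
    """
    if logits.ndim == 0 or labels.shape != logits.shape[:-1]:
        raise ValueError("labels must have shape logits.shape[:-1]")
    # Subtract the per-position max before exponentiating (log-sum-exp trick)
    # so large logits do not overflow.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_sum_exp = np.log(np.exp(shifted).sum(axis=-1))
    # Gather the shifted logit of the true class at every position.
    true_logit = np.take_along_axis(shifted, labels[..., np.newaxis], axis=-1)[..., 0]
    # -log(softmax(z)[y]) == logsumexp(z) - z[y]
    return log_sum_exp - true_logit
```

For 2D inputs this reduces to the usual per-sample loss. PyTorch's `torch.nn.functional.cross_entropy` instead expects the class axis at position 1, so channels-first logits of shape `(N, C, H, W)` would first need `np.moveaxis(logits, 1, -1)` before being passed to this sketch.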

christian-westbrook commented 1 year ago

I ran into this problem while using the latest version of https://github.com/bethgelab/foolbox, which uses EagerPy. I would also benefit from cross-entropy support for inputs with more than two dimensions.
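Until the asserts are relaxed, one possible workaround is to flatten all leading axes into a single batch axis, call a 2D-only cross entropy, and reshape the result back. The sketch below is hypothetical: `crossentropy_2d` is a NumPy stand-in that mirrors the 2D asserts described in this issue (it is not EagerPy's code), and `crossentropy_via_2d` assumes the class axis is last.

```python
import numpy as np


def crossentropy_2d(logits: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for a 2D-only cross entropy:
    logits (N, C) and integer labels (N,) -> per-sample losses (N,)."""
    assert logits.ndim == 2 and labels.ndim == 1 and len(logits) == len(labels)
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_sum_exp = np.log(np.exp(shifted).sum(axis=1))
    true_logit = shifted[np.arange(len(labels)), labels]
    return log_sum_exp - true_logit


def crossentropy_via_2d(logits, labels, ce_2d=crossentropy_2d):
    """Apply a 2D-only cross entropy to (..., C) logits by flattening the
    leading axes into one batch axis, then restoring the leading shape."""
    lead_shape = logits.shape[:-1]
    if labels.shape != lead_shape:
        raise ValueError("labels must have shape logits.shape[:-1]")
    num_classes = logits.shape[-1]
    flat_losses = ce_2d(logits.reshape(-1, num_classes), labels.reshape(-1))
    return flat_losses.reshape(lead_shape)
```

The reshape costs nothing conceptually because cross entropy is computed independently at each position, which is also why removing the 2D asserts would be safe as long as the reductions run over the class axis only.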