ruby-numo / numo-narray

Ruby/Numo::NArray - New NArray class library
http://ruby-numo.github.io/narray/
BSD 3-Clause "New" or "Revised" License

Prefer argmax rather than max_index #117

Closed: sonots closed this issue 4 years ago

sonots commented 5 years ago

NumPy argmax returns the indices of the maximum values along an axis, counted within that axis. Numo max_index also finds the maximum values along an axis, but returns their flat indices, counted from the head of the whole array.

>>> import numpy
>>> a = numpy.arange(6).reshape(2,3)
>>> a
array([[0, 1, 2],
       [3, 4, 5]])
>>> a.argmax(0)
array([1, 1, 1])

irb(main):003:0> a = Numo::SFloat.new(2, 3).seq
=> Numo::SFloat#shape=[2,3]
[[0, 1, 2],
 [3, 4, 5]]
irb(main):004:0> a.max_index(0)
=> Numo::Int32#shape=[3]
[3, 4, 5]
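To make the relationship between the two outputs explicit, here is a minimal pure-Ruby sketch using plain nested Arrays rather than Numo. The method names `argmax_axis0` and `flat_max_index_axis0` are hypothetical; the second only mimics the flat, row-major indices that `max_index(0)` returned in the transcript. For a 2-D array with `ncols` columns, the maximum of column `j` at row `i` has flat index `i * ncols + j`.

```ruby
# Hypothetical helpers, not part of Numo or NumPy.

# NumPy-style argmax(0): for each column, the row position of the maximum.
def argmax_axis0(rows)
  ncols = rows.first.size
  (0...ncols).map do |j|
    column = rows.map { |row| row[j] }
    # each_with_index yields [value, row_index]; keep the row index of the max
    column.each_with_index.max_by { |value, _i| value }[1]
  end
end

# max_index(0)-style result: the same maxima, reported as flat row-major
# offsets counted from the head of the whole array (i * ncols + j).
def flat_max_index_axis0(rows)
  ncols = rows.first.size
  argmax_axis0(rows).each_with_index.map { |i, j| i * ncols + j }
end

a = [[0, 1, 2],
     [3, 4, 5]]
p argmax_axis0(a)          # => [1, 1, 1]
p flat_max_index_axis0(a)  # => [3, 4, 5]
```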

Because of this, red-chainer has to use a workaround that converts the indices:

https://github.com/red-data-tools/red-chainer/blob/2a454e57d1341dda63b150e666d67221ca65ad23/lib/chainer/functions/evaluation/accuracy.rb#L33
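The linked red-chainer line is not reproduced here, so the following is only a hypothetical sketch of the general conversion such a workaround has to perform, not red-chainer's actual code. Assuming a C-contiguous (row-major) `shape`, the position along `axis` of a flat index `f` is `(f / stride[axis]) % shape[axis]`. The helper names `row_major_strides` and `flat_to_axis_index` are invented for illustration.

```ruby
# Hypothetical helpers: row-major strides (in elements) for a given shape.
def row_major_strides(shape)
  strides = Array.new(shape.size, 1)
  (shape.size - 2).downto(0) { |k| strides[k] = strides[k + 1] * shape[k + 1] }
  strides
end

# Recover the argmax-style position along `axis` from flat indices
# of the kind max_index returns.
def flat_to_axis_index(flat_indices, shape, axis)
  stride = row_major_strides(shape)[axis]
  flat_indices.map { |f| (f / stride) % shape[axis] }
end

p flat_to_axis_index([3, 4, 5], [2, 3], 0)  # => [1, 1, 1]
p flat_to_axis_index([2, 5], [2, 3], 1)     # => [2, 2]
```

This sketch runs as an element-by-element Ruby `map` on the host. The issue attributes the Cumo cost to exactly this kind of per-element `map`, which forces GPU/CPU synchronization and a device-to-host copy.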

whereas Chainer simply writes:

https://github.com/chainer/chainer/blob/aeccb0aa8e8063c0c0c49908155ce240280675ed/chainer/functions/evaluation/accuracy.py#L53

This difference has a large performance impact, especially in Cumo, because the workaround requires synchronizing the GPU with the CPU and transferring GPU memory to CPU memory at `map`.

I do not know the background of why Numo chose the max_index specification rather than argmax. If there are any reasons, please let me know.

masa16 commented 4 years ago

New methods argmin and argmax have been implemented:
https://github.com/ruby-numo/numo-narray/commit/6c12cc01932cdfa85ca8ce24a39fe95692a0b06f
https://github.com/ruby-numo/numo-narray/commit/0eda526c94cafefc4e7137797c1560e33229671c