raskr / rust-autograd

Tensors and differentiable operations (like TensorFlow) in Rust
MIT License

Bug in `g.argmax` #46

laocaoshilaocao closed this issue 3 years ago

laocaoshilaocao commented 3 years ago

Hi, I am using the `g.argmax` method and ran into a strange result.

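```rust
// 85.0 appears at indices 0, 3 and 5, so the maximum is tied three times.
let test = g.constant(array![85.0, 16.0, 0.04, 85.0, 16.0, 85.0]);
// Argmax along axis 0; `false` means the reduced axis is not kept.
let max_dis_index = g.argmax(test, 0, false).show_with("max_dis_index is");
g.eval(&[max_dis_index], &[]);
```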

The output is 8.0, which is clearly wrong: the input has only six elements, so any valid index along axis 0 must be between 0 and 5.

This happens when the maximum value occurs more than twice.
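For comparison, here is a minimal sketch in plain Rust (no rust-autograd involved) of an argmax that breaks ties by returning the first maximal index, the convention NumPy's `argmax` documents. It also checks one hypothesis for the reported value: 8.0 equals 0 + 3 + 5, the sum of the indices where 85.0 occurs, which is what you would get if the op summed every maximal position instead of picking one. That explanation is an assumption; this thread does not state the actual cause or what the fix changed. The function names `argmax_first` and `sum_of_tied_indices` are hypothetical.

```rust
/// Index of the first maximum in `xs`, or `None` if `xs` is empty.
/// Ties go to the lowest index. NaN handling is out of scope for this sketch.
fn argmax_first(xs: &[f64]) -> Option<usize> {
    let mut best: Option<(usize, f64)> = None;
    for (i, &x) in xs.iter().enumerate() {
        // Only a strictly larger value replaces the current best,
        // so an equal value at a later index never wins.
        let replace = match best {
            None => true,
            Some((_, b)) => x > b,
        };
        if replace {
            best = Some((i, x));
        }
    }
    best.map(|(i, _)| i)
}

/// Hypothetical buggy behaviour: sum the indices of *all* maximal elements.
fn sum_of_tied_indices(xs: &[f64]) -> f64 {
    let max = xs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    xs.iter()
        .enumerate()
        .filter(|&(_, &x)| x == max)
        .map(|(i, _)| i as f64)
        .sum()
}

fn main() {
    let test = [85.0, 16.0, 0.04, 85.0, 16.0, 85.0];
    println!("{:?}", argmax_first(&test)); // prints "Some(0)"
    println!("{:?}", sum_of_tied_indices(&test)); // prints "8.0"
}
```

Under that hypothesis, a tie at only two positions (for example indices 0 and 3) would yield an in-range but still wrong index, which might explain why the symptom seemed to need three or more occurrences.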

raskr commented 3 years ago

Fixed in https://github.com/raskr/rust-autograd/pull/49