gorgonia / tensor

package tensor provides efficient and generic n-dimensional arrays in Go that are useful for machine learning and deep learning purposes
Apache License 2.0

Fix scalar multiplication #51

Closed: bezineb5 closed this 5 years ago

bezineb5 commented 5 years ago

Hello, I found an issue in the Mul operator: when the first term is a scalar, it always returns the first term unchanged. The actual problem seems to be in defaultengine_arith.go, in MulScalar, where the leftTensor parameter is not handled properly (on line 669, it seems to call Mul in a way that doesn't change the output). However, this is generated code, so I'm not sure how to fix it there. Besides, it is simpler to commute the two terms. What was the rationale behind the use of leftTensor?

Thanks!
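To make the question about leftTensor concrete, here is a minimal sketch on plain float64 slices. It is not the library's generated code: mulScalar and subScalar are hypothetical helpers, and the idea that the flag exists to tell "tensor op scalar" apart from "scalar op tensor" is an assumption made for illustration. Under that assumption the flag is irrelevant for multiplication, which is why commuting the two terms works, but it does matter for an operation like subtraction.

```go
package main

import "fmt"

// mulScalar is a hypothetical helper, not part of gorgonia/tensor.
// It multiplies every element of t by the scalar s. leftTensor says
// whether the tensor is the left operand (t * s) or the right one (s * t).
// Multiplication commutes, so both branches give the same result.
func mulScalar(t []float64, s float64, leftTensor bool) []float64 {
	out := make([]float64, len(t))
	for i, v := range t {
		if leftTensor {
			out[i] = v * s
		} else {
			out[i] = s * v
		}
	}
	return out
}

// subScalar is the same sketch for subtraction, where operand order
// changes the result: t - s when leftTensor is true, s - t otherwise.
func subScalar(t []float64, s float64, leftTensor bool) []float64 {
	out := make([]float64, len(t))
	for i, v := range t {
		if leftTensor {
			out[i] = v - s
		} else {
			out[i] = s - v
		}
	}
	return out
}

func main() {
	t := []float64{3, 2}
	fmt.Println(mulScalar(t, 2, true))  // [6 4]
	fmt.Println(mulScalar(t, 2, false)) // [6 4]
	fmt.Println(subScalar(t, 2, true))  // [1 0]
	fmt.Println(subScalar(t, 2, false)) // [-1 0]
}
```

In this sketch, swapping the operands is a safe workaround for Mul only; an operation whose result depends on operand order would still need the flag handled correctly.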

coveralls commented 5 years ago

Coverage Status

Coverage increased (+0.02%) to 72.809% when pulling f719761e84a1db3ee47db8b7a6f8f6a9d35b55a1 on bezineb5:scalar-mul into d78b17f3e9857e0b755070a011d200176b7baa75 on gorgonia:master.

coveralls commented 5 years ago

Coverage Status

Coverage increased (+0.003%) to 72.794% when pulling f719761e84a1db3ee47db8b7a6f8f6a9d35b55a1 on bezineb5:scalar-mul into d78b17f3e9857e0b755070a011d200176b7baa75 on gorgonia:master.

chewxy commented 5 years ago

Two things

1. You Found A Bug

You have found a bug! Thank you! The bug is very likely related to the issue you found, though it is not quite what you described.

The test cases you wrote highlight the bug:

    a2 := New(WithBacking([]float64{2}))
    b2 := New(WithBacking([]float64{3}))
    var correct interface{} = 6.0

    res, err := Mul(a2, b2)
    if err != nil {
        t.Fatalf("Error: %v", err)
    }
    assert.Equal(t, correct, res.Data())
    t.Logf("a2 %v b2 %v, res %v", a2, b2, res)

and

    a := New(WithBacking([]float64{3, 2}))
    b := New(WithBacking([]float64{2}))
    correct := []float64{6, 4}

    res, err := Mul(a, b)
    if err != nil {
        t.Fatalf("Error: %v", err)
    }
    assert.Equal(t, correct, res.Data())
    t.Logf("a %v b %v, res %v", a, b, res)

    // Test commutativity
    res, err = Mul(b, a)
    if err != nil {
        t.Fatalf("Error: %v", err)
    }
    assert.Equal(t, correct, res.Data())
    t.Logf("a %v b %v, res %v", a, b, res)

The solution, however, is not to fix the API, but to fix the defaultEngine's arith functions, which are, as you say, generated.

Specifically, one line short of 666: https://github.com/gorgonia/tensor/blob/master/defaultengine_arith.go#L665

This problem is easily fixed. But for now I'd like to close this PR, write up an issue, and then attribute the fix to you.

2. Email Me

Please email me - chewxy [at] gmail.com