gonum / matrix

Matrix packages for the Go language [DEPRECATED]

mat64: consider protecting Copy with shadow detection #351

kortschak closed this issue 8 years ago

kortschak commented 8 years ago

See comment https://github.com/gonum/matrix/issues/336#issuecomment-194589199.

Either we detect the direction of the copy (which we can't guarantee, since the BLAS implementation may not be aware of it, and which is hard to do in pure Go), or we prohibit overlapping copies. At the moment an overlapping copy silently corrupts data.
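To make the corruption concrete, here is a minimal sketch of a row-by-row copy between two overlapping views of one backing slice. It is not the mat64 code: `copyRowsForward` and its index-based parameters are hypothetical names used only for illustration. Each individual `copy` call is safe on its own, but the loop as a whole overwrites source rows before they have been read.

```go
package main

import "fmt"

// copyRowsForward copies an r×c block row by row, first row first.
// dstOff and srcOff are hypothetical start indices into the shared
// backing slice; both views are assumed to use the same stride.
func copyRowsForward(data []float64, dstOff, srcOff, r, c, stride int) {
	for i := 0; i < r; i++ {
		d := dstOff + i*stride
		s := srcOff + i*stride
		copy(data[d:d+c], data[s:s+c])
	}
}

func main() {
	data := []float64{1, 2, 3, 4, 5, 6, 7, 8}
	// Source rows start at indices 0, 2, 4; destination rows at 2, 4, 6.
	copyRowsForward(data, 2, 0, 3, 2, 2)
	fmt.Println(data) // [1 2 1 2 1 2 1 2]
}
```

A copy that read every source row before it was overwritten would have left `[1 2 1 2 3 4 5 6]`; the forward loop instead propagates the first row into all three destination rows without any error.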

kortschak commented 8 years ago

A correction on this: we can guarantee the direction of the copy via BLAS by choosing the signs of incX and incY. This means that we can use the sign of offset(dst, src) to decide the signs of the increments.
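The same idea can be sketched in pure Go with the builtin `copy` instead of BLAS increments. This is an illustration under stated assumptions, not the library's implementation: `offset` here is a stand-in that is assumed to return how many float64 values `b[0]` lies after `a[0]`, `copyRows` is a hypothetical helper, and both views are assumed to share one backing array and one stride.

```go
package main

import (
	"fmt"
	"unsafe"
)

// offset reports how many float64 values b[0] lies after a[0].
// Assumed convention for this sketch; the result is only meaningful
// when a and b are views of the same backing array.
func offset(a, b []float64) int {
	if len(a) == 0 || len(b) == 0 {
		return 0
	}
	pa := uintptr(unsafe.Pointer(&a[0]))
	pb := uintptr(unsafe.Pointer(&b[0]))
	size := unsafe.Sizeof(float64(0))
	if pb >= pa {
		return int((pb - pa) / size)
	}
	return -int((pa - pb) / size)
}

// copyRows copies r rows of c elements from src to dst and picks the
// row order from the sign of offset(dst, src).
func copyRows(dst, src []float64, r, c, stride int) {
	switch o := offset(dst, src); {
	case o < 0:
		// src starts before dst: copy the last row first so that no
		// source row is overwritten before it has been read.
		for i := r - 1; i >= 0; i-- {
			copy(dst[i*stride:i*stride+c], src[i*stride:i*stride+c])
		}
	case o > 0:
		// src starts after dst: copying the first row first is safe.
		for i := 0; i < r; i++ {
			copy(dst[i*stride:i*stride+c], src[i*stride:i*stride+c])
		}
	default:
		// dst and src start at the same element: nothing to do.
	}
}

func main() {
	data := []float64{1, 2, 3, 4, 5, 6, 7, 8}
	copyRows(data[2:], data[:6], 3, 2, 2)
	fmt.Println(data) // [1 2 1 2 3 4 5 6]
}
```

With the direction chosen this way, the destination rows end up holding the original source rows in both overlap directions.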

kortschak commented 8 years ago

We should, however, check for overlap in the transposed (T) case.
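In the transposed case each destination row is filled from a source column, so as I read it, choosing a row order alone does not protect the source. A conservative check could look like the sketch below. The names, the slice-based signature, and the panic message are assumptions for illustration; the `checkOverlap` method in the proposal takes the raw matrix and may be more precise.

```go
package main

import (
	"fmt"
	"unsafe"
)

// overlaps reports whether the memory ranges covered by a and b
// intersect. It is conservative: two strided views whose elements
// interleave without touching would still be reported as overlapping.
func overlaps(a, b []float64) bool {
	if len(a) == 0 || len(b) == 0 {
		return false
	}
	aLo := uintptr(unsafe.Pointer(&a[0]))
	aHi := uintptr(unsafe.Pointer(&a[len(a)-1]))
	bLo := uintptr(unsafe.Pointer(&b[0]))
	bHi := uintptr(unsafe.Pointer(&b[len(b)-1]))
	return aLo <= bHi && bLo <= aHi
}

// checkOverlap is a hypothetical stand-in for the check proposed for
// the transposed case: it panics rather than allow silent corruption.
func checkOverlap(dst, src []float64) {
	if overlaps(dst, src) {
		panic("overlapping copy of transposed data") // hypothetical message
	}
}

func main() {
	data := make([]float64, 8)
	fmt.Println(overlaps(data[:4], data[4:])) // false
	fmt.Println(overlaps(data[:6], data[2:])) // true
}
```

Panicking makes the overlap visible to the caller, at the cost of rejecting some copies that an element-exact check would allow.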

kortschak commented 8 years ago

This is the gist of my proposal (requires some change to the Copier docs):

diff --git a/mat64/dense.go b/mat64/dense.go
index 505bf63..d8500e6 100644
--- a/mat64/dense.go
+++ b/mat64/dense.go
@@ -394,14 +394,24 @@ func (m *Dense) Copy(a Matrix) (r, c int) {
        case RawMatrixer:
                amat := aU.RawMatrix()
                if trans {
+                       m.checkOverlap(amat)
                        for i := 0; i < r; i++ {
                                blas64.Copy(c,
                                        blas64.Vector{Inc: amat.Stride, Data: amat.Data[i : i+(c-1)*amat.Stride+1]},
                                        blas64.Vector{Inc: 1, Data: m.mat.Data[i*m.mat.Stride : i*m.mat.Stride+c]})
                        }
                } else {
-                       for i := 0; i < r; i++ {
-                               copy(m.mat.Data[i*m.mat.Stride:i*m.mat.Stride+c], amat.Data[i*amat.Stride:i*amat.Stride+c])
+                       switch o := offset(m.mat.Data, amat.Data); {
+                       case o < 0:
+                               for i := r - 1; i >= 0; i-- {
+                                       copy(m.mat.Data[i*m.mat.Stride:i*m.mat.Stride+c], amat.Data[i*amat.Stride:i*amat.Stride+c])
+                               }
+                       case o > 0:
+                               for i := 0; i < r; i++ {
+                                       copy(m.mat.Data[i*m.mat.Stride:i*m.mat.Stride+c], amat.Data[i*amat.Stride:i*amat.Stride+c])
+                               }
+                       default:
+                               // Do nothing.
                        }
                }
        default: