stepelu / lua-sci

SciLua: Scientific Computing with LuaJIT
http://scilua.org

Mat operations #11

Open sonoro1234 opened 1 year ago

sonoro1234 commented 1 year ago

This is a continuation of #4, with several related questions.

How should this be done without the sci syntax extensions? For multiplication we should use alg.mul, but which function should be used for sums?

Also, both with and without the sci syntax extensions, how can I apply math.exp to every element of a matrix?
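For reference, here is a minimal plain-Lua sketch of the two elementwise operations asked about (a sum and `math.exp`). It assumes a hypothetical matrix represented as a table with `nrow`, `ncol` and a flat, row-major, 1-based `data` array; `newmat`, `map` and `add` are illustrative names, not part of sci.alg. Each operation is a single pass over the flat element buffer, which is also the shape of the loop in the sci-lang-generated `__aexpr_1` function.

```lua
-- Hypothetical plain-Lua matrix: flat row-major storage, 1-based.
-- This is NOT sci.alg's API; it only illustrates elementwise loops.
local function newmat(nrow, ncol, fill)
  local m = { nrow = nrow, ncol = ncol, data = {} }
  for i = 1, nrow * ncol do
    m.data[i] = fill or 0
  end
  return m
end

-- Apply f to every element, returning a new matrix (e.g. f = math.exp).
local function map(x, f)
  local r = newmat(x.nrow, x.ncol)
  for i = 1, x.nrow * x.ncol do
    r.data[i] = f(x.data[i])
  end
  return r
end

-- Elementwise sum of two matrices with identical dimensions.
local function add(x, y)
  assert(x.nrow == y.nrow and x.ncol == y.ncol, "dimension mismatch")
  local r = newmat(x.nrow, x.ncol)
  for i = 1, x.nrow * x.ncol do
    r.data[i] = x.data[i] + y.data[i]
  end
  return r
end

local A = newmat(2, 3, 0)    -- all zeros
local B = newmat(2, 3, 1.5)  -- all 1.5
local E = map(A, math.exp)   -- exp(0) = 1 everywhere
local S = add(E, B)          -- 1 + 1.5 = 2.5 everywhere
print(S.data[1], S.data[6])
```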

How do I get the transpose of a matrix without the sci syntax? Related to that, while trying to answer the question above myself, I found an error in sci-lang with this script:

local alg = require 'sci.alg'
local A = alg.mat(4,7)
local B = (A[]`)

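It fails. sci-lang translates the script into the following Lua code, where the backtick ends up verbatim inside the elementwise loop: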

local function __aexpr_1(__x1)
    local __r1 = __array_alloc(__x1, __dim_elw_1(__x1))
    for __i = 0, __r1._n - 1 do
        __r1._p[__i] = ` __x1._p[__i]
    end
    return __r1
end
local alg = require("sci.alg")
local A = alg.mat(4, 7)
local B = __aexpr_1(A)

Output:
c:\scilua\luajit.exe: c:\scilua\lua\sci-lang\__bin\scilua.lua:77: [string "local __alg = require("sci.alg").__..."]:7: unexpected symbol near '`'

Of course, it can be done by hand:

-- Naive transpose: copy element (r, c) of x into (c, r) of a new ncol x nrow matrix.
local function transpose(x)
    local t = alg.mat(x:ncol(),x:nrow())
    for r=1,x:nrow() do
        for c=1,x:ncol() do
            t[{c, r}] = x[{r, c}]
        end
    end
    return t
end

but I had the impression that the point of using OpenBLAS was to call library functions to multiply, transpose, or sum matrices. Is that not the case?
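A copy is not strictly required to work with a transpose. A common technique in BLAS-style libraries is to describe a matrix by its storage plus row and column strides, so a transpose is just a new view with the dimensions and strides swapped. The sketch below shows the idea in plain Lua; the table layout and the `newmat`, `get`, `set` and `transpose_view` names are hypothetical and do not describe sci.alg's internals.

```lua
-- Hypothetical strided matrix view (NOT sci.alg's API): element (r, c),
-- 1-based, lives at data[offset + (r-1)*rs + (c-1)*cs].
local function newmat(nrow, ncol)
  local data = {}
  for i = 1, nrow * ncol do data[i] = 0 end
  -- Row-major layout: moving down one row skips ncol elements.
  return { data = data, nrow = nrow, ncol = ncol, rs = ncol, cs = 1, offset = 1 }
end

local function get(m, r, c)
  return m.data[m.offset + (r - 1) * m.rs + (c - 1) * m.cs]
end

local function set(m, r, c, v)
  m.data[m.offset + (r - 1) * m.rs + (c - 1) * m.cs] = v
end

-- O(1) transpose: swap the dimensions and the strides, share the storage.
local function transpose_view(m)
  return { data = m.data, nrow = m.ncol, ncol = m.nrow,
           rs = m.cs, cs = m.rs, offset = m.offset }
end

local A = newmat(4, 7)
set(A, 2, 5, 42)
local At = transpose_view(A)   -- 7 x 4 view of the same storage
print(At.nrow, At.ncol, get(At, 5, 2))
```

Because the view shares `data` with `A`, writes through `At` also modify `A`. BLAS routines such as gemv and gemm apply the same idea through their transpose argument, reading A as A^T without materializing it.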

sonoro1234 commented 1 year ago

To compute res = A*x + y, I have defined:

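-- Format the dimensions of a matrix as "rows x cols" for debug output.
local function matdims(x)
    return x:nrow().."x"..x:ncol()
end
-- Raise an error unless all three operands share the same element type.
local function same_type_check_3(x, y, z)
    local ct = x:elementct()
    if ct ~= y:elementct() or ct ~= z:elementct() then
        error('constant element type required')
    end
end
-- matrix, vector, vector : res = A*x + y
local function muladdv(A,x,y)
    print("muladdv",matdims(A),matdims(x),matdims(y))
    -- A is m x n, x is n x 1, y is m x 1
    assert(A:ncol() == x:nrow() and y:nrow()== A:nrow() and y:ncol() == x:ncol() and y:ncol()==1)
    local res = y:copy()
    -- the output buffer must not alias any of the inputs
    assert(not(rawequal(res, A) or rawequal(res, x) or rawequal(res,y)))
    same_type_check_3(A,x,res)
    -- internal (underscore) sci.alg BLAS wrapper, intended to compute res = A*x + res
    res:_gemv(A, x, 0, 1, 1)
    return res
end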

After several calls to muladdv, the program suddenly stops (perhaps a bad memory access?) without any error message or assertion failure. What could be the reason?
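One way to narrow down a silent crash like this is to check the `_gemv` results against a slow but obviously correct reference. The sketch below implements the standard BLAS gemv semantics, y := alpha * op(A) * x + beta * y with op(A) either A or A^T, on hypothetical tables of rows (`A[r][c]`, 1-based). `gemv_ref` is an illustrative name and does not reproduce the argument order of sci.alg's internal `_gemv`. It always returns a fresh table, which sidesteps the aliasing between output and inputs that BLAS does not allow.

```lua
-- Pure-Lua reference for BLAS gemv semantics (hypothetical helper, NOT sci.alg):
--   out = alpha * op(A) * x + beta * y,  op(A) = A (trans == false) or A^T
-- A is a table of rows, A[r][c], 1-based; x and y are plain arrays.
local function gemv_ref(trans, alpha, A, x, beta, y)
  local m, n = #A, #A[1]
  local rows, cols = m, n
  if trans then rows, cols = n, m end
  assert(#x == cols and #y == rows, "dimension mismatch")
  local out = {}
  for i = 1, rows do
    local s = 0
    for j = 1, cols do
      local a
      if trans then a = A[j][i] else a = A[i][j] end
      s = s + a * x[j]
    end
    out[i] = alpha * s + beta * y[i]
  end
  return out -- fresh table, so it never aliases A, x or y
end

local A = { {1, 2}, {3, 4}, {5, 6} }                        -- 3 x 2
local r1 = gemv_ref(false, 1, A, {1, 1}, 1, {10, 20, 30})   -- A*x + y
local r2 = gemv_ref(true, 1, A, {1, 0, 1}, 0, {0, 0})       -- A^T*x
print(r1[1], r1[2], r1[3], r2[1], r2[2])
```

With alpha = 1 and beta = 1 this is exactly res = A*x + y from `muladdv`; the values of a sci.alg matrix can be copied into such tables with the `x[{r, c}]` indexing used in the `transpose` function.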

sonoro1234 commented 1 year ago

Another question: why were __add and the related metamethods not used in SciLua?
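For comparison, here is what operator overloading through `__add` and `__mul` looks like for a hypothetical plain-Lua vector type (`Vec` and `vec` are illustrative names, not sci.alg). Each operator call allocates a new result, so `u + w + 2 * u` creates two intermediate vectors and makes three passes over the data, while a fused loop like the generated `__aexpr_1` allocates only the result and makes one pass. That overhead is one plausible reason to prefer a code-generating syntax extension, but it is an assumption, not a rationale stated by the SciLua authors.

```lua
-- Hypothetical vector type with arithmetic metamethods (NOT sci.alg).
local Vec = {}
Vec.__index = Vec

local function vec(t)
  local v = {}
  for i = 1, #t do v[i] = t[i] end
  return setmetatable(v, Vec)
end

-- Elementwise sum; every call allocates a fresh result vector.
Vec.__add = function(a, b)
  assert(#a == #b, "dimension mismatch")
  local r = {}
  for i = 1, #a do r[i] = a[i] + b[i] end
  return setmetatable(r, Vec)
end

-- Scalar times vector, in either order (vector * vector is not supported).
Vec.__mul = function(a, b)
  if type(a) == "number" then a, b = b, a end
  local r = {}
  for i = 1, #a do r[i] = a[i] * b end
  return setmetatable(r, Vec)
end

local u = vec{1, 2, 3}
local w = vec{10, 20, 30}
-- Parsed as (u + w) + (2 * u): two temporaries plus the final result.
local z = u + w + 2 * u
print(z[1], z[2], z[3])
```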