crystal-data / num.cr

Scientific computing in pure Crystal
MIT License

`matmul` not Defined for Float * Complex #77

Open stellarpower opened 1 year ago

stellarpower commented 1 year ago

Code Sample, a copy-pastable example if possible

require "num"

include Num

alias RT64 = Tensor(Float64, CPU(Float64))
alias CT64 = Tensor(Complex, CPU(Complex))

complexMatrix = RT64.random(0.0..1.0, [3, 8])
realVector    = RT64.random(0.0..1.0, [8   ])

complexMatrix *= (1 + 1.i) # Couldn't get random generation working for complex tensors, so cheat: scale the real matrix by a complex factor.

result = complexMatrix.matmul(realVector)

pp result

Problem description

`matmul` (and possibly other operations; I haven't managed to check yet) only seems to be defined for operand matrices of the same data type. Multiplying a complex number by a float is well-defined on an element-by-element basis, as is multiplying a real by an integer, so I reckon an overload should be provided for cases where the underlying element types are different but compatible.
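To illustrate that the mixed-type product is well-defined, here is a minimal sketch in plain Crystal, assuming row-major `Array(Array(...))` matrices rather than num.cr tensors. `mixed_matvec` is a hypothetical helper, not part of num.cr; it relies only on Crystal's standard `Complex`, whose `Number#*(Complex)` and `Complex#*(Number)` overloads let the compiler infer the result element type.

```crystal
require "complex"

# Hypothetical helper, not part of num.cr: multiplies an m x n matrix
# (array of rows) by a length-n vector whose element type may differ.
# The result's element type is whatever Crystal infers for `A * X`,
# e.g. Complex * Float64 -> Complex, Float64 * Int32 -> Float64.
def mixed_matvec(a : Array(Array(A)), x : Array(X)) forall A, X
  raise ArgumentError.new("vector must not be empty") if x.empty?
  a.map do |row|
    unless row.size == x.size
      raise ArgumentError.new("row length #{row.size} != vector length #{x.size}")
    end
    # Element-wise products, then their sum (reduce needs a non-empty row).
    row.zip(x).map { |r, v| r * v }.reduce { |acc, p| acc + p }
  end
end

m = [[Complex.new(1, 1), Complex.new(2, 0)],
     [Complex.new(0, 1), Complex.new(1, -1)]]
v = [1.0, 2.0]

pp mixed_matvec(m, v)                 # complex matrix times real vector
pp mixed_matvec([[1.5, 2.0]], [2, 3]) # real matrix times integer vector
```

A BLAS-backed `matmul` cannot lean on this inference, which is the limitation raised in the reply that follows.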

christopherzimmerman commented 1 year ago

Thanks for the issue. Yeah, this probably isn’t ideal API design. For any method that uses internal Crystal operators, the compiler knows what type is produced by combining a Float and a Complex, but any method that calls a BLAS routine does not.

I need to do a bit more research to see how much effort it would be to be able to coalesce types before BLAS ops, otherwise multiple overloads should take care of it.
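The "multiple overloads" route could look roughly like the sketch below, assuming plain arrays. `complex_matvec` is a hypothetical stand-in for a single-typed BLAS kernel (such as a double-complex matrix-vector routine), and the `matvec` overloads coalesce a real operand to `Complex` with `Number#to_c` before delegating. This is not num.cr's actual code.

```crystal
require "complex"

# Hypothetical stand-in for a typed BLAS kernel: it only accepts
# Complex operands on both sides.
def complex_matvec(a : Array(Array(Complex)), x : Array(Complex)) : Array(Complex)
  a.map do |row|
    acc = Complex.new(0, 0)
    row.each_with_index { |r, j| acc += r * x[j] }
    acc
  end
end

# Overloads that promote the real operand to Complex (a copy), then
# reuse the single-typed kernel.
def matvec(a : Array(Array(Complex)), x : Array(Float64)) : Array(Complex)
  complex_matvec(a, x.map(&.to_c))
end

def matvec(a : Array(Array(Float64)), x : Array(Complex)) : Array(Complex)
  complex_matvec(a.map { |row| row.map(&.to_c) }, x)
end

def matvec(a : Array(Array(Complex)), x : Array(Complex)) : Array(Complex)
  complex_matvec(a, x)
end

pp matvec([[Complex.new(1, 1), Complex.new(2, 0)]], [1.0, 2.0])
pp matvec([[1.0, 2.0]], [Complex.new(0, 1), Complex.new(1, 0)])
```

The trade-off is an extra allocation for the promoted operand and one overload per type pair, which is why coalescing types generically before the BLAS call may be preferable.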

stellarpower commented 1 year ago

I understand what you mean; I had assumed it would behave the same as the element-wise operators.

Almost unrelated tangent: I don't know if I opened this as an issue, but I'd really like to see compile-time (partially) fixed-size tensors in Crystal. Having used these in C++, and using partial specialisation to write functions operating on them, IMO it's an extremely powerful way to ensure your maths is checked before the code runs, and it can be combined with similar static techniques to take things further (dimensional analysis and conversion of units, strong typing between co-ordinate spaces, for example). Given we have all these great features in the Crystal compiler available to us, I think it'd be a big plus if we could make use of them to get rid of a whole class of errors that things like numpy just can't pick up until you test the code.

And yeah, I don't know if you can do this using an expression template before then turning around and calling the implementation as it stands now. You'd presumably need a 1-1 correspondence between the BLAS types and the Crystal ones, in terms of what the deduced type is for compound expressions.
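One way to picture that correspondence is a sketch like the following, which uses Crystal's `typeof` to ask the compiler for the type of a compound expression and then picks a BLAS precision prefix by overload. `blas_prefix` is a hypothetical name; the s/d/z letters follow the usual BLAS naming convention. Since Crystal's `Complex` stores two `Float64` values, it corresponds to the double-complex prefix, and there is no built-in single-precision complex type to map to `c`.

```crystal
require "complex"

# Hypothetical mapping from a Crystal element type to a BLAS precision
# prefix: s = single real, d = double real, z = double complex.
def blas_prefix(t : Float32.class) : Char
  's'
end

def blas_prefix(t : Float64.class) : Char
  'd'
end

def blas_prefix(t : Complex.class) : Char
  'z'
end

# typeof gives the statically deduced type of each compound expression,
# and overload resolution on that type happens at compile time.
pp blas_prefix(typeof(1.0 * 2.0))         # Float64 * Float64
pp blas_prefix(typeof(1.0 * (1 + 1.i)))   # Float64 * Complex
pp blas_prefix(typeof(2.0_f32 * 3.0_f32)) # Float32 * Float32
```

An unsupported deduced type (for example `Int32`) simply has no matching overload, so it is rejected at compile time rather than at the BLAS call.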

christopherzimmerman commented 1 year ago

Would you mind opening a separate issue for fixed-size Tensors? And if you have any, could you provide a link to what you’re referencing in C++?

Compile time bounds checking and broadcasting would be fantastic.