TensorBFS / TropicalGEMM.jl

The fastest tropical matrix multiplication in the world!

Wrong answers for `Tropical{Float32}` #19

Open chriselrod opened 1 year ago

chriselrod commented 1 year ago

e.g.

julia> using TropicalNumbers, Octavian, TropicalGEMM, BenchmarkTools, Tullio

julia> v2!(r, d) = @tullio (max) r[j, i] = d[k, j] + d[i, k]
v2! (generic function with 1 method)

julia> d = rand(Float32, 4000, 4000); r = similar(d); td = Tropical.(d); tr = similar(td);

julia> typeof(tr)
Matrix{TropicalF32} (alias for Array{Tropical{Float32}, 2})

julia> typeof(td)
Matrix{TropicalF32} (alias for Array{Tropical{Float32}, 2})

julia> @time(Octavian.matmul!(tr, td', td'));
 11.552436 seconds (28.39 M allocations: 1.414 GiB, 7.06% gc time, 198.04% compilation time)

julia> @time(v2!(r,d));
  2.761347 seconds (439.87 k allocations: 27.319 MiB, 20.82% compilation time)

julia> r[1:10,1:10]
10×10 Matrix{Float32}:
 1.98815  1.98032  1.96756  1.97688  1.98138  1.99293  1.99121  1.99046  1.98951  1.99046
 1.98347  1.97355  1.98345  1.96617  1.98862  1.97611  1.98876  1.98316  1.98376  1.98263
 1.98005  1.98607  1.94482  1.99388  1.97896  1.98763  1.97235  1.95483  1.98607  1.97469
 1.97811  1.96782  1.96914  1.99468  1.9638   1.96488  1.99281  1.96719  1.99046  1.99007
 1.97995  1.99206  1.96033  1.99925  1.97888  1.99531  1.98424  1.97277  1.96929  1.99149
 1.98175  1.97349  1.98187  1.96846  1.97438  1.99013  1.98184  1.96848  1.97482  1.99134
 1.97559  1.99141  1.97511  1.98494  1.98186  1.97386  1.95691  1.95286  1.97816  1.98359
 1.96895  1.95868  1.96969  1.98323  1.98321  1.96809  1.96434  1.9723   1.97961  1.9598
 1.97215  1.98366  1.97773  1.98742  1.98507  1.95414  1.96615  1.98585  1.97801  1.97805
 1.98135  1.98764  1.98053  1.96926  1.97798  1.988    1.94989  1.97951  1.98827  1.99674

julia> tr[1:10,1:10]
10×10 Matrix{TropicalF32}:
 1.9881469ₜ  1.9545851ₜ  1.9813809ₜ  1.9596899ₜ  1.9895067ₜ  1.9505024ₜ  1.9669962ₜ  1.9318342ₜ  1.9598813ₜ  1.9938858ₜ
 1.9601333ₜ  1.9426515ₜ  1.9789603ₜ  1.9723496ₜ   1.972192ₜ  1.9418759ₜ  1.9936779ₜ  1.9892867ₜ  1.9816258ₜ  1.9438664ₜ
 1.9267545ₜ  1.9517486ₜ  1.9759095ₜ  1.9842439ₜ   1.969292ₜ  1.9947274ₜ   1.967751ₜ  1.9854705ₜ  1.9549026ₜ  1.9786203ₜ
 1.9466515ₜ  1.9751123ₜ  1.9818645ₜ  1.9569125ₜ  1.9781556ₜ  1.9860659ₜ  1.9599998ₜ  1.9799122ₜ  1.9720198ₜ  1.9701405ₜ
  1.972152ₜ  1.9777286ₜ  1.9671929ₜ  1.9584064ₜ  1.9780083ₜ  1.9773908ₜ   1.958333ₜ  1.9829702ₜ  1.9757335ₜ  1.9570483ₜ
 1.9852564ₜ   1.976543ₜ  1.9943745ₜ   1.965773ₜ  1.9823151ₜ  1.9887956ₜ  1.9841229ₜ  1.9784414ₜ  1.9614553ₜ  1.9801428ₜ
 1.9767675ₜ  1.9753637ₜ  1.9792922ₜ  1.9716537ₜ  1.9809005ₜ  1.9872162ₜ  1.9794067ₜ  1.9565374ₜ  1.9802582ₜ   1.963557ₜ
 1.9824476ₜ  1.9638793ₜ  1.9712152ₜ  1.9819689ₜ  1.9859302ₜ  1.9882684ₜ  1.9788489ₜ  1.9857358ₜ   1.979456ₜ  1.9608781ₜ
 1.9712175ₜ  1.9773014ₜ  1.9713657ₜ   1.985241ₜ  1.9794915ₜ  1.9511222ₜ   1.959055ₜ  1.9835603ₜ  1.9517698ₜ  1.9724486ₜ
 1.9561397ₜ  1.9714637ₜ  1.9776536ₜ  1.9738066ₜ  1.9854319ₜ  1.9693574ₜ  1.9678532ₜ  1.9894817ₜ  1.9719748ₜ  1.9628868ₜ

By contrast, it works for Float64:

julia> d = rand(4000, 4000); r = similar(d); td = Tropical.(d); tr = similar(td);

julia> @time(v2!(r,d));
  2.539927 seconds (261.82 k allocations: 16.076 MiB, 16.36% compilation time)

julia> @time(Octavian.matmul!(tr, td', td'));
  9.220967 seconds (25.62 M allocations: 1.240 GiB, 3.42% gc time, 192.78% compilation time)

julia> tr == Tropical.(r)
true
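As an independent check on results like the one above, a plain-Julia max-plus reference can be compared against any fast kernel on a small input. This is only a sketch: `maxplus_ref` is a hypothetical helper name, not part of TropicalGEMM, Octavian, or Tullio, and it works on raw floating-point values rather than `Tropical` wrappers.

```julia
# Naive max-plus ("tropical") matrix product: C[i, j] = max over k of A[i, k] + B[k, j].
# `maxplus_ref` is a hypothetical helper for cross-checking, not a library function.
function maxplus_ref(A::AbstractMatrix{T}, B::AbstractMatrix{T}) where {T<:AbstractFloat}
    size(A, 2) == size(B, 1) || throw(DimensionMismatch("inner dimensions must match"))
    # typemin of a float type is -Inf, the additive identity of the max-plus semiring.
    C = fill(typemin(T), size(A, 1), size(B, 2))
    for j in axes(B, 2), k in axes(A, 2), i in axes(A, 1)
        C[i, j] = max(C[i, j], A[i, k] + B[k, j])
    end
    return C
end

d = Float32[1 2; 3 4]
# Same contraction as v2!: r[j, i] = max over k of d[k, j] + d[i, k],
# which is the max-plus product of transpose(d) with itself.
r = maxplus_ref(permutedims(d), permutedims(d))
```

Running the same comparison on a small `Float32` matrix should make any mismatch much easier to localize than at 4000×4000.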

chriselrod commented 1 year ago

Note that this is on the following custom branches, since it threw errors before: https://github.com/JuliaLinearAlgebra/Octavian.jl/pull/157 https://github.com/JuliaArrays/ArrayInterface.jl/pull/369

GiggleLiu commented 1 year ago

Hey, @chriselrod. Thanks for the issue. The adjoint of Tropical numbers is not well defined.

julia> (Matrix(td') * Matrix(td'))[1:10, 1:10]
ERROR: MethodError: no method matching conj(::TropicalF32)
Closest candidates are:
  conj(::Union{LinearAlgebra.Hermitian{T, S}, LinearAlgebra.Symmetric{T, S}} where {T, S}) at ~/.julia/juliaup/julia-1.8.3+0.x64.linux.gnu/share/julia/stdlib/v1.8/LinearAlgebra/src/symmetric.jl:368
  conj(::LinearAlgebra.UniformScaling) at ~/.julia/juliaup/julia-1.8.3+0.x64.linux.gnu/share/julia/stdlib/v1.8/LinearAlgebra/src/uniformscaling.jl:123
  conj(::LinearAlgebra.Adjoint) at ~/.julia/juliaup/julia-1.8.3+0.x64.linux.gnu/share/julia/stdlib/v1.8/LinearAlgebra/src/adjtrans.jl:345
  ...
Stacktrace:
  [1] adjoint(x::TropicalF32)
    @ Base ./number.jl:213

So I am confused about why matmul! can produce a result at all. If you use the following statement instead,

julia> @time(Octavian.matmul!(tr, transpose(td), transpose(td)));

The result is correct.
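The difference between the two wrappers can be reproduced without any of the packages involved. The sketch below assumes a hypothetical `MaxPlus` type standing in for `Tropical`; like the type in the error above, it subtypes `Number` and defines no `conj` method. It only illustrates why materializing an adjoint fails element by element; it does not explain how `matmul!` got past that for `Float32`.

```julia
using LinearAlgebra

# `MaxPlus` is a hypothetical stand-in for a tropical number type. It deliberately
# defines no `conj` method, mirroring the MethodError shown above.
struct MaxPlus{T<:Real} <: Number
    n::T
end

A = MaxPlus.(Float32[1 2; 3 4])

# Both wrappers are lazy, so constructing them succeeds either way.
At = transpose(A)
Aa = A'

# transpose(x::Number) returns x unchanged, so element access works.
t12 = At[1, 2]   # the element stored at A[2, 1]

# adjoint(x::Number) calls conj(x), which has no method for MaxPlus,
# so reading an element through the adjoint wrapper throws.
adjoint_failed = try
    Aa[1, 2]
    false
catch err
    err isa MethodError
end
```

A kernel that reads the parent array directly and never touches individual elements through the wrapper would not hit this error, which is one reason to prefer `transpose` for element types without a meaningful conjugate.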

My package versions

(jl_QKsjQL) pkg> st
Status `/tmp/jl_QKsjQL/Project.toml`
  [6fd5a793] Octavian v0.3.18
  [a4ad3063] TropicalGEMM v0.1.8 `~/.julia/dev/TropicalGEMM`   # master branch is the same as the registered version
  [bc48ee85] Tullio v0.3.5

chriselrod commented 1 year ago

If you're curious, the motivation was this slack discussion: https://julialang.slack.com/archives/C67TK21LJ/p1670593631969889

I can confirm that transpose works. I was using adjoints because ' is faster to type.