Jutho / TensorOperations.jl

Julia package for tensor contractions and related operations
https://jutho.github.io/TensorOperations.jl/stable/

Floating Point Accuracy of @tensor results with CUDA #144

Closed: ejmeitz closed this issue 10 months ago

ejmeitz commented 11 months ago

What should the expected accuracy of the results be when using the @tensor macro with CUDA Float32 arrays?

I am using the following contraction:

    using TensorOperations

    @tensor begin
        A[n,m,l] = B[i,j,k]*C[i,n]*C[j,m]*C[k,l]
    end

and noticed that when I compute individual entries of the resulting A tensor by brute force, the agreement varies wildly. For the largest elements of A, the brute-force result matches @tensor fairly well (0.009529917140821487 vs 0.009529919f0). However, for the small elements (which should be zero), the brute-force result (computed in 64-bit) differs considerably from the @tensor result (1e-18 vs 1e-10). Is this expected behavior? I'm having a hard time deciding on a cutoff that separates values that are truly there from values that should be zero.

Any advice on how to apply tolerances to the A tensor would be appreciated! I feel like it should have something to do with eps(Float32), but if the library uses tensor cores internally, the computation might be lowered to Float16.

Jutho commented 11 months ago

I don't understand the question very well, but if you are working with Float32, the expected precision is of order 1e-7. Check:

julia> eps(Float32)
1.1920929f-7

Hence, entries of the arrays that should be zero can easily end up of order 1e-7 (relative to the magnitude of the terms being summed) due to finite numerical precision.
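To see this effect without a GPU, here is a minimal sketch in plain Julia built on a hypothetical test case where the exact answer is known: C is an orthogonal matrix and B is constructed so that A[n,n,n] = d[n] and every other entry of A is exactly zero in exact arithmetic. The contraction is then evaluated with brute-force loops in Float32 and in Float64. The helper names (make_inputs, contract, offdiag_max) are illustrative assumptions, not TensorOperations functions.

```julia
using LinearAlgebra, Random

# Hypothetical setup: C is orthogonal and B[i,j,k] = sum_p d[p]*C[i,p]*C[j,p]*C[k,p],
# so in exact arithmetic A[n,m,l] = d[n] if n == m == l and 0 otherwise.
function make_inputs(::Type{T}, N; seed = 1) where {T<:AbstractFloat}
    rng = MersenneTwister(seed)
    C = Matrix(qr(randn(rng, N, N)).Q)   # N×N orthogonal matrix, computed in Float64
    d = rand(rng, N)                      # exact values of the diagonal entries of A
    B = zeros(N, N, N)
    for i in 1:N, j in 1:N, k in 1:N
        B[i, j, k] = sum(d[p] * C[i, p] * C[j, p] * C[k, p] for p in 1:N)
    end
    return T.(B), T.(C), d                # round the inputs to the working precision T
end

# Brute-force A[n,m,l] = sum_{i,j,k} B[i,j,k]*C[i,n]*C[j,m]*C[k,l], accumulated in T.
function contract(B::Array{T,3}, C::Matrix{T}) where {T}
    N = size(C, 1)
    A = zeros(T, N, N, N)
    for n in 1:N, m in 1:N, l in 1:N
        s = zero(T)
        for i in 1:N, j in 1:N, k in 1:N
            s += B[i, j, k] * C[i, n] * C[j, m] * C[k, l]
        end
        A[n, m, l] = s
    end
    return A
end

# Largest magnitude among the entries that should be exactly zero.
offdiag_max(A) = maximum(abs(A[n, m, l]) for n in axes(A, 1), m in axes(A, 2),
                         l in axes(A, 3) if !(n == m == l))

N = 6
B32, C32, d = make_inputs(Float32, N)
B64, C64, _ = make_inputs(Float64, N)
println("Float32 off-diagonal max: ", offdiag_max(contract(B32, C32)))
println("Float64 off-diagonal max: ", offdiag_max(contract(B64, C64)))
```

The "zero" entries are not zero in either precision, but the Float32 ones should come out many orders of magnitude larger than the Float64 ones, in line with eps(Float32) versus eps(Float64). The exact values depend on the random inputs.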

ejmeitz commented 11 months ago

Yeah, I understand the significant-digits part. I guess I have two questions:

lkdvos commented 10 months ago

Let me first elaborate a bit on the cuTENSOR side of things. As far as I understand, cuTENSOR allows the floating-point accuracy of the computation to be chosen dynamically, and this compute precision is allowed to be higher than that of the output array. The exact details are mostly covered in the docs page you linked, but here are some important notes:

On the TensorOperations side of things, we in principle do not expose the additional option of using higher intermediate accuracy when the output array has lower accuracy. TensorOperations (not only for CuArrays) chooses the output element type by promoting the element types of all input arrays, and that precision is also the guarantee we then request from cuTENSOR.
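The promotion rule described above can be mimicked with Base's promote_type. This is a minimal sketch in plain Julia, not TensorOperations' internal code; output_eltype is a hypothetical helper name.

```julia
# Hypothetical helper: the element type obtained by promoting all input element types.
output_eltype(arrays...) = promote_type(map(eltype, arrays)...)

B = rand(Float32, 2, 2, 2)
C = rand(Float32, 2, 2)
println(output_eltype(B, C))                       # all inputs Float32
println(output_eltype(B, Float64.(C)))             # one Float64 input promotes the result
println(output_eltype(B, rand(ComplexF32, 2, 2)))  # complex Float32 input
```

So if a single input is Float64, the output element type, and with it the precision requested from cuTENSOR, becomes Float64; with all-Float32 inputs it stays Float32.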

Finally, if you are using Julia, be careful with the conversion from Array to CuArray. cu(A) is defined to silently change the precision to Float32 from whatever it was before, as this is typically the best-optimized precision on a GPU. If that is not what you want, use CuArray(A) instead, which keeps the element type.

Thus, to answer your first question: cuTENSOR will in general not lower the floating-point accuracy, but it may increase it if all input/output arrays and scalars support this. TensorOperations only requests from cuTENSOR the precision guaranteed for the output array, which is determined by the inputs.
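As a rough consistency check on the Float16 concern from the question, one can compare the rounding error each candidate precision would produce on results of the size reported there (about 0.0095). This is a back-of-the-envelope sketch under the assumption that errors scale like eps(T) times the magnitude of the largest terms; it does not inspect what cuTENSOR actually does.

```julia
scale = 0.0095   # approximate size of the largest entries of A reported in the question
for T in (Float16, Float32, Float64)
    # Rough rounding-error scale for a result of this magnitude computed in precision T
    err = Float64(eps(T)) * scale
    println(rpad(string(T), 8), "eps = ", Float64(eps(T)), "   expected error ~ ", err)
end
```

Float16 would give errors of order 1e-5, so the reported 1e-10 is far below that and well within the Float32 range, while the 1e-18 from the 64-bit brute force matches Float64.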

For your second question, the answer is rather application-dependent. In generic cases, there is really no way of knowing whether a computed value of 1e-2 is a floating-point artifact on an entry that should be zero or a genuine nonzero value. This is simply inherent to working with floating-point numbers. As a pathological example, the following holds:

julia> 1e-1 + 1e20 - 1e20 - 1e-1
-0.1
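The same four numbers show that the -0.1 is purely a rounding artifact: when the two small terms are added first, or when the sum is carried out in BigFloat, the result is exactly zero. This is a small illustrative sketch using only Base functions.

```julia
xs = [1e-1, 1e20, -1e20, -1e-1]
println(foldl(+, xs))                 # -0.1: left to right, 0.1 is absorbed by 1e20
println(foldl(+, xs[[1, 4, 2, 3]]))   # 0.0: the two small terms cancel before 1e20 enters
println(sum(big.(xs)))                # 0.0: BigFloat has enough bits to keep every term
```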

Nevertheless, when only addition and multiplication are involved, this mostly happens when your input floats are of vastly different scales. This is precisely what eps tells you: eps(x) is the spacing between x and the next larger representable floating-point number, so adding to x a positive number smaller than eps(x)/2 leaves x unchanged, and any precision below that scale is lost.

If your use case puts bounds on the scales of the input floats, you can possibly exclude these pathological cases and decide that anything larger than some predefined limit is definitely not zero, while anything smaller is probably zero. People often use sqrt(eps) or eps^(3/4) for this limit, but this is mostly phenomenological, and you should probably experiment to see what works for you. However, unless you absolutely require this, you can simply leave the entries that are probably zero and continue, as they should only contribute about as much to the final result as the floating-point errors you are already making.
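A minimal sketch of both points in plain Julia: first how eps behaves, then a hypothetical chop_small helper that zeroes entries below a cutoff. Measuring smallness relative to the largest entry of the array is an assumption made here for illustration, not something TensorOperations provides; the default rtol = sqrt(eps(T)) follows the rule of thumb above.

```julia
# eps(x) is the gap between x and the next larger representable float;
# eps(Float32) is the same as eps(1f0).
@show eps(Float32)                         # same value as in the REPL output above
@show 1f0 + eps(Float32) / 4 == 1f0        # an addend this small is lost entirely
@show eps(1024f0) == 1024 * eps(1f0)       # the spacing grows with the magnitude

# Hypothetical helper: zero out entries that are tiny relative to the largest entry.
# rtol = sqrt(eps(T)) by default; eps(T)^(3/4) is the other common choice.
function chop_small(A::AbstractArray{T}; rtol = sqrt(eps(T))) where {T<:AbstractFloat}
    cutoff = rtol * maximum(abs, A)
    return map(x -> abs(x) < cutoff ? zero(x) : x, A)
end

A = Float32[0.0095299, 1f-10, -3f-9, 0.0021]
println(chop_small(A))                               # the two tiny entries become zero
println(chop_small(A; rtol = eps(Float32)^(3/4)))    # stricter cutoff, same result here
```

A relative cutoff fits the question's setting, where the meaningful entries are of order 1e-2; if the meaningful entries could themselves be tiny, an absolute cutoff based on known input scales may be more appropriate.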