facebookresearch / TensorComprehensions

A domain specific language to express machine learning workloads.
https://facebookresearch.github.io/TensorComprehensions/
Apache License 2.0

Support for CPU computations and add is_cuda check #211

Open arogozhnikov opened 6 years ago

arogozhnikov commented 6 years ago


TensorComprehensions fails when given CPU tensors (and crashes the Python kernel). Here is an example:

import tensor_comprehensions as tc
import torch

lang = """
def fcrelu(float(B,M) I, float(N,M) W1, float(N) B1) -> (O1) {
    O1(b, n) +=! I(b, m) * W1(n, m)
    O1(b, n) = O1(b, n) + B1(n)
    O1(b, n) = fmax(O1(b, n), 0)
}
"""
fcrelu = tc.define(lang, name="fcrelu")

B, M, N = 100, 128, 100
# I, W1, B1 = torch.randn(B, M).cuda(), torch.randn(N, M).cuda(), torch.randn(N).cuda()
I, W1, B1 = torch.randn(B, M), torch.randn(N, M), torch.randn(N)

fcrelu.autotune(I, W1, B1, cache="fcrelu_100_128_100.tc")
out = fcrelu(I, W1, B1)

Result:

WARNING: Logging before InitGoogleLogging() is written to STDERR
W0326 11:57:08.843168   229 rtc.cc:144] Error at: /opt/conda/conda-bld/tensor_comprehensions_1520457708651/work/src/core/rtc.cc:144: CUDA_ERROR_INVALID_CONTEXT
W0326 11:57:08.843502   229 rtc.cc:44] Error at: /opt/conda/conda-bld/tensor_comprehensions_1520457708651/work/src/core/rtc.cc:44: CUDA_ERROR_INVALID_HANDLE
terminate called after throwing an instance of 'std::runtime_error'
  what():  Error at: /opt/conda/conda-bld/tensor_comprehensions_1520457708651/work/src/core/rtc.cc:44: CUDA_ERROR_INVALID_HANDLE

It also crashes the Python kernel if you first autotune the function and then try to run it with CPU tensors.

P.S. Is there a plan to support CPU computations? If not, a simple check like is_cuda should be enough.
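
For reference, a user-side guard along those lines could look like the sketch below. The wrapper name run_fcrelu is just illustrative and not part of the TC API; it only uses the fcrelu definition from the example above and PyTorch's is_cuda attribute.

def run_fcrelu(I, W1, B1):
    # Refuse CPU tensors up front instead of letting the CUDA runtime
    # fail later with CUDA_ERROR_INVALID_CONTEXT / INVALID_HANDLE.
    if not (I.is_cuda and W1.is_cuda and B1.is_cuda):
        raise ValueError("TC currently requires CUDA tensors; "
                         "call .cuda() on the inputs first.")
    return fcrelu(I, W1, B1)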

nicolasvasilache commented 6 years ago

@arogozhnikov TC indeed does not support CPU at the moment; this is work in progress.

prigoyal commented 6 years ago

Hi @arogozhnikov , thanks for reporting. We don't support CPU computations in TC right now. However, adding a simple check for is_cuda() should definitely be possible. I'll re-title this issue to reflect that more properly. Thank you :)