artyom-beilis / pytorch_dlprim

DLPrimitives/OpenCL out of tree backend for pytorch
http://blog.dlprimitives.org/
MIT License

Error evaluating network with torch.float64 #14

Open kurnevsky opened 1 year ago

kurnevsky commented 1 year ago

I get the following error when I try to evaluate the network using the torch.float64 data type:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/kurnevsky/workspace/oppai-rs/zero-torch/oppai_net.py", line 59, in predict
    policies, values = self(inputs)
  File "/nix/store/s7ndjjy8fbsn9np75d24r57v2xikfmvq-python3-3.10.7-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kurnevsky/workspace/oppai-rs/zero-torch/oppai_net.py", line 43, in forward
    x = F.relu(self.bn1(self.conv1(x)))
  File "/nix/store/s7ndjjy8fbsn9np75d24r57v2xikfmvq-python3-3.10.7-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/nix/store/s7ndjjy8fbsn9np75d24r57v2xikfmvq-python3-3.10.7-env/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 457, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/nix/store/s7ndjjy8fbsn9np75d24r57v2xikfmvq-python3-3.10.7-env/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 453, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Failed dtype == float_data at /build/source/src/gemm.cpp:490

artyom-beilis commented 1 year ago

dlprim does not support float64 yet.
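Until float64 is supported, one workaround is to run the network in float32 instead. The sketch below is a hypothetical helper (`to_float32` is not part of dlprim or PyTorch). It relies only on the standard `Module.to` and `Tensor.to` calls, which cast floating-point parameters, buffers, and inputs to float32 and leave integer tensors alone. `device` defaults to "cpu" so the sketch runs anywhere. With dlprim you would pass the device string your pytorch_dlprim build registers, which depends on the version, so check the project README. The small conv/batch-norm network is likewise made up, loosely mirroring the `conv1`/`bn1` layers in the traceback.

```python
import torch
import torch.nn as nn


def to_float32(model: nn.Module, inputs: torch.Tensor, device: str = "cpu"):
    """Hypothetical helper: downcast a float64 model and its inputs to float32
    and move both to `device` (replace "cpu" with your dlprim device string)."""
    # Module.to(dtype=...) casts only floating-point parameters and buffers;
    # integer buffers such as BatchNorm's num_batches_tracked stay int64.
    model = model.to(device=device, dtype=torch.float32)
    # Cast floating-point inputs; move integer inputs without changing dtype.
    if inputs.is_floating_point():
        inputs = inputs.to(device=device, dtype=torch.float32)
    else:
        inputs = inputs.to(device=device)
    return model, inputs


# Illustrative float64 network and input (hypothetical, not from the issue).
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU()).double()
x = torch.randn(2, 3, 16, 16, dtype=torch.float64)

net, x = to_float32(net, x)
net.eval()
with torch.no_grad():
    y = net(x)
print(y.dtype)  # torch.float32
```

Casting at the model boundary keeps the rest of the code unchanged. The trade-off is float32 precision, which may or may not be acceptable for a given network.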