I get the following error when I try to evaluate the network using the torch.float64 data type:
```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/kurnevsky/workspace/oppai-rs/zero-torch/oppai_net.py", line 59, in predict
    policies, values = self(inputs)
  File "/nix/store/s7ndjjy8fbsn9np75d24r57v2xikfmvq-python3-3.10.7-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/kurnevsky/workspace/oppai-rs/zero-torch/oppai_net.py", line 43, in forward
    x = F.relu(self.bn1(self.conv1(x)))
  File "/nix/store/s7ndjjy8fbsn9np75d24r57v2xikfmvq-python3-3.10.7-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/nix/store/s7ndjjy8fbsn9np75d24r57v2xikfmvq-python3-3.10.7-env/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 457, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/nix/store/s7ndjjy8fbsn9np75d24r57v2xikfmvq-python3-3.10.7-env/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 453, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Failed dtype == float_data at /build/source/src/gemm.cpp:490
```
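To help narrow this down, here is a minimal sketch of the same conv2d -> batch norm -> ReLU pattern from line 43 of oppai_net.py, run entirely in float64. TinyNet, its channel counts and the input shape are hypothetical and not taken from the real network. The only point is that model.double() casts all parameters and buffers to float64 and that the inputs are also created as float64, so the dtypes match. On a stock CPU build of PyTorch I would expect this to run. If it fails with the same gemm.cpp error in this environment, the problem probably lies in the backend this torch build uses rather than in the model code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    """Hypothetical stand-in for the first block of the network in the traceback."""

    def __init__(self, in_channels: int = 4, channels: int = 8):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same pattern as the failing line: conv -> batch norm -> ReLU
        return F.relu(self.bn1(self.conv1(x)))


def run_float64(batch: int = 2, size: int = 5) -> torch.Tensor:
    net = TinyNet().double()  # cast every parameter and buffer to float64
    net.eval()                # use running stats in batch norm, as in inference
    # Inputs must have the same dtype as the weights
    inputs = torch.randn(batch, 4, size, size, dtype=torch.float64)
    with torch.no_grad():
        return net(inputs)


if __name__ == "__main__":
    out = run_float64()
    print(out.dtype, out.shape)
```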