Closed by z1xy2 11 months ago.
The code was tested on Ubuntu 20.04 with 4 NVIDIA RTX 3090 GPUs (24 GB memory each). The CUDA version is 11.3 and the PyTorch version is 1.12.1.
Thanks, I found the problem. My globally installed CUDA toolkit was preventing torch.mm from running properly. I deleted the global environment variable, then restarted the terminal and relaunched PyCharm with pycharm.sh so that the change took effect. It's working fine now, thanks for the help!
Great!
Hi, when I run the code, it reports an error at the line out[start:end] = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1)). After checking, it seems to be a CUDA problem. Which versions of nvcc and the CUDA toolkit were you using at the time? Could you give me a reference? Thanks!
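For anyone hitting the same error, here is a minimal diagnostic sketch in the spirit of the fix described in this thread. It lists environment variables that point at a CUDA install. If PyTorch and a GPU are available, it also runs a tiny torch.mm using the same flatten pattern as the failing line. The helper names (cuda_env_entries, torch_mm_check) and the tensor shapes are hypothetical, and the thread does not say which variable caused the original problem.

```python
import importlib.util
import os

# Variables through which a system-wide CUDA toolkit is commonly picked up.
# Which one was at fault in this thread is not stated (assumption).
CUDA_ENV_VARS = ("CUDA_HOME", "CUDA_PATH", "LD_LIBRARY_PATH", "PATH")


def cuda_env_entries(environ=None):
    """Return {variable: [entries mentioning 'cuda']} for CUDA_ENV_VARS."""
    environ = os.environ if environ is None else environ
    found = {}
    for var in CUDA_ENV_VARS:
        value = environ.get(var, "")
        hits = [p for p in value.split(os.pathsep) if "cuda" in p.lower()]
        if hits:
            found[var] = hits
    return found


def torch_mm_check():
    """Run a small torch.mm on the GPU; return the result shape, or None if skipped."""
    if importlib.util.find_spec("torch") is None:
        return None  # PyTorch not installed
    import torch

    if not torch.cuda.is_available():
        return None  # no usable GPU / driver
    print("torch", torch.__version__, "built with CUDA", torch.version.cuda)
    # Hypothetical shapes: buffer is (N, K, C_in), weights is (K, C_in, C_out).
    buffer = torch.randn(8, 3, 4, device="cuda")
    weights = torch.randn(3, 4, 5, device="cuda")
    # (8, 12) @ (12, 5) -> (8, 5), mirroring the flatten calls in the failing line.
    out = torch.mm(buffer.flatten(1, 2), weights.flatten(0, 1))
    torch.cuda.synchronize()  # surface asynchronous CUDA errors at this point
    return tuple(out.shape)


if __name__ == "__main__":
    for var, entries in cuda_env_entries().items():
        print(f"{var}: {entries}")
    print("torch.mm check:", torch_mm_check())
```

torch.version.cuda reports the CUDA version PyTorch was built against. Comparing it with whatever toolkit the listed variables point to (for example via nvcc --version) is one way to spot the kind of mismatch described above.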