VincLee8188 / GMAN-PyTorch

Implementation of Graph Multi-Attention Network with PyTorch
134 stars, 30 forks

It seems that this version doesn't support running on GPU? #6

Open 0shelter0 opened 2 years ago

0shelter0 commented 2 years ago

Running a single epoch takes too long for me to continue. I would like to open a PR to make this project able to use GPU(s).

VincLee8188 commented 2 years ago

No, the present version doesn't support GPU. It can be made to support GPU with the following modifications:

(1) In the main.py file, add the following statement at the beginning: `device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")`

(2) In the train.py file:

- Add the following statement at the beginning: `model.to(device)`
- Change each of the existing statements `X = trainX[start_idx: end_idx]`, `TE = trainTE[start_idx: end_idx]`, and `label = trainY[start_idx: end_idx]` to `X = trainX[start_idx: end_idx].to(device)`, `TE = trainTE[start_idx: end_idx].to(device)`, and `label = trainY[start_idx: end_idx].to(device)`, respectively.

(3) In the test.py file, make modifications similar to those in item (2).

Hope it could be helpful to you.
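The steps above can be sketched as a small standalone script. This is only an illustration of the pattern, not the repository's code: the `nn.Linear` model, the random `trainX` / `trainY` tensors, the optimizer, and the batch size are hypothetical stand-ins, and `TE` is omitted for brevity.

```python
import torch
import torch.nn as nn

# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-ins for the real model and training tensors.
model = nn.Linear(4, 1).to(device)
trainX = torch.randn(32, 4)
trainY = torch.randn(32, 1)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
batch_size = 8

for start_idx in range(0, trainX.shape[0], batch_size):
    end_idx = min(trainX.shape[0], start_idx + batch_size)
    # Move only the current batch to the device; the full tensors stay on the CPU.
    X = trainX[start_idx:end_idx].to(device)
    label = trainY[start_idx:end_idx].to(device)

    optimizer.zero_grad()
    pred = model(X)
    loss = loss_fn(pred, label)
    loss.backward()
    optimizer.step()
```

Moving one batch at a time keeps GPU memory use low; if the whole dataset fits on the GPU, moving it once before the loop avoids repeated host-to-device copies.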

zxz15063731130 commented 1 year ago

I tried what you suggested, and it reports an error: RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor
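This error means some tensor reaching a layer is still on the CPU while the layer's weights are on the GPU. One possible cause, which I have not verified against this repository, is a tensor stored on the model as a plain Python attribute: `model.to(device)` moves only parameters and registered buffers. The sketch below uses two hypothetical modules and a hypothetical tensor named `SE` to show the difference.

```python
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class WithAttribute(nn.Module):
    """Keeps SE as a plain attribute, so .to(device) does not move it."""

    def __init__(self, SE):
        super().__init__()
        self.SE = SE
        self.fc = nn.Linear(SE.shape[-1], 2)

    def forward(self):
        return self.fc(self.SE)


class WithBuffer(nn.Module):
    """Registers SE as a buffer, so .to(device) moves it with the weights."""

    def __init__(self, SE):
        super().__init__()
        self.register_buffer("SE", SE)
        self.fc = nn.Linear(SE.shape[-1], 2)

    def forward(self):
        return self.fc(self.SE)


SE = torch.randn(5, 3)  # hypothetical constant input tensor

plain = WithAttribute(SE).to(device)
buffered = WithBuffer(SE).to(device)

# On a GPU machine this prints "cpu" for the plain attribute and
# "cuda:0" for the buffer; on a CPU-only machine both print "cpu".
print(plain.SE.device, buffered.SE.device)

out = buffered()  # input and weights share a device, so this works
```

If this is the cause, either registering the tensor as a buffer or calling `.to(device)` on it before it is passed to the model should make the devices match.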