microsoft / IRNet

An algorithm for cross-domain NL2SQL
MIT License
268 stars, 80 forks

mask with dtype torch.uint8, deprecated #5

Closed: ng2dev closed this issue 4 years ago

ng2dev commented 4 years ago

I am getting a lot of deprecation warnings with Python 3.6, CUDA 9.2, and PyTorch 1.3. Any idea what I could do to avoid them?

/tmp/pip-req-build-vxpey3tb/aten/src/ATen/native/cuda/LegacyDefinitions.cpp:19: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead.
/tmp/pip-req-build-vxpey3tb/aten/src/ATen/native/cuda/LegacyDefinitions.cpp:19: UserWarning: masked_fill_ received a mask with dtype torch.uint8, this behavior is now deprecated,please use a mask with dtype torch.bool instead.
^CTraceback (most recent call last):
  File "train.py", line 117, in <module>
    train(args)
  File "train.py", line 80, in train
    sketch_loss_coefficient=args.sketch_loss_coefficient)
  File "/home/eni/IRNet/src/utils.py", line 242, in epoch_train
    score = model.forward(examples)
  File "/home/eni/IRNet/src/models/model.py", line 281, in forward
    pre_types.append(pre_type)
KeyboardInterrupt
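For context, here is a minimal sketch of what the warning is about. It uses made-up scores and a made-up padding mask, not IRNet's actual tensors: masked_fill_ expects a torch.bool mask, and an existing uint8 mask can be converted with .bool() before it is used.

```python
import torch

# Hypothetical scores for a batch of 2 sequences padded to length 4.
scores = torch.zeros(2, 4)

# Old-style padding mask stored as uint8: 1 marks a padded position.
byte_mask = torch.tensor([[0, 0, 1, 1],
                          [0, 0, 0, 1]], dtype=torch.uint8)

# Passing byte_mask directly is what triggered the deprecation warning in
# PyTorch 1.3, so convert it to torch.bool before masking.
bool_mask = byte_mask.bool()

# Fill padded positions with -inf (e.g. before a softmax over the scores).
masked = scores.masked_fill(bool_mask, -float("inf"))
print(masked)
```

Converting at each call site works as a quick workaround, but the cleaner fix is to have the helpers that build the masks return torch.bool tensors in the first place.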

longxudou commented 4 years ago

@ng2dev Just modify src/models/nn_utils.py as follows: in length_array_to_mask_tensor(), table_dict_to_mask_tensor(), and pred_col_mask(), change mask = torch.ByteTensor(mask) to mask = torch.BoolTensor(mask), so that each function returns a torch.bool mask instead of a torch.uint8 one.
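The suggested change can be sketched as follows. This is a simplified, hypothetical stand-in for length_array_to_mask_tensor(), not the actual code in IRNet's nn_utils.py; it assumes the helper takes a list of sequence lengths and returns a (batch, max_len) mask in which padded positions are marked 1/True.

```python
import torch


def length_array_to_mask_tensor(length_array, cuda=False):
    """Hypothetical, simplified padding-mask helper (not IRNet's exact code).

    Returns a (batch, max_len) mask where True marks padded positions.
    """
    max_len = max(length_array)
    # Start with every position marked as padding (1), then clear real tokens.
    mask = [[1] * max_len for _ in length_array]
    for i, seq_len in enumerate(length_array):
        for j in range(seq_len):
            mask[i][j] = 0
    # The fix from this thread: build a bool tensor instead of a uint8 one.
    mask = torch.BoolTensor(mask)  # previously: torch.ByteTensor(mask)
    return mask.cuda() if cuda else mask


mask = length_array_to_mask_tensor([2, 3, 4])
print(mask.dtype)  # torch.bool
print(mask)
```

Only the tensor constructor changes; the mask values are the same, so callers that pass the result to masked_fill_ need no other edits and the warning goes away.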

JasperGuo commented 4 years ago

Thanks @DreamerDeo. We have created a patch for this issue. @ng2dev, please try the latest version of IRNet.