henryhungle / NADST

Code for the paper Non-Autoregressive Dialog State Tracking (ICLR20)
MIT License

What class is Variable? I get an error #3

Closed. koliaok closed this issue 4 years ago.

koliaok commented 4 years ago

In your "dataset.py":

    def make_std_mask(tgt, pad):
        "Create a mask to hide padding and future words."
        tgt_mask = (tgt != pad).unsqueeze(-2)
        tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
        return tgt_mask

This is your function. Which class does "Variable" refer to?

henryhungle commented 4 years ago

@koliaok Variable is the PyTorch class Variable (torch.autograd.Variable). It is imported in https://github.com/henryhungle/NADST/blob/afdc1d1f7ecb855b03933e441c0b2fcefbc28feb/model/nadst.py#L7

koliaok commented 4 years ago

The file "NADST/utils/dataset.py" uses Variable on line 165: "tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))"

So I added "from torch.autograd import Variable" at line 8 of "dataset.py".

After adding it, there is no error in my macOS environment.

Is this correct?

henryhungle commented 4 years ago

@koliaok Yes, it is correct. Thanks for pointing it out.
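For reference, here is a minimal standalone sketch of the patched helper with the confirmed import in place. The thread does not show the body of subsequent_mask, so the version below is an assumption modelled on the common Annotated-Transformer-style helper (upper-triangular "future" positions masked out); the actual NADST code may differ. The small demo tensor and pad id 0 are hypothetical.

```python
# Hypothetical standalone version of the patched utils/dataset.py helpers.
import numpy as np
import torch
from torch.autograd import Variable  # the import added near the top of dataset.py


def subsequent_mask(size):
    "Mask out subsequent positions (assumed Annotated-Transformer-style helper)."
    attn_shape = (1, size, size)
    # Ones strictly above the diagonal mark "future" positions.
    future = np.triu(np.ones(attn_shape), k=1).astype("uint8")
    # True where attention is allowed (the current and earlier positions).
    return torch.from_numpy(future) == 0


def make_std_mask(tgt, pad):
    "Create a mask to hide padding and future words."
    # (batch, 1, seq_len): True for non-padding tokens.
    tgt_mask = (tgt != pad).unsqueeze(-2)
    # Broadcast against (1, seq_len, seq_len) to also hide future tokens.
    tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
    return tgt_mask


if __name__ == "__main__":
    tgt = torch.tensor([[5, 7, 0]])  # one sequence; 0 is the (hypothetical) pad id
    mask = make_std_mask(tgt, pad=0)
    print(mask.shape)
    print(mask[0].int())
```

For this input the mask has shape (1, 3, 3), and the last column stays 0 in every row because the padded position is hidden. Since PyTorch 0.4, Variable and Tensor have been merged and Variable(...) just returns a Tensor, so dropping the wrapper should behave the same on recent versions.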