koliaok closed this issue 4 years ago.
@koliaok `Variable` is PyTorch's `Variable` class. It is imported in https://github.com/henryhungle/NADST/blob/afdc1d1f7ecb855b03933e441c0b2fcefbc28feb/model/nadst.py#L7
`NADST/utils/dataset.py` also uses `Variable`, at line 165:
`tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))`
So I added `from torch.autograd import Variable` at line 8 of `dataset.py`. After that change there are no errors in my macOS environment.
Is this correct?
@koliaok Yes, it is correct. Thanks for pointing it out.
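For reference, here is a minimal sketch of `make_std_mask` with the import in place, runnable on its own. The `subsequent_mask` below is a hypothetical stand-in built with `torch.triu` (the repo's own helper may be written differently); it marks each position as allowed to attend only to itself and earlier positions.

```python
import torch
from torch.autograd import Variable  # the import added near the top of dataset.py


def subsequent_mask(size):
    # Hypothetical stand-in for the repo's helper: shape (1, size, size),
    # True on and below the diagonal (no peeking at future positions).
    return torch.triu(torch.ones(1, size, size), diagonal=1) == 0


def make_std_mask(tgt, pad):
    "Create a mask to hide padding and future words."
    # (batch, 1, seq): False where the target token is padding
    tgt_mask = (tgt != pad).unsqueeze(-2)
    # Broadcast with the (1, seq, seq) causal mask -> (batch, seq, seq)
    tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
    return tgt_mask


tgt = torch.tensor([[5, 7, 0]])  # one sequence; the last token is padding (pad id 0)
print(make_std_mask(tgt, pad=0))
```

Since PyTorch 0.4, `Variable` has been merged into `Tensor`, so `Variable(x)` simply returns a tensor; on recent versions the wrapper could be dropped and the result would be the same.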
In your `dataset.py`:
```python
def make_std_mask(tgt, pad):
    "Create a mask to hide padding and future words."
    tgt_mask = (tgt != pad).unsqueeze(-2)
    tgt_mask = tgt_mask & Variable(subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data))
    return tgt_mask
```
This is your function. Which class does `Variable` refer to here?