Closed by colleenjg 3 years ago.
Fixes #25, enabling torch compatibility (checked up to 1.8.1).

`dpc/main.py`, lines 215 and 267:
- `target_flattened` is placed on `cuda` when created, to avoid an error when `criterion()` is called.
- `target_flattened` is converted to `int` before `argmax()`, to avoid an error if `target_flattened` is a boolean tensor.

`utils/utils.py`, line 53:
- In `correct_k = correct[:k].view(-1).float().sum(0)`, a `contiguous()` call is added to avoid an error when `view(-1)` is called.

Thanks!
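For the first `dpc/main.py` fix, the underlying rule is that the target passed to the loss must live on the same device as the model's scores. Below is a minimal sketch, assuming a cross-entropy criterion and a made-up 4-by-4 score matrix. The names `score`, `target` and `device`, and the shapes, are hypothetical and not the repository's code.

```python
import torch
import torch.nn as nn

# Pick CUDA when available; otherwise fall back to CPU so the sketch still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for the model's score matrix, produced on `device`.
score = torch.randn(4, 4, device=device)

# Create the target directly on the same device as the scores. A CPU target
# combined with CUDA scores makes the loss call fail with a device mismatch.
target = torch.arange(4, device=device)

criterion = nn.CrossEntropyLoss()
loss = criterion(score, target)
print(loss.item())
```

Passing `device=` at creation time, rather than calling `.cuda()` unconditionally, is only a design choice for this sketch: it keeps the same code usable on CPU-only machines.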
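For the second `dpc/main.py` fix, the target is a boolean mask, and the PR converts it to `int` before calling `argmax()` to avoid the reported error. Here is a small sketch, assuming a hypothetical mask with exactly one `True` (the positive) per row. The mask values and the name `mask` are made up for illustration.

```python
import torch

# Hypothetical boolean mask: row i has a single True at the column of its
# positive match (a stand-in for a boolean target_flattened).
mask = torch.tensor([
    [False, True, False],
    [True, False, False],
    [False, False, True],
])

# Cast to int first, so argmax() never receives a bool tensor.
target = mask.int().argmax(dim=1)  # one class index per row, dtype torch.int64
print(target)  # → tensor([1, 0, 2])
```

Because `argmax()` returns `int64` indices, the result can be passed directly as class targets to a loss such as `nn.CrossEntropyLoss`.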
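The `utils/utils.py` fix targets a top-k accuracy computation. After `pred.t()`, the comparison result `correct` follows the transposed memory layout, so `correct[:k]` can be non-contiguous, and `view(-1)` then raises an error. The sketch below is a hypothetical `topk_accuracy` helper modeled on the common ImageNet-example pattern. It assumes, but does not confirm, that the repository's function has the same structure.

```python
import torch


def topk_accuracy(output, target, topk=(1,)):
    """Hypothetical top-k accuracy helper (in percent); a sketch, not
    necessarily identical to utils/utils.py."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)  # (batch, maxk) class indices
    pred = pred.t()  # (maxk, batch): a transposed, non-contiguous view
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # correct[:k] may be non-contiguous, so make it contiguous before
        # view(-1); reshape(-1) would be an equivalent alternative.
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


# Made-up scores for 4 samples over 3 classes.
output = torch.tensor([
    [0.1, 0.7, 0.2],
    [0.6, 0.3, 0.1],
    [0.2, 0.3, 0.5],
    [0.4, 0.35, 0.25],
])
target = torch.tensor([1, 1, 2, 2])
top1, top2 = topk_accuracy(output, target, topk=(1, 2))
print(top1.item(), top2.item())  # → 50.0 75.0
```

In this example, rows 0 and 2 are correct at top-1. Row 1 becomes correct only at top-2, and row 3 misses both, which gives 50% and 75%.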