WarBean / tps_stn_pytorch

PyTorch implementation of Spatial Transformer Network (STN) with Thin Plate Spline (TPS)

run python mnist_visualize.py --model unbounded_stn --angle 90 --grid_size 4 error #9

Open KakaVlasic opened 5 years ago

KakaVlasic commented 5 years ago

Hello! When I run python mnist_visualize.py --model unbounded_stn --angle 90 --grid_size 4, I get the following error:

create model with STN
Traceback (most recent call last):
  File "mnist_visualize.py", line 47, in <module>
    data_list = target2data_list[target]
KeyError: tensor(7)

Any help would be appreciated!

w11m commented 5 years ago

I have the same problem.

jackweiwang commented 4 years ago

target2data_list = [list() for i in range(10)]

freepoet commented 3 years ago

I think the index shouldn't be a tensor; it should be an int, like this:

data_list = target2data_list[int(target)]
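Why the KeyError happens, as a minimal sketch: torch.Tensor hashes by object identity, so a 0-dim tensor such as tensor(7) never matches the int key 7 in a dict. A list, by contrast, accepts it as an index, because integer tensors implement __index__. The example below assumes target2data_list was a dict keyed by int labels (the script itself is not reproduced in this thread). It uses a hypothetical ScalarTensorStandIn class in place of a real tensor so it runs without PyTorch; the class copies only the two behaviors that matter here.

```python
class ScalarTensorStandIn:
    """Hypothetical stand-in for a 0-dim integer torch.Tensor such as tensor(7).

    It mimics only the behaviors relevant to this issue:
    identity-based hashing (torch.Tensor.__hash__ returns id(self)) and
    conversion to a Python int via __int__ / __index__.
    """

    def __init__(self, value: int) -> None:
        self.value = value

    def __hash__(self) -> int:
        # Hash by identity, so it never matches the int key with the same value.
        return id(self)

    def __int__(self) -> int:
        return self.value

    def __index__(self) -> int:
        # Allows use as a list index, like an integer tensor.
        return self.value

    def __repr__(self) -> str:
        return f"tensor({self.value})"


target = ScalarTensorStandIn(7)

# Assumed layout: a dict mapping each digit label 0-9 to a list of samples.
dict_lookup = {label: [] for label in range(10)}
try:
    dict_lookup[target]  # fails: the identity hash finds no matching key
except KeyError as err:
    print("KeyError:", err)  # prints: KeyError: tensor(7)

# Fix 1 (freepoet): convert the label to a Python int before the lookup.
dict_lookup[int(target)].append("sample")

# Fix 2 (jackweiwang): use a list, which indexes through __index__.
list_lookup = [list() for i in range(10)]
list_lookup[target].append("sample")

print(len(dict_lookup[7]), len(list_lookup[7]))  # prints: 1 1
```

With a real tensor, target.item() returns the same Python int as int(target), so either call works for the dict lookup.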