WarBean / tps_stn_pytorch

PyTorch implementation of Spatial Transformer Network (STN) with Thin Plate Spline (TPS)

run the training on wider input images (width=height>>28) #7

Open kao123 opened 5 years ago

kao123 commented 5 years ago

Great project! I managed to run the code on my own images with input size 28. I then tried a different input size (e.g. width = height = 300). As soon as I change args.image_height = args.image_width to any value other than 28 (in my data_loader and mnist_train), I get the following error:

  File "/home/myaccount/tps_stn_pytorch/tps_grid_gen.py", line 67, in forward
    assert source_control_points.size(1) == self.num_points
AssertionError

I tried modifying the tps_grid_gen code, but nothing has worked. Any help, please?
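
One possible reading of the assertion above, offered as an unconfirmed guess: if the network that predicts the control points flattens its convolutional features with a size hard-coded for 28x28 inputs, a larger image yields more features per image, and the reshape then produces more control points than `self.num_points` expects. The sketch below only does the shape arithmetic. The layer stack (two 5x5 convolutions, each followed by 2x2 max pooling, 20 output channels) and the constant 320 are assumptions borrowed from the classic PyTorch MNIST example, and `conv_out` and `flat_features` are hypothetical helper names, not functions from this repository.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of one spatial dimension after a conv or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def flat_features(image_size, channels=20):
    """Flattened feature count for the ASSUMED layer stack:
    conv 5x5 -> max-pool 2x2 -> conv 5x5 -> max-pool 2x2."""
    s = conv_out(image_size, kernel=5)
    s = conv_out(s, kernel=2, stride=2)
    s = conv_out(s, kernel=5)
    s = conv_out(s, kernel=2, stride=2)
    return channels * s * s


# ASSUMPTION: flatten size baked into the network for 28x28 inputs.
HARD_CODED = 320

for size in (28, 300):
    feats = flat_features(size)
    # A reshape to (-1, HARD_CODED) would produce this many rows per image,
    # which multiplies the number of predicted control points by the same factor.
    rows_per_image = feats // HARD_CODED
    print(f"input {size}x{size}: {feats} features, {rows_per_image} row(s) per image")
```

If this is indeed the cause, changing tps_grid_gen would not help, since the mismatch would originate earlier in the model. Computing the flatten size from the image size, instead of hard-coding it, would be the place to start.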

KakaVlasic commented 5 years ago

Hello! I have the same problem. How did you deal with it? I tried realigning the code in 'def train(epoch)' in mnist_train; the AssertionError disappears, but I don't get any output (checkpoint, accuracy log), even though the code does run on my GPU. Any help would be appreciated!

ilyalasy commented 3 years ago

Hey, is this problem solved? @WarBean

Fleyderer commented 1 year ago

This problem is not solved yet :\