zhen-he / tracking-by-animation

A PyTorch implementation of the "Tracking-by-Animation" algorithm published at CVPR 2019.

TypeError: type Tensor doesn't define __round__ method #4

Closed: RandomFool123 closed this issue 5 years ago.

RandomFool123 commented 5 years ago

My PyTorch version is 1.0.1 instead of 0.3.1, and when I run gen_mnist.py, it fails with the following error:

(pytorch) yeyu@yeyu-QiTianM415-D004:~/downProject/tracking-by-animation$ python scripts/gen_mnist.py 
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/mnist/MNIST/raw/train-images-idx3-ubyte.gz
100%|█████████████████████████████▉| 9904128/9912422 [01:29<00:00, 44937.66it/s]
Extracting data/mnist/MNIST/raw/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
9920512it [01:40, 44937.66it/s]
Extracting data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!
train frame number: 1920000.0
test frame number: 80000.00000000007
Running with 4 cores.
Traceback (most recent call last):
  File "scripts/gen_mnist.py", line 170, in <module>
    out_batch = parallel(delayed(process_batch)(states_batch[n], s) for n in range(0, N)) # N * 2 * T * H * W * D
  File "/home/yeyu/anaconda3/envs/pytorch/lib/python3.6/site-packages/joblib/parallel.py", line 934, in __call__
    self.retrieve()
  File "/home/yeyu/anaconda3/envs/pytorch/lib/python3.6/site-packages/joblib/parallel.py", line 833, in retrieve
    self._output.extend(job.get(timeout=self.timeout))
  File "/home/yeyu/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/pool.py", line 644, in get
    raise self._value
  File "/home/yeyu/anaconda3/envs/pytorch/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "/home/yeyu/anaconda3/envs/pytorch/lib/python3.6/site-packages/joblib/_parallel_backends.py", line 567, in __call__
    return self.func(*args, **kwargs)
  File "/home/yeyu/anaconda3/envs/pytorch/lib/python3.6/site-packages/joblib/parallel.py", line 225, in __call__
    for func, args, kwargs in self.items]
  File "/home/yeyu/anaconda3/envs/pytorch/lib/python3.6/site-packages/joblib/parallel.py", line 225, in <listcomp>
    for func, args, kwargs in self.items]
  File "scripts/gen_mnist.py", line 95, in process_batch
    h_, w_ = round(h * scale * ratio), round(w * scale / ratio)
TypeError: type Tensor doesn't define __round__ method
9920512it [01:57, 84620.34it/s]
1654784it [00:26, 63032.69it/s] 

So I tried to fix it by changing h_, w_ = round(h * scale * ratio), round(w * scale / ratio) into h_, w_ = np.round(h * scale * ratio), np.round(w * scale / ratio). It seems to work, but I don't know whether it will affect the results.
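The error comes from Python's built-in round(), which looks up the argument type's __round__ method; the traceback shows that the Tensor type in this PyTorch release does not define one. The sketch below reproduces the failure and shows an int-returning alternative. It uses a hypothetical FakeScalar class as a stand-in for a single-element tensor (it only defines __float__, as such a tensor does), so it runs without PyTorch; round_size is likewise a hypothetical helper name. One difference worth checking: np.round returns a NumPy float such as 28.0, whereas round(float(x)) returns a Python int. If h_ and w_ are later used as image sizes or slice bounds, the int version may be the safer choice (this is an assumption about how gen_mnist.py uses them). Both functions round halves to even, so the numeric values agree.

```python
import numpy as np


class FakeScalar:
    """Hypothetical stand-in for a 0-dim tensor: convertible to float, no __round__."""

    def __init__(self, value):
        self.value = value

    def __float__(self):
        return float(self.value)


def round_size(x):
    """Round a scalar-like value to a Python int.

    float(x) works for anything that defines __float__ (Python numbers,
    NumPy scalars, single-element tensors), so the built-in round() then
    receives a plain float and returns an int.
    """
    return round(float(x))


h = FakeScalar(27.5)

# Reproduces the failure: round() looks for FakeScalar.__round__ and finds none.
try:
    round(h)
except TypeError as err:
    print(err)

print(round_size(h))        # 28 (a Python int; 27.5 rounds half to even)
print(np.round(float(h)))   # 28.0 (a NumPy float with the same value)
```

With a real tensor, float(t) or t.item() gives the Python float to pass to round().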

zhen-he commented 5 years ago

Hi, I think your solution is right. However, with a PyTorch version newer than 0.3.1, many more places need to be changed (for the differences between PyTorch 0.3.1, 0.4, and 1.0, please check the official website). So I would recommend using PyTorch 0.3.1 until we release the next version.
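Following this recommendation, a script could warn up front when it runs under an untested PyTorch release, rather than failing deep inside a worker process as in the traceback above. A minimal sketch, assuming the version string has the usual "major.minor.patch" form, possibly with a suffix such as "+cu90"; check_torch_version is a hypothetical helper, and in a real script the string would come from torch.__version__.

```python
import warnings


def check_torch_version(version, tested="0.3.1"):
    """Warn if `version` differs in major.minor from the tested release.

    Hypothetical helper; pass torch.__version__ in practice.
    Returns True when the major.minor numbers match, False otherwise.
    """
    # Drop local build suffixes such as "+cu90" before comparing.
    core = version.split("+")[0]
    got = tuple(int(p) for p in core.split(".")[:2])
    want = tuple(int(p) for p in tested.split(".")[:2])
    if got != want:
        warnings.warn(
            "PyTorch %s detected; this code was written for %s and may need "
            "changes for newer releases." % (version, tested)
        )
        return False
    return True


if __name__ == "__main__":
    check_torch_version("0.3.1")   # matches, no warning
    check_torch_version("1.0.1")   # emits a UserWarning
```

Comparing only major.minor treats patch releases such as 0.3.0 and 0.3.1 as compatible, which is itself an assumption.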

RandomFool123 commented 5 years ago

Thank you for your answer.