bubbliiiing / Siamese-pytorch

A Siamese network library for comparing the similarity of images.
MIT License

Data loading issue #18

Closed: boris-lb closed this issue 1 year ago

boris-lb commented 1 year ago

Hi, and many thanks for the code. After organizing my data as instructed, I get the following error:

Traceback (most recent call last):
  File "train.py", line 343, in <module>
    fit_one_epoch(model_train, model, loss, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, Cuda, fp16, scaler, save_period, save_dir, local_rank)
  File "/home//projects/Siamese-pytorch/utils/utils_fit.py", line 37, in fit_one_epoch
    outputs = model_train(images)
  File "/home//anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home//anaconda3/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 167, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/home//anaconda3/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py", line 177, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/home//anaconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/home//anaconda3/lib/python3.7/site-packages/torch/_utils.py", line 429, in reraise
    raise self.exc_type(msg)
ValueError: Caught ValueError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home//anaconda3/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home//anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home//projects/Siamese-pytorch/nets/siamese.py", line 34, in forward
    x1, x2 = x
ValueError: not enough values to unpack (expected 2, got 1)

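I read through the code carefully but could not find the cause. While debugging, I saw that the data is generated correctly at first, and the error only appears after several dozen generations. Could you explain why x sometimes has length 1, which triggers this error?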
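One plausible mechanism, not confirmed in this thread: the traceback shows the model is wrapped in `nn.DataParallel`, which splits each tensor input along dim 0 (using the same rule as `torch.chunk`) before handing the pieces to the GPUs. If the collate function stacks the left and right images into a single tensor of shape (2, B, C, H, W) (an assumption; check `dataset_collate` in your copy), then dim 0 is the pair dimension, so with two GPUs each replica receives shape (1, B, C, H, W) and `x1, x2 = x` finds only one element. The pure-Python sketch below models only the shape arithmetic; `chunk_sizes`, `scatter_dim0` and `try_unpack` are hypothetical helpers, and the batch size and 105x105 resolution are arbitrary.

```python
import math


def chunk_sizes(size, n_chunks):
    """Mimic torch.chunk on one dim: pieces of ceil(size / n_chunks), the last may be smaller."""
    if size == 0:
        return []
    step = math.ceil(size / n_chunks)
    return [min(step, size - start) for start in range(0, size, step)]


def scatter_dim0(shape, n_gpus):
    """Per-replica shapes when a single tensor is split along dim 0 across n_gpus."""
    return [(s,) + tuple(shape[1:]) for s in chunk_sizes(shape[0], n_gpus)]


def try_unpack(dim0):
    """Simulate `x1, x2 = x`, which iterates over dim 0 of x."""
    items = list(range(dim0))
    try:
        x1, x2 = items
        return "ok"
    except ValueError as err:
        return str(err)


# Assumed layout: left and right images stacked as (2, B, C, H, W).
batch_shape = (2, 8, 3, 105, 105)
for shape in scatter_dim0(batch_shape, n_gpus=2):
    print(shape, "->", try_unpack(shape[0]))
# (1, 8, 3, 105, 105) -> not enough values to unpack (expected 2, got 1)
# (1, 8, 3, 105, 105) -> not enough values to unpack (expected 2, got 1)
```

With a single GPU the whole tensor stays in one chunk (`chunk_sizes(2, 1) == [2]`), so under this assumption the unpack would only fail on multi-GPU runs.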

boris-lb commented 1 year ago

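I found a solution in the Q&A section of the author's blog post.
Step 1: in nets/siamese.py, change forward(self, x) to forward(self, x1, x2).
Step 2: in utils/utils_fit.py, change model_train(images) to model_train(images[0], images[1]). This call appears in two places, the training step and the validation step.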
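Assuming the same `DataParallel` behaviour, this fix would work because each positional argument is now split along its own dim 0, which is the batch dimension, and replica i receives chunk i of every argument, so `forward(self, x1, x2)` always gets two tensors. The hypothetical `scatter_args` helper below models that regrouping on shapes only; it is meant to roughly mirror the zip-based regrouping in PyTorch's scatter, not to reproduce the library code.

```python
import math


def chunk_sizes(size, n_chunks):
    """Mimic torch.chunk on one dim: pieces of ceil(size / n_chunks), the last may be smaller."""
    if size == 0:
        return []
    step = math.ceil(size / n_chunks)
    return [min(step, size - start) for start in range(0, size, step)]


def scatter_args(arg_shapes, n_gpus):
    """Split each positional arg along dim 0, then give replica i the i-th chunk of every arg."""
    per_arg = [
        [(s,) + tuple(shape[1:]) for s in chunk_sizes(shape[0], n_gpus)]
        for shape in arg_shapes
    ]
    # zip stops at the shortest list, so fewer chunks means fewer replicas
    return [tuple(chunks) for chunks in zip(*per_arg)]


# model_train(images[0], images[1]): two args, each assumed to be (B, C, H, W).
left = right = (8, 3, 105, 105)
for replica_args in scatter_args([left, right], n_gpus=2):
    print(len(replica_args), replica_args)
# 2 ((4, 3, 105, 105), (4, 3, 105, 105))
# 2 ((4, 3, 105, 105), (4, 3, 105, 105))
```

A final batch smaller than the GPU count simply yields fewer chunks; `DataParallel` then replicates the module onto only that many devices, and every replica still receives two arguments.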

bubbliiiing commented 1 year ago

OK, but I still don't really see where the problem was... 0 0