AntreasAntoniou / Learning_to_Learn_via_Self-Critique

The original code for the paper "Learning to Learn via Self-Critique".
https://arxiv.org/abs/1905.10295
MIT License

about image_channels #3

Open qianyewu opened 4 years ago

qianyewu commented 4 years ago

Excuse me, I would like to know how to set things up for 1-channel images, such as Omniglot. I changed the experiment_config and set image_channels to 1. When I run the program, the following error occurs:

Traceback (most recent call last):
  File "train_few_shot_system.py", line 25, in <module>
    maml_system.run_experiment()
  File "/home/qyw/pythonCode/Learning_to_Learn_via_Self-Critique/experiment_builder.py", line 1156, in run_experiment
    augment_images=self.augment_flag)):
  File "/home/qyw/pythonCode/Learning_to_Learn_via_Self-Critique/data.py", line 741, in get_train_batches
    for sample_id, sample_batched in enumerate(self.get_dataloader()):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 582, in __next__
    return self._process_next_batch(batch)
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 608, in _process_next_batch
    raise batch.exc_type(batch.exc_msg)
RuntimeError: Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 99, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/qyw/pythonCode/Learning_to_Learn_via_Self-Critique/data.py", line 667, in __getitem__
    self.get_set(self.current_set_name, seed=self.seed[self.current_set_name] + idx, augment_images=self.augment_images)
  File "/home/qyw/pythonCode/Learning_to_Learn_via_Self-Critique/data.py", line 511, in get_set
    x = augment_image(image=x_class_data)
  File "/home/qyw/pythonCode/Learning_to_Learn_via_Self-Critique/data.py", line 22, in augment_image
    cur_image = transform_current(cur_image)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py", line 164, in __call__
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py", line 208, in normalize
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
RuntimeError: output with shape [1, 84, 84] doesn't match the broadcast shape [3, 84, 84]

Looking forward to your reply.