gitabcworld / MatchingNetworks

This repo provides PyTorch code that replicates the results of the Matching Networks for One Shot Learning paper on the Omniglot and MiniImageNet datasets.

There is something wrong in MatchingNetwork.py #9

Closed wan3333 closed 1 year ago

wan3333 commented 1 year ago
def forward(self, support_set_images, support_set_labels_one_hot, target_image, target_label):
        """
        Builds graph for Matching Networks, produces losses and summary statistics.
        :param support_set_images: A tensor containing the support set images [batch_size, sequence_size, n_channels, 28, 28]
        :param support_set_labels_one_hot: A tensor containing the support set labels [batch_size, sequence_size, n_classes]
        :param target_image: A tensor containing the target image (image to produce label for) [batch_size, n_channels, 28, 28]
        :param target_label: A tensor containing the target label [batch_size, 1]
        :return: 
        """
        # produce embeddings for support set images
        encoded_images = []
        for i in np.arange(support_set_images.size(1)):
            gen_encode = self.g(support_set_images[:,i,:,:,:])
            encoded_images.append(gen_encode)

        # produce embeddings for target images
        for i in np.arange(target_image.size(1)):
            gen_encode = self.g(target_image[:,i,:,:,:])
            encoded_images.append(gen_encode)
            outputs = torch.stack(encoded_images)

for i in np.arange(support_set_images.size(1)):
    gen_encode = self.g(support_set_images[:,i,:,:,:])

should be

for i in np.arange(support_set_images.size(0)):
    gen_encode = self.g(support_set_images[i,:,:,:,:])
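
For reference, a minimal shape sketch comparing the two indexing conventions; the dummy sizes and the check itself are illustrative and not taken from the repo. Per the docstring above, dim 1 of support_set_images is sequence_size, so iterating over size(1) and slicing [:, i] feeds the encoder one support position across the whole batch, whereas iterating over size(0) and slicing [i] feeds it the full support sequence of a single batch element.

import torch

# Illustrative sizes only; the docstring above fixes the layout
# [batch_size, sequence_size, n_channels, 28, 28].
batch_size, sequence_size, n_channels = 4, 5, 1
support_set_images = torch.randn(batch_size, sequence_size, n_channels, 28, 28)

# Original code: iterate over dim 1 (sequence_size), slice [:, i, ...]
# -> one support position across the whole batch per encoder call.
slice_dim1 = support_set_images[:, 0, :, :, :]
print(slice_dim1.shape)  # torch.Size([4, 1, 28, 28]) == [batch_size, n_channels, 28, 28]

# Suggested change: iterate over dim 0 (batch_size), slice [i, ...]
# -> the whole support sequence of a single batch element per encoder call.
slice_dim0 = support_set_images[0, :, :, :, :]
print(slice_dim0.shape)  # torch.Size([5, 1, 28, 28]) == [sequence_size, n_channels, 28, 28]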