facebookresearch / EGG

EGG: Emergence of lanGuage in Games
MIT License

Issue in MNIST tutorial #237

Open dmitrySorokin opened 2 years ago

dmitrySorokin commented 2 years ago

The plot function doesn't work for receiver_rnn, because the RNN receiver returns n images, where n == len(tokens).
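A minimal illustration of the shape mismatch, assuming (hypothetically) a message of 6 positions and a receiver that emits one flattened 28x28 reconstruction per position. The zero tensor stands in for real receiver output; it is not taken from the tutorial.

```python
import torch

# Hypothetical receiver output for one sample: one flattened 28x28 image
# per message position (6 positions here).
steps = 6
receiver_output = torch.zeros(steps, 28 * 28)

view_failed = False
try:
    # This is what the tutorial's plot() does; 6 * 784 elements cannot be
    # viewed as a single 28x28 image, so PyTorch raises a RuntimeError.
    receiver_output.view(28, 28)
except RuntimeError as err:
    view_failed = True
    print("view(28, 28) fails:", err)

# Reshaping to (steps, 28, 28) works; indexing [-1] keeps the last image.
images = receiver_output.view(-1, 28, 28)
last_image = images[-1]
print(images.shape, last_image.shape)
```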

Steps to Reproduce

  1. Run the EGG walkthrough with an MNIST autoencoder until cell #26.

Possible Implementation

Taking the last image in plot() would help:

import matplotlib.pyplot as plt
import torch

import egg.core as core


def plot(game, test_dataset, is_gs, variable_length):
    interaction = core.dump_interactions(game, test_dataset, is_gs, variable_length)

    for z in range(10):
        src = interaction.sender_input[z].squeeze(0)
        if variable_length:
            # the RNN receiver returns one reconstruction per message position;
            # keep only the last one
            dst = interaction.receiver_output[z].view(-1, 28, 28)[-1]
        else:
            dst = interaction.receiver_output[z].view(28, 28)
        # we'll plot two images side-by-side: the original (left) and the reconstruction
        image = torch.cat([src, dst], dim=1).cpu().numpy()

        plt.title(f"Input: digit {z}, channel message {interaction.message[z]}")
        plt.imshow(image, cmap='gray')
        plt.show()

robertodessi commented 2 years ago

Hi @dmitrySorokin,

Thanks for spotting it. Technically, we should plot the receiver's output at the position where the sender generated EOS (index 0 in EGG), if there is one; otherwise, the last output image. Feel free to open a PR for this.

Thanks!
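The EOS-aware selection suggested above could look roughly like the following. This is a minimal sketch, not EGG's implementation: the helper name select_eos_image is hypothetical, and it assumes that message[z] is a 1-D tensor of integer symbol ids and that receiver_output[z] holds one flattened 28x28 reconstruction per message position, aligned with the message.

```python
import torch


def select_eos_image(message: torch.Tensor, receiver_output: torch.Tensor,
                     eos_id: int = 0, side: int = 28) -> torch.Tensor:
    """Hypothetical helper: pick the reconstruction at the sender's first EOS.

    Assumes `message` is a 1-D tensor of symbol ids for one sample and that
    `receiver_output` holds one flattened image per message position.
    """
    images = receiver_output.reshape(-1, side, side)  # (steps, side, side)
    eos_positions = (message == eos_id).nonzero(as_tuple=True)[0]
    if eos_positions.numel() > 0:
        idx = int(eos_positions[0])  # first position where EOS was emitted
    else:
        idx = images.size(0) - 1  # no EOS: fall back to the last output
    # guard in case the message is longer than the number of receiver outputs
    idx = min(idx, images.size(0) - 1)
    return images[idx]
```

Under those assumptions, the variable_length branch of plot() would become dst = select_eos_image(interaction.message[z], interaction.receiver_output[z]). Taking the first EOS rather than any later zero matters because positions after EOS are padding.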

dmitrySorokin commented 2 years ago

OK, I'll do it soon.