facebookresearch / EGG

EGG: Emergence of lanGuage in Games
MIT License

Add kwargs in SenderReceiverRnnReinforce forward pass #173

Closed: nicofirst1 closed this issue 3 years ago

nicofirst1 commented 3 years ago

Is your proposal related to a problem?

I'm working with the COCO dataset, which uses a dictionary of elements as its label. I would like to pass these dicts to the Sender/Receiver and to the logger (via aux_info).

Describe the solution you'd like to have implemented

The solution is quite simple: just add kwargs here, as follows:

    def forward(self, sender_input, labels, receiver_input=None, **kwargs):
        message, log_prob_s, entropy_s = self.sender(sender_input, **kwargs)
        message_length = find_lengths(message)
        receiver_output, log_prob_r, entropy_r = self.receiver(
            message, receiver_input, message_length, **kwargs
        )

        loss, aux_info = self.loss(
            sender_input, message, receiver_input, receiver_output, labels, **kwargs
        )

        # ... (loss computation elided; it defines optimized_loss)

        aux_info["sender_entropy"] = entropy_s.detach()
        aux_info["receiver_entropy"] = entropy_r.detach()
        aux_info["length"] = message_length.float()  # will be averaged
        aux_info["kwargs"] = kwargs

        logging_strategy = (
            self.train_logging_strategy if self.training else self.test_logging_strategy
        )
        interaction = logging_strategy.filtered_interaction(
            sender_input=sender_input,
            labels=labels,
            receiver_input=receiver_input,
            message=message.detach(),
            receiver_output=receiver_output.detach(),
            message_length=message_length,
            aux=aux_info,
        )

        return optimized_loss, interaction
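To illustrate how this proposal would thread extra inputs through the game, here is a minimal, framework-free sketch. `ToySender`, `ToyReceiver`, and `ToyGame` are hypothetical stand-ins rather than EGG classes, the `caption` keyword is an assumed COCO-style field, and the loss is a placeholder:

```python
from typing import Any, Dict, Tuple


class ToySender:
    """Hypothetical sender that reads an optional extra keyword argument."""

    def __call__(self, sender_input: Any, **kwargs: Any) -> str:
        # Use the (assumed) caption field from the COCO-style label dict, if present.
        caption = kwargs.get("caption", "")
        return f"msg({sender_input}|{caption})"


class ToyReceiver:
    """Hypothetical receiver that simply echoes the message."""

    def __call__(self, message: str, receiver_input: Any = None, **kwargs: Any) -> str:
        return message


class ToyGame:
    """Hypothetical game whose forward pass hands **kwargs to every component."""

    def __init__(self) -> None:
        self.sender = ToySender()
        self.receiver = ToyReceiver()

    def forward(
        self, sender_input: Any, labels: Any, receiver_input: Any = None, **kwargs: Any
    ) -> Tuple[float, Dict[str, Any]]:
        message = self.sender(sender_input, **kwargs)
        receiver_output = self.receiver(message, receiver_input, **kwargs)
        loss = 0.0 if receiver_output == labels else 1.0  # placeholder loss
        aux_info: Dict[str, Any] = {"kwargs": kwargs}  # extra inputs reach the logger
        return loss, aux_info


game = ToyGame()
loss, aux = game.forward("img0", "msg(img0|a dog)", caption="a dog")
print(loss, aux["kwargs"])  # 0.0 {'caption': 'a dog'}
```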
nicofirst1 commented 3 years ago

Unfortunately, my proposed solution won't work because of the shape of the kwargs: they would be passed to the game as a list rather than as a proper dict.

More info

At the moment, my batch looks like this:

list( -> [4]
    1. tensor -> [batch size, image size],
    2. tensor -> [batch size, label size],
    3. tensor -> [batch size, image size],
    4. list ( -> [batch size]
        dict[7 keys],
        dict[7 keys],
        ....
        dict[7 keys]
        ),
    )
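A minimal sketch of why this layout breaks the kwargs proposal, assuming the trainer hands the batch to the game positionally (as in `self.game(*batch)`). `forward` is a hypothetical stand-in with the proposed signature, and plain strings stand in for the tensors:

```python
from typing import Any, Dict, List


def forward(sender_input: Any, labels: Any, receiver_input: Any = None, **kwargs: Any) -> Dict[str, Any]:
    """Hypothetical stand-in with the signature proposed above; returns its kwargs."""
    return kwargs


# Mirrors the current batch layout: three "tensors" plus a list of per-sample dicts.
label_dicts = [{"image_id": 1, "caption": "a dog"}, {"image_id": 2, "caption": "a cat"}]
batch: List[Any] = ["sender_input", "labels", "receiver_input", label_dicts]

errors: List[TypeError] = []

# 1) Positional unpacking: the list of dicts becomes a 4th positional argument.
try:
    forward(*batch)
except TypeError as err:
    errors.append(err)

# 2) ** unpacking requires a mapping, so a list cannot be spread into kwargs.
try:
    forward(*batch[:3], **batch[3])
except TypeError as err:
    errors.append(err)

for err in errors:
    print("TypeError:", err)
```

Both calls fail, which is why the extra inputs have to reach the game either as a real dict or as an explicit extra argument.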

I see two solutions here:

Pass a dict of lists

Instead of passing a list of dicts (in my case, the 4th element of the batch) from the DataLoader to the forward method, pass a dict of lists. This should work with kwargs, although it is somewhat odd to have all the inputs as a list and the kwargs as a dict. The result would look something like this:

list( -> [4]
    1. tensor -> [batch size, image size],
    2. tensor -> [batch size, label size],
    3. tensor -> [batch size, image size],
    4. dict ( -> [7]
        list -> [batch size],
        list -> [batch size],
        ....
        list -> [batch size],
        ),
    )
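Producing the 4th element in this layout is a small transposition that could live in a custom `collate_fn`. A minimal pure-Python sketch, assuming every per-sample dict has the same keys; the helper name `to_dict_of_lists` and the example keys are hypothetical:

```python
from typing import Any, Dict, List


def to_dict_of_lists(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Transpose a batch of per-sample dicts into one dict of per-key lists.

    Assumes all dicts share the keys of the first sample.
    """
    if not samples:
        return {}
    keys = samples[0].keys()
    return {key: [sample[key] for sample in samples] for key in keys}


batch_labels = [
    {"image_id": 1, "caption": "a dog"},
    {"image_id": 2, "caption": "a cat"},
]
print(to_dict_of_lists(batch_labels))
# {'image_id': [1, 2], 'caption': ['a dog', 'a cat']}
```

For reference, PyTorch's default collate function already handles a mapping inside each sample by collating its values key by key, so a list of dicts in the batch usually comes from a custom collate function.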

The necessary modification would be in the trainer:


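        # imports used here: from contextlib import nullcontext
        #                    from torch.cuda.amp import autocast
        for batch_id, batch in enumerate(self.train_data):
            batch = move_to(batch, self.device)

            context = autocast() if self.scaler else nullcontext()
            with context:
                # the last element is the dict of lists; the others are positional inputs
                *args, kwargs = batch
                optimized_loss, interaction = self.game(*args, **kwargs)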

And here, in the game's forward pass, as I showed before.
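The trainer-side unpacking can be checked in isolation. With the 4-element dict-of-lists batch, the dict has to be split off from the three positional inputs, which star-unpacking does in one line. `toy_game` is a hypothetical stand-in for `self.game`, and strings stand in for tensors:

```python
from typing import Any, Dict, Tuple


def toy_game(
    sender_input: Any, labels: Any, receiver_input: Any = None, **kwargs: Any
) -> Tuple[float, Dict[str, Any]]:
    """Hypothetical game: returns a dummy loss and echoes its extra inputs."""
    return 0.0, {"kwargs": kwargs}


# Batch layout of the dict-of-lists proposal: positional inputs, then one dict.
batch = [
    "sender_input",
    "labels",
    "receiver_input",
    {"image_id": [1, 2], "caption": ["a dog", "a cat"]},
]

*args, kwargs = batch  # everything except the last element is positional
optimized_loss, interaction = toy_game(*args, **kwargs)
print(interaction["kwargs"]["caption"])  # ['a dog', 'a cat']
```

One caveat of this design: if the dict contains a key that matches a parameter name, such as `labels`, the call raises a `TypeError` for receiving multiple values for that argument.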

Use aux input

In this case, add an auxiliary input to the forward method, which is then passed on to the sender, receiver, and logger.

    def forward(self, sender_input, labels, receiver_input=None, aux_input=None):
        message, log_prob_s, entropy_s = self.sender(sender_input, aux_input)
        message_length = find_lengths(message)
        receiver_output, log_prob_r, entropy_r = self.receiver(
            message, receiver_input, message_length, aux_input
        )

        loss, aux_info = self.loss(
            sender_input, message, receiver_input, receiver_output, labels, aux_input
        )

        # ... (loss computation elided; it defines optimized_loss)

        aux_info["sender_entropy"] = entropy_s.detach()
        aux_info["receiver_entropy"] = entropy_r.detach()
        aux_info["length"] = message_length.float()  # will be averaged
        aux_info["aux_input"] = aux_input

        logging_strategy = (
            self.train_logging_strategy if self.training else self.test_logging_strategy
        )
        interaction = logging_strategy.filtered_interaction(
            sender_input=sender_input,
            labels=labels,
            receiver_input=receiver_input,
            message=message.detach(),
            receiver_output=receiver_output.detach(),
            message_length=message_length,
            aux=aux_info,
        )

        return optimized_loss, interaction

The problem with this approach is that the sender and receiver would raise an error, since agents written as ordinary PyTorch nn.Module subclasses usually define forward with a single input. Maybe one could define a generic class that subclasses nn.Module and whose forward method also accepts an aux parameter.
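The generic base class suggested above could look roughly like this. It is a sketch only, assuming PyTorch is installed; `AgentWithAux`, `LinearSender`, and the `IgnoreAux` adapter (which lets an existing single-input module be used unchanged) are hypothetical names, not part of EGG:

```python
from typing import Any, Optional

import torch
from torch import nn


class AgentWithAux(nn.Module):
    """Hypothetical base class: agents accept an optional aux_input in forward."""

    def forward(self, x: torch.Tensor, aux_input: Optional[Any] = None) -> torch.Tensor:
        raise NotImplementedError


class LinearSender(AgentWithAux):
    """Example agent; it ignores aux_input, but a COCO-aware agent could read it."""

    def __init__(self, n_features: int, n_hidden: int) -> None:
        super().__init__()
        self.fc = nn.Linear(n_features, n_hidden)

    def forward(self, x: torch.Tensor, aux_input: Optional[Any] = None) -> torch.Tensor:
        return self.fc(x)


class IgnoreAux(AgentWithAux):
    """Adapter for existing single-input modules: drops aux_input before calling them."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor, aux_input: Optional[Any] = None) -> torch.Tensor:
        return self.module(x)


x = torch.randn(2, 4)
sender = LinearSender(n_features=4, n_hidden=8)
out_without_aux = sender(x)  # single-input call still works
out_with_aux = sender(x, aux_input={"image_id": [1, 2]})
wrapped_out = IgnoreAux(nn.Linear(4, 8))(x, aux_input={"image_id": [1, 2]})
print(out_without_aux.shape, wrapped_out.shape)  # torch.Size([2, 8]) torch.Size([2, 8])
```

Because aux_input defaults to None, agents written against such a base class keep working when the game calls them with a single input.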

robertodessi commented 3 years ago

After the discussion in #174, I'm closing this.