Closed: nicofirst1 closed this issue 3 years ago
Unfortunately my proposed solution won't work because of the shape of kwargs: it is passed to the game as a list rather than a proper dict.
At the moment my batch looks like this:

list -> [4] (
    1. tensor -> [batch size, image size],
    2. tensor -> [batch size, label size],
    3. tensor -> [batch size, image size],
    4. list -> [batch size] (
        dict [7 keys],
        dict [7 keys],
        ...
        dict [7 keys]
    ),
)
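To spell out the failure mode: Python's `**` unpacking only accepts a mapping, so a list of per-sample dicts cannot be forwarded as keyword arguments. A minimal illustration (the `game` signature mimics the forward below; all names are stand-ins):

```python
def game(sender_input, labels, receiver_input=None, aux_input=None):
    return aux_input

args = ("sender_batch", "label_batch", "receiver_batch")  # stand-ins for tensors

# Shape produced by the current collation: one dict per sample.
kwargs_as_list = [{"id": 1}, {"id": 2}]
try:
    game(*args, **kwargs_as_list)
except TypeError as e:
    print("fails:", e)  # argument after ** must be a mapping

# Shape that works: a single dict whose values are batched.
kwargs_as_dict = {"aux_input": {"id": [1, 2]}}
print(game(*args, **kwargs_as_dict))  # {'id': [1, 2]}
```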
I see two solutions here:
Instead of having a list of dicts (as in my case, the 4th element of the batch) passed from the DataLoader to the forward method, use a dict of lists. This should work with kwargs, but it would be kind of weird to have all the inputs as a list and the kwargs as a dict. The result would be something like:
list -> [4] (
    1. tensor -> [batch size, image size],
    2. tensor -> [batch size, label size],
    3. tensor -> [batch size, image size],
    4. dict -> [7 keys] (
        list -> [batch size],
        list -> [batch size],
        ...
        list -> [batch size],
    ),
)
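One way to obtain this dict-of-lists shape is a custom `collate_fn` on the DataLoader that transposes the per-sample dicts; a minimal sketch (function and key names are illustrative, not EGG API):

```python
def dict_collate(samples):
    """Transpose a list of per-sample dicts into one dict of per-key lists."""
    return {key: [sample[key] for sample in samples] for key in samples[0]}

# e.g. two COCO-style annotation dicts become one batched dict:
batch = dict_collate([{"id": 1, "iscrowd": 0}, {"id": 2, "iscrowd": 1}])
print(batch)  # {'id': [1, 2], 'iscrowd': [0, 1]}
```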
The necessary modification would be in the trainer:
for batch_id, batch in enumerate(self.train_data):
    batch = move_to(batch, self.device)
    context = autocast() if self.scaler else nullcontext()
    with context:
        args, kwargs = batch
        optimized_loss, interaction = self.game(*args, **kwargs)
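Note that once dicts appear inside the batch, `move_to` also has to recurse into them. A duck-typed sketch of what such a helper needs to handle (EGG's actual `move_to` may differ; here anything with a `.to` method is treated as tensor-like):

```python
def move_to(obj, device):
    """Recursively send tensor-like leaves of nested lists/tuples/dicts to a device."""
    if hasattr(obj, "to"):  # tensor-like leaf (torch.Tensor has .to)
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to(v, device) for v in obj)
    return obj  # plain values (ints, strings) stay put
```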
And here, as I showed before.
Alternatively, add an auxiliary input to the forward method which is later passed to the sender/receiver/logger:
def forward(self, sender_input, labels, receiver_input=None, aux_input=None):
    message, log_prob_s, entropy_s = self.sender(sender_input, aux_input)
    message_length = find_lengths(message)
    receiver_output, log_prob_r, entropy_r = self.receiver(
        message, receiver_input, message_length, aux_input
    )
    loss, aux_info = self.loss(
        sender_input, message, receiver_input, receiver_output, labels, aux_input
    )
    # ...
    aux_info["sender_entropy"] = entropy_s.detach()
    aux_info["receiver_entropy"] = entropy_r.detach()
    aux_info["length"] = message_length.float()  # will be averaged
    aux_info["aux_input"] = aux_input
    logging_strategy = (
        self.train_logging_strategy if self.training else self.test_logging_strategy
    )
    interaction = logging_strategy.filtered_interaction(
        sender_input=sender_input,
        labels=labels,
        receiver_input=receiver_input,
        message=message.detach(),
        receiver_output=receiver_output.detach(),
        message_length=message_length,
        aux=aux_info,
    )
    return optimized_loss, interaction
The problem with this approach is that the sender and receiver would raise an error, since the default PyTorch nn.Module forward takes only one input. Maybe one could define a generic class which subclasses nn.Module and takes an aux parameter as an argument.
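Such a generic class could be a thin adapter that lets an existing single-input agent tolerate the extra argument; a hypothetical sketch (class name and behavior are assumptions, not EGG API):

```python
import torch
from torch import nn

class AuxWrapper(nn.Module):
    """Wrap an agent so its forward accepts (and optionally uses) aux_input."""

    def __init__(self, agent, use_aux=False):
        super().__init__()
        self.agent = agent
        self.use_aux = use_aux  # set True only if the agent's forward takes aux

    def forward(self, x, aux_input=None):
        if self.use_aux and aux_input is not None:
            return self.agent(x, aux_input)
        return self.agent(x)  # default agents simply ignore aux_input
```

A wrapped nn.Identity, for example, can then be called as wrapped(x, aux_input=...) without raising.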
After discussion in #174 I'm closing this
Is your proposal related to a problem?
I'm working with the COCO dataset, which uses a dictionary of elements as labels. I would like to pass such dicts to the Sender/Receiver and the logger (via aux_info).
Describe the solution you'd like to have implemented
The solution is quite easy. Just add kwargs here as: