facebookresearch / EGG

EGG: Emergence of lanGuage in Games
MIT License

dump_interactions giving error. #255

Open Shubham0209 opened 1 year ago

Shubham0209 commented 1 year ago

I have already trained my emergent game, and now I want to generate the interactions it has learned. Specifically, I want to pass my training dataset through my trained 'game' and store the messages that are passed from sender to receiver. For this, I use the command below:

interaction_final = core.dump_interactions(game.to(DEVICE), train_loader, gs=True, variable_length=True, device=DEVICE)

But I get the following error:

[image: screenshot of the error]

Any help would be greatly appreciated. @robertodessi

robertodessi commented 1 year ago

Hi @Shubham0209,

This doesn't look like the standard EGG code. Could you please add a few more details on how you are calling these functions and what dict1 and dict2 are?

Also, just to make sure: you trained your game (the one you are passing to dump_interactions) with Gumbel softmax and length > 1, right?

Shubham0209 commented 1 year ago

Hey @robertodessi, thanks for your reply.

The above code snippet shows the code that is part of the interaction.py file. I am using the dump_interactions function from the same file to extract the interactions.

Yes, the game has been trained with the Gumbel softmax and length>1.

robertodessi commented 1 year ago

Hi @Shubham0209,

I think I understand the problem. Could you please share the line(s) that raise the error, though? How is your code calling the __add__ function or any other code from the interaction file?

Thanks!

Shubham0209 commented 1 year ago

@robertodessi here is the code:

code_interactions = core.dump_interactions(game.to(DEVICE), dataset=train_loader, gs=True, variable_length=True, device=DEVICE)

robertodessi commented 1 year ago

Oh you wrote it already, sorry!

I think the error is here: https://github.com/facebookresearch/EGG/blob/main/egg/core/interaction.py#L102

Could you try adding these two checks at the beginning of the _combine_aux_dicts function?

if not (dict1 or dict2):
    return {}

if bool(dict1) ^ bool(dict2):
    raise RuntimeError("Found an empty and non-empty dict when aggregating interactions")

Shubham0209 commented 1 year ago

Yes, works now!

robertodessi commented 1 year ago

Great! Leaving it open as a reminder for myself to push the fix.
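
For anyone landing here later, the two guard checks suggested earlier in the thread can be tried in isolation. The following is only a minimal sketch, not EGG's actual implementation: `combine_aux_dicts` is a hypothetical stand-in for `_combine_aux_dicts`, plain Python lists stand in for tensors, and the key-by-key concatenation is an assumed merge rule used purely for illustration.

```python
from typing import Dict, List

AuxDict = Dict[str, List[float]]


def combine_aux_dicts(dict1: AuxDict, dict2: AuxDict) -> AuxDict:
    """Hypothetical stand-in for a helper that merges two auxiliary dicts.

    Lists stand in for tensors; list concatenation plays the role that
    tensor concatenation would play in the real code (an assumption).
    """
    # Guard 1: both dicts are empty, so there is nothing to merge.
    if not (dict1 or dict2):
        return {}

    # Guard 2: exactly one dict is empty. XOR on the truthiness of the two
    # dicts is True only in that case, so fail loudly instead of merging.
    if bool(dict1) ^ bool(dict2):
        raise RuntimeError(
            "Found an empty and non-empty dict when aggregating interactions"
        )

    # Assumed merge rule for this sketch: both dicts share the same keys
    # and their values are concatenated key by key.
    if dict1.keys() != dict2.keys():
        raise RuntimeError("Auxiliary dicts have different keys")
    return {key: dict1[key] + dict2[key] for key in dict1}
```

With this sketch, two empty dicts merge to an empty dict, two non-empty dicts with matching keys are concatenated, and mixing an empty dict with a non-empty one raises the RuntimeError from the second check.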