facebookresearch / EGG

EGG: Emergence of lanGuage in Games
MIT License

Signal game not able to run with GumbelSoftmax #256

Open DominikKuenkele opened 10 months ago

DominikKuenkele commented 10 months ago

Expected Behavior

When running the signal game with the GumbelSoftmax optimization, the agents should be trained on the data and the loss and metrics should be displayed after each epoch.

Current Behavior

Training fails during the first epoch with the following error:

Traceback (most recent call last):
  File "/home/dominik/miniconda3/envs/thesis/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/dominik/miniconda3/envs/thesis/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/dominik/Development/EGG/egg/zoo/signal_game/train.py", line 148, in <module>
    trainer.train(n_epochs=opts.n_epochs)
  File "/home/dominik/Development/EGG/egg/core/trainers.py", line 273, in train
    train_loss, train_interaction = self.train_epoch()
  File "/home/dominik/Development/EGG/egg/core/trainers.py", line 262, in train_epoch
    full_interaction = Interaction.from_iterable(interactions)
  File "/home/dominik/Development/EGG/egg/core/interaction.py", line 212, in from_iterable
    aux[k] = _check_cat([x.aux[k] for x in interactions])
  File "/home/dominik/Development/EGG/egg/core/interaction.py", line 190, in _check_cat
    return torch.cat(lst, dim=0)
RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated

Steps to Reproduce

  1. download signaling_game_data to /signaling_game_data
  2. run python -m egg.zoo.signal_game.train --root=/signaling_game_data --mode gs

Detailed Description

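The problem lies in the function loss_nll. It calculates the accuracy over the batch and then takes the mean, which yields a zero-dimensional tensor. When the trainer later collects the per-batch values in Interaction.from_iterable, torch.cat cannot concatenate these zero-dimensional tensors, which is the error shown in the traceback. loss_nll should not take the mean, since the averaging is already done in the forward call of SymbolGameGS.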

Possible Implementation

Remove the calculation of the mean in loss_nll, so that the accuracy is returned per sample rather than as a single scalar.
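The failure mode and the proposed change can be illustrated with a minimal sketch. This is not the actual EGG code: the function names loss_nll_mean, loss_nll_per_sample and collect, as well as their signatures, are hypothetical and only mimic the pattern described above (a loss function returning an aux dict whose values are later concatenated with torch.cat).

```python
import torch
import torch.nn.functional as F


def loss_nll_mean(receiver_output, labels):
    # Hypothetical variant: accuracy is reduced to a 0-dim tensor via .mean()
    nll = F.nll_loss(receiver_output, labels, reduction="none")
    acc = (receiver_output.argmax(dim=1) == labels).float().mean()
    return nll, {"acc": acc}


def loss_nll_per_sample(receiver_output, labels):
    # Hypothetical variant: accuracy is kept per sample, shape (batch_size,)
    nll = F.nll_loss(receiver_output, labels, reduction="none")
    acc = (receiver_output.argmax(dim=1) == labels).float()
    return nll, {"acc": acc}


def collect(loss_fn, n_batches=2, batch_size=4, n_classes=3):
    # Stand-in for the aggregation step: gather the aux values of every
    # batch and concatenate them along dim 0, as the traceback shows.
    torch.manual_seed(0)
    accs = []
    for _ in range(n_batches):
        log_probs = F.log_softmax(torch.randn(batch_size, n_classes), dim=1)
        labels = torch.randint(n_classes, (batch_size,))
        _, aux = loss_fn(log_probs, labels)
        accs.append(aux["acc"])
    return torch.cat(accs, dim=0)


if __name__ == "__main__":
    try:
        collect(loss_nll_mean)
    except RuntimeError as err:
        print("mean variant fails:", err)
    print("per-sample variant shape:", collect(loss_nll_per_sample).shape)
```

With the per-sample variant, the concatenated tensor has one entry per example across all batches, and any averaging can still happen afterwards.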