Expected Behavior
When running the signal game with Gumbel-Softmax (GS) optimization, the agents should be trained on the data, and the loss and metrics should be displayed after each epoch.
Current Behavior
Currently, training crashes with the following error:
Traceback (most recent call last):
  File "/home/dominik/miniconda3/envs/thesis/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/dominik/miniconda3/envs/thesis/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/dominik/Development/EGG/egg/zoo/signal_game/train.py", line 148, in <module>
    trainer.train(n_epochs=opts.n_epochs)
  File "/home/dominik/Development/EGG/egg/core/trainers.py", line 273, in train
    train_loss, train_interaction = self.train_epoch()
  File "/home/dominik/Development/EGG/egg/core/trainers.py", line 262, in train_epoch
    full_interaction = Interaction.from_iterable(interactions)
  File "/home/dominik/Development/EGG/egg/core/interaction.py", line 212, in from_iterable
    aux[k] = _check_cat([x.aux[k] for x in interactions])
  File "/home/dominik/Development/EGG/egg/core/interaction.py", line 190, in _check_cat
    return torch.cat(lst, dim=0)
RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated
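For illustration, the failure can be reproduced outside EGG with made-up values: `torch.cat` refuses zero-dimensional tensors, which is exactly what a batch-averaged accuracy is, whereas per-sample accuracy tensors (one dimension, one entry per sample) concatenate without problems. A minimal sketch:

```python
import torch

# Hypothetical per-batch accuracies that were already averaged: zero-dimensional tensors.
scalar_accs = [torch.tensor(0.5), torch.tensor(0.75)]

try:
    torch.cat(scalar_accs, dim=0)  # raises the same RuntimeError as in the traceback
except RuntimeError as err:
    print(f"RuntimeError: {err}")

# Hypothetical per-sample accuracies keep the batch dimension and concatenate fine.
per_sample_accs = [torch.tensor([1.0, 0.0]), torch.tensor([1.0, 1.0, 0.0, 1.0])]
print(torch.cat(per_sample_accs, dim=0).shape)  # torch.Size([6])
```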
Steps to Reproduce
1. Download the signaling game data to `/signaling_game_data`.
2. Run `python -m egg.zoo.signal_game.train --root=/signaling_game_data --mode gs`.
Detailed Description
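The problem lies in the function `loss_nll`. There, the accuracy is computed over the batch and then averaged, which turns it into a zero-dimensional tensor. No mean should be taken at this point, since averaging is already done in the `forward` call of `SymbolGameGS`. Because the accuracy is returned as an `aux` entry, `Interaction.from_iterable` later tries to concatenate these scalars across batches with `torch.cat`, which fails.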
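To see where the zero-dimensional tensor comes from, the sketch below compares a per-sample accuracy with its batch mean. The batch size, the number of candidates, and the random receiver outputs are hypothetical; only the resulting shapes matter.

```python
import torch

batch_size, n_candidates = 4, 2  # hypothetical sizes

# Hypothetical receiver log-probabilities over the candidates, and the true indices.
receiver_output = torch.log_softmax(torch.randn(batch_size, n_candidates), dim=1)
labels = torch.tensor([0, 1, 1, 0])

# One accuracy value (0.0 or 1.0) per sample: shape (batch_size,).
per_sample_acc = (receiver_output.argmax(dim=1) == labels).float()
# Averaging over the batch collapses it to a zero-dimensional tensor.
batch_mean_acc = per_sample_acc.mean()

print(per_sample_acc.dim(), tuple(per_sample_acc.shape))  # 1 (4,)
print(batch_mean_acc.dim(), tuple(batch_mean_acc.shape))  # 0 ()
```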
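Possible Implementation
Remove the averaging of the accuracy in `loss_nll`, so that the accuracy is returned per sample and only averaged later in `SymbolGameGS`.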
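A minimal sketch of the proposed change, assuming `loss_nll` follows EGG's usual loss signature `(sender_input, message, receiver_input, receiver_output, labels, aux_input)` and that `receiver_output` holds log-probabilities over the candidates. Both assumptions should be checked against `egg/zoo/signal_game/train.py`; the body only illustrates dropping the mean, not the file's exact contents.

```python
import torch
import torch.nn.functional as F


def loss_nll(_sender_input, _message, _receiver_input, receiver_output, labels, _aux_input=None):
    """Per-sample NLL loss and accuracy (sketch; signature is an assumption)."""
    # Negative log-likelihood per sample; assumes receiver_output holds log-probabilities.
    nll = F.nll_loss(receiver_output, labels, reduction="none")
    # Per-sample accuracy with no .mean(): the tensor keeps its batch dimension,
    # so the per-batch values can be concatenated with torch.cat across batches.
    acc = (labels == receiver_output.argmax(dim=1)).float()
    return nll, {"acc": acc}


# Usage with made-up log-probabilities for three samples and two candidates.
log_probs = torch.log_softmax(torch.tensor([[2.0, 0.1], [0.3, 1.5], [1.0, 0.2]]), dim=1)
labels = torch.tensor([0, 1, 1])
loss, aux = loss_nll(None, None, None, log_probs, labels)
print(loss.shape)  # torch.Size([3])
print(aux["acc"])  # tensor([1., 1., 0.])
```

With the batch dimension kept, concatenating `acc` across batches yields one entry per training sample, and any averaging happens afterwards.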