There is another potential issue here: wandb does not render nice loss curves when logging more than one model. I don't know why.
Logging the reward and gating model losses might look something like this:
import time
import tqdm
import wandb

for i in tqdm.tqdm(range(max_iter)):
    t0 = time.time()               # step start time, in case we want to log timing
    qsize = model.history.qsize()  # current size of the history queue
    model.train(max_iter=1)
    # Assumes the most recent backward pass populates self.optimizer.loss.
    wandb.log({'gating_train_loss': model.gating_model.optimizer.loss.item()})
    # Note: the reward model is not currently trained in the neuron class.
    wandb.log({'reward_train_loss': model.reward_model.optimizer.loss.item()})
Note that above I'm assuming we can access the loss from the optimizer. This needs to be confirmed.
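If the optimizer does not already expose the loss, one way to guarantee it does is a thin optimizer subclass that caches the loss returned by the closure passed to step(). This is a sketch, not the existing implementation; the class name LossTrackingSGD is hypothetical, and it relies only on the standard PyTorch behavior that step(closure) returns the closure's loss.

```python
import torch


class LossTrackingSGD(torch.optim.SGD):
    """Hypothetical SGD subclass that stores the last loss on self.loss."""

    def step(self, closure=None):
        loss = super().step(closure)  # SGD.step returns the closure's loss
        if loss is not None:
            self.loss = loss  # cache so external code can log it later
        return loss


# Minimal usage sketch with a toy model.
model = torch.nn.Linear(2, 1)
opt = LossTrackingSGD(model.parameters(), lr=0.1)


def closure():
    opt.zero_grad()
    loss = model(torch.ones(1, 2)).pow(2).mean()
    loss.backward()
    return loss


opt.step(closure)
# opt.loss now holds the tensor returned by the closure, ready for wandb.log.
```

The same wrapper pattern works for any torch.optim optimizer, so it could be applied to whichever optimizer the gating and reward models actually use.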
Another thing: the baseline models such as ConstantGatingModel and RandomGatingModel are PyTorch modules but have no trainable parameters, so tracking their gradients and parameters may not work at all. If this causes errors, we should redefine the baseline models to include a torch.nn.Linear layer with frozen (untrainable) weights.
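A minimal sketch of what that redefinition could look like, assuming ConstantGatingModel just emits a fixed gate score per input row (the constructor signature and the constant value are assumptions, not the current implementation):

```python
import torch


class ConstantGatingModel(torch.nn.Module):
    """Hypothetical baseline: carries a frozen linear layer purely so that
    parameter/gradient tracking (e.g. wandb.watch) has something to inspect,
    while the forward pass still returns a constant gate score."""

    def __init__(self, n_inputs: int, value: float = 1.0):
        super().__init__()
        self.linear = torch.nn.Linear(n_inputs, n_inputs)
        # Freeze the layer: parameters exist but never receive gradients.
        self.linear.weight.requires_grad = False
        self.linear.bias.requires_grad = False
        self.value = value

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Ignore x's contents; emit one constant score per row.
        return torch.full((x.shape[0],), self.value)
```

RandomGatingModel could follow the same pattern, replacing the constant output with torch.rand(x.shape[0]).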