I have added my questions as comment lines in the code below.
I think there might be some problems in the code based on my understanding to it:
1) epoch_i += 1 should not be added twice in the while loop
2) i += 1 should move to after the if to_reinit: block
Please correct me if I am wrong...
Many thanks
def train_as_vaelp(self, train_loader, num_epochs=10,
verbose_step=50, lr=1e-3):
optimizer = optim.Adam(self.parameters(), lr=lr)
global_stats = TrainStats()
local_stats = TrainStats()
epoch_i = 0
to_reinit = False
buf = None
while epoch_i < num_epochs:
i = 0
if verbose_step:
print("Epoch", epoch_i, ":")
if epoch_i in [0, 1, 5]: #why set to_reinit to True only in epoch 0, 1, 5
to_reinit = True
epoch_i += 1 #epoch is added by 1 here but it is added again right before the "if i > 0:" block down below
for x_batch, y_batch in train_loader:
if verbose_step:
print("!", end='')
i += 1 #I think this line should be moved to right after the "if to_reinit" block to avoid the "i" is updated but not local/global_stats
y_batch = y_batch.float().to(self.lp.tt_cores[0].device)
if len(y_batch.shape) == 1:
y_batch = y_batch.view(-1, 1).contiguous()
if to_reinit:
if (buf is None) or (buf.shape[0] < 5000):
enc_out = self.enc.encode(x_batch)
means, log_stds = torch.split(enc_out,
len(self.latent_descr),
dim=1)
z_batch = (means + torch.randn_like(log_stds) *
torch.exp(0.5 * log_stds))
cur_batch = torch.cat([z_batch, y_batch], dim=1)
if buf is None:
buf = cur_batch
else:
buf = torch.cat([buf, cur_batch])
else:
descr = len(self.latent_descr) * [0]
descr += len(self.feature_descr) * [1]
self.lp.reinit_from_data(buf, descr)
self.lp.cuda()
buf = None
to_reinit = False
continue
#I thought i += 1 should be here instead (see above)
elbo, cur_stats = self.get_elbo(x_batch, y_batch)
local_stats.update(cur_stats)
global_stats.update(cur_stats)
optimizer.zero_grad()
loss = -elbo
loss.backward()
optimizer.step()
if verbose_step and i % verbose_step == 0:
local_stats.print()
local_stats.reset()
i = 0
epoch_i += 1 #why adding epoch_i by 1 again here
if i > 0:
local_stats.print()
local_stats.reset()
return global_stats
I have added my questions as comment lines in the code below. I think there might be some problems in the code based on my understanding to it: 1) epoch_i += 1 should not be added twice in the while loop 2) i += 1 should move to after the if to_reinit: block
Please correct me if I am wrong... Many thanks