insilicomedicine / GENTRL

Generative Tensorial Reinforcement Learning (GENTRL) model

Various problems in train_as_vaelp function in gentrl.py #16

Open albertma-evotec opened 4 years ago

albertma-evotec commented 4 years ago

I have added my questions as comment lines in the code below. Based on my understanding of it, I think there might be some problems in the code:

1) epoch_i += 1 should not appear twice in the while loop
2) i += 1 should be moved to after the "if to_reinit:" block

Please correct me if I am wrong... Many thanks. (A small standalone repro of the epoch counting issue is at the end of this post.)

def train_as_vaelp(self, train_loader, num_epochs=10,
                       verbose_step=50, lr=1e-3):
        optimizer = optim.Adam(self.parameters(), lr=lr)

        global_stats = TrainStats()
        local_stats = TrainStats()

        epoch_i = 0
        to_reinit = False
        buf = None
        while epoch_i < num_epochs:
            i = 0
            if verbose_step:
                print("Epoch", epoch_i, ":")

            if epoch_i in [0, 1, 5]:      # why is to_reinit set to True only in epochs 0, 1, and 5?
                to_reinit = True

            epoch_i += 1     # epoch_i is incremented here, but it is incremented again just before the "if i > 0:" block below

            for x_batch, y_batch in train_loader:
                if verbose_step:
                    print("!", end='')

                i += 1      # I think this line should be moved to right after the "if to_reinit:" block, so that i is not incremented on iterations that skip the stats update

                y_batch = y_batch.float().to(self.lp.tt_cores[0].device)
                if len(y_batch.shape) == 1:
                    y_batch = y_batch.view(-1, 1).contiguous()

                if to_reinit:
                    if (buf is None) or (buf.shape[0] < 5000):
                        enc_out = self.enc.encode(x_batch)
                        means, log_stds = torch.split(enc_out,
                                                      len(self.latent_descr),
                                                      dim=1)
                        z_batch = (means + torch.randn_like(log_stds) *
                                   torch.exp(0.5 * log_stds))
                        cur_batch = torch.cat([z_batch, y_batch], dim=1)
                        if buf is None:
                            buf = cur_batch
                        else:
                            buf = torch.cat([buf, cur_batch])
                    else:
                        descr = len(self.latent_descr) * [0]
                        descr += len(self.feature_descr) * [1]
                        self.lp.reinit_from_data(buf, descr)
                        self.lp.cuda()
                        buf = None
                        to_reinit = False

                    continue
                # I think i += 1 should be here instead (see above)
                elbo, cur_stats = self.get_elbo(x_batch, y_batch)
                local_stats.update(cur_stats)
                global_stats.update(cur_stats)

                optimizer.zero_grad()
                loss = -elbo
                loss.backward()
                optimizer.step()

                if verbose_step and i % verbose_step == 0:
                    local_stats.print()
                    local_stats.reset()
                    i = 0

            epoch_i += 1       # why is epoch_i incremented by 1 again here?
            if i > 0:
                local_stats.print()
                local_stats.reset()

        return global_stats
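
For what it's worth, here is a tiny standalone snippet (not GENTRL code, just the same bookkeeping pattern with a dummy inner loop) showing what the double increment does to the number of epochs actually run:

num_epochs = 10
epoch_i = 0
epochs_started = 0
while epoch_i < num_epochs:
    epochs_started += 1
    epoch_i += 1          # first increment, before the batch loop
    for _ in range(3):    # stands in for iterating over train_loader
        pass
    epoch_i += 1          # second increment, after the batch loop
print(epochs_started)     # prints 5, i.e. only half of the requested 10 epochs

If that is the intended behaviour it would be good to know, but otherwise keeping only the epoch_i += 1 at the bottom of the while loop (and moving i += 1 below the "if to_reinit:" block so that i only counts batches that actually update local_stats/global_stats) seems more natural to me.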