jscuds / rf-bert


Refactor train.py to move messy inference logic into experiment #22

Closed · jxmorris12 closed this issue 2 years ago

jxmorris12 commented 2 years ago

Enhancement that I'll do once #18 gets merged. I want to move this logic:

if args.experiment == 'finetune':
    if len(train_batch) == 2:  # single-sentence classification
        sentence, targets = train_batch
        sentence, targets = sentence.to(device), targets.to(device)  # TODO(js) retrofit_change
        preds = experiment.model(sentence)
    elif len(train_batch) == 3:  # sentence-pair classification
        sentence1, sentence2, targets = train_batch
        # We pass sentence pairs as a tensor of shape (B, 2, ...) instead of a tuple of two tensors.
        sentence_stacked = torch.stack((sentence1, sentence2), dim=1).to(device)
        targets = targets.to(device)
        preds = experiment.model(sentence_stacked)
    else:
        raise ValueError(f'Expected train_batch of length 2 or 3, got {len(train_batch)}')
    train_loss = experiment.compute_loss_and_update_metrics(preds, targets, 'Train')
else:
    # sent1, sent2, nsent1, nsent2, token1, token2, ntoken1, ntoken2 = train_batch
    train_batch = (t.to(device) for t in train_batch)
    word_rep_pos_1, word_rep_pos_2, word_rep_neg_1, word_rep_neg_2 = experiment.model(*train_batch)
    train_loss = experiment.compute_loss_and_update_metrics(word_rep_pos_1, word_rep_pos_2, word_rep_neg_1, word_rep_neg_2, 'Train')

into the experiment class, so that we can just make a call like:

train_loss = experiment.compute_loss_and_update_metrics(batch)

and have the experiment handle calling its own model.
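
For concreteness, here is a rough sketch of what that could look like. This is only an illustration, not the repo's actual API: the FinetuneExperiment / RetrofitExperiment class names, the constructor signature, and the split argument are all assumptions beyond the compute_loss_and_update_metrics call shown above.

# Hypothetical sketch: each Experiment subclass owns its model and knows how
# to unpack its own batches, so train.py stays experiment-agnostic.
import torch


class FinetuneExperiment:  # hypothetical class name
    def __init__(self, model, loss_fn, device):
        self.model = model
        self.loss_fn = loss_fn
        self.device = device

    def compute_loss_and_update_metrics(self, batch, split='Train'):
        if len(batch) == 2:  # single-sentence classification
            sentence, targets = (t.to(self.device) for t in batch)
            preds = self.model(sentence)
        elif len(batch) == 3:  # sentence-pair classification
            sentence1, sentence2, targets = batch
            # Stack the pair into shape (B, 2, ...) as in the current train.py.
            stacked = torch.stack((sentence1, sentence2), dim=1).to(self.device)
            targets = targets.to(self.device)
            preds = self.model(stacked)
        else:
            raise ValueError(f'Expected batch of length 2 or 3, got {len(batch)}')
        loss = self.loss_fn(preds, targets)
        # ... update metrics for `split` here ...
        return loss


class RetrofitExperiment:  # hypothetical class name
    def __init__(self, model, loss_fn, device):
        self.model = model
        self.loss_fn = loss_fn
        self.device = device

    def compute_loss_and_update_metrics(self, batch, split='Train'):
        batch = tuple(t.to(self.device) for t in batch)
        pos_1, pos_2, neg_1, neg_2 = self.model(*batch)
        loss = self.loss_fn(pos_1, pos_2, neg_1, neg_2)
        # ... update metrics for `split` here ...
        return loss

With something along these lines, the training loop reduces to one experiment-agnostic call per batch, and adding a new experiment type means adding a class rather than another branch in train.py.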