jxmorris12 closed this issue 2 years ago.
Enhancement that I'll do once #18 gets merged. I want to move this logic:
```python
if args.experiment == 'finetune':
    if len(train_batch) == 2:
        # single-sentence classification
        sentence, targets = train_batch
        sentence, targets = sentence.to(device), targets.to(device)
        # TODO(js) retrofit_change
        preds = experiment.model(sentence)
    elif len(train_batch) == 3:
        # sentence-pair classification
        sentence1, sentence2, targets = train_batch
        # We pass sentence pairs as a tensor of shape (B, 2, ...) instead of a tuple of two tensors.
        sentence_stacked = torch.stack((sentence1, sentence2), axis=1).to(device)
        targets = targets.to(device)
        preds = experiment.model(sentence_stacked)
    else:
        raise ValueError(f'Expected train_batch of length 2 or 3, got {len(train_batch)}')
    train_loss = experiment.compute_loss_and_update_metrics(preds, targets, 'Train')
else:
    # sent1, sent2, nsent1, nsent2, token1, token2, ntoken1, ntoken2 = train_batch
    train_batch = (t.to(device) for t in train_batch)
    word_rep_pos_1, word_rep_pos_2, word_rep_neg_1, word_rep_neg_2 = experiment.model(*train_batch)
    train_loss = experiment.compute_loss_and_update_metrics(word_rep_pos_1, word_rep_pos_2, word_rep_neg_1, word_rep_neg_2, 'Train')
```
into the experiment class, so that we can just make a call like:
```python
train_loss = experiment.compute_loss_and_update_metrics(batch)
```
and have the experiment handle calling its own model.
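A minimal sketch of what this refactor could look like, assuming each experiment type becomes a subclass that owns its model and knows its own batch layout. The class names `FinetuneExperiment` and `RetrofitExperiment`, the `forward_batch`/`compute_loss` split, the `metrics` dict, and the injected `loss_fn` and `stack_pair` callables are all hypothetical (not taken from the repo). Plain Python lists stand in for tensors so the sketch runs without PyTorch; with real tensors, `stack_pair` could be `lambda a, b: torch.stack((a, b), dim=1)`.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple


def _to_device(x: Any, device: str) -> Any:
    # Tensors expose .to(device); the plain Python values used here pass through unchanged.
    return x.to(device) if hasattr(x, "to") else x


class Experiment:
    """Hypothetical base class: the training loop hands over a raw batch and nothing else."""

    def __init__(self, model: Callable[..., Any], device: str = "cpu") -> None:
        self.model = model
        self.device = device
        self.metrics: Dict[str, List[float]] = {}  # per-split loss history (hypothetical)

    def forward_batch(self, batch: Sequence[Any]) -> Tuple[Any, ...]:
        """Unpack the batch, move it to the device, run the model; return the loss inputs."""
        raise NotImplementedError

    def compute_loss(self, *outputs: Any, split: str) -> float:
        raise NotImplementedError

    def compute_loss_and_update_metrics(self, batch: Sequence[Any], split: str = "Train") -> float:
        return self.compute_loss(*self.forward_batch(batch), split=split)


class FinetuneExperiment(Experiment):
    def __init__(self, model, loss_fn, stack_pair=lambda a, b: list(zip(a, b)), device="cpu"):
        super().__init__(model, device)
        self.loss_fn = loss_fn
        self.stack_pair = stack_pair  # combines a sentence pair into one (B, 2, ...) input

    def forward_batch(self, batch):
        if len(batch) == 2:  # single-sentence classification
            sentence, targets = batch
            inputs = _to_device(sentence, self.device)
        elif len(batch) == 3:  # sentence-pair classification
            sentence1, sentence2, targets = batch
            inputs = _to_device(self.stack_pair(sentence1, sentence2), self.device)
        else:
            raise ValueError(f"Expected batch of length 2 or 3, got {len(batch)}")
        return self.model(inputs), _to_device(targets, self.device)

    def compute_loss(self, preds, targets, split):
        loss = self.loss_fn(preds, targets)
        self.metrics.setdefault(split, []).append(loss)
        return loss


class RetrofitExperiment(Experiment):
    def __init__(self, model, loss_fn, device="cpu"):
        super().__init__(model, device)
        self.loss_fn = loss_fn

    def forward_batch(self, batch):
        # The model returns (word_rep_pos_1, word_rep_pos_2, word_rep_neg_1, word_rep_neg_2).
        return tuple(self.model(*(_to_device(t, self.device) for t in batch)))

    def compute_loss(self, pos_1, pos_2, neg_1, neg_2, split):
        loss = self.loss_fn(pos_1, pos_2, neg_1, neg_2)
        self.metrics.setdefault(split, []).append(loss)
        return loss


# Toy usage: the model sums each stacked pair; the loss is total absolute error.
ft = FinetuneExperiment(
    model=lambda xs: [sum(x) if isinstance(x, tuple) else x for x in xs],
    loss_fn=lambda preds, targets: sum(abs(p - t) for p, t in zip(preds, targets)),
)
print(ft.compute_loss_and_update_metrics(([1, 2], [1, 3])))          # 1
print(ft.compute_loss_and_update_metrics(([1, 2], [3, 4], [4, 6])))  # 0
```

The training loop then shrinks to the single `compute_loss_and_update_metrics(batch)` call above, and the batch-length dispatch lives only in the experiment that actually produces 2- or 3-element batches.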