Closed korchi closed 10 months ago
@korchi, if you are truncating the target you should use `trainer.predict()`, which uses the first n-1 inputs to predict the nth item. We are not masking anything if we use `.predict()`.

However, if you are using `trainer.evaluate()`, we automatically mask the last item under the hood, so that a prediction result is generated for the last item in the given input. So you don't need to truncate the input sequence if you use `.evaluate()`.
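The two protocols described above should end up scoring the same target item. A minimal standalone sketch of the difference between manual truncation (predict-style) and masking the last position (evaluate-style) — plain NumPy, not the Transformers4Rec API, with a toy sequence as an assumption:

```python
import numpy as np

# Toy item-id sequence (hypothetical data, not from the issue)
sequence = np.array([10, 42, 7, 99])

# predict()-style: truncate manually, feed the first n-1 items,
# and treat the nth item as the target
inputs_predict = sequence[:-1]
target_predict = sequence[-1]

# evaluate()-style: feed the full sequence; the last position is
# masked under the hood, so the model never sees the item it must predict
mask = np.ones(len(sequence), dtype=bool)
mask[-1] = False
inputs_evaluate = sequence[mask]
target_evaluate = sequence[-1]

# Both protocols present the same visible history and the same target
assert np.array_equal(inputs_predict, inputs_evaluate)
assert target_predict == target_evaluate
```

Either way the model only sees `[10, 42, 7]` and is scored on predicting `99`; the bug report below argues that `evaluate()` was not actually hiding the last item.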
Bug description
When `trainer.evaluate()` is called, the model can see all the inputs, including the targets, whose embeddings influence all the latent embeddings. I believe that the targets should be truncated to simulate the production environment.

Steps/Code to reproduce bug
1. Take a trained `model` and any `sequence` from a dataset.
2. Split the `sequence` into `input, target = sequence[:-1], sequence[-1]`, run `pred = trainer.evaluate(input_dataset).predictions[0]` (on `input_dataset` created from the `input` sequence), and compute `recall_simulated = recall(target, pred)`.
3. Run `recall_eval = trainer.evaluate(sequence)`.
4. `recall_eval.recall` is different from `recall_simulated`, which it shouldn't be.

Expected behavior
`recall_eval.recall` should return the same recall as `recall_simulated`.
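The manual reproduction relies on a `recall` helper that the issue does not define. A minimal sketch of what recall@k for a single next-item prediction might look like — the function name, the score-vector shape, and the toy catalogue are all assumptions:

```python
import numpy as np

def recall_at_k(target: int, scores: np.ndarray, k: int = 10) -> float:
    """Recall@k for one next-item prediction: 1.0 if the true target
    is among the k highest-scored items, else 0.0."""
    top_k = np.argsort(-scores)[:k]  # item ids of the k best scores
    return float(target in top_k)

# Toy score vector over a catalogue of 20 items, standing in for
# what `.predictions[0]` might contain
scores = np.zeros(20)
scores[7] = 0.9  # model ranks item 7 first
scores[3] = 0.5  # item 3 second

print(recall_at_k(7, scores, k=1))  # → 1.0, target ranked first
print(recall_at_k(3, scores, k=1))  # → 0.0, target not in top-1
print(recall_at_k(3, scores, k=2))  # → 1.0, target in top-2
```

With a correct masking implementation, applying such a metric to the truncated-input predictions and to `evaluate()`'s internal predictions should yield identical values.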
Environment details
Additional context
See the attached masking.patch file, which fixed the result discrepancy for me.