michaeldeistler closed this pull request 2 months ago.
In #1066, we defined that log_prob and loss have the same input and output shapes:
density_estimator.log_prob(input, condition)
input: (sample_input, batch_input, *event_shape_input)
condition: (batch_condition, *event_shape_condition)
returns: (sample_input, batch_input)
raises: an error if batch_input != batch_condition
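The log_prob shape contract can be checked with a minimal, pure-Python sketch. It works on shape tuples only, not on real tensors, and `log_prob_shape` is a hypothetical helper for illustration, not part of the actual `density_estimator` API. The event shapes are passed in explicitly as an assumption, since the contract does not say how the estimator knows them.

```python
from typing import Tuple


def log_prob_shape(
    input_shape: Tuple[int, ...],
    condition_shape: Tuple[int, ...],
    input_event_shape: Tuple[int, ...],
    condition_event_shape: Tuple[int, ...],
) -> Tuple[int, int]:
    """Hypothetical helper: output shape of log_prob under the #1066 contract.

    input:     (sample_input, batch_input, *event_shape_input)
    condition: (batch_condition, *event_shape_condition)
    returns:   (sample_input, batch_input)
    """
    input_shape = tuple(input_shape)
    condition_shape = tuple(condition_shape)
    input_event_shape = tuple(input_event_shape)
    condition_event_shape = tuple(condition_event_shape)

    # The input carries a sample dim and a batch dim in front of its event dims.
    if len(input_shape) != 2 + len(input_event_shape) or input_shape[2:] != input_event_shape:
        raise ValueError(f"input shape {input_shape} is not (sample, batch, *{input_event_shape})")

    # The condition carries only a batch dim in front of its event dims (no sample dim).
    if (
        len(condition_shape) != 1 + len(condition_event_shape)
        or condition_shape[1:] != condition_event_shape
    ):
        raise ValueError(f"condition shape {condition_shape} is not (batch, *{condition_event_shape})")

    sample_input, batch_input = input_shape[0], input_shape[1]
    batch_condition = condition_shape[0]

    # Batch sizes of input and condition must agree, otherwise raise.
    if batch_input != batch_condition:
        raise ValueError(f"batch_input={batch_input} != batch_condition={batch_condition}")

    # One log-probability per (sample, batch) entry; the event dims are reduced away.
    return (sample_input, batch_input)


# 100 samples for each of 8 conditions; 3-dim inputs, 5-dim conditions.
print(log_prob_shape((100, 8, 3), (8, 5), (3,), (5,)))  # (100, 8)
```

Note that the sample dimension appears only on the input; the condition is shared across all samples of the same batch entry.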
However, we are now removing the sample_dim from .loss. Therefore, .loss now has the following signature:
density_estimator.loss(input, condition)
input: (batch_input, *event_shape_input)
condition: (batch_condition, *event_shape_condition)
returns: (batch_input)
raises: an error if batch_input != batch_condition
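The new .loss contract differs from log_prob only in the missing sample dimension. The pure-Python sketch below illustrates that on shape tuples; `loss_shape` is a hypothetical helper, not the real .loss implementation, and passing the event shapes explicitly is an assumption made for the illustration.

```python
from typing import Tuple


def loss_shape(
    input_shape: Tuple[int, ...],
    condition_shape: Tuple[int, ...],
    input_event_shape: Tuple[int, ...],
    condition_event_shape: Tuple[int, ...],
) -> Tuple[int]:
    """Hypothetical helper: output shape of .loss once the sample_dim is removed.

    input:     (batch_input, *event_shape_input)
    condition: (batch_condition, *event_shape_condition)
    returns:   (batch_input,)
    """
    input_shape = tuple(input_shape)
    condition_shape = tuple(condition_shape)
    input_event_shape = tuple(input_event_shape)
    condition_event_shape = tuple(condition_event_shape)

    # Unlike log_prob, the input has no leading sample dim: only a batch dim.
    if len(input_shape) != 1 + len(input_event_shape) or input_shape[1:] != input_event_shape:
        raise ValueError(f"input shape {input_shape} is not (batch, *{input_event_shape})")

    # The condition layout is unchanged from log_prob.
    if (
        len(condition_shape) != 1 + len(condition_event_shape)
        or condition_shape[1:] != condition_event_shape
    ):
        raise ValueError(f"condition shape {condition_shape} is not (batch, *{condition_event_shape})")

    batch_input, batch_condition = input_shape[0], condition_shape[0]

    # Batch sizes of input and condition must agree, otherwise raise.
    if batch_input != batch_condition:
        raise ValueError(f"batch_input={batch_input} != batch_condition={batch_condition}")

    # One loss value per batch entry.
    return (batch_input,)


print(loss_shape((8, 3), (8, 5), (3,), (5,)))  # (8,)
```

In this sketch, passing a log_prob-style input of shape (100, 8, 3) raises a ValueError, because .loss no longer accepts a leading sample dimension.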
Checklist
Put an x in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your code.
All modified and coverable lines are covered by tests ✅
Project coverage is 77.01%. Comparing base (005aeac) to head (6ae9ede).