- Override `__call__` in `_Weighting` and `Aggregator` to specify type hints
- Force `_as_tensor_list` to cast tensors into a list of tensors
- Fix wrong type provided to `_grad` and `_jac`
- Rename `v` to `grad_outputs` in `get_vjp`
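A minimal sketch of the pattern in the first item. It assumes that, as with `torch.nn.Module`, the base class's `__call__` is loosely typed (returns `Any`) and dispatches to `forward`, so a subclass overrides `__call__` only to expose precise type hints. `Module` and `ToyAggregator` are hypothetical stand-ins so the sketch runs without PyTorch, and plain lists replace tensors. The real `_Weighting` and `Aggregator` signatures may differ.

```python
from typing import Any


class Module:
    """Hypothetical stand-in for torch.nn.Module: __call__ is loosely typed."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Like nn.Module, delegate to forward (hooks etc. omitted).
        return self.forward(*args, **kwargs)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError


class ToyAggregator(Module):
    """Hypothetical aggregator: combines the rows of a matrix into one vector."""

    def __call__(self, matrix: list[list[float]]) -> list[float]:
        # Override only to narrow the type hints; behavior is unchanged.
        return super().__call__(matrix)

    def forward(self, matrix: list[list[float]]) -> list[float]:
        # Toy aggregation: column-wise mean of the rows.
        n_rows = len(matrix)
        return [sum(column) / n_rows for column in zip(*matrix)]


result = ToyAggregator()([[1.0, 2.0], [3.0, 4.0]])
print(result)  # [2.0, 3.0]
```

With the override, a type checker infers `list[float]` for `result` instead of `Any`, while the runtime path still goes through the base class's `__call__`.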
I think one of the best changes is the one to `get_vjp`: it was not at all clear what `v` was supposed to be. In fact, `jac_outputs` is a `Sequence[Tensor]`, and `v` is also a `Sequence[Tensor]`, but each of its tensors has one less dimension than the corresponding tensor of `jac_outputs`. So `v` is what we usually refer to as `grad_outputs`, and I thus renamed `v` to `grad_outputs`.
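A toy sketch of the shape relationship described above, using plain lists and a linear map `y = A @ x` whose vector-Jacobian product is `grad_outputs^T @ A`. The functions `get_vjp` and `jac` here are hypothetical simplifications: they handle a single output instead of a `Sequence[Tensor]`, and the per-row loop stands in for what a `vmap` over the first dimension of `jac_outputs` would do. Each row of `jac_outputs` is one `grad_outputs`, which is why it has one less dimension.

```python
Vector = list[float]
Matrix = list[list[float]]


def get_vjp(A: Matrix, grad_outputs: Vector) -> Vector:
    """Hypothetical VJP of y = A @ x: returns grad_outputs^T @ A.

    grad_outputs has the shape of the output y (no leading dimension).
    """
    n_inputs = len(A[0])
    return [sum(g * A[k][j] for k, g in enumerate(grad_outputs)) for j in range(n_inputs)]


def jac(A: Matrix, jac_outputs: Matrix) -> Matrix:
    """Hypothetical: apply get_vjp to each row of jac_outputs.

    jac_outputs has one extra leading dimension compared to grad_outputs.
    """
    return [get_vjp(A, row) for row in jac_outputs]


A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 outputs, 2 inputs
identity = [[1.0 if i == j else 0.0 for j in range(3)] for i in range(3)]
print(jac(A, identity))  # [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
```

Feeding the identity as `jac_outputs` recovers the full Jacobian `A`, one row per `grad_outputs` vector.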
This PR also includes an `ExceptionContext` type hint for expectations.