ThomasHeap closed this 1 year ago.
A quick solution would be to accept the covariates in `trace_pred`, so that the name-to-dim mapping happens there, and then grab `tr.covariates_all`.
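A minimal sketch of that proposal, in plain Python so it runs without torch. `TracePredSketch`, `plate_sizes` and the plate name `plate_1` are hypothetical; the real `trace_pred` works with torchdims and its signature may differ. The only parts taken from the comment are the idea of building a single name-to-size mapping inside trace_pred from train + test covariates, and exposing the result as `covariates_all`.

```python
from dataclasses import dataclass, field


@dataclass
class TracePredSketch:
    """Hypothetical stand-in for trace_pred that owns the one name-to-size mapping."""
    train_covariates: dict
    test_covariates: dict
    plate: str
    covariates_all: dict = field(init=False)
    plate_sizes: dict = field(init=False)

    def __post_init__(self):
        # Concatenate train and test along the plate so the covariates have the
        # same size as the predictive latents (train set + test set).
        self.covariates_all = {
            name: list(self.train_covariates[name]) + list(self.test_covariates[name])
            for name in self.train_covariates
        }
        sizes = {len(v) for v in self.covariates_all.values()}
        if len(sizes) != 1:
            raise ValueError(f"inconsistent lengths along {self.plate}: {sizes}")
        # Build the mapping once, here, instead of reusing the training-time one.
        self.plate_sizes = {self.plate: sizes.pop()}


tr = TracePredSketch({"x": [0.1, 0.2, 0.3]}, {"x": [0.4, 0.5]}, "plate_1")
print(tr.plate_sizes)  # {'plate_1': 5}
```

Because the mapping is created in one place from the combined covariates, there is no second, train-sized mapping to fall out of sync with.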
Is this still an issue on the `laurence` branch?
This is fixed now.
How should covariates work with pred lls? Presumably we want to do something similar to `.sample`, where we can pass inputs, but we need to call `self.dims_data_inputs` on the covariates first to ensure they have the right dims. The issue is that the platedims we already have will be the size of the train set, not train set + test set like the latents in trace pred will be. So we are doing the named-to-torchdim mapping in two places with two different sizes:
https://github.com/alan-ppl/alan/blob/44c5546778baee599adcd2e307fa8d95e7f97d49/alan/model.py#L106
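A toy illustration of the size mismatch, in plain Python rather than torchdim. `name_to_size` is a hypothetical helper standing in for the named-to-torchdim mapping (real code would create torchdims of the inferred size), and `plate_1` and the covariate values are made up.

```python
def name_to_size(plate_name, covariates):
    """Hypothetical helper: infer the plate size from the covariates' lengths."""
    lengths = {len(v) for v in covariates.values()}
    if len(lengths) != 1:
        raise ValueError(f"covariates disagree on the size of {plate_name}: {lengths}")
    return {plate_name: lengths.pop()}


train_x = {"x": [0.1, 0.2, 0.3]}
test_x = {"x": [0.4, 0.5]}
all_x = {k: train_x[k] + test_x[k] for k in train_x}

# Mapping 1: plate dims created when the model was set up on the training data.
train_sizes = name_to_size("plate_1", train_x)
# Mapping 2: the size the latents in trace pred have (train + test).
pred_sizes = name_to_size("plate_1", all_x)

# The two mappings disagree, so covariates mapped with train_sizes cannot be
# lined up with the predictive latents without remapping them.
mismatch = train_sizes["plate_1"] != pred_sizes["plate_1"]
print(train_sizes, pred_sizes, mismatch)  # {'plate_1': 3} {'plate_1': 5} True
```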