tillahoffmann opened this issue 5 months ago
This is a good point. I guess a better check is to make sure that there are no latent variables under the subsample plates. When that is the case, there is no need to specify the `create_plates` argument.
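A minimal sketch of that check, assuming a trace-shaped mapping of site name to site dict. The field names (`type`, `args`, `is_observed`, `cond_indep_stack`) follow NumPyro trace sites, where a plate site records `args == (full_size, subsample_size)`. The helper `latent_sites_under_subsample_plates`, the `Frame` stand-in, and the hand-built trace are hypothetical; with NumPyro, a real trace could be obtained via `numpyro.handlers.trace(numpyro.handlers.seed(model, rng_seed=0)).get_trace(*args, **kwargs)`.

```python
from collections import namedtuple

# Hypothetical stand-in for NumPyro's CondIndepStackFrame (name, dim, size).
Frame = namedtuple("Frame", ["name", "dim", "size"])


def latent_sites_under_subsample_plates(trace):
    """Return names of latent sample sites nested inside a subsampled plate."""
    subsampled = set()
    for name, site in trace.items():
        if site["type"] == "plate":
            full_size, subsample_size = site["args"]
            # A plate counts as subsampled only if a smaller size was requested.
            if subsample_size is not None and subsample_size < full_size:
                subsampled.add(name)
    return [
        name
        for name, site in trace.items()
        if site["type"] == "sample"
        and not site["is_observed"]
        and any(frame.name in subsampled for frame in site["cond_indep_stack"])
    ]


# Hand-built example: a global latent "sigma", a local latent "z", and data "obs".
trace = {
    "n": {"type": "plate", "args": (1000, 100)},
    "sigma": {"type": "sample", "is_observed": False, "cond_indep_stack": []},
    "z": {"type": "sample", "is_observed": False,
          "cond_indep_stack": [Frame("n", -1, 100)]},
    "obs": {"type": "sample", "is_observed": True,
            "cond_indep_stack": [Frame("n", -1, 100)]},
}
print(latent_sites_under_subsample_plates(trace))  # ['z']
```

If the returned list is empty, only observed sites live under the subsample plate, which is the case where `create_plates` could be skipped.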
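@tillahoffmann, sorry for my last, misleading comment. For subsampling, the usage is as follows (`create_plates` receives the same arguments as the model):

```python
import numpyro
from numpyro.infer.autoguide import AutoNormal

create_plates = lambda n, x, subsample_size=None: numpyro.plate("n", n, subsample_size=subsample_size)
AutoNormal(..., create_plates=create_plates)
```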
`AutoNormal`, `AutoDelta`, and `AutoGuideList` raise an exception in SVI when the subsample size varies across different `log_density` evaluations. Here is an example reproducing the issue (run on `master`). The traceback for the failed call is as follows.
I think the issue is that these guides use `_create_plates`, which in turn uses prototype traces to determine the subsample size: https://github.com/pyro-ppl/numpyro/blob/aec6bd58b4cf0c2d81b96e62d4d0cf7af3744885/numpyro/infer/autoguide.py#L108-L113
The prototype traces are, of course, only created on the first invocation, so the expected subsample size no longer matches when a different mini-batch size is used later. Guides inheriting from `AutoContinuous` do not call `_create_plates` and do not use plates in their `__call__` method. I couldn't quite figure out why some guides do and others don't.
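The prototype-trace failure mode described above (a size captured on the first call and reused afterwards) can be mimicked without NumPyro. This is a toy sketch: `PrototypeCachingGuide` and `PerCallGuide` are hypothetical classes, not NumPyro's guide internals. The second class only stands in, by analogy, for a guide whose plates are rebuilt from each call's arguments, as with `create_plates`.

```python
class PrototypeCachingGuide:
    """Hypothetical guide that freezes the subsample size on its first call."""

    def __init__(self):
        self._prototype_subsample_size = None

    def __call__(self, batch):
        if self._prototype_subsample_size is None:
            # Like a prototype trace: recorded only on the first invocation.
            self._prototype_subsample_size = len(batch)
        if len(batch) != self._prototype_subsample_size:
            raise ValueError(
                f"expected subsample size {self._prototype_subsample_size}, "
                f"got {len(batch)}"
            )
        return len(batch)


class PerCallGuide:
    """Hypothetical guide that reads the subsample size from every call."""

    def __call__(self, batch):
        # Nothing is cached, so the mini-batch size may change between calls.
        return len(batch)


guide = PrototypeCachingGuide()
guide([0] * 100)
try:
    guide([0] * 50)
except ValueError as err:
    print(err)  # expected subsample size 100, got 50
```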