masa-su / pixyz

A library for developing deep generative models in a more concise, intuitive and extendable way
https://pixyz.io
MIT License
487 stars, 41 forks

Fix/graph option #162

Closed: ktaaaki closed this 3 years ago

ktaaaki commented 3 years ago

Issue

The following assertion fails: the sample of y has shape torch.Size([1]) instead of the expected torch.Size([2, 3, 4]).

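import torch
from pixyz.distributions import Normal

# Joint p(x|y)p(y): x is conditioned on y, y is a standard normal.
dist = Normal(var=['x'], cond_var=['y'],
              loc='y', scale=1) * Normal(var=['y'], loc=0, scale=1)
# Set local sampling options on the node for y only.
dist.graph.set_option(dict(batch_n=4, sample_shape=(2, 3)), ['y'])
sample = dist.sample()
assert sample['y'].shape == torch.Size([2, 3, 4])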
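
The expected shape in the assertion can be read as sample_shape followed by batch_n, with no extra feature dimensions because y is scalar. The helper below is a hypothetical sketch of that composition, assuming this ordering; it is not pixyz's own code.

```python
def expected_sample_shape(sample_shape, batch_n, feature_shape=()):
    """Hypothetical helper: compose a sample's shape, assuming the order
    sample_shape + (batch_n,) + feature_shape (not pixyz's implementation)."""
    return tuple(sample_shape) + (batch_n,) + tuple(feature_shape)


# Local options from the example: batch_n=4, sample_shape=(2, 3); y is scalar,
# so it contributes no feature dimensions.
print(expected_sample_shape((2, 3), 4))  # (2, 3, 4)
```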

This happens because the local options set on y are overwritten by the global options, so they are effectively ignored.
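
The precedence problem can be illustrated with plain dictionaries. This is a minimal sketch, not pixyz internals: the function names and the global default values are hypothetical. It contrasts merging so that global options win (the behaviour described above) with merging so that local options win, which is one way to avoid the problem and not necessarily the change this PR makes.

```python
def merge_options_buggy(local_options, global_options):
    # Hypothetical: global options are applied last, so they overwrite local ones.
    return {**local_options, **global_options}


def merge_options_fixed(local_options, global_options):
    # Hypothetical: local options are applied last, so they take precedence.
    return {**global_options, **local_options}


# Hypothetical global defaults vs. the local options set on y in the example.
global_options = {"batch_n": None, "sample_shape": ()}
local_options = {"batch_n": 4, "sample_shape": (2, 3)}

print(merge_options_buggy(local_options, global_options))
# {'batch_n': None, 'sample_shape': ()}
print(merge_options_fixed(local_options, global_options))
# {'batch_n': 4, 'sample_shape': (2, 3)}
```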

Solution