Closed ktaaaki closed 3 years ago
Issue
The following assertion fails and the sample shape is `torch.Size([1])`.

```python
import torch
from pixyz.distributions import Normal

dist = Normal(var=['x'], cond_var=['y'], loc='y', scale=1) * Normal(var=['y'], loc=0, scale=1)
# Set local sampling options for the variable 'y' only
dist.graph.set_option(dict(batch_n=4, sample_shape=(2, 3)), ['y'])
sample = dist.sample()
assert sample['y'].shape == torch.Size([2, 3, 4])  # fails: shape is torch.Size([1])
```

This is because local options are overwritten by global options and ignored.
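To make the precedence problem concrete, here is a minimal sketch using plain dictionaries. It is not pixyz's internal code: the `resolve` helper, the `local_wins` flag, and the default option values are all hypothetical and chosen only for illustration. It shows how the merge order decides whether a per-variable option survives.

```python
def resolve(global_option, local_option, local_wins):
    """Combine two option dicts; keys from the dict merged last win."""
    if local_wins:
        return {**global_option, **local_option}
    return {**local_option, **global_option}


# Hypothetical defaults, not taken from pixyz.
global_option = {"batch_n": 1, "sample_shape": ()}
local_option = {"batch_n": 4, "sample_shape": (2, 3)}

# Global merged last: the local settings are silently dropped.
ignored = resolve(global_option, local_option, local_wins=False)

# Local merged last: the per-variable settings take effect.
honored = resolve(global_option, local_option, local_wins=True)
```

With the global options merged last, every key they share with the local options is replaced, which matches the symptom described here.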
Solution
Change `global_option` from `.update` to `=` (set) so that `global_option` can be rewritten by the `set_option` method.
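The difference between updating and assigning a dictionary can be sketched as follows. The `OptionHolder` class, its method names, and its default values are hypothetical stand-ins, not the library's actual implementation; only the `dict.update` versus `=` semantics are standard Python.

```python
class OptionHolder:
    """Hypothetical holder for a global option dict (illustration only)."""

    def __init__(self):
        # Assumed defaults, not taken from pixyz.
        self.global_option = {"batch_n": 1, "sample_shape": (1,)}

    def set_by_update(self, option):
        # Merge: keys absent from `option` keep their previous values.
        self.global_option.update(option)

    def set_by_assignment(self, option):
        # Replace: only the keys present in `option` remain afterwards.
        self.global_option = dict(option)


merged = OptionHolder()
merged.set_by_update({"batch_n": 4})

replaced = OptionHolder()
replaced.set_by_assignment({"batch_n": 4})
```

With `update`, stale keys from an earlier call linger in the dictionary. With assignment, each call fully determines the stored options.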