masa-su / pixyz

A library for developing deep generative models in a more concise, intuitive and extendable way
https://pixyz.io
MIT License
491 stars 41 forks source link

Add a product of normal distributions #66

Closed: masa-su closed this issue 5 years ago

masa-su commented 5 years ago

Example (ProductOfNormal)

    >>> pon = ProductOfNormal([p_x, p_y])
    >>> pon.sample({"x": x, "y": y})
    {'x': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
     'y': tensor([[0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 1., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.]]),
     'z': tensor([[ 0.6611,  0.3811,  0.7778,  ..., -0.0468, -0.3615, -0.6569],
         [-0.0071, -0.9178,  0.6620,  ..., -0.1472,  0.6023,  0.5903],
         [-0.3723, -0.7758,  0.0195,  ...,  0.8239, -0.3537,  0.3854],
         ...,
         [ 0.7820, -0.4761,  0.1804,  ..., -0.5701, -0.0714, -0.5485],
         [-0.1873, -0.2105, -0.1861,  ..., -0.5372,  0.0752,  0.2777],
         [-0.2563, -0.0828,  0.1605,  ...,  0.2767, -0.8456,  0.7364]])}
    >>> pon.sample({"y": y})
    {'y': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 1.],
         [0., 0., 0.,  ..., 1., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
     'z': tensor([[-0.3264, -0.4448,  0.3610,  ..., -0.7378,  0.3002,  0.4370],
         [ 0.0928, -0.1830,  1.1768,  ...,  1.1808, -0.7226, -0.4152],
         [ 0.6999,  0.2222, -0.2901,  ...,  0.5706,  0.7091,  0.5179],
         ...,
         [ 0.5688, -1.6612, -0.0713,  ..., -0.1400, -0.3903,  0.2533],
         [ 0.5412, -0.0289,  0.6365,  ...,  0.7407,  0.7838,  0.9218],
         [ 0.0299,  0.5148, -0.1001,  ...,  0.9938,  1.0689, -1.1902]])}
    >>> pon.sample()  # same as sampling from a unit Gaussian.
    {'z': tensor(-0.4494)}

Example (ElementWiseProductOfNormal)

    >>> pon = ElementWiseProductOfNormal(p)
    >>> pon.sample({"x": x})
    {'x': tensor([[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]]),
     'z': tensor([[-0.3572, -0.0632,  0.4872,  0.2269, -0.1693, -0.0160, -0.0429,  0.2017,
          -0.1589, -0.3380, -0.9598,  0.6216, -0.4296, -1.1349,  0.0901,  0.3994,
           0.2313, -0.5227, -0.7973,  0.3968,  0.7137, -0.5639, -0.4891, -0.1249,
           0.8256,  0.1463,  0.0801, -1.2202,  0.6984, -0.4036,  0.4960, -0.4376,
           0.3310, -0.2243, -0.2381, -0.2200,  0.8969,  0.2674,  0.4681,  1.6764,
           0.8127,  0.2722, -0.2048,  0.1903, -0.1398,  0.0099,  0.4382, -0.8016,
           0.9947,  0.7556, -0.2017, -0.3920,  1.4212, -1.2529, -0.1002, -0.0031,
           0.1876,  0.4267,  0.3622,  0.2648,  0.4752,  0.0843, -0.3065, -0.4922],
         [ 0.3770, -0.0413,  0.9102,  0.2897, -0.0567,  0.5211,  1.5233, -0.3539,
           0.5163, -0.2271, -0.1027,  0.0294, -1.4617,  0.1640,  0.2025, -0.2190,
           0.0555,  0.5779, -0.2930, -0.2161,  0.2835, -0.0354, -0.2569, -0.7171,
           0.0164, -0.4080,  1.1088,  0.3947,  0.2720, -0.0600, -0.9295, -0.0234,
           0.5624,  0.4866,  0.5285,  1.1827,  0.2494,  0.0777,  0.7585,  0.5127,
           0.7500, -0.3253,  0.0250,  0.0888,  1.0340, -0.1405, -0.8114,  0.4492,
           0.2725, -0.0270,  0.6379, -0.8096,  0.4259,  0.3179, -0.1681,  0.3365,
           0.6305,  0.5203,  0.2384,  0.0572,  0.4804,  0.9553, -0.3244,  1.5373]])}
    >>> pon.sample({"x": torch.zeros_like(x)})  # same as sampling from a unit Gaussian.
    {'x': tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
     'z': tensor([[-0.7777, -0.5908, -1.5498, -0.7505,  0.6201,  0.7218,  1.0045,  0.8923,
          -0.8030, -0.3569,  0.2932,  0.2122,  0.1640,  0.7893, -0.3500, -1.0537,
          -1.2769,  0.6122, -1.0083, -0.2915, -0.1928, -0.7486,  0.2418, -1.9013,
           1.2514,  1.3035, -0.3029, -0.3098, -0.5415,  1.1970, -0.4443,  2.2393,
          -0.6980,  0.2820,  1.6972,  0.6322,  0.4308,  0.8953,  0.7248,  0.4440,
           2.2770,  1.7791,  0.7563, -1.1781, -0.8331,  0.1825,  1.5447,  0.1385,
          -1.1348,  0.0257,  0.3374,  0.5889,  1.1231, -1.2476, -0.3801, -1.4404,
          -1.3066, -1.2653,  0.5958, -1.7423,  0.7189, -0.7236,  0.2330,  0.3117],
         [ 0.5495,  0.7210, -0.4708, -2.0631, -0.6170,  0.2436, -0.0133, -0.4616,
          -0.8091, -0.1592,  1.3117,  0.0276,  0.6625, -0.3748, -0.5049,  1.8260,
          -0.3631,  1.1546, -1.0913,  0.2712,  1.5493,  1.4294, -2.1245, -2.0422,
           0.4976, -1.2785,  0.5028,  1.4240,  1.1983,  0.2468,  1.1682, -0.6725,
          -1.1198, -1.4942, -0.3629,  0.1325, -0.2256,  0.4280,  0.9830, -1.9427,
          -0.2181,  1.1850, -0.7514, -0.8172,  2.1031, -0.1698, -0.3777, -0.7863,
           1.0936, -1.3720,  0.9999,  1.3302, -0.8954, -0.5999,  2.3305,  0.5702,
          -1.0767, -0.2750, -0.3741, -0.7026, -1.5408,  0.0667,  1.2550, -0.5117]])}