HEmile / storchastic

Stochastic Automatic Differentiation library for PyTorch.
GNU General Public License v3.0
178 stars, 5 forks

Implement GO/Generalized Stochastic Backpropagation #85

Open HEmile opened 3 years ago

HEmile commented 3 years ago

See https://hal.archives-ouvertes.fr/hal-02968975/document

Implement: for a multivariate Bernoulli, draw a single regular (unflipped) sample, then, for each dimension, make a copy of this sample with the corresponding dimension flipped.
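A minimal pure-Python sketch of this sampling step, assuming the distribution is given as a list of per-dimension probabilities. The helper names `sample_bernoulli` and `flipped_samples` are hypothetical and not part of storchastic's API.

```python
import random
from typing import List


def sample_bernoulli(probs: List[float], rng: random.Random) -> List[int]:
    """Draw one regular (unflipped) sample z with z[d] ~ Bernoulli(probs[d])."""
    # rng.random() is in [0, 1), so p = 1.0 always gives 1 and p = 0.0 always gives 0
    return [1 if rng.random() < p else 0 for p in probs]


def flipped_samples(z: List[int]) -> List[List[int]]:
    """Return D copies of z; copy d has only dimension d flipped (0 <-> 1)."""
    flips = []
    for d in range(len(z)):
        x = list(z)        # copy so the regular sample is left untouched
        x[d] = 1 - x[d]    # flip dimension d
        flips.append(x)
    return flips
```

For example, `flipped_samples([1, 0, 1])` returns `[[0, 0, 1], [1, 1, 1], [1, 0, 0]]`, so a D-dimensional Bernoulli needs D + 1 cost evaluations per regular sample.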

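Evaluate the cost on the regular (unflipped) sample and on each flipped sample. Weight the regular sample with 1. Its multiplicative term is the sum of the parameters over all dimensions, where a parameter is negated if the corresponding dimension is 0. For each flipped sample, the multiplicative term is the negated parameter of the flipped dimension if it flipped to 0, and the parameter itself (not negated) if it flipped to 1. With these signs, the two terms for a given dimension have opposite signs, so its gradient estimate is the cost with that dimension set to 1 minus the cost with it set to 0.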
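The weighting can be checked with a small pure-Python sketch (hypothetical helper names, not storchastic's API). It treats the cost `f` as a constant, so differentiating each multiplicative term with respect to the parameter p_d leaves s_d * (f(z) - f(flip_d(z))), where s_d = +1 if z_d = 1 and -1 otherwise. The sketch assumes the sign convention in which the regular-sample and flipped-sample terms for the same dimension have opposite signs; that is what makes each dimension reduce to f(z with z_d = 1) - f(z with z_d = 0). Averaging the estimate over all 2^D outcomes should then reproduce the exact gradient of E[f(z)].

```python
import itertools
from typing import Callable, List, Sequence, Tuple

Cost = Callable[[Sequence[int]], float]


def multiplicatives(z: Sequence[int], probs: Sequence[float]) -> Tuple[float, List[float]]:
    """Multiplicative term of the regular sample and of each flipped sample."""
    signs = [1.0 if zd == 1 else -1.0 for zd in z]
    regular = sum(s * p for s, p in zip(signs, probs))  # sum of signed parameters
    flipped = [-s * p for s, p in zip(signs, probs)]     # -p_d if flipped to 0, +p_d if flipped to 1
    return regular, flipped


def go_bernoulli_grad(z: Sequence[int], f: Cost) -> List[float]:
    """Gradient of the surrogate w.r.t. each p_d, with the cost f held constant."""
    f_z = f(z)
    grads = []
    for d in range(len(z)):
        x = list(z)
        x[d] = 1 - x[d]                   # flipped sample for dimension d
        s = 1.0 if z[d] == 1 else -1.0    # d(regular)/dp_d = s, d(flipped[d])/dp_d = -s
        grads.append(s * f_z - s * f(x))
    return grads


def prob_of(z: Sequence[int], probs: Sequence[float]) -> float:
    """Probability of outcome z under independent Bernoulli(probs)."""
    out = 1.0
    for zd, p in zip(z, probs):
        out *= p if zd == 1 else 1.0 - p
    return out


def exact_mean(probs: Sequence[float], f: Cost) -> float:
    """E[f(z)] computed by enumerating all 2^D outcomes."""
    return sum(prob_of(z, probs) * f(z) for z in itertools.product([0, 1], repeat=len(probs)))


def mean_go_grad(probs: Sequence[float], f: Cost) -> List[float]:
    """Expectation of the single-sample GO estimate over all 2^D outcomes."""
    total = [0.0] * len(probs)
    for z in itertools.product([0, 1], repeat=len(probs)):
        w = prob_of(z, probs)
        total = [t + w * g for t, g in zip(total, go_bernoulli_grad(z, f))]
    return total
```

With `f = lambda z: (z[0] + 2 * z[1] - z[2]) ** 2`, `go_bernoulli_grad([1, 0, 1], f)` returns `[-1.0, 4.0, -1.0]`. Because E[f(z)] is linear in each p_d, a central finite difference of `exact_mean` matches `mean_go_grad` up to rounding error, which is a quick way to confirm the signs are right.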

To keep the zeroth-order estimate (the forward value) correct, apply importance weights to the flipped samples. In the multiplicative estimator, divide by these importance weights again so that the gradient stays correct.
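For comparison, the same two goals (a forward value equal to f(z), and the GO gradient in the backward pass) can also be met with a different, swapped-in device: the stop-gradient trick `m - m.detach()`, which is zero in the forward pass but keeps the gradient of `m`. The PyTorch sketch below uses that trick instead of the importance weighting described above. `go_surrogate` is a hypothetical name, and `f` is assumed to map a 1-D sample tensor to a scalar tensor.

```python
from typing import Callable

import torch


def go_surrogate(probs: torch.Tensor, f: Callable[[torch.Tensor], torch.Tensor]) -> torch.Tensor:
    """Single-sample GO surrogate for a multivariate Bernoulli (stop-gradient variant).

    Forward value equals f(z); the gradient w.r.t. probs[d] is s_d * (f(z) - f(flip_d(z))).
    """
    with torch.no_grad():
        z = torch.bernoulli(probs)                     # regular (unflipped) sample
        d = z.numel()
        flips = z.repeat(d, 1)                         # row i is a copy of z
        idx = torch.arange(d)
        flips[idx, idx] = 1.0 - flips[idx, idx]        # flip dimension i in row i
        f_z = f(z)
        f_flips = torch.stack([f(x) for x in flips])   # cost of each flipped sample
    signs = 2.0 * z - 1.0                              # +1 where z_d = 1, -1 where z_d = 0
    m_regular = (signs * probs).sum()                  # multiplicative of the regular sample
    m_flips = -signs * probs                           # multiplicative of each flipped sample
    # (m - m.detach()) is zero in the forward pass but carries m's gradient backward
    grad_part = f_z * (m_regular - m_regular.detach()) + (f_flips * (m_flips - m_flips.detach())).sum()
    return f_z + grad_part
```

With `probs = torch.tensor([1.0, 0.0, 1.0], requires_grad=True)` the sample is deterministically `[1, 0, 1]`, so for the cost `(x[0] + 2 * x[1] - x[2]) ** 2` the forward value is 0 and `probs.grad` after `backward()` is `[-1, 4, -1]`. Unlike the importance-weighted version, this variant never mixes the flipped samples into the forward value.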

This doesn't use a baseline.

HEmile commented 3 years ago

For the unbiased GO-gradient implementation, see the paper.