braun-steven / simple-einet

An implementation of EinsumNetworks in PyTorch.
MIT License

Is Differentiable Sampling implemented? #5

Closed · agneet42 closed this issue 10 months ago

agneet42 commented 10 months ago

Hi @braun-steven, thanks for the great work. I wanted to confirm with you whether this repository contains the updated code for differentiable sampling as in the paper: https://proceedings.mlr.press/v181/lang22a/lang22a.pdf

Along similar lines, I see here (https://github.com/braun-steven/simple-einet/blob/main/examples/test_iris.py) that you are using CE as the loss; does that mean I can use this as an example for differentiable sampling?

If not, where can I find an end-to-end example where backpropagation through an EiNet happens with the Gumbel-softmax implementation?

braun-steven commented 10 months ago

Hey Agneet,

thanks for reaching out! Correct, the differentiable sampling implementation in this repository is the one used in our paper.

> Along similar lines, I see here (https://github.com/braun-steven/simple-einet/blob/main/examples/test_iris.py) that you are using CE as the loss; does that mean I can use this as an example for differentiable sampling?

Those two things are independent of each other. The CE loss simply converts the model log-likelihoods $\log p(x \mid y)$ into the posterior $p(y \mid x)$ and then applies the NLL loss to that posterior.
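For intuition, here is a minimal self-contained sketch (illustrative PyTorch, not code from this repo) of that equivalence: treating the per-class log-likelihoods as logits, `log_softmax` yields the log-posterior under a uniform class prior, and cross-entropy on the log-likelihoods is exactly the NLL of that posterior.

```python
import torch
import torch.nn.functional as F

# Hypothetical class-conditional log-likelihoods log p(x | y)
# for a batch of 4 inputs and 3 classes (stand-in for Einet outputs).
log_likelihoods = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 0])

# With a uniform prior p(y), Bayes' rule gives
# log p(y | x) = log p(x | y) - logsumexp_y' log p(x | y'),
# which is exactly log_softmax over the class dimension.
log_posterior = F.log_softmax(log_likelihoods, dim=1)

# NLL of the posterior equals cross-entropy on the log-likelihoods.
nll = F.nll_loss(log_posterior, targets)
ce = F.cross_entropy(log_likelihoods, targets)
assert torch.allclose(nll, ce)
```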

Btw., I've just moved the iris example into a notebook to make it easier to understand and follow: https://github.com/braun-steven/simple-einet/blob/main/notebooks/iris_classification.ipynb.

> If not, where can I find an end-to-end example where backpropagation through an EiNet happens with the Gumbel-softmax implementation?

We have provided the code for our experiments over here: https://github.com/ml-research/differentiable-sampling-pc/

Let me know if you have any other questions!

agneet42 commented 10 months ago

Thanks @braun-steven for getting back to me! I have indeed looked at the https://github.com/ml-research/differentiable-sampling-pc/ repo too; however, what I am looking for is a basic implementation where a differentiable loss (such as CE/L1/L2) backprops through the network. I am developing a loss function based on EiNets and am looking for a sample implementation on which I can base my own. Essentially, something like this:

```python
model = Einet()                  # note: the class is Einet, not EiNet
out = model(data)                # forward pass: per-sample log-likelihoods
loss = custom_loss(out, target)  # any differentiable loss
loss.backward()                  # backprop through the EiNet
```
braun-steven commented 10 months ago

Differentiable sampling allows you to draw a sample, compute a loss on that sample, and backpropagate this loss through the network. That is:

```python
model = Einet()
# Draw N samples with the relaxation enabled, so the samples stay
# differentiable w.r.t. the model parameters.
sample = model.sample(num_samples=N, is_differentiable=True)
loss = custom_loss(sample, data)
loss.backward()  # gradients flow back through the sampling procedure
```
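To make the mechanism concrete, here is a toy, self-contained sketch of the Gumbel-softmax trick that this style of differentiable sampling builds on. It illustrates the general technique only; it is not simple-einet internals, and the logits/values below are hypothetical.

```python
import torch
import torch.nn.functional as F

# Logits of a categorical distribution we want to sample from while
# keeping gradients (think: the mixture weights of a sum node).
logits = torch.randn(5, requires_grad=True)

# hard=True returns a one-hot sample in the forward pass but uses the
# soft relaxation for the backward pass (straight-through estimator).
one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True)

# A "sampled" value that is a differentiable function of the logits.
values = torch.tensor([0.1, 0.5, 1.0, 2.0, 4.0])
sample = (one_hot * values).sum()

sample.backward()
print(logits.grad)  # nonzero: gradients flow back into the logits
```

This is what allows the `loss.backward()` call above to push gradients into the EiNet's parameters even though a discrete sampling step sits in between.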
agneet42 commented 10 months ago

Thanks @braun-steven, let me try some of these implementations out and I will reach out to you!