SaynaEbrahimi / UCB

Original PyTorch implementation of Uncertainty-guided Continual Learning with Bayesian Neural Networks, ICLR 2020
https://openreview.net/pdf?id=HklUCCVKDB
MIT License

TensorFlow Probability Implementation #2

Status: Open. Opened by sumitsinha 4 years ago.

sumitsinha commented 4 years ago

Is there a TensorFlow Probability-based implementation of the uncertainty-based learning-rate adjustment that uses Flipout layers?

SaynaEbrahimi commented 4 years ago

I have only implemented UCB in PyTorch, but after a quick search I found this notebook, which implements Bayes-by-Backprop (BBB), the core building block of UCB, in TensorFlow using the TensorFlow Probability library. Adding the uncertainty-based learning-rate update to that implementation should be straightforward.
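The uncertainty-guided update mentioned above can be sketched in a few lines. This is a minimal, framework-free illustration, not the UCB reference code: it assumes the softplus parameterization σ = log(1 + exp(ρ)) commonly used in Bayes-by-Backprop, uses σ itself as each weight's learning-rate multiplier, and updates only the means μ with plain SGD. The function name `uncertainty_scaled_sgd_step` is hypothetical.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x)), used here to map rho to sigma."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def uncertainty_scaled_sgd_step(mu, rho, grad_mu, base_lr):
    """One plain-SGD step on the means, with each weight's learning rate
    scaled by its own standard deviation sigma = softplus(rho).

    Assumption: sigma = log(1 + exp(rho)) and multiplier = sigma. This is a
    simplified illustration, not the UCB reference implementation, and it
    leaves the update of rho out entirely.
    """
    new_mu = []
    for m, r, g in zip(mu, rho, grad_mu):
        sigma = softplus(r)          # per-weight uncertainty
        lr_i = base_lr * sigma       # per-weight learning rate
        new_mu.append(m - lr_i * g)  # SGD update of the mean
    return new_mu


# Same gradient for both weights; only their uncertainty differs.
updated = uncertainty_scaled_sgd_step(
    mu=[0.5, -0.2], rho=[-5.0, 2.0], grad_mu=[1.0, 1.0], base_lr=0.1
)
# The low-sigma (confident) first weight barely moves; the second moves much further.
```

In TensorFlow, the same effect can be approximated for plain SGD by multiplying each gradient elementwise by σ before calling `optimizer.apply_gradients`. With adaptive optimizers such as Adam, this gradient rescaling largely cancels out, because Adam normalizes each step by running gradient magnitudes, so the per-weight multiplier would need to be applied to the update itself instead.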

sumitsinha commented 4 years ago

Thank you for this. I already have the BBB part set up in my repo. What I am looking for is a way to extract the weight standard deviations and set a separate learning rate for each weight, as described in the paper, which does not seem straightforward in TensorFlow Probability. Any insights that could help with this would be appreciated. I will be starting work on this to build a continual learning framework for manufacturing systems.

SaynaEbrahimi commented 4 years ago

If you have BBB set up, then you already have $\mu$ and $\rho$ computed for each weight, and from $\rho$ you can compute the standard deviation. All you need to do next is set a learning-rate multiplier (the uncertainty, in this case) for each layer. It has been a while since I last used TensorFlow, but it seems you can set learning-rate multipliers for each layer as described here. Several approaches are suggested there; I am sure at least one of them will work for you.
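Written per weight, the two pieces mentioned above combine as follows. This is a sketch assuming the softplus parameterization that Bayes-by-Backprop commonly uses and a plain gradient step on the means; $\alpha$ is the base learning rate, $\mathcal{L}$ the training loss, and $i$ indexes individual weights.

```math
\sigma_i = \log\left(1 + \exp(\rho_i)\right), \qquad
\alpha_i = \alpha \, \sigma_i, \qquad
\mu_i \leftarrow \mu_i - \alpha_i \, \frac{\partial \mathcal{L}}{\partial \mu_i}
```

Because $\alpha_i$ is proportional to $\sigma_i$, weights the model is confident about (small $\sigma_i$) change little when training on a new task, while uncertain weights remain free to adapt.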

sumitsinha commented 4 years ago

Thanks, I will take a look and implement a learning-rate multiplier. Just to clarify: my understanding from the paper is that not only each layer but each individual weight within a layer needs its own multiplier (its standard deviation).
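The difference between a per-layer and a per-weight multiplier can be made concrete with a small sketch. It assumes σ = softplus(ρ) and, purely for contrast, uses the mean σ over a layer as a hypothetical per-layer multiplier; neither `per_weight_multipliers` nor `per_layer_multiplier` comes from the UCB code.

```python
import math
from statistics import mean


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x)), mapping rho to sigma."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def per_weight_multipliers(rho):
    """One learning-rate multiplier per weight: that weight's own sigma."""
    return [softplus(r) for r in rho]


def per_layer_multiplier(rho):
    """Hypothetical per-layer alternative: a single shared multiplier,
    taken here as the mean sigma over the layer's weights."""
    return mean(per_weight_multipliers(rho))


layer_rho = [-6.0, -6.0, 3.0]  # two confident weights, one uncertain weight
print(per_weight_multipliers(layer_rho))
print(per_layer_multiplier(layer_rho))
```

With a single per-layer value, the two confident weights would receive roughly the same large step as the uncertain one (a multiplier near 1 instead of about 0.0025), which is exactly what the per-weight formulation avoids.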