yaringal / ConcreteDropout

Code for Concrete Dropout as presented in https://arxiv.org/abs/1705.07832
MIT License

How to adapt heteroscedastic loss for a classification problem? #11


manuelblancovalentin commented 5 years ago

How could we adapt this for a classification problem using cross-entropy?

from keras import backend as K

# D is the number of regression targets; the network outputs 2*D values
# (D means followed by D log-variances)
def heteroscedastic_mse(true, pred):
    mean = pred[:, :D]
    log_var = pred[:, D:]
    precision = K.exp(-log_var)
    return K.sum(precision * (true - mean)**2. + log_var, -1)
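
For context, this is roughly how that loss gets wired up in the repo's regression notebook (the toy model and data below are just placeholders of my own):

import numpy as np
from keras.layers import Input, Dense, concatenate
from keras.models import Model

D = 1                                  # number of regression targets
inp = Input(shape=(10,))
h = Dense(64, activation='relu')(inp)
mean = Dense(D)(h)                     # predictive mean head
log_var = Dense(D)(h)                  # predictive log-variance head
model = Model(inp, concatenate([mean, log_var]))
model.compile(optimizer='adam', loss=heteroscedastic_mse)

# y has D columns even though the model outputs 2*D values; Keras only
# fixes the ndim of the target placeholder, so the custom loss can slice
X, y = np.random.randn(256, 10), np.random.randn(256, D)
model.fit(X, y, epochs=5, verbose=0)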

I found a post suggesting a Monte Carlo simulation, but accuracy gets stuck at a very low value and won't go up:

import numpy as np
from keras import backend as K
from tensorflow import distributions  # TF 1.x; on TF 2 use tensorflow_probability

def heteroscedastic_categorical_crossentropy(true, pred):
    mean = pred[:, :D]
    log_var = pred[:, D:]

    # std of the logits from the predicted log-variance: exp(log_var / 2)
    std = K.exp(0.5 * log_var)

    # variance depressor
    logvar_dep = K.exp(log_var) - K.ones_like(log_var)

    # undistorted loss (K.categorical_crossentropy takes the target first)
    undistorted_loss = K.categorical_crossentropy(true, mean, from_logits=True)

    # apply Monte Carlo simulation over T logit samples
    T = 100
    iterable = K.variable(np.ones(T))
    dist = distributions.Normal(loc=K.zeros_like(std), scale=std)
    monte_carlo_results = K.map_fn(
        gaussian_categorical_crossentropy(true, mean, dist,
                                          undistorted_loss, D),
        iterable, name='monte_carlo_results')

    var_loss = K.mean(monte_carlo_results, axis=0) * undistorted_loss

    return var_loss + undistorted_loss + K.sum(logvar_dep, -1)

where gaussian_categorical_crossentropy is defined as:

def gaussian_categorical_crossentropy(true, pred, dist, undistorted_loss, num_classes):
    def map_fn(i):
        # distort the logits with one draw of Gaussian noise
        std_samples = dist.sample(1)
        distorted_loss = K.categorical_crossentropy(true, pred + std_samples[0],
                                                    from_logits=True)
        diff = undistorted_loss - distorted_loss
        return -K.elu(diff)
    return map_fn

The source of the last code: https://github.com/kyle-dorman/bayesian-neural-network-blogpost
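
For what it's worth, the objective that blogpost approximates is the Monte Carlo classification likelihood from Kendall & Gal (2017), "What Uncertainties Do We Need in Bayesian Deep Learning?", and it can be written more directly: sample T logit vectors from N(mean, std^2), map each to log-probabilities, and average in probability space with a logsumexp. A minimal sketch (D is the number of classes as above; T and the use of K.random_normal instead of a Distribution object are my choices):

import numpy as np
import keras.backend as K

T = 50  # number of Monte Carlo samples

def mc_heteroscedastic_crossentropy(true, pred):
    mean = pred[:, :D]             # logits
    log_var = pred[:, D:]          # log-variance of the logits
    std = K.exp(0.5 * log_var)     # std from log-variance

    # T samples of distorted logits, each mapped to log-probabilities
    log_probs = []
    for _ in range(T):
        logits = mean + std * K.random_normal(K.shape(mean))
        log_probs.append(logits - K.logsumexp(logits, axis=-1, keepdims=True))
    log_probs = K.stack(log_probs, axis=0)              # (T, batch, D)

    # log of the MC-averaged predictive probability per class
    log_avg_prob = K.logsumexp(log_probs, axis=0) - np.log(T)

    # negative log-likelihood of the (one-hot) true class
    return -K.sum(true * log_avg_prob, axis=-1)

With this form there is no map_fn or Distribution object, and gradients flow through the noise via the reparameterisation mean + std * eps.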

Thanks in advance!

ianpdavies commented 4 years ago

Yeah, I agree that this would be very useful. The loss function for a classification problem is given in the Concrete Dropout paper, but it's unclear to me how to implement it in this code.
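
As far as I can tell from the paper's MNIST experiment, the data term for classification is just the standard softmax cross-entropy; the concrete-dropout regularisation is added inside the ConcreteDropout wrapper as a layer loss, so no heteroscedastic machinery is needed. Something like this (a sketch of my own; the regulariser constants are placeholders, and I'm assuming the ConcreteDropout wrapper from this repo's notebook):

from keras.layers import Input, Dense
from keras.models import Model

N = 60000          # training-set size
wr = 1e-6          # weight regulariser (placeholder; see the notebook)
dr = 2. / N        # dropout regulariser (placeholder; see the notebook)

inp = Input(shape=(784,))
x = ConcreteDropout(Dense(512, activation='relu'),
                    weight_regularizer=wr, dropout_regularizer=dr)(inp)
out = ConcreteDropout(Dense(10, activation='softmax'),
                      weight_regularizer=wr, dropout_regularizer=dr)(x)
model = Model(inp, out)

# plain cross-entropy; the concrete-dropout terms come in through the
# wrapper's layer losses, so they are optimised jointly with the data term
model.compile(optimizer='adam', loss='categorical_crossentropy',
              metrics=['accuracy'])

At test time you would average the softmax outputs over multiple stochastic forward passes to get the predictive distribution and its uncertainty.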