lukemelas / EfficientNet-PyTorch

A PyTorch implementation of EfficientNet
Apache License 2.0

Question about Drop Connect & Stochastic Depth #104

Open Bear-kai opened 4 years ago

Bear-kai commented 4 years ago

The parameter drop_connect_rate is used for the stochastic depth of the network, but the function drop_connect() seems to drop whole samples instead?

def drop_connect(inputs, p, training):
    if not training:
        return inputs
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0,
    # drawn once per sample (mask shape: [batch_size, 1, 1, 1])
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)
    # scaling by 1 / keep_prob keeps the expected activation unchanged
    output = inputs / keep_prob * binary_tensor
    return output

Stochastic depth should drop whole building blocks at random, so I don't understand the function above. Can anyone help?!
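For reference, a quick check of what the function actually does (a sketch, assuming PyTorch is available; drop_connect is copied from above): the mask has shape [batch_size, 1, 1, 1], so each sample in the batch is either kept whole (scaled by 1/keep_prob) or zeroed out entirely — individual activations within a sample are never dropped.

```python
import torch

def drop_connect(inputs, p, training):
    # copied from above, condensed, for a self-contained check
    if not training:
        return inputs
    keep_prob = 1 - p
    random_tensor = keep_prob + torch.rand(
        [inputs.shape[0], 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)  # 1 with prob keep_prob, else 0
    return inputs / keep_prob * binary_tensor

torch.manual_seed(0)
x = torch.ones(8, 2, 3, 3)
out = drop_connect(x, p=0.5, training=True)
# every sample is either all zeros or all 1 / keep_prob = 2.0
for sample in out.flatten(1):
    assert bool(sample.eq(0).all() or sample.eq(2.0).all())
```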

Bear-kai commented 4 years ago

Obviously, DropConnect (which zeroes individual weights) and stochastic depth (which drops whole blocks) are totally different things! I think the naming in the implementation is confusing.

Renthal commented 4 years ago

Relevant issue: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956
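In short, per that thread, the function implements sample-wise stochastic depth rather than DropConnect, and it only makes sense because the blocks apply it to the residual branch before the skip addition: zeroing a sample's branch output turns the block into an identity for that sample. Roughly (a sketch with assumed names, not the repo's exact code):

```python
import torch

def drop_connect(inputs, p, training):
    # same per-sample mask as in the repo's implementation
    if not training:
        return inputs
    keep_prob = 1 - p
    random_tensor = keep_prob + torch.rand(
        [inputs.shape[0], 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    return inputs / keep_prob * torch.floor(random_tensor)

def block_forward(x, residual_branch, p, training):
    # When the mask zeroes a sample, the whole block reduces to the skip
    # connection for that sample -- i.e. per-sample stochastic depth.
    return x + drop_connect(residual_branch(x), p, training)

torch.manual_seed(0)
x = torch.randn(4, 8, 5, 5)
branch = torch.nn.Conv2d(8, 8, kernel_size=3, padding=1)
out = block_forward(x, branch, p=0.2, training=True)
```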

JVGD commented 4 years ago

Thank you for the clarification; I came here with similar questions. I still have some doubts about the implementation. Specifically, I expected to see outputs dropped according to binary_tensor alone (with multiplication):

output = inputs * binary_tensor

However, I see that in addition to the dropping (implemented with binary_tensor), a scaling factor (1 / keep_prob) is also applied to the inputs:

output = inputs / keep_prob * binary_tensor

Why is that? I know there are variants like Gaussian dropout where inputs are scaled up or down according to a normal distribution, but here I'm not quite sure of the purpose.

I see that by doing inputs / keep_prob you are actually scaling the inputs up by 1/(1-p). Is this some sort of normalization to ensure the expected value of the outputs matches that of the inputs before dropping?
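A small numerical check of that hypothesis (a sketch, assuming PyTorch): the inverted scaling keeps the mean of the output equal to the mean of the input, which is why inference (training=False) can simply return the inputs with no mask and no rescaling — the same trick standard inverted dropout uses.

```python
import torch

torch.manual_seed(0)
p = 0.25
keep_prob = 1 - p
x = torch.ones(100_000, 1, 1, 1)
# Bernoulli(keep_prob) mask, built the same way as in drop_connect
mask = torch.floor(keep_prob + torch.rand_like(x))
out = x / keep_prob * mask
# survivors are scaled up by 1 / keep_prob, so the mean is preserved:
# E[out] = keep_prob * (1 / keep_prob) * E[x] = E[x]
mean_out = out.mean().item()
```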