wavefrontshaping / complexPyTorch

A high-level toolbox for using complex valued neural networks in PyTorch

complex dropout #7

Closed xxoospring closed 3 years ago

xxoospring commented 3 years ago

Here is your implementation of complex dropout: https://github.com/wavefrontshaping/complexPyTorch/blob/master/complexFunctions.py#L21-L23. Dropout disables some units with a given probability, so for complex dropout the real and imaginary parts should share the same set of disabled indices. How can I guarantee that the indices are the same in the implementation above?
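For reference, the linked lines apply torch.nn.functional.dropout to the real and imaginary parts independently, so each call draws its own random mask. A rough paraphrase (not the verbatim source):

from torch.nn.functional import dropout

# Paraphrase of the linked implementation: each dropout call draws
# its own mask, so the two parts can zero out different indices.
def complex_dropout(input_r, input_i, p=0.5, training=True, inplace=False):
    return dropout(input_r, p, training, inplace), dropout(input_i, p, training, inplace)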

xxoospring commented 3 years ago

Here is my implementation of complex dropout. Could you check whether it is correct?

import numpy as np
import torch
import torch.nn as nn

def complex_dropout(input_r, input_i, p=0.5, training=True, inplace=False):
    if not training:
        return input_r, input_i

    # Draw a single Bernoulli mask so the real and imaginary parts
    # are zeroed at the same indices.
    bernoulli_dist = torch.from_numpy(np.random.binomial(1, 1 - p, input_r.shape))
    d_r = input_r.masked_fill(bernoulli_dist == 0, 0.)
    d_i = input_i.masked_fill(bernoulli_dist == 0, 0.)
    # TODO: inplace implementation
    return d_r, d_i

class ComplexDropout(nn.Module):
    def __init__(self, p=0.5, inplace=False):
        super(ComplexDropout, self).__init__()
        self.p = p
        self.inplace = inplace

    def forward(self, input_r, input_i):
        return complex_dropout(input_r, input_i, self.p, self.training, self.inplace)
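A quick sanity check (a sketch, not part of the original comment) that both parts share the same zero pattern. Note that, unlike torch.nn.functional.dropout, this version does not rescale the surviving elements by 1/(1-p), so the expected activation shrinks during training:

x_r = torch.randn(4, 4)
x_i = torch.randn(4, 4)
y_r, y_i = complex_dropout(x_r, x_i, p=0.5)
# Both parts must be zeroed at exactly the same indices.
assert torch.equal(y_r == 0, y_i == 0)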
wavefrontshaping commented 3 years ago

Indeed, the dropout was not correct, as the real and imaginary parts did not drop the same elements. I also had a related issue with max_pool. I rewrote the entire thing to use the new complex tensors and corrected the dropout. I use a not-so-elegant solution: I apply dropout to a tensor filled with ones and use the result as a mask that I multiply the complex tensor by.

import torch
from torch.nn.functional import dropout

def complex_dropout(input, p=0.5, training=True):
    # Need the same dropout mask for the real and imaginary parts;
    # this is not a clean solution!
    mask = torch.ones(input.shape, dtype=torch.float32, device=input.device)
    # dropout already rescales surviving elements by 1/(1-p) during
    # training, so no extra scaling of the mask is needed here.
    mask = dropout(mask, p, training)
    return mask * input
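A minimal usage sketch, assuming the corrected function above: multiplying the real-valued mask with a complex tensor promotes the result to a complex dtype, so an element is either dropped entirely or kept with both parts scaled together.

x = torch.randn(4, 4, dtype=torch.complex64)
y = complex_dropout(x, p=0.5, training=True)
# Dropped elements become 0+0j; survivors keep their phase and are
# scaled by 1/(1-p), so the expected value matches evaluation mode.
print(y)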