xxoospring closed this issue 3 years ago.
Here is my implementation of complex dropout. Can you help me check whether it is correct?
import torch
import torch.nn as nn


def complex_dropout(input_r, input_i, p=0.5, training=True, inplace=False):
    if not training:
        return input_r, input_i
    # One Bernoulli(1 - p) keep-mask, created on the input's device and shared
    # by the real and imaginary parts so both drop the same elements.
    keep = torch.bernoulli(torch.full_like(input_r, 1 - p)).bool()
    # Inverted-dropout rescaling, as done by nn.Dropout during training.
    scale = 1.0 / (1 - p)
    d_r = input_r.masked_fill(~keep, 0.) * scale
    d_i = input_i.masked_fill(~keep, 0.) * scale
    # TODO: inplace implementation
    return d_r, d_i


class ComplexDropout(nn.Module):
    def __init__(self, p=0.5, inplace=False):
        super(ComplexDropout, self).__init__()
        self.p = p
        self.inplace = inplace

    def forward(self, input_r, input_i):
        return complex_dropout(input_r, input_i, self.p, self.training, self.inplace)
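Standard dropout (`torch.nn.functional.dropout`, `nn.Dropout`) rescales the surviving elements by 1/(1-p) during training so that each element keeps its expected value; without that factor, activations shrink by (1-p) at training time only. A minimal sketch of that check, assuming a single shared Bernoulli mask for the real and imaginary parts; the tensor size, values and seed are arbitrary choices:

```python
import torch

torch.manual_seed(0)
p = 0.5
x_r = torch.ones(100_000)
x_i = torch.full((100_000,), 2.0)

# One keep-mask shared by the real and imaginary parts.
keep = torch.bernoulli(torch.full_like(x_r, 1 - p)).bool()

unscaled_r = x_r.masked_fill(~keep, 0.0)
scaled_r = unscaled_r / (1 - p)
scaled_i = x_i.masked_fill(~keep, 0.0) / (1 - p)

# Without rescaling the mean drops to about (1 - p); with it, it stays near 1.
print(unscaled_r.mean().item())  # roughly 0.5
print(scaled_r.mean().item())    # roughly 1.0
# Both parts are zero at exactly the same positions.
print(torch.equal(scaled_r == 0, scaled_i == 0))  # True
```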
Indeed, the dropout was not correct, as the real and imaginary parts did not drop the same elements. I had a related issue with max_pool. I rewrote everything to use PyTorch's new native complex tensors and corrected the dropout. The solution is not very elegant: I apply dropout to a tensor filled with ones and multiply the complex tensor by the result, which acts as a mask shared by the real and imaginary parts.
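import torch
from torch.nn.functional import dropout


def complex_dropout(input, p=0.5, training=True):
    # need to have the same dropout mask for real and imaginary part,
    # this is not a clean solution!
    # Real-valued ones matching the complex input's device and precision.
    mask = torch.ones_like(input.real)
    # dropout() already rescales the kept entries by 1/(1-p) in training
    # mode (and is a no-op in eval mode), so no extra scaling is applied.
    mask = dropout(mask, p, training)
    return mask * input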
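Instead of passing a tensor of ones through `dropout`, the shared mask can also be drawn directly with `torch.bernoulli`. The sketch below is a hypothetical variant (`complex_dropout_bernoulli` is an illustrative name, not part of complexPyTorch): it samples a real Bernoulli(1 - p) mask on the input's device, applies the usual 1/(1-p) rescaling, and multiplies the native complex tensor by it, so real and imaginary parts are zeroed together.

```python
import torch


def complex_dropout_bernoulli(input, p=0.5, training=True):
    """Hypothetical variant: draw the keep-mask directly instead of
    passing a ones tensor through F.dropout."""
    if not training or p == 0.0:
        return input
    if p == 1.0:
        return torch.zeros_like(input)
    # Real-valued keep-probabilities matching the input's device and precision.
    keep_prob = torch.full_like(input.real, 1.0 - p)
    # Bernoulli(1 - p) mask, rescaled so E[mask] = 1 (inverted dropout).
    mask = torch.bernoulli(keep_prob) / (1.0 - p)
    # A single real mask scales the real and imaginary parts together.
    return mask * input


torch.manual_seed(0)
x = torch.randn(4, 5, dtype=torch.cfloat)
y = complex_dropout_bernoulli(x, p=0.5)
```

Both versions behave the same in distribution; the explicit Bernoulli draw just makes the shared mask visible instead of relying on `dropout` applied to ones.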
Regarding your implementation of complex dropout (https://github.com/wavefrontshaping/complexPyTorch/blob/master/complexFunctions.py#L21-L23): dropout disables randomly chosen elements according to a certain probability distribution, so for complex dropout the real and imaginary parts should share the same dropped indices. How can I guarantee that their indices are the same in that implementation?
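The guarantee comes from the mask being a single real-valued tensor: multiplying a complex number a + bi by a real m gives ma + mbi, so wherever m = 0 both parts vanish together. Calling dropout separately on the real and imaginary parts draws two independent masks instead, and with p = 0.5 roughly half of the elements end up with only one part dropped. A small sketch comparing the two approaches (sizes and seed are arbitrary; `mismatch` is a helper written for this illustration):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
p = 0.5
x = torch.complex(torch.ones(10_000), torch.ones(10_000))

# Independent masks: dropout is sampled separately for each part.
indep = torch.complex(F.dropout(x.real, p), F.dropout(x.imag, p))

# Shared mask: one real-valued dropout mask multiplies the complex tensor.
mask = F.dropout(torch.ones_like(x.real), p)
shared = mask * x


def mismatch(z):
    # Fraction of elements where exactly one of the two parts was dropped.
    return ((z.real == 0) != (z.imag == 0)).float().mean().item()


print(mismatch(indep))   # around 2 * p * (1 - p) = 0.5
print(mismatch(shared))  # 0.0
```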