tky823 / DNN-based_source_separation

A PyTorch implementation of DNN-based source separation.

Conv-TasNet Cumulative Layer Norm Bug? #101

Closed michelg10 closed 2 years ago

michelg10 commented 2 years ago

Shouldn't lines 78-92 be

```python
step_sum = input.sum(dim=1)  # -> (batch_size, T)
cum_sum = torch.cumsum(step_sum, dim=1)  # -> (batch_size, T)

cum_num = torch.arange(C, C*(T+1), C, dtype=torch.float)  # -> (T,): [C, 2C, ..., T*C]
cum_mean = cum_sum / cum_num  # (batch_size, T)
cum_var = (cum_sum - cum_mean)**2 / cum_num

cum_mean = cum_mean.unsqueeze(dim=1)
cum_var = cum_var.unsqueeze(dim=1)

output = (input - cum_mean) / (torch.sqrt(cum_var) + eps) * self.gamma + self.beta
```

according to the Conv-TasNet paper?
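For reference, this is the cumulative layer norm definition as I read it in the paper, where $f_k$ is the feature at frame $k$ over $N$ channels:

```latex
% Cumulative layer norm over frames t <= k, feature f_k with N channels
% (my reading of the Conv-TasNet paper; epsilon is for numerical stability):
\mathrm{cLN}(f_k) =
  \frac{f_k - \mathbb{E}\left[f_{t \le k}\right]}
       {\sqrt{\mathrm{Var}\left[f_{t \le k}\right] + \epsilon}}
  \odot \gamma + \beta,
\qquad
\mathbb{E}\left[f_{t \le k}\right] = \frac{1}{kN} \sum_{t \le k} \sum_{N} f_t,
\qquad
\mathrm{Var}\left[f_{t \le k}\right] =
  \frac{1}{kN} \sum_{t \le k} \sum_{N}
  \left(f_t - \mathbb{E}\left[f_{t \le k}\right]\right)^2
```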

tky823 commented 2 years ago

My CumulativeLayerNorm1d is based on the official implementation. You can compare it with mine here. If there is any chance I have misunderstood something, please let me know in more detail.
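As a rough sketch of the approach being discussed (variable names follow this thread; this is an illustration, not the repo's exact code):

```python
import torch
import torch.nn as nn

class CumulativeLayerNorm1d(nn.Module):
    """Sketch of cumulative layer norm via running sums: cum_var = E[x^2] - E[x]^2."""
    def __init__(self, num_features, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(1, num_features, 1))
        self.beta = nn.Parameter(torch.zeros(1, num_features, 1))

    def forward(self, input):
        # input: (batch_size, C, T)
        batch_size, C, T = input.size()

        step_sum = input.sum(dim=1)              # (batch_size, T)
        step_pow_sum = (input**2).sum(dim=1)     # (batch_size, T)
        cum_sum = torch.cumsum(step_sum, dim=1)              # running sum of x
        cum_squared_sum = torch.cumsum(step_pow_sum, dim=1)  # running sum of x^2

        # Number of entries accumulated up to each frame: [C, 2C, ..., T*C]
        cum_num = torch.arange(C, C*(T+1), C, dtype=input.dtype, device=input.device)

        cum_mean = cum_sum / cum_num                   # E[x]
        cum_squared_mean = cum_squared_sum / cum_num   # E[x^2]
        cum_var = cum_squared_mean - cum_mean**2       # E[x^2] - E[x]^2

        cum_mean = cum_mean.unsqueeze(dim=1)   # (batch_size, 1, T)
        cum_var = cum_var.unsqueeze(dim=1)     # (batch_size, 1, T)

        return (input - cum_mean) / torch.sqrt(cum_var + self.eps) * self.gamma + self.beta

# Usage:
# cln = CumulativeLayerNorm1d(num_features=4)
# y = cln(torch.randn(2, 4, 8))  # (batch_size, C, T) -> same shape
```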

michelg10 commented 2 years ago

Compared to the official implementation, your code writes `cum_var = cum_squared_mean - cum_mean**2`, while the official code writes `cum_var = (cum_pow_sum - 2*cum_mean*cum_sum) / entry_cnt + cum_mean.pow(2)`.

Then `cum_squared_mean` should equal `(cum_pow_sum - 2*cum_mean*cum_sum) / entry_cnt`. In your code, `cum_squared_mean = cum_squared_sum / cum_num`. Since `entry_cnt` corresponds to `cum_num`, `cum_squared_sum` should equal `cum_pow_sum - 2*cum_mean*cum_sum`.

However, `cum_squared_sum` is defined as `torch.cumsum(step_pow_sum, dim=1)`, which equals `cum_pow_sum` in the official implementation, so you're missing the `2*cum_mean*cum_sum` term.

Am I missing something here, or is this omitted as a speed/accuracy tradeoff?

tky823 commented 2 years ago

With $\mathrm{cum\\_mean} = \mathrm{cum\\_sum} / \mathrm{entry\\_cnt}$ and $\mathrm{cum\\_squared\\_mean} = \mathrm{cum\\_pow\\_sum} / \mathrm{entry\\_cnt}$, where $\mathrm{entry\\_cnt}$ is `cum_num` in my repo and $\mathrm{cum\\_pow\\_sum}$ is `cum_squared_sum`:

In the official implementation:

$$\mathrm{cum\\_var} = \frac{\mathrm{cum\\_pow\\_sum} - 2\,\mathrm{cum\\_mean} \cdot \mathrm{cum\\_sum}}{\mathrm{entry\\_cnt}} + \mathrm{cum\\_mean}^2 = \mathrm{cum\\_squared\\_mean} - 2\,\mathrm{cum\\_mean}^2 + \mathrm{cum\\_mean}^2$$

In my repo:

$$\mathrm{cum\\_var} = \mathrm{cum\\_squared\\_mean} - \mathrm{cum\\_mean}^2$$

So the two expressions are identical: the $2\,\mathrm{cum\\_mean} \cdot \mathrm{cum\\_sum}$ term is cancelled by the $+\,\mathrm{cum\\_mean}^2$ term rather than missing.

I renamed some variables in check_layer_norm.ipynb at https://github.com/tky823/DNN-based_source_separation/issues/101#issuecomment-975686051 for readability.
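For anyone reading later, here is a quick standalone check along the same lines as the notebook (my own sketch, not the notebook itself) that the two variance expressions agree numerically:

```python
import torch

torch.manual_seed(0)
batch_size, C, T = 2, 4, 5
x = torch.randn(batch_size, C, T, dtype=torch.float64)

step_sum = x.sum(dim=1)
step_pow_sum = (x**2).sum(dim=1)
cum_sum = torch.cumsum(step_sum, dim=1)
cum_pow_sum = torch.cumsum(step_pow_sum, dim=1)
entry_cnt = torch.arange(C, C*(T+1), C, dtype=x.dtype)  # == cum_num

cum_mean = cum_sum / entry_cnt

# Official implementation's form:
var_official = (cum_pow_sum - 2*cum_mean*cum_sum) / entry_cnt + cum_mean**2
# This repo's form, E[x^2] - E[x]^2:
var_repo = cum_pow_sum / entry_cnt - cum_mean**2

print(torch.allclose(var_official, var_repo))  # True
```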

michelg10 commented 2 years ago

My bad! Thank you for clearing it up!