ivannz / cplxmodule

Complex-valued neural networks for pytorch and Variational Dropout for real and complex layers.
MIT License

Complex Backprop and Learning speed #2

Closed (pfeatherstone closed this issue 4 years ago)

pfeatherstone commented 4 years ago

Question: In the real domain, you only require differentiability for back-propagation to work. In the complex domain, you need holomorphism. Now pytorch doesn't check this because it doesn't natively support complex numbers. Do you think there could be a learning/training problem with back-propagation if some of the functions don't satisfy the Cauchy-Riemann equations?

Question: In complex analysis, a function f(z) has two derivatives: df/dz and df/dz*. If the forward passes are implemented correctly, as you have done, is back-propagation well defined? Specifically, do you get both derivatives being back-propagated?

ivannz commented 4 years ago

Differentiability in the complex domain requires that a C-to-C function satisfy the Cauchy-Riemann conditions, which imply that a holomorphic C-to-R function is constant. Therefore there cannot be an end-to-end holomorphic learning problem, since we ultimately use some sort of non-constant real-valued fitness function L, e.g. empirical/theoretical risk or negative log-likelihood, which makes the composition L o f non-holomorphic.
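To spell out that implication (the standard Cauchy-Riemann argument, added here for completeness): write f = u + j v with v identically zero, since f is R-valued; the Cauchy-Riemann equations du/da = dv/db and du/db = -dv/da then force du/da = du/db = 0 everywhere, so u, and hence f, is constant on a connected domain.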

To be able to learn complex-valued networks we have to forgo the C-differentiability requirement. This can be done by using C-R (Wirtinger) calculus, which extends holomorphic calculus to non-holomorphic functions by treating C-domain functions as functions on R^2 with z and conj(z) as independent variables. It still gives the correct derivatives in the holomorphic case. Since CR calculus treats a complex parameter and its conjugate as independent, it is possible to equivalently parameterize the function f(z) = F(z, conj(z)) as f(z) = f(a + j b) = H(a, b) -- an R^2-to-C function. Thus complex-valued arithmetic is emulated within a double-real network through a skew-symmetry constraint on the linear operations.

Via CR calculus we get df/dz = 1/2 (dH/da - j dH/db) and df/d conj(z) = 1/2 (dH/da + j dH/db) for f: C to C. The gradient of a real-valued loss f with respect to z, i.e. the "direction of maximum increase", is df/d conj(z), not df/dz. For an R-valued loss this is consistent, since conj(df/dz) = d conj(f)/d conj(z) = df/d conj(z), which is what the autodiff of pytorch effectively computes. See Hunger (2007).
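As a quick numerical sanity check (a minimal sketch of my own, not part of cplxmodule's API), autograd on the double-real parametrization does recover the Wirtinger derivative of a real-valued loss. For L(z) = |z|^2 = z conj(z) one has dL/d conj(z) = z:

```python
# Minimal sketch: recover dL/d conj(z) = 1/2 (dL/da + j dL/db) from autograd
# on the double-real representation, for the real-valued loss L(z) = |z|^2.
import torch

a = torch.tensor([1.5, -0.3], requires_grad=True)  # real part of z
b = torch.tensor([0.7,  2.0], requires_grad=True)  # imaginary part of z

loss = (a.square() + b.square()).sum()             # L(a, b) = |z|^2, real-valued
loss.backward()

# Wirtinger derivative w.r.t. conj(z); for L = |z|^2 it equals z itself
dL_dconjz = 0.5 * torch.complex(a.grad, b.grad)
z = torch.complex(a, b).detach()
print(torch.allclose(dL_dconjz, z))                # True: steepest ascent along z
```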

Other references include Trabelsi et al. (2018), Zhang et al. (2014), Adali et al. (2011), Benvenuto and Piazza (1992), appendix D of Nazarov and Burnaev (2020) or a discussion here.

pfeatherstone commented 4 years ago

@ivannz Thank you for the detailed answer.

pfeatherstone commented 4 years ago

Might need to brush up on some complex analysis

pfeatherstone commented 4 years ago

@ivannz I thought I might carry on asking questions here rather than opening separate issues, hence why I've edited the title.

I've ported a mobilenetv2 network from torchvision to use modules in this repo, namely CplxConv1d, CplxBatchNorm1d and CplxAdaptiveModReLU. I'm training this network to classify some complex-valued 1D data, but it is very, very slow to learn. Have you noticed the same thing when applying these modules to your datasets? I've only tried this in one domain, so it's hard to tell whether it's the nature of the modules or the problem domain.

ivannz commented 4 years ago

I have used this module for both classification and regression tasks.

If by learning speed you mean overall arithmetic complexity, then yes -- a complex network uses 4 times as many multiplications as a real-valued network with the same number of intermediate features, i.e. the same number of linear outputs or convolutional channels, not the same overall parameter count. Even if one compares a complex network to a real one with the same number of floating-point parameters, the complex network is still slower, though not dramatically. Please also bear in mind the discussion in issue #1 -- cplxmodule is a pure-python extension for pytorch and, although it offloads all computation to torch, there is still some minor overhead.
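To illustrate where the factor of 4 comes from, here is a standalone sketch (made-up shapes and a hypothetical helper, not cplxmodule's internal code): applying a complex weight W = A + j B to an input z = x + j y takes four real matrix products.

```python
# Standalone illustration: a complex linear map costs 4 real matmuls,
# since (A + jB)(x + jy) = (Ax - By) + j(Ay + Bx).
import torch

def complex_linear(x, y, A, B):
    """Hypothetical helper: apply W = A + jB to z = x + jy via real matmuls."""
    real = x @ A.t() - y @ B.t()   # Re(Wz)
    imag = y @ A.t() + x @ B.t()   # Im(Wz)
    return real, imag

batch, n_in, n_out = 8, 64, 32
A, B = torch.randn(n_out, n_in), torch.randn(n_out, n_in)  # weight parts
x, y = torch.randn(batch, n_in), torch.randn(batch, n_in)  # input parts

re, im = complex_linear(x, y, A, B)

# cross-check against native complex matmul (available in recent pytorch)
ref = torch.complex(x, y) @ torch.complex(A, B).t()
print(torch.allclose(torch.complex(re, im), ref, atol=1e-4))  # True
```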

If by learning speed you mean the convergence rate of gradient descent or the rate at which the train loss decreases, I haven't measured it per se, but at the same time I haven't noticed anything suggesting that complex nets are slower to learn. The test performance depends on the dataset, but in my experience complex-valued networks seldom outperformed real-valued ones and were mostly on par. See Nazarov and Burnaev (2020) and the references therein.

PS: I'd rather you created another issue to keep unrelated discussions separate.

pfeatherstone commented 4 years ago

@ivannz Again, thank you for your response, and I will make sure to open new issues for separate discussions. Regarding classification, did you use the RadioML dataset? If so, I found that the modulation labels are wrong. If not, did you use another publicly available dataset?

ivannz commented 4 years ago

I used a private dataset for a digital predistortion task, i.e. approximating a perturbation of the input signal so that a power amplifier would operate in its linear gain regime.