rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

CNF Implementation details #37

Closed kmkolasinski closed 5 years ago

kmkolasinski commented 5 years ago

Hi, I'm trying to play with Neural ODEs and do some reimplementation in TensorFlow. I was able to implement a basic solver which works with the spiral problem in your examples. However, I got stuck on the CNF implementation :/

Here is my current version of the planar flow, implemented as a Keras Model:

import tensorflow as tf

# Note: w_initializer was not defined in my original snippet; a small
# random-normal initializer is assumed here.
w_initializer = tf.random_normal_initializer(stddev=0.01)


class PlanarFlow(tf.keras.Model):
    def __init__(self, dim, scale=1.0, bias_scale=1.0, w_scale=1.0):
        super().__init__()
        self.weight = tf.Variable(w_scale * w_initializer([1, dim]), name='weight')
        self.bias = tf.Variable(bias_scale * w_initializer([1, 1]), name='bias')
        self.scale = tf.Variable(scale * w_initializer([1, dim]), name='scale')
        self.activation = tf.nn.tanh

    def linear(self, z):
        return tf.reduce_sum(z * self.weight, axis=-1, keepdims=True) + self.bias

    def call(self, inputs, **kwargs):
        t, z, logdet = inputs

        with tf.GradientTape() as g:
            g.watch(z)
            logits = self.linear(z)
            hfunc = self.activation(logits)

        # dz/dt = u * h(w^T z + b), shape (batch_size, dim)
        new_z = self.scale * hfunc
        # dh/dz = h'(w^T z + b) * w, shape (batch_size, dim)
        gradients = g.gradient(target=hfunc, sources=z)
        # d log p(z)/dt = -tr(df/dz) = -(dh/dz) . u, shape (batch_size, 1)
        new_logdet = -tf.matmul(gradients, tf.transpose(self.scale))
        # return the time derivatives of z and log p(z)
        return new_z, new_logdet
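
A quick shape check for the class above (assuming eager execution, so the GradientTape runs as written):

flow = PlanarFlow(dim=2)
z = tf.random.normal([512, 2])
logdet = tf.zeros([512, 1])
dz, dlogdet = flow((tf.constant(0.0), z, logdet))
print(dz.shape, dlogdet.shape)  # (512, 2) (512, 1)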

This implementation is missing the gating mechanism (some NN, as mentioned in the paper), which is not described in detail there. How was the gating mechanism implemented in your case?

Secondly, please correct me if I misunderstand the algorithm. Once I have a correct implementation of the planar flow, I have to create a combination of them (Eq. 10), like this:

class MultipleFlow(tf.keras.Model):

    def __init__(self, num_flows, flow_factory=lambda: PlanarFlow(2)):
        super().__init__()
        self.flows = [flow_factory() for _ in range(num_flows)]

    def call(self, inputs, **kwargs):
        # Eq. 10: the dynamics is the sum of the individual flows,
        # so accumulate the derivatives starting from zero.
        t, z, logdet = inputs
        dz, dlogdet = 0.0, 0.0
        for flow in self.flows:
            z_k, logdet_k = flow(inputs)
            dz = dz + z_k
            dlogdet = dlogdet + logdet_k
        return t, dz, dlogdet

# create a CNF with M = 32 flows ???
cnf = MultipleFlow(num_flows=32)
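
To make the intended use concrete, here is a fixed-step Euler sketch of how I integrate (z, logdet) jointly; euler_odeint is just a stand-in for my actual solver:

def euler_odeint(dynamics, z0, logdet0, tstart=1.0, tend=0.0, num_steps=100):
    # integrate the joint state (z, logdet) from tstart to tend
    z, logdet, t = z0, logdet0, tstart
    dt = (tend - tstart) / num_steps
    for _ in range(num_steps):
        _, dz, dlogdet = dynamics((t, z, logdet))
        z = z + dt * dz
        logdet = logdet + dt * dlogdet
        t = t + dt
    return z, logdet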

And finally, I want to maximize the probability under the energy function from Fig. 4:

# pseudo code:
z_samples = tf.random.normal([512, 2])
logdet0 = tf.zeros([512, 1])
# integrate the dynamics from t = 1 to t = 0
z_output, z_logdet = odeint(cnf, [z_samples, logdet0], tstart=1, tend=0, num_steps=100)
# potential energy: p(z_output) = exp(-U(z_output)), i.e. -log p(z_output) = U(z_output)
loss = -potential_energy(z_output) - z_logdet
# quantity to maximize (equivalently, minimize its negative)
loss = tf.reduce_mean(loss)

Is this approach correct? Thank you in advance!

rtqichen commented 5 years ago

Hi @kmkolasinski, that looks about right. Just make sure you have the signs in the correct direction.
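
For example, with the convention in your PlanarFlow (z_logdet starts at zero and integrates d log p / dt = -tr(df/dz), so log q(z_output) = log q0(z_samples) + z_logdet), minimizing KL(q || p) for p(z) = exp(-U(z))/Z reduces to the free-energy objective of Rezende & Mohamed. Dropping the constant entropy of the fixed base q0, a sketch would be:

# minimize E_q[ log q(z) + U(z) ] up to a constant
loss = tf.reduce_mean(z_logdet + potential_energy(z_output))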

The planar CNF shown in the paper used a 1-hidden-layer hypernet that takes in t and outputs multiple weight, bias, and scale parameters. The gating mechanism was implemented by outputting the scale parameter as a multiplication between a regular output and a sigmoid output of the hypernet. See https://gist.github.com/rtqichen/91924063aa4cc95e7ef30b3a5491cc52.
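
Roughly, that hypernet looks like this (a sketch based on the description above, not the exact gist code; the class name and hidden size are made up):

class HyperPlanarParams(tf.keras.Model):
    def __init__(self, dim, hidden_dim=32):
        super().__init__()
        self.dim = dim
        self.hidden = tf.keras.layers.Dense(hidden_dim, activation=tf.nn.tanh)
        # outputs: weight (dim) + bias (1) + scale (dim) + gate (dim)
        self.out = tf.keras.layers.Dense(3 * dim + 1)

    def call(self, t):
        params = self.out(self.hidden(tf.reshape(t, [1, 1])))
        w = params[:, :self.dim]                      # weight, (1, dim)
        b = params[:, self.dim:self.dim + 1]          # bias, (1, 1)
        u = params[:, self.dim + 1:2 * self.dim + 1]  # raw scale, (1, dim)
        g = params[:, 2 * self.dim + 1:]              # gate logits, (1, dim)
        scale = u * tf.nn.sigmoid(g)                  # gated scale
        return w, b, scale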

kmkolasinski commented 5 years ago

Thanks a lot @rtqichen !