jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Disable gradients with special type #2588

Open awav opened 4 years ago

awav commented 4 years ago

I have difficulties with structures that have non-differentiable parts, and this causes a lot of issues in my current design. E.g. I want a differentiable structure that has some parts that are discrete.

I decided to outsmart JAX and make an object that behaves like a scalar (an array) and pass it to the initialiser of DifferentiableView. But this gives me a different issue, related to which types JAX accepts for gradient tracing.

import typing

import jax
import jax.numpy as jnp

class NonDifferentiable(int):
    pass

class DifferentiableView(typing.NamedTuple):
    parameter: jnp.array
    num_data: int

def func(s):
    return jnp.exp(s.parameter) ** 2 / s.num_data

s = DifferentiableView(2.0, NonDifferentiable(100))
grad_func = jax.grad(func)(s)

Gives an error:

TypeError: <class '__main__.NonDifferentiable'> is not a valid Jax type

I think it would be pretty cool to have such a recognisable class for labelling (wrapping) scalars and shaped arrays. For JAX it would be a signal to stop tracing gradients for that variable.

awav commented 4 years ago

Hello @mattjj, what do you think about the idea, and is it clear what the problem is?

yingted commented 4 years ago

Have you tried flax.struct?

import jax
import jax.numpy as jnp
import flax.struct

@flax.struct.dataclass
class DifferentiableView():
    parameter: jnp.array
    num_data: int = flax.struct.field(pytree_node=False)

def func(s):
    return jnp.exp(s.parameter) ** 2 / s.num_data

s = DifferentiableView(2.0, 100)
grad_func = jax.grad(func)(s)
grad_func
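
If I did the arithmetic right, this should evaluate to roughly DifferentiableView(parameter=1.09, num_data=100): the gradient is 2 * exp(2 * 2.0) / 100 ≈ 1.09, and the pytree_node=False field is kept as static metadata rather than differentiated.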

awav commented 4 years ago

@yingted, no, I didn't know about flax. Thanks a lot for the reference! @mattjj, I think this could be supported by JAX out of the box, and the implementation looks trivial.

sharadmv commented 4 years ago

The Pytree API that JAX offers can also accomplish this.

import dataclasses

import jax
import jax.numpy as jnp

@dataclasses.dataclass
class DifferentiableView:
    parameter: jnp.array
    num_data: int

def flatten(view):
    # children (differentiated leaves) first, auxiliary/static data second
    return (view.parameter,), (view.num_data,)

def unflatten(aux_data, children):
    num_data, = aux_data
    parameter, = children
    return DifferentiableView(parameter, num_data)

jax.tree_util.register_pytree_node(DifferentiableView, flatten, unflatten)

The notion of what is differentiable can be handled by the flatten/unflatten functions. Note that you don't have to use a dataclass and can use whatever class you want.
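
For example, with the registration above, gradients flow only through parameter, while num_data rides along as auxiliary data (rough sketch, output values approximate):

def func(s):
    return jnp.exp(s.parameter) ** 2 / s.num_data

s = DifferentiableView(2.0, 100)
g = jax.grad(func)(s)
print(g.parameter)  # ~1.09, i.e. 2 * exp(4) / 100
print(g.num_data)   # 100, reconstructed unchanged by unflatten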

awav commented 4 years ago

@sharadmv, yes, I'm aware of this solution. The problem is that every time a user creates a new class, they have to write custom encoding and decoding; that's exactly what a user doesn't want to do, and sometimes it is not even trivial to do. Non-differentiable types are ubiquitous, and dataclasses and namedtuples are the main structures for organising a model in JAX. Ideally, JAX should understand straight away what can be differentiated in a structure and provide a tool to control this behaviour when it cannot deduce whether an object is differentiable or not.

NeilGirdhar commented 4 years ago

My feeling is that this isn't JAX's problem, but rather a design problem on the user side. My design preference would be to break up your object into differentiable and nondifferentiable components before grad:

jax.grad(func)(parameter, num_data)

This seems much closer to "saying what you mean" to me.
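
A minimal sketch of what I mean, with func reworked to take the pieces separately (grad differentiates only the first argument by default):

import jax
import jax.numpy as jnp

def func(parameter, num_data):
    return jnp.exp(parameter) ** 2 / num_data

# num_data is just a constant as far as autodiff is concerned
print(jax.grad(func)(2.0, 100))  # ~1.09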

After all, you might choose to differentiate with respect to one thing in one context and another thing in another context:

jax.grad(func)(state, parameter)

Trying to solve this problem using JAX's partitioning of structured data into pytrees and hashable components runs into problems down the line. If you try to put your "non-differentiable components" into the set of hashable components, then you can never have a tracer in their place. This can happen if, for example, they are produced from jitted arguments. If that happens, you will find yourself debugging leaked tracers.

This confusion between "static arguments" and "nondifferentiable arguments" is the cause of https://github.com/google/jax/issues/2912.
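
A rough sketch of the failure mode I mean, assuming DifferentiableView is registered as above with num_data in the static half of the pytree (rebuild is a hypothetical function; the exact behaviour depends on the JAX version):

import jax

@jax.jit
def rebuild(parameter, n):
    # Inside jit, n is a tracer. Returning it in the static half of the
    # pytree puts a tracer into the treedef, which JAX needs to hash and
    # compare; that either errors out or leaks the tracer past the jit
    # boundary, and you only notice later when num_data is used.
    return DifferentiableView(parameter * 2, n)

result = rebuild(1.0, 100)  # depending on the version, this errors or
                            # leaves a tracer sitting in result.num_data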

awav commented 4 years ago

My feeling is that this isn't JAX's problem, but rather a design problem on the user side.

I disagree. The very first question I had when I started looking at JAX and recommending it to others was: "How do I stop gradient propagation without changing the code?". This is a UX question, not a design issue on the user's side: https://twitter.com/srush_nlp/status/1260583364102434817?s=20.

The point is that the differentiable parts are not static. A user may decide to change the differentiation path w.r.t. a structure. Therefore there is no way to know in advance how to split the object into differentiable and non-differentiable parts, especially when you have the ability to extend that differentiable structure, the number of parameters is unknown, and the underlying objective algorithm is the same "regardless" of parameters.

Here is an example (apologies, this is almost a duplicate of what I showed before):

from dataclasses import dataclass

import jax.numpy as jnp

@dataclass(frozen=True)
class Kernel:
    variance: float

@dataclass(frozen=True)
class SquaredExponential(Kernel):
    lengthscale: float

@dataclass(frozen=True)
class UnknownKernelWithAdditional100HyperParameters(Kernel):
    ...

def loss(kernel: Kernel) -> jnp.ndarray:
    pass  # Compute the marginal likelihood of the Gaussian process here

Often we need to compute gradients w.r.t. the lengthscale or the variance only, while keeping the same code that uses the kernel. I propose introducing a proxy type that stops gradient computation for plain types and still interacts cleanly with other structures:

# hypothetical proxy class
class NonDifferentiable:
    def __init__(self, wrap_this_object):
        ...

a = jnp.zeros(10)
non_differentiable_a = NonDifferentiable(a)
a2 = non_differentiable_a + a    # Still works!

Therefore we could do this:

k1 = Kernel(1.0)
k2 = SquaredExponential(NonDifferentiable(1.0), 1.0)

jax.grad(loss)(k1)
jax.grad(loss)(k2)
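
For comparison, here is a rough sketch of what achieving this today looks like with jax.lax.stop_gradient applied inside a toy loss (using a plain dict as the kernel to sidestep pytree registration). It works, but it is exactly the kind of code change I would like to avoid:

import jax
import jax.numpy as jnp

def loss(kernel):
    return kernel["variance"] * jnp.exp(kernel["lengthscale"])

def loss_fixed_variance(kernel):
    # stop_gradient zeroes out the gradient for this one field
    kernel = dict(kernel, variance=jax.lax.stop_gradient(kernel["variance"]))
    return loss(kernel)

kernel = {"variance": 1.0, "lengthscale": 1.0}
print(jax.grad(loss)(kernel))                 # gradients for both fields
print(jax.grad(loss_fixed_variance)(kernel))  # variance gradient is 0.0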

After all, you might choose to differentiate with respect to one thing in one context and another thing in another context.

I'm sorry, I don't understand your example here. Can you elaborate? Thanks!

then you can never have a tracer in their place

I wouldn't say never. As far as I understand nobody has tried to implement it or fix it if that's a bug. Also, it doesn't mean that that level of flexibility cannot and should not be provided to a user. TensorFlow and PyTorch have had this naive and simple feature since the beginning. In turn, JAX could offer a hybrid solution for trainable arrays (structures) that would benefit lots of potential users, in addition to really cool features like vmap and pmap.

NeilGirdhar commented 4 years ago

Just so it's clear, I'm not a JAX developer. I'm just curious about how all this fits together, and I learn a lot reading these issues. I have written a lot of JAX code over the last few months, and your issue resonated with me, which is why I answered.

The point is that the differentiable parts are not static. A user may decide to change the differentiation path w.r.t. a structure. Therefore there is no way to know in advance how to split the object into differentiable and non-differentiable parts, especially when you have the ability to extend that differentiable structure, the number of parameters is unknown, and the underlying objective algorithm is the same "regardless" of parameters.

Totally agree.

Often we will need to compute gradients w.r.t. lengthscale or variance only and work with the same code that uses that kernel.

Yup, that makes sense.

I'm sorry, I don't understand your example here. Can you elaborate? Thanks!

I was just trying to say exactly the same thing that you said about "how there's no way to know in advance how to split the object..." Looks like we agree though :)

I wouldn't say never. As far as I understand nobody has tried to implement it or fix it if that's a bug.

I was talking about the problem with the solutions that were suggested to you in this thread whereby you use the hashable output of tree_flatten to mark nondifferentiable arguments. Maybe I don't understand your problem, but if it's that you have non-concrete, nondifferentiable arguments, then these solutions will not work in general since those arguments are treated like static arguments, and it doesn't make sense for tracers to be treated as static arguments. They can't be hashed, which is what the jitter needs to do in order to prevent recompilation.
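
A minimal illustration of the hashing point (the exact error message may vary by version):

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=1)
def times(x, n):
    return x * n

times(jnp.ones(3), 2)   # fine: 2 is hashable and becomes part of the cache key
times(jnp.ones(3), 3)   # triggers a recompile: new static value, new cache entry
# times(jnp.ones(3), jnp.asarray(2))  # fails: arrays (and tracers) are not
#                                     # hashable, so they cannot be cache keys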

I think it would be nice if JAX were a bit more proactive about preventing you from accidentally sending tracers down that path, but otherwise, you're in for a lot of debugging (what happened to me) if you let that happen. I was just trying to warn you to save you the pain I went through :)

Also, it doesn't mean that that level of flexibility cannot and should not be provided to a user. TensorFlow and PyTorch have had this naive and simple feature since the beginning. In turn, JAX could offer a hybrid solution for trainable arrays (structures) that would benefit lots of potential users, in addition to really cool features like vmap and pmap.

I'm just curious, but what's the problem with grad(loss)(kernel).variance? I think it's only expensive to compile; computationally, it should be really fast, but maybe one of the JAX devs could confirm.

If you don't like the simple solution, maybe you could propose a fancier interface for grad, like:

attribute_grad(loss, 'variance')(kernel)  # Returns d loss by d kernel.variance.

Or else, what do you think of this more general solution?

import dataclasses
from dataclasses import dataclass

from jax import grad

@dataclass
class Kernel:
    variance: float
    scale: float

def loss(kernel: Kernel):
    return 2 * kernel.variance + 3 * kernel.scale

def fancy_grad(derivand, extract, replace):
    # Differentiate `derivand` w.r.t. the piece extracted from a compound
    # object, by rebuilding the object with that piece swapped for the
    # traced value before calling the original function.
    def f(derivator, compound_object, *args):
        return derivand(replace(compound_object, derivator), *args)
    def g(compound_object, *args):
        return grad(f)(extract(compound_object), compound_object, *args)
    return g

kernel = Kernel(1.4, 3.2)

# Prints d loss by d kernel.variance (2.0).
print(fancy_grad(loss,
                 lambda kernel: kernel.variance,
                 lambda kernel, variance: dataclasses.replace(kernel, variance=variance))(kernel))
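
(Swapping in lambda kernel: kernel.scale and lambda kernel, scale: dataclasses.replace(kernel, scale=scale) would give d loss by d kernel.scale, i.e. 3.0.)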

awav commented 4 years ago

I was talking about the problem with the solutions that were suggested to you in this thread whereby you use the hashable output of tree_flatten to mark nondifferentiable arguments. Maybe I don't understand your problem, but if it's that you have non-concrete, nondifferentiable arguments, then these solutions will not work in general since those arguments are treated like static arguments, and it doesn't make sense for tracers to be treated as static arguments. They can't be hashed, which is what the jitter needs to do in order to prevent recompilation. I think it would be nice if JAX were a bit more proactive about preventing you from accidentally sending tracers down that path, but otherwise, you're in for a lot of debugging (what happened to me) if you let that happen. I was just trying to warn you to save you the pain I went through :)

Thanks a lot! I thought that flatten could split an object into static and non-static parts and then build it back up, so that a hash could be computed from the static parts alone. Caveat: I don't possess deep knowledge on this topic (yet).

I'm just curious, but what's the problem with grad(loss)(kernel).variance? I think it's only expensive to compile; computationally, it should be really fast, but maybe one of the JAX devs could confirm.

In this particular case, it doesn't matter. But there are situations where we might switch parameters on or off, like the covariance matrix of a Gaussian distribution (sparse Gaussian processes), where the gradient computation goes through a Cholesky decomposition and other expensive operations.

If you don't like the simple solution, maybe you could propose a fancier interface for grad, like: ... Or else, what do you think of this more general solution? ...

Both proposals are good, although I might have an example where this would not work (or would not be user-friendly): compositional kernels like sums, products, and hence their nested combinations.

compositional_kernel = kernel1 + kernel2 * kernel3   # all kernels have variance parameter

Writing extract and replace functions for such compositions would be a nightmare, but there might be another, better design for kernel compositions.
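
To make the point concrete, here is a hypothetical sketch (Sum and Product are made-up composite kernels) of the extract/replace pair needed just to target kernel2's variance inside kernel1 + kernel2 * kernel3:

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class Sum:
    left: object
    right: object

@dataclass(frozen=True)
class Product:
    left: object
    right: object

# composite = Sum(kernel1, Product(kernel2, kernel3))
extract = lambda k: k.right.left.variance

def replace_variance(k, v):
    # rebuild every dataclass on the path from the root down to kernel2
    new_k2 = replace(k.right.left, variance=v)
    return replace(k, right=replace(k.right, left=new_k2))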