ucl-bug / jaxdf

A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations
GNU Lesser General Public License v3.0

Running the example on the README does not work. #118

Closed elma16 closed 1 year ago

elma16 commented 1 year ago

Describe the bug

When I run the example given in the README, it gives me the following error:

SignatureError: The argument 'params' must be a keyword argument in . Example: def evaluate(x, *, params): ...

To Reproduce

Steps to reproduce the behavior:

  1. Run the code example given in the README.

Expected behavior

No error, and the gradient is computed at the end.


astanziola commented 1 year ago

Thanks for noticing it; I forgot to update that! I should also make sure to throw a clearer error...

Fundamentally, the error occurs because every operator is expected to take a keyword-only parameter called params, which the README example was missing.

From the documentation of operator:

Keyword arguments are defined after the * in the function signature.

@operator
def my_operator(x: FourierSeries, *, dx: float, params=None):
  ...

The argument params is mandatory and it must be a keyword argument. It is used to pass the parameters of the operator, for example the stencil coefficients of a finite difference operator.
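The requirement can be checked with Python's standard inspect module. The sketch below uses a hypothetical helper, check_params_signature, written only for illustration; it is not jaxdf's actual implementation, but it rejects the same kinds of signatures that trigger the SignatureError above: a function with no params at all, and one where params can be passed positionally.

```python
import inspect


def check_params_signature(fn):
  """Hypothetical check: fn must declare a keyword-only parameter named 'params'."""
  param = inspect.signature(fn).parameters.get("params")
  if param is None or param.kind is not inspect.Parameter.KEYWORD_ONLY:
    raise TypeError(
        f"The argument 'params' must be a keyword argument in {fn.__name__}. "
        "Example: def evaluate(x, *, params): ..."
    )
  return fn


def good_op(u, *, params=None):  # 'params' comes after *, so it is keyword-only
  return u


def bad_op(u):  # no 'params' at all
  return u


def positional_op(u, params=None):  # 'params' can also be passed positionally
  return u


check_params_signature(good_op)  # accepted silently

messages = []
for fn in (bad_op, positional_op):
  try:
    check_params_signature(fn)
  except TypeError as err:
    messages.append(str(err))
print("\n".join(messages))
```

Unlike the original message, which prints "in ." with an empty name, this sketch includes the function's name, which makes the offending operator easy to find.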

The default value of the parameters is specified by the function passed as init_params to the decorator, as follows:


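import jax.numpy as jnp
from jaxdf import operator

def params_initializer(x, *, dx):
  # Default parameters: a stencil of ones scaled by dx
  return {"stencil": jnp.ones(x.shape) * dx}

@operator(init_params=params_initializer)
def my_operator(x, *, dx, params=None):
  # params is populated from params_initializer, not from the default None
  b = params["stencil"] / dx
  y_params = jnp.convolve(x.params, b, mode="same")
  return x.replace_params(y_params)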

The default value written in the signature (params=None) is not used during computation; the parameters come from init_params instead. If the operator has no parameters, init_params can be omitted, in which case params is set to None.
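To make the role of init_params concrete, here is a toy re-implementation of the idea in plain Python. toy_operator is a hypothetical, simplified stand-in, not jaxdf's operator: it calls init_params with the same arguments as the operator and passes the result as params, so the signature default is never used. Letting explicitly passed params take precedence is an assumption of this toy.

```python
import functools


def toy_operator(fn=None, *, init_params=None):
  """Toy stand-in for a decorator like jaxdf's operator (hypothetical, simplified)."""
  if fn is None:  # used as @toy_operator(init_params=...)
    return functools.partial(toy_operator, init_params=init_params)

  @functools.wraps(fn)
  def wrapper(*args, params=None, **kwargs):
    # The signature default (params=None) is ignored: when no params are passed,
    # this toy builds them with init_params from the same arguments.
    if params is None and init_params is not None:
      params = init_params(*args, **kwargs)
    return fn(*args, params=params, **kwargs)

  return wrapper


def scale_initializer(x, *, factor):
  return {"scale": 2.0 * factor}


@toy_operator(init_params=scale_initializer)
def scale(x, *, factor, params=None):
  return [v * params["scale"] for v in x]


@toy_operator
def identity(x, *, params=None):
  # No init_params was given, so params stays None here
  return x if params is None else None


print(scale([1.0, 2.0], factor=0.5))                         # [1.0, 2.0]
print(scale([1.0, 2.0], factor=0.5, params={"scale": 3.0}))  # [3.0, 6.0]
print(identity([1.0]))                                       # [1.0]
```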

For constant parameters, the constants function from jaxdf can be used:

from jaxdf import constants, operator

@operator(init_params=constants({"a": 1, "b": 2.0}))
def my_operator(x, *, params):
  # params is always {"a": 1, "b": 2.0}
  return x + params["a"] + params["b"]
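A helper like constants only has to return an init_params-style function that ignores its inputs and always returns the same value. The following is a guess at how such a helper can be written, under the hypothetical name constants_sketch; check jaxdf's own source for the real definition.

```python
def constants_sketch(value):
  """Return an init_params-style function that ignores its arguments (hypothetical)."""
  def init_params(*args, **kwargs):
    return value  # the same parameters for every input
  return init_params


init = constants_sketch({"a": 1, "b": 2.0})
print(init(10.0))            # {'a': 1, 'b': 2.0}
print(init([1, 2], dx=0.1))  # {'a': 1, 'b': 2.0}

# With these parameters, the operator body x + params["a"] + params["b"] adds 3.0:
params = init(5.0)
print(5.0 + params["a"] + params["b"])  # 8.0
```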

For the README example, the operator needs to be defined as:

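from jax import numpy as jnp
from jaxdf import operator
from jaxdf import operators as jops

@operator
def custom_op(u, *, params=None):  # params must be keyword-only, even if unused
  grad_u = jops.gradient(u)
  diag_jacobian = jops.diag_jacobian(grad_u)
  laplacian = jops.sum_over_dims(diag_jacobian)  # trace of the Hessian
  sin_u = jops.compose(u)(jnp.sin)  # apply sin pointwise
  return laplacian + sin_u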
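The operator above computes the Laplacian of u plus sin(u), and the README then differentiates a loss built on it. For readers without jaxdf at hand, the same idea can be sketched in plain JAX on a periodic 1D grid, using a second-order finite-difference Laplacian instead of jaxdf's discretizations. The names custom_op_fd and loss, the grid size, the spacing, and the loss itself are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp


def custom_op_fd(u, dx):
  """Laplacian(u) + sin(u) on a periodic 1D grid (finite differences, not jaxdf)."""
  laplacian = (jnp.roll(u, -1) - 2.0 * u + jnp.roll(u, 1)) / dx**2
  return laplacian + jnp.sin(u)


def loss(u, dx):
  # Scalar loss so that jax.grad can differentiate it with respect to u
  return jnp.mean(custom_op_fd(u, dx) ** 2)


n = 64
dx = 2.0 * jnp.pi / n
x = jnp.arange(n) * dx
u = jnp.sin(x)

grad_u = jax.grad(loss)(u, dx)  # gradient with respect to the first argument, u
print(grad_u.shape)  # (64,)
```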

I will change it in the README now 😄

astanziola commented 1 year ago

Should be fixed now, but please feel free to reopen if that's not the case