rlouf / mcx

Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
https://rlouf.github.io/mcx
Apache License 2.0

generative_function is broken for models with non-optional args #102

Closed elanmart closed 3 years ago

elanmart commented 3 years ago

Given the model:

@mcx.model
def model(x):
    a <~ dist.Normal(0, 1)
    return a

The code at https://github.com/rlouf/mcx/blob/master/mcx/model.py#L218-L220 generates a function with the signature model_sample_posterior_predictive(rng_key, x, a).

That function is then called with *self.trace.chain_values(self.chain_id) unpacked positionally, so the posterior values are bound to x instead of a: https://github.com/rlouf/mcx/blob/master/mcx/model.py#L226-L228
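To make the mis-binding concrete, here is a minimal sketch in plain Python. The function below is a hypothetical stand-in with the same signature as the generated one, and posterior_values is an assumed placeholder for whatever self.trace.chain_values(self.chain_id) returns; neither is mcx's actual code.

```python
import inspect


# Hypothetical stand-in with the same signature as the generated function.
def model_sample_posterior_predictive(rng_key, x, a):
    return {"x": x, "a": a}


# Assumed placeholder for the unpacked posterior values: one entry, meant for `a`.
posterior_values = [0.3]

# Bind the arguments the way a positional call would, without running the body.
bound = inspect.signature(model_sample_posterior_predictive).bind_partial(
    "key", *posterior_values
)

# The posterior value lands on `x`, the first free positional slot; `a` stays unbound.
misbound = dict(bound.arguments)
```

In this sketch the value intended for a ends up under the key x, which is the behaviour described above.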

The fix is to change the generative_function.__call__ method to:

    def __call__(self, rng_key, *args, **kwargs) -> jnp.DeviceArray:
        """Call the model as a generative function."""

        posterior_kwd = {k: self.trace[k].reshape(-1) for k in self.trace.keys()}
        return self.call_fn(
            rng_key, *args, **kwargs, **posterior_kwd,
        )
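The idea behind the fix can be tried out without mcx. The sketch below is only an illustration under stated assumptions: GenerativeFunctionSketch is a hypothetical stand-in for generative_function, the trace is assumed to be a plain dict of lists (so the reshape(-1) step is dropped), and the two toy functions just mirror the argument orders reported in this issue.

```python
class GenerativeFunctionSketch:
    """Hypothetical stand-in for mcx's generative_function (illustration only)."""

    def __init__(self, call_fn, trace):
        self.call_fn = call_fn
        self.trace = trace  # assumed here: a plain dict of name -> samples

    def __call__(self, rng_key, *args, **kwargs):
        # Pass posterior samples by name so the parameter order no longer matters.
        posterior_kwd = {k: self.trace[k] for k in self.trace.keys()}
        return self.call_fn(rng_key, *args, **kwargs, **posterior_kwd)


# The two argument orders reported in this issue.
def required_arg(rng_key, x, a):
    return {"x": x, "a": a}


def optional_arg(rng_key, a, x=None):
    return {"x": x, "a": a}


trace = {"a": [0.1, 0.2]}
out_required = GenerativeFunctionSketch(required_arg, trace)("key", 1.0)
out_optional = GenerativeFunctionSketch(optional_arg, trace)("key", x=1.0)
```

Note that in this sketch x still has to be passed by keyword for the optional-argument ordering, since a positional value would take the slot of a.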

elanmart commented 3 years ago

The source of the issue is that the order of arguments differs between models with optional and non-optional args:

@mcx.model
def model(x):
    a <~ dist.Normal(0, 1)
    return a

gives

<function __main__.model_sample_posterior_predictive(rng_key, x, a)>

While

@mcx.model
def model(x = None):
    a <~ dist.Normal(0, 1)
    return a

gives

<function __main__.model_sample_posterior_predictive(rng_key, a, x=None)>
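
The thread does not say how the generator decides on this order. One constraint that does apply regardless is Python's own rule that a positional parameter without a default cannot follow one with a default, so (rng_key, x=None, a) is not a legal signature. The sketch below checks that with the standard inspect module; the positional helper is a hypothetical convenience, not part of mcx.

```python
import inspect


def positional(name, default=inspect.Parameter.empty):
    """Hypothetical helper: build a positional-or-keyword parameter."""
    return inspect.Parameter(
        name, inspect.Parameter.POSITIONAL_OR_KEYWORD, default=default
    )


# (rng_key, x, a): valid, no defaults involved.
sig_required = inspect.Signature(
    [positional("rng_key"), positional("x"), positional("a")]
)

# (rng_key, x=None, a): rejected, because `a` has no default but follows `x=None`.
try:
    inspect.Signature(
        [positional("rng_key"), positional("x", None), positional("a")]
    )
    same_order_is_valid = True
except ValueError:
    same_order_is_valid = False

# (rng_key, a, x=None): valid once `a` is moved ahead of the defaulted argument.
sig_optional = inspect.Signature(
    [positional("rng_key"), positional("a"), positional("x", None)]
)
```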

elanmart commented 3 years ago

I'd actually suggest making it impossible to pass non-keyword arguments to those generated functions; I guess it would simplify things a lot.

The alternative would be to always generate a signature of the form func(*params, *model_args, **model_kwargs).
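
As a rough illustration of the keyword-only option, the sketch below uses a bare * in the signature. The function is a hypothetical stand-in, not what mcx generates. Keyword-only parameters may mix defaulted and non-defaulted names in any order, so the same ordering would work for both model variants.

```python
# Hypothetical keyword-only variant of the generated function.
# After the bare `*`, `x=None` may legally precede `a`.
def model_sample_posterior_predictive(rng_key, *, x=None, a):
    return {"x": x, "a": a}


# Keyword calls work in any order.
by_keyword = model_sample_posterior_predictive("key", a=0.3, x=1.0)

# A positional call is rejected outright instead of being silently mis-bound.
try:
    model_sample_posterior_predictive("key", 0.3)
    positional_rejected = False
except TypeError:
    positional_rejected = True
```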

rlouf commented 3 years ago

> I'd actually suggest making it impossible to pass non-keyword arguments to those generated functions; I guess it would simplify things a lot.

Yes, that would make things a lot easier. The difference exists because I never needed to pay attention to argument order: in the core, everything is passed as a keyword argument.

> The alternative would be to always generate a signature of the form func(*params, *model_args, **model_kwargs).

In terms of usability, which of these two options would you say is best?

rlouf commented 3 years ago

(again, thank you for raising an issue!)

elanmart commented 3 years ago

> In terms of usability, which of these two options would you say is best?

On second thought, I don't think going with the "keyword-only" option would be a good idea.

So the generated functions would have to have the signature func(key, *params, *model_args, **model_kwargs).

rlouf commented 3 years ago

Basically the same signature as the original function?

rlouf commented 3 years ago

@elanmart There is actually a good reason why things are as they currently are; I will go with your solution and will try to write the patch to fix #90 at the same time.

elanmart commented 3 years ago

@rlouf thanks a lot! In the meantime I've just patched my local version of mcx so I can carry on without issues. For #90 I also have a small workaround, so it's all fine for now :)

rlouf commented 3 years ago

Cool! I'll go through the outstanding issues and missing features next week or the week after that. Hopefully this and the key splitting issue will be solved. Let me know if you find anything else.

rlouf commented 3 years ago

@elanmart This was fixed in #106