probml / dynamax

State Space Models library in JAX
https://probml.github.io/dynamax/
MIT License

Applying constraints for LogisticRegressionHMM #372

Closed: gergogomori closed this issue 3 months ago

gergogomori commented 3 months ago

Dear Developers,

I encountered an issue while fitting a LogisticRegressionHMM with the EM algorithm. Specifically, when I impose constraints, such as restricting the weights to be positive, the fit returns NaN values. Below is a simple code snippet that reproduces the problem.

Since I am relatively new to this package, I would greatly appreciate any guidance or suggestions you could offer.

Have a nice weekend, Gergő

import numpy as np
import tensorflow_probability as tfp

from dynamax.hidden_markov_model import LogisticRegressionHMM

rng = np.random.default_rng()

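# Binary emissions (the observed outputs of the HMM), shape (100,)
emissions = rng.choice([0, 1], size=100)
# Despite the name, these are the covariates passed to fit_em as `inputs`, shape (100, 1)
observations = rng.choice([-2, -1, 1, 2], size=100).reshape(-1, 1)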

# Without constraints

lrhmm = LogisticRegressionHMM(num_states=2, input_dim=1)
lrhmm_params, lrhmm_props = lrhmm.initialize()

new_lrhmm_params, _ = lrhmm.fit_em(params=lrhmm_params, props=lrhmm_props, emissions=emissions, inputs=observations)

print(new_lrhmm_params.emissions.weights)

# With constraints

lrhmm_const = LogisticRegressionHMM(num_states=2, input_dim=1)
lrhmm_const_params, lrhmm_const_props = lrhmm_const.initialize()

# Constrain weights to be positive
lrhmm_const_props.emissions.weights.constrainer = tfp.substrates.jax.bijectors.Softplus()

new_lrhmm_const_params, _ = lrhmm_const.fit_em(params=lrhmm_const_params, props=lrhmm_const_props, emissions=emissions, inputs=observations)

print(new_lrhmm_const_params.emissions.weights)

gergogomori commented 3 months ago

Initializing the weights with positive numbers resolves the issue. Apologies for bringing this up prematurely.
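
A likely explanation, offered as an assumption rather than something confirmed in this thread: if the fitting routine maps the constrained weights to unconstrained space through the bijector's inverse, a negative initial weight has no preimage under Softplus, so the inverse evaluates to NaN. The sketch below reproduces that arithmetic with plain NumPy. The names `softplus` and `softplus_inverse` and the example values are illustrative and are not part of dynamax or TensorFlow Probability.

```python
import numpy as np


def softplus(x):
    # softplus(x) = log(1 + exp(x)); its output is always strictly positive.
    return np.logaddexp(0.0, x)


def softplus_inverse(y):
    # Inverse of softplus: log(exp(y) - 1). Only defined for y > 0.
    # For y < 0, exp(y) - 1 is negative and the log yields NaN.
    with np.errstate(invalid="ignore", divide="ignore"):
        return np.log(np.expm1(y))


# Hypothetical initial weights: one negative entry, as a random draw might produce.
negative_init = np.array([-0.7, 0.3])
# The same weights made positive, mirroring the fix reported above.
positive_init = np.abs(negative_init)

print(softplus_inverse(negative_init))  # first entry is nan
print(softplus_inverse(positive_init))  # all entries finite

# Round trip: mapping back through softplus recovers the positive weights.
print(np.allclose(softplus(softplus_inverse(positive_init)), positive_init))
```

Under this assumption, any initial weight that already satisfies the constraint maps to a finite unconstrained value, which is consistent with positive initialization resolving the NaNs.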