twoertwein / NeuralMixedEffects

BSD 3-Clause "New" or "Revised" License
3 stars, 0 forks

NME without z-normalization #1

Open sardnar opened 7 months ago

sardnar commented 7 months ago

Hello, thank you for this inspirational work. I have been trying to use this code for a specific kind of pharmacometric data, but I don't want to z-normalize the features and outputs. Every time I turn off z-normalization, I get strange outputs that seem to be constrained somehow.

Is there a workaround for that?

twoertwein commented 7 months ago

Can you please provide more details, ideally a minimal reproducible example?

Two thoughts (sorry, just guessing):

edit:

sardnar commented 7 months ago

Sure. I am trying to fit data with three variables: x is time, y is concentration, and the dose is a regressor. So I had to tweak the examples a little to pass the dose via the meta key. This differs from the example provided with PDsim.

It is similar to here: https://saemixdevelopment.github.io/saemix_bookdown/casestudies.html

Also, when I turn off z-normalization for the example provided in the test file, I get strange observed-vs-predicted plots.

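```python
from typing import Any, Optional

import torch

# LossModule comes from the author's python_tools package (its import is not shown in the thread).


class theophylline(LossModule):
    def __init__(self, **kwargs: Any) -> None:
        """Model has three mixed effects: ka, V and CL."""
        super().__init__()

        # initial values: absorption rate, volume of distribution, clearance
        self.ka = torch.nn.Parameter(torch.tensor(1.0))
        self.v = torch.nn.Parameter(torch.tensor(20.0))
        self.cl = torch.nn.Parameter(torch.tensor(0.5))

    def get_parameter_names(self) -> tuple[str, ...]:
        # nn.Module keeps parameters in self._parameters, not in self.__dict__,
        # so iterate over the registered parameters instead
        return tuple(name for name, _ in self.named_parameters())

    def forward(
        self,
        x: torch.Tensor,
        meta: dict[str, torch.Tensor],
        y: Optional[torch.Tensor] = None,
        dataset: str = "",
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Predict the concentration at time x given the dosage."""
        ka = self.ka
        v = self.v
        k = self.cl / v  # elimination rate constant
        dose = meta["AMT"]
        # print(f'subject: {meta["meta_id"]}')
        # print(f' AMT {torch.unique(dose)}')
        # one-compartment model with first-order absorption
        return (dose * ka / (v * (ka - k)) * (torch.exp(-k * x) - torch.exp(-ka * x))).float(), meta
```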
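The forward pass is the standard one-compartment model with first-order absorption (bioavailability fixed at 1). A minimal NumPy sketch like the following can check, outside of NME, that the initial values give a sensible curve on the original time scale: on raw time, C(0) = 0 and the peak sits at t_max = ln(ka/k) / (ka - k). The helper `one_compartment_oral`, the 320 mg dose and the 0 to 24 h grid are hypothetical choices for illustration.

```python
import numpy as np


def one_compartment_oral(t, dose, ka, v, cl):
    """Concentration after a single oral dose (first-order absorption, F = 1)."""
    k = cl / v  # elimination rate constant
    return dose * ka / (v * (ka - k)) * (np.exp(-k * t) - np.exp(-ka * t))


t = np.linspace(0.0, 24.0, 24001)  # hours, 0.001 h resolution
c = one_compartment_oral(t, dose=320.0, ka=1.0, v=20.0, cl=0.5)

# analytical time of the peak: t_max = ln(ka / k) / (ka - k)
k = 0.5 / 20.0
t_max = np.log(1.0 / k) / (1.0 - k)
print(f"C(0) = {c[0]:.3f}, t_max = {t_max:.2f} h, C_max = {c.max():.2f}")
```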

I just tried commenting out this part of the transforms, since normalizing time makes the parameters hard to interpret:

```python
data = (data - transform_dict["mean"]) / transform_dict["std"]
assert np.isfinite(data).all(), key

# clip outliers at p~0.998
if key == "x":
    data.clip(-3, 3, out=data)
```
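One way to keep interpretable parameters while still normalizing is to map results back to the original units afterwards. A minimal NumPy sketch, assuming population mean/std statistics as in the snippet; the helper names and sampling times are hypothetical. It also shows two facts worth keeping in mind: centering moves the dosing time t = 0 to -mean/std, whereas scaling only (without centering) keeps t = 0 fixed and just rescales rate constants by std.

```python
import numpy as np


def z_normalize(data, mean, std):
    # same transform as in the snippet above: center, then scale
    return (data - mean) / std


def z_denormalize(data_z, mean, std):
    # inverse transform: back to the original units (e.g. hours)
    return data_z * std + mean


t = np.array([0.0, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0, 12.0, 24.0])  # hypothetical sampling times (h)
mean, std = t.mean(), t.std()

t_z = z_normalize(t, mean, std)
t_back = z_denormalize(t_z, mean, std)
print("dosing time after z-norm:", t_z[0])  # equals -mean / std, not 0

# Scaling without centering keeps t = 0 at 0. A rate constant k_s fitted on the
# scaled axis t_s = t / std corresponds to k = k_s / std on the original axis,
# because exp(-k_s * t_s) == exp(-(k_s / std) * t).
k_s = 0.5
t_s = t / std
k = k_s / std
```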

saemix works fine, for sure, but things will get interesting if I can scale this idea up by integrating patient information!

If I understand your edit correctly, do you mean getting initial estimates from fixed effects first?

twoertwein commented 7 months ago

Wow, you are probably the first person besides me to use python_tools :) I hope it is not too frustrating; I wrote it to reuse code blocks I use often.

> I just tried to comment out this part in the transforms

If you don't z-norm, you definitely also have to remove the clip command: it clips everything beyond ±3 standard deviations, which only makes sense when the data are z-normed.
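To make that coupling explicit, the clip can be tied to the same switch as the normalization. A minimal NumPy sketch, assuming a hypothetical `apply_transform` helper with a `z_norm` flag (this is not the actual python_tools function), which also shows what a leftover clip would do to raw times in hours:

```python
import numpy as np


def apply_transform(data, key, stats, z_norm=True):
    """Hypothetical helper mirroring the snippet above.

    The +/-3 clip only makes sense in standard-deviation units, so it is
    applied only together with z-normalization.
    """
    data = np.array(data, dtype=float)  # copy, so the caller's array is untouched
    if z_norm:
        data = (data - stats["mean"]) / stats["std"]
        assert np.isfinite(data).all(), key
        if key == "x":
            np.clip(data, -3, 3, out=data)  # clip outliers beyond 3 standard deviations
    return data


t = np.array([0.0, 1.0, 2.0, 4.0, 8.0, 12.0, 24.0])  # hypothetical times in hours
stats = {"mean": t.mean(), "std": t.std()}

z = apply_transform(t, "x", stats, z_norm=True)
raw = apply_transform(t, "x", stats, z_norm=False)
# a leftover clip on raw hours would turn 4, 8, 12 and 24 h into 3 h
clipped_raw = np.clip(t, -3, 3)
print(raw, clipped_raw)
```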

> If I understand your edit correctly, do you mean getting initial estimates from fixed effects first?

Yes: train a normal neural network without random effects, then use the estimated parameters to initialize NME.
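A minimal PyTorch sketch of this two-stage idea on simulated data. Stage 1 fits a fixed-effects-only version of the same PK formula; stage 2 copies those estimates into the initial values of a mixed model. `MixedPK` is a hypothetical toy (a random effect on ka only, no penalty on it); it is not the NME class, and the data, dose and learning rate are made-up choices.

```python
import torch

torch.manual_seed(0)


def pk_curve(t, dose, ka, v, cl):
    """One-compartment model with first-order absorption (same formula as forward())."""
    k = cl / v
    return dose * ka / (v * (ka - k)) * (torch.exp(-k * t) - torch.exp(-ka * t))


# Hypothetical simulated data: 6 subjects that differ only in ka, same dose.
n_subjects = 6
t = torch.tensor([0.25, 0.5, 1.0, 2.0, 3.5, 5.0, 7.0, 9.0, 12.0, 24.0])
dose = 320.0
true_ka = 1.5 * torch.exp(0.2 * torch.randn(n_subjects, 1))
y = pk_curve(t, dose, true_ka, 30.0, 2.0)  # shape (n_subjects, n_times)

# Stage 1: fixed effects only, one set of parameters shared by all subjects.
fixed = torch.nn.ParameterDict({
    "ka": torch.nn.Parameter(torch.tensor(1.0)),
    "v": torch.nn.Parameter(torch.tensor(20.0)),
    "cl": torch.nn.Parameter(torch.tensor(0.5)),
})


def fixed_loss():
    pred = pk_curve(t, dose, fixed["ka"], fixed["v"], fixed["cl"])
    return ((pred - y) ** 2).mean()


initial_loss = fixed_loss().item()
opt = torch.optim.Adam(fixed.parameters(), lr=0.05)
for _ in range(2000):
    opt.zero_grad()
    loss = fixed_loss()
    loss.backward()
    opt.step()
final_loss = fixed_loss().item()


# Stage 2: toy mixed model whose population parameters start at the stage-1
# estimates; the per-subject random effects on ka start at zero.
class MixedPK(torch.nn.Module):
    def __init__(self, init: torch.nn.ParameterDict, n_subjects: int) -> None:
        super().__init__()
        self.ka = torch.nn.Parameter(init["ka"].detach().clone())
        self.v = torch.nn.Parameter(init["v"].detach().clone())
        self.cl = torch.nn.Parameter(init["cl"].detach().clone())
        self.ka_re = torch.nn.Parameter(torch.zeros(n_subjects, 1))

    def forward(self, t: torch.Tensor, dose: float) -> torch.Tensor:
        return pk_curve(t, dose, self.ka + self.ka_re, self.v, self.cl)


mixed = MixedPK(fixed, n_subjects)
print(f"fixed-effects loss: {initial_loss:.3f} -> {final_loss:.3f}")
```

Because the random effects start at zero, the mixed model initially reproduces the fixed-effects fit exactly, so the subsequent training only has to learn the per-subject deviations.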

I would expect that even a normal neural network will take a long time to converge without z-norming: gradient descent changes parameters in small steps, so going from -1 to 1 takes far fewer updates than going from 20 to -20. I would try starting with a larger learning rate (and decaying it as the parameters converge) and letting the model run for many, many, many epochs.
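A minimal PyTorch sketch of "start with a large learning rate, then decay it", assuming `ExponentialLR` as the decay schedule (other schedulers in `torch.optim.lr_scheduler` would work too). The toy parameter `w` has to travel from 0 to 20, similar in scale to the un-normalized volume parameter; the learning rate, decay factor and epoch count are illustrative, not tuned values.

```python
import torch

# Toy problem: one parameter that has to travel far (from 0 to 20).
w = torch.nn.Parameter(torch.tensor(0.0))
target = 20.0

opt = torch.optim.Adam([w], lr=1.0)  # start with a large learning rate
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.995)  # lr *= 0.995 each epoch

n_epochs = 2000
for epoch in range(n_epochs):
    opt.zero_grad()
    loss = (w - target) ** 2
    loss.backward()
    opt.step()
    sched.step()  # decay after the optimizer step

print(f"w = {w.item():.4f}, final lr = {opt.param_groups[0]['lr']:.2e}")
```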

twoertwein commented 7 months ago

Another idea: you might need to adjust simulated_annealing_alpha, as the weighted L2 penalty of NME might become too strong too quickly when the model takes a long time to train. I think this is less of an issue with saemix because it uses sampling, which can make large jumps, unlike gradient descent.
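A minimal NumPy sketch of why the ramp speed matters, assuming the penalty weight grows as 1 - alpha**epoch. This schedule is made up for illustration and is not necessarily how NME uses `simulated_annealing_alpha`; the random effects and their covariance are also made-up values. The weighted L2 penalty is written as the sum over subjects of b_i^T P b_i, with P the inverse covariance of the random effects.

```python
import numpy as np


def annealed_weight(epoch: int, alpha: float) -> float:
    """Hypothetical ramp: the penalty weight grows from 0 towards 1 as 1 - alpha**epoch."""
    return 1.0 - alpha**epoch


def weighted_l2(random_effects: np.ndarray, precision: np.ndarray) -> float:
    """Sum over subjects of b_i^T P b_i, with P the inverse covariance of the random effects."""
    return float(np.einsum("ni,ij,nj->", random_effects, precision, random_effects))


# made-up random effects: 3 subjects x 2 parameters (e.g. ka and CL)
b = np.array([[0.3, -0.1], [-0.2, 0.05], [0.1, 0.2]])
precision = np.linalg.inv(np.diag([0.04, 0.01]))
penalty = weighted_l2(b, precision)

# under this schedule, alpha closer to 1 -> the penalty reaches full strength more slowly
for alpha in (0.9, 0.99):
    weights = [annealed_weight(epoch, alpha) for epoch in (1, 10, 100, 1000)]
    print(alpha, [round(w * penalty, 3) for w in weights])
```

If the fixed-effect parameters are still far from their optimum when the weight approaches 1, the penalty pulls the random effects toward zero before the data term has had a chance to shape them, which is the failure mode described above.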

sardnar commented 7 months ago

Thanks a lot for your input. It does indeed seem to converge much more slowly than saemix.