bethgelab / decompose

Blind source separation based on the probabilistic tensor factorisation framework
MIT License

Is there a way to get the hyperparameters from the model? #3

jenwallace closed this issue 6 years ago

jenwallace commented 6 years ago

I have one data set where the source extraction looks good and another where it doesn't do so well, and I'm curious to explore how this relates to the prior distributions I picked and the hyperparameters the model found for those priors. Thank you!

aboettcher commented 6 years ago

@jenwallace thanks for the suggestion. It absolutely makes sense to have easy access to them. I am thinking of adding a parameters property to trained DECOMPOSE model instances that contains a nested dictionary with the parameters.
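A minimal sketch of what such a read-only property could look like, assuming a hypothetical `FittedModel` class that keeps its learned values in a private nested dictionary. This only illustrates the idea and is not DECOMPOSE's actual implementation.

```python
import copy
from typing import Any, Dict


class FittedModel:
    """Hypothetical stand-in for a trained model; not the DECOMPOSE class."""

    def __init__(self) -> None:
        # Learned values grouped by component, e.g. {"prior0": {"tau": [...]}}.
        self._params: Dict[str, Dict[str, Any]] = {}

    @property
    def parameters(self) -> Dict[str, Dict[str, Any]]:
        # Return a deep copy so callers cannot mutate the model's stored values.
        return copy.deepcopy(self._params)


model = FittedModel()
model._params = {"prior0": {"tau": [1.0, 2.0, 0.5]}}  # illustrative values only
print(model.parameters["prior0"]["tau"])  # prints [1.0, 2.0, 0.5]
```

Returning a copy keeps the property purely for inspection; changing the returned dictionary has no effect on the model.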

aboettcher commented 6 years ago

This should now be implemented with #5. Look for keys starting with `prior0/` and `prior1/`. For example, `model.parameters['prior0/CenNormal/tau']` should give you the precisions of the normal priors of the first factor.
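Since the keys are flat, slash-delimited strings, one way to collect everything that belongs to a single factor is to split each key at its first `/`. The sketch below uses a plain dictionary with made-up keys and values as a stand-in for `model.parameters`; only the `prior0/CenNormal/tau` naming pattern comes from the comment above, and the helper `group_by_prior` is hypothetical.

```python
from typing import Any, Dict


def group_by_prior(params: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Turn {'prior0/CenNormal/tau': v, ...} into {'prior0': {'CenNormal/tau': v}, ...}.

    Only keys whose first segment starts with 'prior' are kept.
    """
    grouped: Dict[str, Dict[str, Any]] = {}
    for key, value in params.items():
        head, sep, rest = key.partition("/")
        if not sep or not head.startswith("prior"):
            continue  # skip keys that are not per-factor prior parameters
        grouped.setdefault(head, {})[rest] = value
    return grouped


# Made-up stand-in for model.parameters; keys and values are illustrative only.
params = {
    "prior0/CenNormal/tau": [0.8, 1.3, 2.1],
    "prior1/CenNormal/tau": [1.1, 0.4, 0.9],
    "unrelated/value": 5.0,
}
for prior, values in sorted(group_by_prior(params).items()):
    print(prior, values)
```

Comparing the grouped precisions for the data set that separates well against the one that does not is one way to start the exploration described in the original question.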

jenwallace commented 6 years ago

Thanks so much for your quick response; this will be really helpful! I tried the new version with the test code, but now I'm getting an error:

```
\decompose\models\tensorFactorisation.py in __model(cls, priorTypes, M, K, stopCriterion, phase, dtype, reuse, trainsetProb, doRescale, transform, suffix)
    312         for f, priorType in enumerate(priorTypes):
    313             prior = priorType.random(shape=(K,), latentShape=(M[f],),
--> 314                                      name=f"prior{suffix}{f}", dtype=dtype)
    315             priors.append(prior)
    316         tefa = cls.random(priorU=priors, likelihood=likelihood, M=M, K=K,

TypeError: random() missing 1 required positional argument: 'self'
```

Any idea what the problem might be?

aboettcher commented 6 years ago

That is due to a backward-incompatible change I made to the API. The prior distributions now have to be instantiated:

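```python
# Module paths assumed from the decompose package layout.
from decompose.sklearnDECOMPOSE import DECOMPOSE
from decompose.distributions.cenNormal import CenNormal

# create an instance of a decompose model; note the () after each prior class
model = DECOMPOSE(modelDirectory="/tmp/myNewModel",
                  priors=[CenNormal(), CenNormal()],
                  n_components=3)
```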

Adding parentheses () after the prior class names should fix the error. Unfortunately, I cannot yet guarantee that the API will stay the same.
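The error arises because `random` is called on the class rather than on an instance, so nothing is bound to `self`. A minimal reproduction, assuming a hypothetical `Prior` class whose `random` is an ordinary instance method (the real DECOMPOSE prior classes are more involved):

```python
class Prior:
    """Hypothetical prior; `random` is a plain instance method."""

    def random(self, shape, name="prior"):
        return f"{name} with shape {shape}"


# Passing the class: Prior.random is a plain function, so `self` is missing.
try:
    Prior.random(shape=(3,), name="prior0")
except TypeError as err:
    print("class:", err)

# Passing an instance: Prior() binds `self`, so the call succeeds.
print("instance:", Prior().random(shape=(3,), name="prior0"))
```

Older Python versions report the method as `random()` and newer ones as `Prior.random()`, but the missing `'self'` argument is the same in both.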