Closed Demirrr closed 2 years ago
Regarding your question about the subnetwork: you are missing the argument "subnetwork_indices". Before calling Laplace, you need to apply the following steps, where module_names is a list of module names (e.g. ['fc1']):
subnetwork_mask = ModuleNameSubnetMask(model, module_names=module_names)
subnetwork_mask.select()
subnetwork_indices = subnetwork_mask.indices
If you run it on CUDA, you might need to adjust the last part to
subnetwork_indices=subnetwork_mask.indices.type(torch.LongTensor)
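For readers wondering what subnetwork_indices actually contains: as far as I understand it, these are positions in the flattened (vectorized) parameter vector of the model. The snippet below is only a conceptual sketch in plain Python, not the library's implementation. The helper flat_indices_for_modules and the parameter sizes are hypothetical and made up for illustration.

```python
from typing import Dict, List, Sequence


def flat_indices_for_modules(
    param_sizes: Dict[str, int], module_names: Sequence[str]
) -> List[int]:
    """Return the positions, in the flattened parameter vector, of all
    parameters that belong to the named modules.

    param_sizes maps parameter names such as 'fc1.weight' to their number
    of elements, in the order in which they would be concatenated.
    (Hypothetical helper for illustration; not part of the laplace package.)
    """
    indices: List[int] = []
    offset = 0
    for name, size in param_sizes.items():
        # 'fc1.weight' -> module 'fc1'
        module = name.rsplit(".", 1)[0]
        if module in module_names:
            indices.extend(range(offset, offset + size))
        offset += size
    if not indices:
        raise ValueError(f"No parameters found for modules {list(module_names)}")
    return indices


# Assumed toy model: fc1 maps 4 -> 3 features, fc2 maps 3 -> 2 features.
param_sizes = {"fc1.weight": 12, "fc1.bias": 3, "fc2.weight": 6, "fc2.bias": 2}

# Select only the last layer as the subnetwork.
subnetwork_indices = flat_indices_for_modules(param_sizes, ["fc2"])
print(subnetwork_indices)  # [15, 16, 17, 18, 19, 20, 21, 22]
```

In the real library, the subnetwork mask computes these indices for you and returns them as an integer tensor, which is presumably why the cast to torch.LongTensor above is about dtype and device rather than about the values themselves.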
@georgezefko Much appreciated. The problem is solved!
Thanks a lot @Demirrr for raising this issue and to @georgezefko for commenting and providing a solution to this (and thanks to both for your interest in and usage of our library, we really appreciate it) -- I'm glad that it works now, and hope that you'll get reasonable results; please let us know if there's anything else we can help with!
Just FYI, in addition to ModuleNameSubnetMask there exist other ways to select/specify subnetworks -- there are a few examples in the README of this repository (you need to scroll down a bit to see the code examples). A list of all currently supported options can be found as subclasses of SubnetMask within the following file: https://github.com/AlexImmer/Laplace/blob/main/laplace/utils/subnetmask.py
@georgezefko The issue you mention when using CUDA will soon be resolved once PR #87 is merged, addressing issue #86 raised earlier -- thanks again for your interest and help!
Dear all,
Thank you for this open-source project. I enjoyed your paper and presentation. I wanted to use the laplace API to calibrate knowledge graph embedding (KGE) models in the link prediction task, yet I could not get it to work. Below, I briefly describe the situation
where forward_k_vs_all(self, x) returns a 2D torch.Tensor. After the model is trained, the part below is executed and the following error is raised: NotImplementedError: Extension saving to diag_ggn_exact does not have an extension for Module <class 'core.models.real.DistMult'>
Similarly, using
raises TypeError: __init__() missing 1 required positional argument: 'subnetwork_indices'
I was wondering whether you have any suggestions regarding the problem described above.
Cheers!