yourh / AttentionXML

Implementation for "AttentionXML: Label Tree-based Attention-Aware Deep Model for High-Performance Extreme Multi-Label Text Classification"

Loading a trained model with a different number of GPUs #38

Open katjakon opened 1 month ago

katjakon commented 1 month ago

Hello, thank you for your work! In our project, we trained an AttentionXML model on 4 GPUs and are now trying to load it in an environment where only one GPU is available. After modifying the code as described in issue #34, we get the following error:

RuntimeError: Error(s) in loading state_dict for ModuleDict:
        Missing key(s) in state_dict: "Network.attention.attention.weight". 
        Unexpected key(s) in state_dict: "AttentionWeights.emb.0.weight", "AttentionWeights.emb.1.weight", "AttentionWeights.emb.2.weight".

This error occurs only when prediction on Level-1 is performed; no error occurs in the 4-GPU environment. We have already tried concatenating "AttentionWeights.emb.0.weight", "AttentionWeights.emb.1.weight", and "AttentionWeights.emb.2.weight", but the result seems to have a different dimension than required.

Do you have any idea how we can get this to work?

Best wishes,
Katja
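When a checkpoint and a model disagree like this, it can help to list every difference up front instead of fixing one error at a time. The sketch below is a hypothetical helper (`diff_state_dicts` is not part of AttentionXML or PyTorch) that compares the model's own `state_dict()` with a checkpoint without loading it, and reports the same three kinds of problems that `load_state_dict` complains about: missing keys, unexpected keys, and shape mismatches. The toy shapes in the demo are made up and only mimic the keys from the error above.

```python
import torch


def diff_state_dicts(expected, checkpoint):
    """Compare a model's state_dict with a checkpoint without loading it.

    Returns (missing, unexpected, mismatched): keys the model needs but the
    checkpoint lacks, keys the checkpoint has but the model does not, and
    {key: (checkpoint_shape, model_shape)} for keys present in both whose
    shapes differ.
    """
    missing = sorted(set(expected) - set(checkpoint))
    unexpected = sorted(set(checkpoint) - set(expected))
    mismatched = {
        key: (tuple(checkpoint[key].shape), tuple(expected[key].shape))
        for key in sorted(set(expected) & set(checkpoint))
        if checkpoint[key].shape != expected[key].shape
    }
    return missing, unexpected, mismatched


if __name__ == "__main__":
    # Toy stand-ins: one merged attention matrix on the single-GPU side,
    # three per-GPU shards on the checkpoint side.
    expected = {"Network.attention.attention.weight": torch.zeros(7, 4)}
    checkpoint = {
        f"AttentionWeights.emb.{i}.weight": torch.zeros(n, 4)
        for i, n in enumerate([3, 2, 2])
    }
    print(diff_state_dicts(expected, checkpoint))
```

In practice `expected` would be `model.state_dict()` of the single-GPU model and `checkpoint` the dictionary returned by `torch.load(..., map_location="cpu")`.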

yourh commented 1 month ago

We have already tried to concatenate "AttentionWeights.emb.0.weight", "AttentionWeights.emb.1.weight", "AttentionWeights.emb.2.weight" but they seem to have a different dimension than required.

Could you please tell me the dimensions of these weights? I think it should be OK to concatenate them.
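Concatenating the shards can be done directly on the loaded checkpoint before calling `load_state_dict`. The following is a minimal sketch, not code from the repository: `merge_attention_shards` is a hypothetical helper, the key names are copied from the error message above, and it assumes that shard `i` holds the `i`-th contiguous block of label rows, so the shards are joined in index order along dimension 0 (the label dimension, since the shards are `(n_i, 1024)` and the target is `(labels_num, 1024)`).

```python
import re

import torch


def merge_attention_shards(state_dict,
                           src_prefix="AttentionWeights.emb.",
                           dst_key="Network.attention.attention.weight"):
    """Replace per-GPU attention shards with one concatenated weight matrix.

    Assumes shard keys look like '<src_prefix><i>.weight' and that shard i
    holds the i-th contiguous block of labels. All other keys are copied
    through unchanged.
    """
    pattern = re.compile(re.escape(src_prefix) + r"(\d+)\.weight$")
    shards = {}
    merged = {}
    for key, value in state_dict.items():
        match = pattern.match(key)
        if match:
            shards[int(match.group(1))] = value
        else:
            merged[key] = value
    if shards:
        # Sort numerically so that emb.10 would follow emb.9, not emb.1.
        ordered = [shards[i] for i in sorted(shards)]
        merged[dst_key] = torch.cat(ordered, dim=0)
    return merged
```

For example, `merge_attention_shards(torch.load(model_path, map_location="cpu"))`, where `model_path` is a placeholder for the saved Level-1 model file. Note that this only renames and concatenates: the merged matrix has as many rows as the shards together, so it loads only if that equals the number of labels the single-GPU model was built with.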

katjakon commented 4 weeks ago

Thank you for your response! The weights have the following dimensions: (67619, 1024) (67618, 1024) (67618, 1024) When I try to concatenate them, I get the following error:

 RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ModuleDict:
        size mismatch for Network.attention.attention.weight: copying a param with shape torch.Size([202855, 1024]) from checkpoint, the shape in current model is torch.Size([202856, 1024])

As far as I understand, the required first dimension is 202856, but concatenating the tensors with the dimensions listed above gives 202855 (67619 + 67618 + 67618).
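The shard sizes themselves hint at where the off-by-one comes from. Assuming the labels were spread over the three attention shards as evenly as possible, with the remainder going to the first shards (an assumption, though it reproduces the observed sizes exactly), 202855 labels split into three gives (67619, 67618, 67618), while 202856 labels would give (67619, 67619, 67618). That would mean the checkpoint was trained with 202855 labels and the model being rebuilt on one GPU expects one more. The helper `even_split` below is hypothetical and only illustrates the arithmetic.

```python
def even_split(labels_num, parts):
    """Split labels_num into `parts` group sizes that differ by at most one,
    giving the leftover labels to the first groups (assumed scheme)."""
    size, extra = divmod(labels_num, parts)
    groups = [size + 1] * extra + [size] * (parts - extra)
    assert sum(groups) == labels_num
    return groups


observed = [67619, 67618, 67618]
print(sum(observed))          # 202855
print(even_split(202855, 3))  # [67619, 67618, 67618]
print(even_split(202856, 3))  # [67619, 67619, 67618]
```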

yourh commented 3 weeks ago

How many labels do you have? 202856 or 202855? I check the sum by using assert sum( == labels_num in the so I'm confused about it.