dice-group / dice-embeddings

Hardware-agnostic Framework for Large-scale Knowledge Graph Embeddings

PYKEEN Integration #54

Closed · Demirrr closed this issue 1 year ago

Demirrr commented 1 year ago

Update at 07.07.2023:

Currently, we can train many of the KGE models implemented in PyKEEN within our framework, e.g.,

python main.py --model Pykeen_MuRE --num_epochs 10 --batch_size 256 --lr 0.1 --trainer "PL" --num_core 4 --scoring_technique KvsAll --pykeen_model_kwargs embedding_dim=64
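
To illustrate the --pykeen_model_kwargs convention, here is a minimal, hypothetical sketch of how key=value strings such as embedding_dim=64 can be turned into a Python kwargs dictionary that is then handed to the PyKEEN model constructor. This is only an illustration of the convention, not dicee's actual argument parser:

from ast import literal_eval

def parse_model_kwargs(pairs):
    # Hypothetical helper: turn ["embedding_dim=64", "p=2"] into {"embedding_dim": 64, "p": 2}.
    # NOT dicee's actual parser; it only illustrates the key=value convention used on the CLI.
    kwargs = {}
    for pair in pairs:
        key, raw_value = pair.split("=", maxsplit=1)
        try:
            kwargs[key] = literal_eval(raw_value)  # "64" -> 64, "0.1" -> 0.1, "True" -> True
        except (ValueError, SyntaxError):
            kwargs[key] = raw_value  # keep plain strings, e.g. activation names, as-is
    return kwargs

print(parse_model_kwargs(["embedding_dim=64"]))  # {'embedding_dim': 64}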

Demirrr commented 1 year ago

Assigned to a student

Demirrr commented 1 year ago

As shown in the development branch, we can now use most of the knowledge graph embedding models provided by PyKEEN. Yet, these models cannot be loaded and used within KGE().

(dice) demir@demir:~/Desktop/Softwares/dice-embeddings$ python main.py --model Pykeen_MuRE --num_epochs 10 --batch_size 256 --lr 0.1 --trainer "PL" --num_core 4 --scoring_technique KvsAll --pykeen_model_kwargs embedding_dim=64
Start time:2023-07-07 13:44:26.150156
*** Read or Load Knowledge Graph  ***
...
------------------- Description of Dataset KGs/UMLS -------------------
Number of entities:135
Number of relations:92
Number of triples on train set:10432
Number of triples on valid set:1304
Number of triples on test set:1322
...
MyLCWALitModule(
  (model): MuRE(
    (loss): MarginRankingLoss(
      (margin_activation): ReLU()
    )
    (interaction): MuREInteraction()
    (entity_representations): ModuleList(
      (0): Embedding(
        (_embeddings): Embedding(135, 64)
      )
      (1-2): 2 x Embedding(
        (_embeddings): Embedding(135, 1)
      )
    )
    (relation_representations): ModuleList(
      (0-1): 2 x Embedding(
        (_embeddings): Embedding(92, 64)
      )
    )
    (weight_regularizers): ModuleList()
  )
  (loss): MarginRankingLoss(
    (margin_activation): ReLU()
  )
)
Initializing Dataset... Took 0.0228 seconds
Took 0.0240 seconds | Current Memory Usage  520.83 in MB
Took 0.0242 seconds | Current Memory Usage  520.83 in MB
Initializing Dataloader...      Took 0.0003 seconds | Current Memory Usage  520.83 in MB
Epoch 9: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 24.29it/s, loss=0.0188]
*** Save Trained Model ***
Took 0.0016 seconds | Current Memory Usage  673.28 in MB
Total computation time: 3.401 seconds
Evaluate MuRE on Train set: Evaluate MuRE on Train set
{'H@1': 0.9430598159509203, 'H@3': 0.9958780674846626, 'H@10': 1.0, 'MRR': 0.9687399576394975}
Evaluate MuRE on Validation set: Evaluate MuRE on Validation set
{'H@1': 0.6855828220858896, 'H@3': 0.8826687116564417, 'H@10': 0.9693251533742331, 'MRR': 0.7933095053598029}
Evaluate MuRE on Test set: Evaluate MuRE on Test set
{'H@1': 0.705748865355522, 'H@3': 0.8872919818456884, 'H@10': 0.9674735249621785, 'MRR': 0.8049328531462063}
Total computation time: 4.849 seconds

To load this pre-trained model, we use the KGE() class:

from dicee import KGE
KGE('Experiments/2023-07-07 13:44:26.149833')

Yet, this results in the following error:


RuntimeError: Error(s) in loading state_dict for MyLCWALitModule:
    size mismatch for model.entity_representations.0._embeddings.weight: copying a param with shape torch.Size([135, 64]) from checkpoint, the shape in current model is torch.Size([14, 64]).
    size mismatch for model.entity_representations.1._embeddings.weight: copying a param with shape torch.Size([135, 1]) from checkpoint, the shape in current model is torch.Size([14, 1]).
    size mismatch for model.entity_representations.2._embeddings.weight: copying a param with shape torch.Size([135, 1]) from checkpoint, the shape in current model is torch.Size([14, 1]).
    size mismatch for model.relation_representations.0._embeddings.weight: copying a param with shape torch.Size([92, 64]) from checkpoint, the shape in current model is torch.Size([55, 64]).
    size mismatch for model.relation_representations.1._embeddings.weight: copying a param with shape torch.Size([92, 64]) from checkpoint, the shape in current model is torch.Size([55, 64]).
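
For context: a load_state_dict size mismatch like the one above typically means the module was re-instantiated over a different vocabulary than the one the checkpoint was trained on. Here the freshly built model has 14 entities and 55 relations (the size of PyKEEN's Nations dataset) instead of UMLS's 135 and 92, which hints that the loader rebuilds the PyKEEN model against a default dataset rather than the saved one. Below is a minimal, generic PyTorch sketch of the failure mode and the obvious remedy (rebuild the module with the original entity count before restoring the weights); TinyKGE is a stand-in class, not dicee's MyLCWALitModule:

import torch

class TinyKGE(torch.nn.Module):
    # Stand-in for a KGE module: a single entity-embedding table of shape (num_entities, dim).
    def __init__(self, num_entities, dim=64):
        super().__init__()
        self.entity_embeddings = torch.nn.Embedding(num_entities, dim)

# "Checkpoint" produced by a model trained on UMLS-sized data (135 entities).
state = TinyKGE(num_entities=135).state_dict()

# Re-instantiating with the wrong entity count reproduces the size-mismatch error.
try:
    TinyKGE(num_entities=14).load_state_dict(state)
except RuntimeError as error:
    print(error)  # size mismatch for entity_embeddings.weight: ...

# Rebuilding the module with the original count restores the checkpoint cleanly.
TinyKGE(num_entities=135).load_state_dict(state)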

@renzhonglu11 Could you please take a look at this problem in the Pykeen branch?

Demirrr commented 1 year ago

I took care of it, @renzhonglu11 :)