DeepGraphLearning / torchdrug

A powerful and flexible machine learning platform for drug discovery
https://torchdrug.ai/
Apache License 2.0

ImportError: No module named 'embedding' #161

Closed · iamme1234567 closed 1 year ago

iamme1234567 commented 1 year ago

Below is my code:

import torch
from torchdrug import core, datasets, tasks, models
from torchdrug.models import RotatE

import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt

# Load the FB15k-237 knowledge graph and split it into train/valid/test sets
dataset = datasets.FB15k237("~/kg-datasets/")
train_set, valid_set, test_set = dataset.split()

# RotatE embeddings for every entity and relation in the dataset
model: RotatE = models.RotatE(num_entity=dataset.num_entity,
                              num_relation=dataset.num_relation,
                              embedding_dim=2048, max_score=9)

# Knowledge graph completion with self-adversarial negative sampling
task = tasks.KnowledgeGraphCompletion(model, num_negative=256,
                                      adversarial_temperature=1)

optimizer = torch.optim.Adam(task.parameters(), lr=2e-5)
solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                     gpus=[0], batch_size=1024)
solver.train(num_epoch=100)
solver.evaluate("valid")

Below is the error:

Traceback (most recent call last):
  File "C:\Users\lenovo\PycharmProjects\pythonProject2\main.py", line 23, in <module>
    solver.train(num_epoch=100)
  File "C:\Users\lenovo\.conda\envs\td2\lib\site-packages\torchdrug\core\engine.py", line 155, in train
    loss, metric = model(batch)
  File "C:\Users\lenovo\.conda\envs\td2\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\lenovo\.conda\envs\td2\lib\site-packages\torchdrug\tasks\reasoning.py", line 85, in forward
    pred = self.predict(batch, all_loss, metric)
  File "C:\Users\lenovo\.conda\envs\td2\lib\site-packages\torchdrug\tasks\reasoning.py", line 160, in predict
    pred = self.model(self.fact_graph, h_index, t_index, r_index, all_loss=all_loss, metric=metric)
  File "C:\Users\lenovo\.conda\envs\td2\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\lenovo\.conda\envs\td2\lib\site-packages\torchdrug\models\embedding.py", line 191, in forward
    score = functional.rotate_score(self.entity, self.relation * self.relation_scale,
  File "C:\Users\lenovo\.conda\envs\td2\lib\site-packages\torchdrug\layers\functional\embedding.py", line 266, in rotate_score
    score = RotatEFunction.apply(entity, relation, h_index, t_index, r_index)
  File "C:\Users\lenovo\.conda\envs\td2\lib\site-packages\torchdrug\layers\functional\embedding.py", line 108, in forward
    forward = embedding.rotate_forward_cuda
  File "C:\Users\lenovo\.conda\envs\td2\lib\site-packages\torchdrug\utils\torch.py", line 27, in __getattr__
    return getattr(self.module, key)
  File "C:\Users\lenovo\.conda\envs\td2\lib\site-packages\torchdrug\utils\decorator.py", line 102, in __get__
    result = self.func(obj)
  File "C:\Users\lenovo\.conda\envs\td2\lib\site-packages\torchdrug\utils\torch.py", line 31, in module
    return cpp_extension.load(self.name, self.sources, self.extra_cflags, self.extra_cuda_cflags,
  File "C:\Users\lenovo\.conda\envs\td2\lib\site-packages\torch\utils\cpp_extension.py", line 1079, in load
    return _jit_compile(
  File "C:\Users\lenovo\.conda\envs\td2\lib\site-packages\torch\utils\cpp_extension.py", line 1317, in _jit_compile
    return _import_module_from_library(name, build_directory, is_python_module)
  File "C:\Users\lenovo\.conda\envs\td2\lib\site-packages\torch\utils\cpp_extension.py", line 1699, in _import_module_from_library
    file, path, description = imp.find_module(module_name, [path])
  File "C:\Users\lenovo\.conda\evns\td2\lib\imp.py", line 296, in find_module
    raise ImportError(_ERR_MSG.format(name), name=name)
ImportError: No module named 'embedding'

I've done some research but couldn't figure out why. Can anyone help me here?

KiddoZhu commented 1 year ago

Hi! The embedding module is not a regular Python package; it is a C++/CUDA extension that is compiled on the fly by PyTorch's JIT extension loader the first time it is needed. Could you try cleaning the JIT cache using the instructions here?
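
In case it helps, here is a minimal sketch of what cleaning that cache amounts to. It assumes the default cache locations used by torch.utils.cpp_extension; the TORCH_EXTENSIONS_DIR environment variable overrides them when set, and the exact default path depends on the platform and PyTorch version, so adjust the path if yours differs.

# A minimal sketch, not the official instructions: delete the cached JIT build
# directory so the extension is recompiled from scratch on the next import.
import os
import shutil

# TORCH_EXTENSIONS_DIR overrides the default cache root when set; otherwise
# PyTorch typically uses ~/.cache/torch_extensions on Linux/macOS and a
# directory under the user's local app data on Windows (varies by version).
cache_dir = os.environ.get(
    "TORCH_EXTENSIONS_DIR",
    os.path.expanduser("~/.cache/torch_extensions"),
)

if os.path.isdir(cache_dir):
    shutil.rmtree(cache_dir)
    print("Removed JIT cache at", cache_dir)
else:
    print("No JIT cache found at", cache_dir)

After the cache is removed, the next run rebuilds the extension from source, so a working compiler toolchain is needed: on Windows that means MSVC (cl.exe) on the PATH, plus a matching CUDA toolkit when training with gpus=[0]. If the toolchain is missing, the rebuild will fail again with an error similar to the one above.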