dice-group / Ontolearn

Ontolearn is an open-source software library for explainable structured machine learning in Python. It learns OWL class expressions from positive and negative examples.
https://ontolearn-docs-dice-group.netlify.app/index.html
MIT License

DRILL crashes from time to time when training on randomly generated learning problems #447

Open MichaelRoeder opened 6 days ago

MichaelRoeder commented 6 days ago

Problem

DRILL sometimes crashes during its training phase because an assertion in `sequence_of_actions` is not fulfilled:

```
Goal Concept:    Q154797        E^+:[3]  E^-:[3]
Traceback (most recent call last):
  File "/local/upb/users/q/quannian/profiles/unix/cs/Pycharm_Project/Ontolearn-0.7.3/Ontolearn/examples/concept_learning_drill_train.py", line 125, in <module>
    start(parser.parse_args())
  File "/local/upb/users/q/quannian/profiles/unix/cs/Pycharm_Project/Ontolearn-0.7.3/Ontolearn/examples/concept_learning_drill_train.py", line 50, in start
    drill.train(num_of_target_concepts=args.num_of_target_concepts,
  File "/upb/users/q/quannian/profiles/unix/cs/.conda/envs/ontolearn-0.7.3/lib/python3.10/site-packages/ontolearn/learners/drill.py", line 263, in train
    sum_of_rewards_per_actions = self.rl_learning_loop(pos_uri=frozenset(positives),
  File "/upb/users/q/quannian/profiles/unix/cs/.conda/envs/ontolearn-0.7.3/lib/python3.10/site-packages/ontolearn/learners/drill.py", line 224, in rl_learning_loop
    sequence_of_states, rewards = self.sequence_of_actions(root_rl_state)
  File "/upb/users/q/quannian/profiles/unix/cs/.conda/envs/ontolearn-0.7.3/lib/python3.10/site-packages/ontolearn/learners/drill.py", line 458, in sequence_of_actions
    assert current_state.quality > 0
AssertionError
```

How to recreate

```python
import json
from argparse import ArgumentParser
from ontolearn.triple_store import TripleStoreKnowledgeBase
import numpy as np
from sklearn.model_selection import StratifiedKFold
from ontolearn.utils.static_funcs import compute_f1_score
from ontolearn.learning_problem import PosNegLPStandard
from ontolearn.refinement_operators import LengthBasedRefinement
from ontolearn.learners import Drill
from ontolearn.metrics import F1
from ontolearn.heuristics import CeloeBasedReward
from owlapy.owl_individual import OWLNamedIndividual, IRI
from owlapy.render import DLSyntaxObjectRenderer

def start(args):
    kb = TripleStoreKnowledgeBase(args.path_knowledge_base)
    drill = Drill(knowledge_base=kb,
                  path_embeddings=args.path_embeddings,
                  refinement_operator=LengthBasedRefinement(knowledge_base=kb),
                  quality_func=F1(),
                  reward_func=CeloeBasedReward(),
                  epsilon_decay=args.epsilon_decay,
                  learning_rate=args.learning_rate,
                  num_of_sequential_actions=args.num_of_sequential_actions,
                  num_episode=args.num_episode,
                  iter_bound=args.iter_bound,
                  max_runtime=args.max_runtime)
    if args.path_pretrained_dir:
        drill.load(directory=args.path_pretrained_dir)
    else:
        drill.train(num_of_target_concepts=args.num_of_target_concepts,
                    num_learning_problems=args.num_of_training_learning_problems)
        drill.save(directory="pretrained_drill")

    with open(args.path_learning_problem) as json_file:
        examples = json.load(json_file)
        print(examples)
    p = examples['positive_examples']
    n = examples['negative_examples']
    kf = StratifiedKFold(n_splits=args.folds, shuffle=True, random_state=args.random_seed)
    X = np.array(p + n)
    Y = np.array([1.0 for _ in p] + [0.0 for _ in n])
    dl_render = DLSyntaxObjectRenderer()
    for (ith, (train_index, test_index)) in enumerate(kf.split(X, Y)):
        train_pos = {pos_individual for pos_individual in X[train_index][Y[train_index] == 1]}
        train_neg = {neg_individual for neg_individual in X[train_index][Y[train_index] == 0]}
        test_pos = {pos_individual for pos_individual in X[test_index][Y[test_index] == 1]}
        test_neg = {neg_individual for neg_individual in X[test_index][Y[test_index] == 0]}
        train_lp = PosNegLPStandard(pos=set(map(OWLNamedIndividual, map(IRI.create, train_pos))),
                                    neg=set(map(OWLNamedIndividual, map(IRI.create, train_neg))))

        test_lp = PosNegLPStandard(pos=set(map(OWLNamedIndividual, map(IRI.create, test_pos))),
                                   neg=set(map(OWLNamedIndividual, map(IRI.create, test_neg))))

        pred_drill = drill.fit(train_lp).best_hypotheses()
        # Retrieve the individuals covered by the prediction once, reuse for both splits
        retrieved = frozenset(kb.individuals(pred_drill))
        train_f1_drill = compute_f1_score(individuals=retrieved,
                                          pos=train_lp.pos,
                                          neg=train_lp.neg)
        # Quality on test data
        test_f1_drill = compute_f1_score(individuals=retrieved,
                                         pos=test_lp.pos,
                                         neg=test_lp.neg)
        print(
            f"Prediction: {dl_render.render(pred_drill)} | Train Quality: {train_f1_drill:.3f} | Test Quality: {test_f1_drill:.3f} \n")

if __name__ == '__main__':
    parser = ArgumentParser()
    # General
    parser.add_argument("--path_knowledge_base", type=str,
                        default="http://dice-quan.cs.uni-paderborn.de:9080/sparql")
    parser.add_argument("--path_embeddings", type=str,
                        default='/upb/users/q/quannian/profiles/unix/cs/Pycharm_Project/Dice-Embeddings-CSV/QALD10/2024-09-25 13-44-09.306901/Embedding/Removed_Keci_entity_embeddings.csv')
    parser.add_argument("--num_of_target_concepts",
                        type=int,
                        default=1)
    parser.add_argument("--num_of_training_learning_problems",
                        type=int,
                        default=1)
    parser.add_argument("--path_pretrained_dir", type=str, default=None)

    parser.add_argument("--path_learning_problem", type=str, default='/upb/users/q/quannian/profiles/unix/cs/Pycharm_Project/Dice-Embeddings-CSV/QALD10/TandF_MST5_Drill.json',
                        help="Path to a .json file that contains 2 properties 'positive_examples' and "
                             "'negative_examples'. Each of this properties should contain the IRIs of the respective"
                             "instances. e.g. 'some/path/lp.json'")
    parser.add_argument("--max_runtime", type=int, default=1, help="Max runtime")
    parser.add_argument("--folds", type=int, default=10, help="Number of folds of cross validation.")
    parser.add_argument("--random_seed", type=int, default=1)
    parser.add_argument("--iter_bound", type=int, default=10_000, help='iter_bound during testing.')
    # DQL related
    parser.add_argument("--num_episode", type=int, default=1, help='Number of trajectories created for a given lp.')

    parser.add_argument("--epsilon_decay", type=float, default=.01, help='Epsilon greedy trade off per epoch')
    parser.add_argument("--max_len_replay_memory", type=int, default=1024,
                        help='Maximum size of the experience replay')
    parser.add_argument("--num_epochs_per_replay", type=int, default=2,
                        help='Number of epochs on experience replay memory')
    parser.add_argument('--num_of_sequential_actions', type=int, default=1, help='Length of the trajectory.')

    # NN related
    parser.add_argument("--learning_rate", type=int, default=.01)

    start(parser.parse_args())
```
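
For reference, the file passed via `--path_learning_problem` is expected to contain the two properties `positive_examples` and `negative_examples`, each listing individual IRIs. A minimal illustrative shape (placeholder IRIs, not the actual QALD10 data):

```python
# Structure that json.load(json_file) above is expected to return;
# the example.org IRIs are placeholders.
example_lp = {
    "positive_examples": ["http://example.org/ind1", "http://example.org/ind2"],
    "negative_examples": ["http://example.org/ind3", "http://example.org/ind4"],
}
```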

Please message Quan if you need the embedding file or any other additional information.

Demirrr commented 5 days ago

The issue stems from the fact that, for the sampled learning problem, the root state (the top concept ⊤) has zero quality. For a given E^+:[3] E^-:[3], Quality(⊤, E^+, E^-) must be strictly greater than 0; yet on that particular dataset and learning problem, this is not satisfied.
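
To illustrate (a minimal sketch assuming the standard F1 definition behind `compute_f1_score`, with placeholder IRIs): if retrieving ⊤ returns every individual, then with three positives and three negatives we get TP=3, FP=3, FN=0, hence F1 = 6/9 ≈ 0.667 > 0. A quality of exactly 0 can therefore only occur when the retrieval misses every positive example.

```python
from owlapy.owl_individual import OWLNamedIndividual, IRI
from ontolearn.utils.static_funcs import compute_f1_score

def ind(name: str) -> OWLNamedIndividual:
    # Hypothetical helper; the example.org IRIs are placeholders.
    return OWLNamedIndividual(IRI.create(f"http://example.org/{name}"))

pos = frozenset(ind(n) for n in ("p1", "p2", "p3"))
neg = frozenset(ind(n) for n in ("n1", "n2", "n3"))

# Top concept covers everything: TP=3, FP=3, FN=0 -> F1 = 6/9 > 0
print(compute_f1_score(individuals=pos | neg, pos=pos, neg=neg))  # ~0.667

# Quality 0 arises only when no positive example is retrieved at all
print(compute_f1_score(individuals=neg, pos=pos, neg=neg))  # 0.0
```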

Demirrr commented 5 days ago

Thank you. I will improve the assertion error message to make the reason behind this error clear.
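
As a sketch of what the improved check might look like (hypothetical code, not the actual change to `drill.py`):

```python
# Hypothetical replacement for `assert current_state.quality > 0` in
# sequence_of_actions; the wording is illustrative only.
if current_state.quality is None or current_state.quality <= 0:
    raise RuntimeError(
        f"The root state (top concept) must have a quality strictly greater "
        f"than 0, but got {current_state.quality}. This typically means that "
        f"no positive example of the sampled learning problem is retrieved "
        f"for the root state; check the learning problem and the retrieval "
        f"over the knowledge base."
    )
```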

MichaelRoeder commented 5 days ago

I understand that the top concept shouldn't have a quality of zero. That is clearly a problem. So I guess this means that we should check in detail why the scoring method returns a quality of zero although the algorithm was able to select positive and negative examples beforehand, right? :thinking:
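
A possible starting point for that check (a hypothetical debugging sketch; it assumes `OWLThing` from `owlapy.class_expression` and compares the retrieval of ⊤ against the sampled positives):

```python
from owlapy.class_expression import OWLThing
from ontolearn.triple_store import TripleStoreKnowledgeBase

kb = TripleStoreKnowledgeBase("http://dice-quan.cs.uni-paderborn.de:9080/sparql")

# Everything the triple store returns for the top concept.
all_individuals = set(kb.individuals(OWLThing))

# Fill with the OWLNamedIndividuals of a sampled E^+ that triggered the crash.
sampled_positives = set()

missing = sampled_positives - all_individuals
print(f"{len(missing)} positive example(s) not retrieved for the top concept: {missing}")
```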