DeepGraphLearning / RNNLogic

Case Study of Generated Logic Rules #7

Open LuMflowers opened 3 years ago

LuMflowers commented 3 years ago

Hi, I'm interested in how to obtain the generated logic rules shown in Table 4 and Table 7 of your paper. I also can't find the relations mentioned in those tables in the FB15k-237 dataset.

immortalCO commented 3 years ago

The relation names in FB15k-237 are very long, so in the paper we renamed them based on their meaning.

The following code can be used to print rules from trained models, with various constraints:

import sys

import torch

# Kept from the original script: the checkpoint may reference classes
# defined in the repository's model.py.
from model import *

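relation2id = dict()
id2relation = dict()

# Each line of relations.dict is "<id>\t<relation name>".
with open('dataset/FB15k-237/relations.dict') as fin:
    for line in fin:
        rid, relation = line.strip().split('\t')
        relation2id[relation] = int(rid)

        # Shorten the name for printing: walk backwards and keep the last
        # three '/'-separated segments, or only the last two if they already
        # span 30 or more characters.
        rel = ""
        cnt = 0
        for c in reversed(relation):
            if c == '/':
                cnt += 1
                if cnt == 2 and len(rel) >= 30:
                    break
                if cnt >= 3:
                    break
            rel = c + rel

        id2relation[int(rid)] = rel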

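R = len(relation2id)

# Relation ids R..2R-1 denote the inverses of relations 0..R-1.
# inv maps each id to its inverse; inverse names get a "!" prefix.
mov = R
inv = [0] * 2 * R

for i in range(R):
    inv[i + mov] = i
    inv[i] = i + mov
    id2relation[i + mov] = "!" + id2relation[i]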

a = torch.load(sys.argv[1])
r = int(a['r'])

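def has_revlink(rule):
    # True if some relation in the rule is immediately followed by its inverse.
    for i in range(len(rule) - 1):
        if rule[i] == inv[rule[i + 1]]:
            return True
    return False

def contains_r(rule):
    # True if the rule body uses the head relation itself.
    return int(a['r']) in set(map(int, rule))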

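def prt(p, n=10):
    # Print up to n rules, in the order given by p, as LaTeX table rows.
    # The output uses the \relarr and \relarrl macros, which must be
    # defined in the LaTeX document that includes these rows.
    Ltr = "XUVW"  # variable names along the rule body (at most 4 relations)
    for idx in p[:n]:
        path = [id2relation[int(x)] for x in a['rules'][idx]]

        print("&$\\gets$&$", end='')
        for pos, rel in enumerate(path):
            rel = rel.replace('_', '\\_')
            if rel[0] != '!':
                print(f"{Ltr[pos]}\\relarr{{{rel}}}", end="")
            else:
                # Inverse relation: drop the "!" marker, use the reversed arrow.
                print(f"{Ltr[pos]}\\relarrl{{{rel[1:]}}}", end="")

        print("Y$\\\\")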

weight = a['predictor']['rule_weight_raw']

print("Relation:", id2relation[int(a['r'])])

print("general:")
p = sorted(range(len(a['rules'])), 
    key=lambda i : (weight[i]),
    reverse=True)
prt(p)

print("no self:")
p = sorted(range(len(a['rules'])), 
    key=lambda i : (not contains_r(a['rules'][i]), not has_revlink(a['rules'][i]), weight[i]),
    reverse=True)
prt(p,n=40)

print("revlink:")
p = sorted(range(len(a['rules'])), 
    key=lambda i : (has_revlink(a['rules'][i]), weight[i]),
    reverse=True)
prt(p,n=40)
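
For anyone who wants to check the naming and inverse-relation logic without a trained checkpoint, here is a minimal standalone sketch of those two pieces of the script. The function names `shorten_relation` and `build_inverse_table` are hypothetical (the script above inlines this logic), and the relation string is only an illustrative Freebase-style name.

```python
def shorten_relation(name):
    """Keep only the trailing path segments of a Freebase-style relation name.

    Mirrors the loop in the script above: stop at the second '/' (counted
    from the end) if 30 or more characters are already kept, otherwise stop
    at the third '/'.
    """
    short = ""
    slashes = 0
    for ch in reversed(name):
        if ch == "/":
            slashes += 1
            if (slashes == 2 and len(short) >= 30) or slashes >= 3:
                break
        short = ch + short
    return short


def build_inverse_table(num_relations):
    """Return inv, where inv[i] is the id of the inverse of relation i.

    Ids 0..R-1 are the original relations and ids R..2R-1 their inverses.
    """
    inv = [0] * (2 * num_relations)
    for i in range(num_relations):
        inv[i] = i + num_relations
        inv[i + num_relations] = i
    return inv


def has_revlink(rule, inv):
    """True if some relation in the rule is immediately followed by its inverse."""
    return any(rule[i] == inv[rule[i + 1]] for i in range(len(rule) - 1))


# Illustrative inputs (hypothetical, 3 relations so ids 3..5 are inverses).
print(shorten_relation("/people/person/nationality"))  # people/person/nationality
inv = build_inverse_table(3)
print(has_revlink([0, 3, 1], inv))  # True: relation 0 followed by its inverse 3
print(has_revlink([0, 1, 2], inv))  # False
```

The "no self" and "revlink" listings in the script then just sort rule indices by a tuple key that puts these boolean flags ahead of the rule weight.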

LuMflowers commented 3 years ago

When I run the above code, I get the error "IndexError: list index out of range" at the line a = torch.load(sys.argv[1]). How can I solve this problem?

navdeepkjohal commented 2 years ago

Hello,

sys.argv[1] is the path to the model file that is learned during training. Set it to './workspace/model_0.pth' to see the rules for relation 0 (r=0), and so on for the other relations. Also, the rules printed by this code are LaTeX source, so paste the output into a LaTeX file to render them. I hope this helps.
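
The IndexError appears because the script reads sys.argv[1] unconditionally, so running it without a command-line argument fails. A small guard like the sketch below would turn that into a usage message. The helper name `checkpoint_path_from_argv` and the script name `print_rules.py` are hypothetical, not part of the repository.

```python
def checkpoint_path_from_argv(argv):
    """Return the checkpoint path given on the command line.

    Exits with a usage message instead of raising IndexError when the
    path argument is missing.
    """
    if len(argv) < 2:
        prog = argv[0] if argv else "print_rules.py"  # hypothetical script name
        raise SystemExit(f"usage: python {prog} <path-to-model.pth>")
    return argv[1]


# In the script, the load line would then become (assumed usage):
#   a = torch.load(checkpoint_path_from_argv(sys.argv))
```

With this in place, an invocation such as `python print_rules.py ./workspace/model_0.pth` would print the rules for relation 0, assuming the script was saved under that name.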