Shuriken13 / ENRL

9 stars 3 forks source link

Missing Rule Extraction Module #1

Closed voladorlu closed 2 years ago

voladorlu commented 2 years ago

Hi, thank you so much for sharing such important code. From the implementation, it is clear how the model is defined. However, it is still not easy to see how to extract the rules for categorical features. May I ask how to decide the threshold \gamma for a categorical feature when parsing the rules from the well-trained model? Do you have any suggestions or experience to share about how to choose the threshold value? -:)

zichuan-liu commented 2 years ago

I also have the same question: how can I get the rules from `ge_layers`, `le_layers`, and `rule_weight`?

Shuriken13 commented 2 years ago

You may modify the following code to show the rules.

# coding=utf-8
import pickle
import sys
import os
import re
import socket
import numpy as np
import pandas as pd
from collections import defaultdict, Counter
from datetime import datetime, timezone
from sklearn.metrics import accuracy_score, roc_auc_score, auc
from scipy import stats
import torch
import treelib

# Prepend the parent and the current directory to the import search path so
# that the `enrl` package imported below can be resolved.
# NOTE(review): presumably this lets the script be launched either from the
# repository root or from a directory one level below it -- confirm.
sys.path.insert(0, '../')
sys.path.insert(0, './')

# The star imports supply every project name this script uses without
# defining it: DEFAULT_SEED, DATASET_DIR, MODEL_DIR, CKPT_DIR, CKPT_F and the
# model modules resolved by name in read_model().
# NOTE(review): which of the three modules exports which name is not visible
# from this file.
from enrl.configs.constants import *
from enrl.configs.settings import *
from enrl.models import *

# Seed NumPy's global RNG at import time.
np.random.seed(DEFAULT_SEED)
# Print the host name of the machine running the script.
print(socket.gethostname())

def read_model(model_name, model_version, dataset_name, model_args=None):
    """Instantiate a model by name and restore its weights from a checkpoint.

    The class is looked up as ``<model_name>.<model_name>`` in this module's
    namespace (populated by the star imports at the top of the file), built
    with ``model_args`` as keyword arguments, and then asked to load
    ``<MODEL_DIR>/<dataset_name>/<model_name>/<model_version>/<CKPT_DIR>/<CKPT_F>.ckpt``
    together with the ``hparams.yaml`` stored in the same run directory.

    Only the checkpoint is loaded here: no dataset is read and no modules or
    metrics are initialised by this function.

    :param model_name: name of both the model module and the class inside it.
    :param model_version: sub-directory identifying the training run.
    :param dataset_name: dataset sub-directory under ``MODEL_DIR``.
    :param model_args: optional dict of constructor keyword arguments.
    :return: the model object after ``load_model`` has been called on it.
    """
    ctor_kwargs = {} if model_args is None else model_args

    # NOTE: eval() executes arbitrary expressions -- only ever pass a trusted,
    # hard-coded model_name here, never user-supplied text.
    model_cls = eval('{0}.{0}'.format(model_name))
    model = model_cls(**ctor_kwargs)

    run_dir = os.path.join(MODEL_DIR, dataset_name, model_name, model_version)
    model.load_model(
        checkpoint_path=os.path.join(run_dir, CKPT_DIR, CKPT_F + '.ckpt'),
        hparams_file=os.path.join(run_dir, 'hparams.yaml'),
    )
    return model

def translate_context(model, op, f, ctxt_v, bin_dict, threshold=0.5, top_k=5):
    """Turn a node's context vector into a human-readable condition value.

    For a numeric operator (``'<='`` or ``'>='``) every value embedding of
    feature ``f`` is concatenated with the context vector and scored by both
    ``model.le_layers`` and ``model.ge_layers``; the value index where the two
    scores are closest is reported.  When ``bin_dict`` has an entry for ``f``
    (or for ``f`` without its last two characters) the index is rendered as
    ``'<idx>:(<lo>,<hi>]'`` using consecutive entries of that bin list.

    For any other operator the value embeddings of the multi-hot feature are
    scored by ``model.blto_layers``; the ``top_k`` best-scoring value indices
    whose score exceeds ``threshold`` are reported, highest score first.  When
    ``bin_dict`` has an entry for ``f`` the indices are followed by their
    labels as ``'[i,j]:[label_i,label_j]'``.

    :param model: trained model exposing ``nu_dict``/``mh_dict`` (feature ->
        inclusive ``(lo, hi)`` embedding row range), the embedding tables and
        the ``le_layers``/``ge_layers``/``blto_layers`` scorers.
    :param op: ``'<='`` or ``'>='`` for numeric features, anything else is
        treated as a categorical ("in") test.
    :param f: feature name.
    :param ctxt_v: context vector of the node; it is expanded to one row per
        candidate value.
    :param bin_dict: optional mapping feature -> bin edges / index-to-label
        lookup, or ``None`` to report bare indices.
    :param threshold: minimum ``blto_layers`` score for a category to be
        listed (the original hard-coded value was 0.5).
    :param top_k: maximum number of categories listed; ``None`` lists every
        category above the threshold (the original hard-coded value was 5).
    :return: the condition value as a string without spaces.
    """
    if op == '<=' or op == '>=':
        lo, hi = model.nu_dict[f]
        f_embeddings = model.numeric_embeddings.weight[lo:hi + 1]
        pairs = torch.cat([f_embeddings, ctxt_v.expand_as(f_embeddings)], dim=-1)
        le_vs = model.le_layers(pairs).flatten().detach().cpu().numpy()
        ge_vs = model.ge_layers(pairs).flatten().detach().cpu().numpy()
        # Plain int so the index renders as '3' regardless of NumPy version.
        v_idx = int(np.argmin(np.abs(le_vs - ge_vs)))
        text = str(v_idx)
        if bin_dict is not None:
            # Some feature names carry a two-character suffix that the bin
            # dictionary does not; fall back to the name without it.
            key = f if f in bin_dict else f[:-2]
            if key in bin_dict:
                edges = bin_dict[key]
                text = '{}:({},{}]'.format(v_idx, edges[v_idx], edges[v_idx + 1])
        return text.replace(' ', '')

    lo, hi = model.mh_dict[f]
    f_embeddings = model.multihot_embeddings.weight[lo:hi + 1]
    pairs = torch.cat([f_embeddings, ctxt_v.expand_as(f_embeddings)], dim=-1)
    in_vs = model.blto_layers(pairs).flatten().detach().cpu().numpy()

    ranked = in_vs.argsort()[::-1]  # best score first
    if top_k is not None:
        ranked = ranked[:top_k]
    # Convert to built-in ints: str() of a list of np.int64 renders as
    # '[np.int64(3), ...]' under NumPy >= 2, which would corrupt the rule text.
    chosen = [int(i) for i in ranked if in_vs[i] > threshold]

    text = str(chosen)
    # Mirror the numeric branch: only decorate with labels when the feature
    # actually has an entry, instead of raising KeyError.
    if bin_dict is not None and f in bin_dict:
        labels = bin_dict[f]
        # str() so non-string category labels do not break the join.
        text = '{}:[{}]'.format(chosen, ','.join(str(labels[c]) for c in chosen))
    return text.replace(' ', '')

def translate_voting_weights(model, state_vs):
    """Compute the voting weights of the most recently added tree node.

    ``state_vs`` holds one two-row state tensor per node, in heap order (the
    parent of node ``i`` is node ``(i - 1) // 2``).  For the last node in the
    list, the branch state of every ancestor on the path from the root is
    gathered -- row 0 when the path continues through an odd-numbered child,
    row 1 when it continues through an even-numbered one -- broadcast to the
    shape of the node's own state, and concatenated root-first along the last
    dimension with that state.  The result is fed to
    ``model.rule_weight_layers``.

    :param model: object exposing ``rule_weight_layers``.
    :param state_vs: list of per-node state tensors; the last entry is the
        node being evaluated.
    :return: flattened NumPy array of the layer's output.
    """
    own_state = state_vs[-1]

    # Walk from the node up to the root, collecting ancestor branch states
    # nearest-first; they are reversed below to obtain root-first order.
    ancestors = []
    node = len(state_vs) - 1
    while node > 0:
        parent = (node - 1) // 2
        branch = 1 - node % 2  # odd child -> row 0, even child -> row 1
        ancestors.append(state_vs[parent][branch].expand_as(own_state))
        node = parent

    path_states = ancestors[::-1] + [own_state]
    weights = model.rule_weight_layers(torch.cat(path_states, dim=-1))
    return weights.flatten().detach().cpu().numpy()

def draw_tree(model, tree_idx, bin_dict):
    """Render one learned rule tree of the model as a ``treelib.Tree``.

    Nodes are visited in heap order (the parent of node ``i`` is
    ``(i - 1) // 2``).  For each node the operator/feature with the largest
    entry in ``model.nas_w[tree_idx][node]`` is selected: the first
    ``len(nu_dict)`` entries mean ``feature >= value``, the next
    ``len(nu_dict)`` mean ``feature <= value`` and the remaining ones mean
    ``feature in values``.  The condition value is decoded with
    ``translate_context``.  Nodes in the second half of the numbering
    (``node_idx >= node_n // 2``) additionally get two voting weights,
    formatted as ``' <w0>,<w1>'`` with four decimals, appended to their label.

    :param model: trained model exposing ``nu_dict``, ``mh_dict``, ``nas_w``,
        ``node_n``, the ``ge_v``/``le_v``/``blto_v`` context vectors and the
        matching ``*_state_layers``.
    :param tree_idx: index of the tree to render.
    :param bin_dict: optional lookup passed through to ``translate_context``.
    :return: ``treelib.Tree`` whose node identifiers are the node indices and
        whose tags read ``'<node_idx> <feature> <op> <value>[ <w0>,<w1>]'``.
    """
    numeric_feats = list(model.nu_dict.keys())
    multihot_feats = list(model.mh_dict.keys())
    n_numeric = len(numeric_feats)
    tree_nas_w = model.nas_w[tree_idx]
    first_weighted = model.node_n // 2

    state_vs = []
    tree = treelib.Tree()
    for node_idx in range(model.node_n):
        op_idx = tree_nas_w[node_idx].argmax().cpu().numpy()

        # Pick the operator family; the three families differ only in which
        # feature list, context bank and state layers they use.
        if op_idx >= 2 * n_numeric:
            op, f_idx, feats = 'in', op_idx - 2 * n_numeric, multihot_feats
            ctxt_bank, state_layers = model.blto_v, model.blto_state_layers
        elif op_idx >= n_numeric:
            op, f_idx, feats = '<=', op_idx - n_numeric, numeric_feats
            ctxt_bank, state_layers = model.le_v, model.le_state_layers
        else:
            op, f_idx, feats = '>=', op_idx, numeric_feats
            ctxt_bank, state_layers = model.ge_v, model.ge_state_layers

        f = feats[f_idx]
        ctxt_v = ctxt_bank[tree_idx][node_idx][f_idx]
        ctxt = translate_context(model, op, f, ctxt_v, bin_dict)
        # Must be appended before translate_voting_weights is called: that
        # helper treats the last entry of state_vs as the current node.
        state_vs.append(state_layers(ctxt_v).view(2, -1))

        label = '{} {} {}'.format(f, op, ctxt)
        if node_idx >= first_weighted:
            yes_w, no_w = translate_voting_weights(model, state_vs)
            label += ' {:.4f},{:.4f}'.format(yes_w, no_w)
        label = '{} {}'.format(node_idx, label)

        if node_idx == 0:
            tree.create_node(label, node_idx)
        else:
            tree.create_node(label, node_idx, parent=(node_idx - 1) // 2)
    return tree

def read_bin_dict(dataset_name):
    """Load ``bin_dict.pkl`` for a dataset and invert its dict-valued entries.

    The file is read from ``<DATASET_DIR>/<dataset_name>/bin_dict.pkl``.
    Entries whose value is a dict are replaced by the inverse mapping
    (``value -> key``) so they can be indexed by the stored value; entries of
    any other type (e.g. lists) are returned unchanged.

    NOTE: unpickling can execute arbitrary code -- only load ``bin_dict.pkl``
    files that you produced yourself or otherwise trust.

    :param dataset_name: dataset sub-directory under ``DATASET_DIR``.
    :return: the processed dictionary, or ``None`` when the file does not
        exist.
    """
    bin_path = os.path.join(DATASET_DIR, dataset_name, 'bin_dict.pkl')
    if not os.path.exists(bin_path):
        return None

    # Context manager so the file handle is closed deterministically instead
    # of being left to the garbage collector.
    with open(bin_path, 'rb') as bin_file:
        bin_dict = pickle.load(bin_file)
    if bin_dict is None:
        return None

    for key, value in bin_dict.items():
        if isinstance(value, dict):
            # Re-assigning an existing key does not resize the dict, so this
            # is safe while iterating over it.
            bin_dict[key] = {v: k for k, v in value.items()}
    return bin_dict

def reverse_tag(tag):
    """Return ``tag`` with its comparison operator negated.

    The first operator found, checked in this order, is replaced everywhere
    it occurs: ``' <= '`` becomes ``' > '``, otherwise ``' >= '`` becomes
    ``' < '``, otherwise ``' in '`` becomes ``' notin '``.  A tag containing
    none of them is returned unchanged.
    """
    for operator, negated in ((' <= ', ' > '), (' >= ', ' < ')):
        if operator in tag:
            return tag.replace(operator, negated)
    return tag.replace(' in ', ' notin ')

def draw_rules(model, bin_dict):
    """Flatten every tree of the model into root-to-leaf rules with weights.

    Each tree is rendered with ``draw_tree``.  Every node in the second half
    of the numbering (``node_n // 2 .. node_n - 1``) carries two voting
    weights and therefore yields two rules: one where its own condition holds
    (first weight) and one where it is negated via ``reverse_tag`` (second
    weight).  Both rules are prefixed by the conditions of all ancestors,
    root first; an ancestor's condition is negated when the path leaves it
    through its even-numbered child.

    Unlike the earlier implementation, the weight is always parsed and
    stripped from the last condition, including for trees that consist of a
    single node.

    :param model: trained model exposing ``rule_n`` and ``node_n`` plus
        everything ``draw_tree`` needs.
    :param bin_dict: optional lookup passed through to ``draw_tree``.
    :return: list of ``(weight, tree_idx, chain)`` tuples, where ``chain`` is
        a ``treelib.Tree`` in which every condition is the only child of the
        previous one.
    """
    first_leaf = model.node_n // 2
    results = []
    for tree_idx in range(model.rule_n):
        tree = draw_tree(model, tree_idx, bin_dict)
        for leaf_idx in range(first_leaf, model.node_n):
            # Leaf tags end in ' <w0>,<w1>'; split that suffix off once.
            condition, _, weights = tree.get_node(leaf_idx).tag.rpartition(' ')
            left_w, right_w = weights.split(',')

            ancestors = _ancestor_conditions(tree, leaf_idx)
            results.append(
                (float(left_w), tree_idx, _chain_tree(ancestors + [condition])))
            results.append(
                (float(right_w), tree_idx, _chain_tree(ancestors + [reverse_tag(condition)])))
    # Keep the historical ordering: all rules of a tree, leaf by leaf, with
    # the "condition holds" rule before its negation.
    return results


def _ancestor_conditions(tree, node_idx):
    """Return the root-first conditions leading to ``node_idx`` (exclusive)."""
    conditions = []
    while node_idx > 0:
        parent_idx = (node_idx - 1) // 2
        tag = tree.get_node(parent_idx).tag
        if node_idx % 2 == 0:
            # Reached through the even-numbered child: the parent's condition
            # is taken in its negated form.
            tag = reverse_tag(tag)
        conditions.append(tag)
        node_idx = parent_idx
    conditions.reverse()
    return conditions


def _chain_tree(conditions):
    """Build a ``treelib.Tree`` where each condition is the child of the previous."""
    chain = treelib.Tree()
    for idx, condition in enumerate(conditions):
        if idx == 0:
            chain.create_node(condition, idx)
        else:
            chain.create_node(condition, idx, idx - 1)
    return chain

def main(top_n=10):
    """Print the strongest positive and negative rules of a trained model.

    Loads the bin dictionary and the checkpoint of the hard-coded
    dataset/model version, extracts all rules with ``draw_rules``, sorts them
    by weight in descending order and prints the ``top_n`` highest- and
    ``top_n`` lowest-weighted ones, followed by the bin dictionary.  When
    there are no more than ``2 * top_n`` rules, each rule is printed exactly
    once instead of appearing in both halves.

    :param top_n: number of rules shown from each end of the ranking.
    """
    dataset_name = 'adult'
    model_version = '039cc7d4e47ee0bce50c_1951'

    bin_dict = read_bin_dict(dataset_name)
    model = read_model(model_name='Our', model_version=model_version, dataset_name=dataset_name)

    rules = sorted(draw_rules(model, bin_dict), key=lambda x: x[0], reverse=True)
    if len(rules) <= 2 * top_n:
        # The head and tail slices would overlap and print duplicates.
        shown = rules
    else:
        shown = rules[:top_n] + rules[len(rules) - top_n:]

    for weight, tree_idx, rule_tree in shown:
        print(weight, tree_idx)
        print(rule_tree)
    print(bin_dict)


if __name__ == '__main__':
    main()
voladorlu commented 2 years ago

@Shuriken13 Thank you so much for sharing the complete code. From the definition of translate_context, the threshold to parse the ecm for categorical feature is set to 0.5, right? -:)

Shuriken13 commented 2 years ago

yes~ you may adjust it if necessary

voladorlu commented 2 years ago

Thanks so much. That's an awesome project.