JuliaGast / TGB2

Temporal Graph Benchmark project repo

Some more utility functions for dataset.py? #9

Closed JuliaGast closed 5 months ago

JuliaGast commented 5 months ago

I need the following methods for recurrencybaseline.py, and other methods will need them later as well. Does it make sense to move them somewhere shared, e.g. to utils? (You can also find them here: https://github.com/JuliaGast/TGB2/blob/julia_new/examples/linkproppred/tkgl-polecat/recurrencybaseline.py)

from itertools import groupby
from operator import itemgetter

import numpy as np


def group_by(data: np.ndarray, key_idx: int) -> dict:
    """
    Group the rows of an np.ndarray into a dict, keyed by the value at index key_idx,
    e.g. group quadruples by relation.
    :param data: [np.ndarray] data to be grouped
    :param key_idx: [int] index of the element to group by
    returns data_dict: dict with key: each value found at index key_idx, values: all rows in data that have that value
    """
    data_dict = {}
    # itertools.groupby only merges consecutive items, so sort by the key first
    data_sorted = sorted(data, key=itemgetter(key_idx))
    for key, group in groupby(data_sorted, key=itemgetter(key_idx)):
        data_dict[key] = np.array(list(group))
    return data_dict
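
A minimal usage sketch for group_by, assuming rows are laid out as (subject, relation, object, timestamp). The toy array and the name `by_relation` are made up for illustration, and the helper is repeated in compact form only so the snippet runs on its own.

```python
from itertools import groupby
from operator import itemgetter

import numpy as np


def group_by(data: np.ndarray, key_idx: int) -> dict:
    """Compact copy of the helper above, repeated so the sketch is self-contained."""
    data_sorted = sorted(data, key=itemgetter(key_idx))
    return {
        key: np.array(list(group))
        for key, group in groupby(data_sorted, key=itemgetter(key_idx))
    }


# Hypothetical toy quadruples: (subject, relation, object, timestamp).
quads = np.array([
    [0, 1, 2, 0],
    [3, 0, 4, 0],
    [2, 1, 5, 1],
])

# Group rows by the relation id stored in column 1.
by_relation = group_by(quads, key_idx=1)

for rel, rows in by_relation.items():
    print(int(rel), rows.tolist())
```

Because `sorted` is stable, rows that share a key keep their original relative order inside each group.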

def add_inverse_quadruples(triples: np.ndarray, num_rels: int) -> np.ndarray:
    """
    Create an inverse quadruple for each row in triples. The inverse swaps subject and object and
    increases the relation id by num_rels.
    :param triples: [np.ndarray] dataset quadruples, with subject, relation, object in columns 0, 1, 2
    :param num_rels: [int] number of relations that we have originally
    returns all_triples: [np.ndarray] original quadruples followed by their inverses
    """
    # fancy indexing returns a copy, so the input array is not modified
    inverse_triples = triples[:, [2, 1, 0, 3]]
    inverse_triples[:, 1] = inverse_triples[:, 1] + num_rels  # shift relation ids to mark inverses
    all_triples = np.concatenate((triples[:, 0:4], inverse_triples))

    return all_triples
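
A small worked example of the inverse step, again on made-up data and assuming the column order (subject, relation, object, timestamp). The helper is repeated in compact form so the sketch runs standalone; the parameter name `quads` is only used here.

```python
import numpy as np


def add_inverse_quadruples(quads: np.ndarray, num_rels: int) -> np.ndarray:
    """Compact copy of the helper above, repeated so the sketch is self-contained."""
    inverse = quads[:, [2, 1, 0, 3]]  # fancy indexing copies, input stays untouched
    inverse[:, 1] = inverse[:, 1] + num_rels
    return np.concatenate((quads[:, 0:4], inverse))


# Hypothetical toy data with two relations (ids 0 and 1).
quads = np.array([
    [0, 1, 2, 10],
    [3, 0, 4, 20],
])

all_quads = add_inverse_quadruples(quads, num_rels=2)
print(all_quads.tolist())
# [[0, 1, 2, 10], [3, 0, 4, 20], [2, 3, 0, 10], [4, 2, 3, 20]]
```

Relation ids 0 and 1 map to inverse ids 2 and 3, so a model can tell a forward edge from its inverse by checking whether the id is at least num_rels.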

def reformat_ts(timestamps):
    """ reformat timestamps s.t. they start with 0, and have stepsize 1.
    assumes at least two distinct timestamps; the step size is taken from the two smallest values.
    :param timestamps: np.array() with timestamps
    returns: np.array(ts_new)
    """
    all_ts = np.unique(timestamps)  # sorted unique timestamps
    ts_min = all_ts[0]
    ts_dist = all_ts[1] - all_ts[0]

    # vectorized shift and rescale; floor division avoids float rounding on large integer timestamps
    ts_new = (np.asarray(timestamps) - ts_min) // ts_dist
    return ts_new.astype(int)
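
A short sketch of what reformat_ts does to raw timestamps, using invented values. The helper is written here with NumPy operations so the snippet runs on its own; it assumes at least two distinct timestamps and takes the step size from the two smallest ones.

```python
import numpy as np


def reformat_ts(timestamps: np.ndarray) -> np.ndarray:
    """Shift timestamps to start at 0 and rescale them by the smallest observed step."""
    all_ts = np.unique(timestamps)   # sorted unique values
    ts_min = all_ts[0]
    ts_dist = all_ts[1] - all_ts[0]  # assumption: this is the uniform step size
    return ((timestamps - ts_min) // ts_dist).astype(int)


# Hypothetical raw timestamps with a step of 100 and one missing step (300).
raw = np.array([100, 100, 200, 400])

print(reformat_ts(raw).tolist())
# [0, 0, 1, 3]
```

Note that a missing step in the raw data stays a gap in the output (there is no index 2 here), so the result is not a dense re-indexing of the unique timestamps.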
JuliaGast commented 5 months ago

Keep these in the specific TKG methods.

shenyangHuang commented 5 months ago

These are more on the method-specific side, so we won't add them for now.