benedekrozemberczki / karateclub

Karate Club: An API Oriented Open-source Python Framework for Unsupervised Learning on Graphs (CIKM 2020)
https://karateclub.readthedocs.io
GNU General Public License v3.0

How to mini-batch during model.fit() ? #103

Closed: johnnytam100 closed this issue 2 years ago

johnnytam100 commented 2 years ago

My hardware cannot load all of the graphs (~300,000 graphs, averaging 200 nodes each) into memory. I have written a simple loop that divides the graphs into batches of 1,000 and fits the model on each batch. Is this equivalent to ordinary mini-batching?

import networkx as nx
from karateclub import FeatherGraph
import os
import glob
import pickle
import numpy as np
import pandas as pd
import random

# Load
filepath_list = []

for filepath in glob.iglob(<my graphs>):
  filepath_list.append(filepath)

# Shuffle list
#filepath_list.sort()
random.shuffle(filepath_list)

# Define model
model = FeatherGraph()

fit_batch = 1000

# Walk through the file list one batch at a time
for start in range(0, len(filepath_list), fit_batch):

    # Start a fresh list once per batch (not once per graph),
    # otherwise only the last graph of each batch reaches fit()
    graph_list = []

    # The last batch may be shorter than fit_batch
    stop = min(start + fit_batch, len(filepath_list))

    for idx in range(start, stop):

        # Load graph
        graph_path = filepath_list[idx]
        print(idx, "Loading...", graph_path)

        with open(graph_path, 'rb') as f:  # open in binary read mode
            g_load = pickle.load(f)

        # Convert node labels to integers (required by karateclub)
        g_load_reindex = nx.convert_node_labels_to_integers(g_load)

        graph_list.append(g_load_reindex)

    # Fit
    print("Fitting model...")
    model.fit(graph_list)

benedekrozemberczki commented 2 years ago

Dear @johnnytam100,

This algorithm is completely inductive, so you should get the same result this way. Can you cite the FEATHER paper and the Karate Club design paper in your work?

Could you also star the repo and hit follow?

Bests,

Benedek
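
To make the "same result" point concrete, here is a minimal sketch using a toy per-graph descriptor. It is not FEATHER: `describe` and `embed` are hypothetical stand-ins, and graphs are plain edge lists. It only illustrates the assumption behind the answer, namely that when each graph is embedded on its own and nothing is learned across graphs, batching cannot change the output.

```python
from statistics import mean


def describe(edges):
    """Toy stand-in for a per-graph embedding: (node count, mean degree)."""
    degree = {}
    for u, v in edges:
        degree[u] = degree.get(u, 0) + 1
        degree[v] = degree.get(v, 0) + 1
    return (len(degree), mean(degree.values()))


def embed(graphs):
    """Embed every graph independently; no state is shared between graphs."""
    return [describe(g) for g in graphs]


graphs = [
    [(0, 1), (1, 2)],          # path on 3 nodes
    [(0, 1), (1, 2), (2, 0)],  # triangle
    [(0, 1)],                  # single edge
]

all_at_once = embed(graphs)
in_batches = embed(graphs[:2]) + embed(graphs[2:])

# Batch boundaries do not affect a purely per-graph computation
assert all_at_once == in_batches
```

A method that learns shared parameters across the whole collection would not pass this check, which is why the property depends on the algorithm and not on the loop.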

johnnytam100 commented 2 years ago

Thank you @benedekrozemberczki! By the way, if I do it this way, do you know how to get the embeddings of all 300,000 graphs at the end? Is there a model.predict() in Karate Club?

P.S. I have already followed and will certainly cite your work.
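
One possible way to collect the embeddings, sketched under assumptions: Karate Club estimators expose `get_embedding()` after `fit()`, and this sketch assumes each `fit()` call replaces the stored embedding with one row per graph in the batch, in input order. If that holds, reading the embedding after every batch and stacking the pieces yields the full matrix. The helper `embed_in_batches` and its `load_graph` parameter are hypothetical names, not part of the library.

```python
import numpy as np


def embed_in_batches(model, paths, load_graph, batch_size=1000):
    """Fit `model` batch by batch and stack the per-batch embeddings.

    Assumptions: `model` has fit(list_of_graphs) and get_embedding(),
    and get_embedding() returns one row per graph of the latest fit.
    `load_graph` is a caller-supplied function mapping a path to a graph.
    """
    chunks = []
    for start in range(0, len(paths), batch_size):
        # Only one batch of graphs is held in memory at a time
        batch = [load_graph(p) for p in paths[start:start + batch_size]]
        model.fit(batch)
        # Copy the rows so a later fit cannot overwrite them
        chunks.append(np.array(model.get_embedding()))
    if not chunks:
        raise ValueError("no graphs to embed")
    return np.vstack(chunks)
```

With the script above, a call might look like `embed_in_batches(FeatherGraph(), filepath_list, load_graph)`, where `load_graph` unpickles a file and relabels its nodes to integers. Row `i` of the result then corresponds to `filepath_list[i]`, so the shuffled order should be saved alongside the matrix.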