Closed yeswzc closed 10 months ago
Hi, I have a solution, but maybe not the best one.
import copy

import torchsnapshot
from tensordict import TensorDict

# Work on a copy so the original graph_dict is left untouched.
tmp = copy.deepcopy(graph_dict)
# Convert the sparse tensors to dense before taking the snapshot.
tmp['adj_norm'] = tmp['adj_norm'].to_dense()
tmp['adj_label'] = tmp['adj_label'].to_dense()
d = TensorDict(tmp, [])
state = {'state': d}
snapshot = torchsnapshot.Snapshot.take(app_state=state, path="snapshot")

# Restore the snapshot into an empty TensorDict and check it matches.
snapshot = torchsnapshot.Snapshot(path="snapshot")
graph_dict_r = TensorDict({}, [])
state_target = {"state": graph_dict_r}
snapshot.restore(app_state=state_target)
assert (graph_dict_r == d).all()

# Convert back to a plain dict and make the two tensors sparse again.
graph_dict_r = graph_dict_r.to_dict()
graph_dict_r['adj_norm'] = graph_dict_r['adj_norm'].to_sparse()
graph_dict_r['adj_label'] = graph_dict_r['adj_label'].to_sparse()
Thank you! Works well!
Hi, is there a way to save the graph_dict in Python so that it can be loaded back into Python next time? I tried pickle with pickle.HIGHEST_PROTOCOL, but the resulting file can't be read.
Thank you!
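For reference, here is a minimal sketch of a plain pickle round trip with pickle.HIGHEST_PROTOCOL. The dictionary `graph_dict_demo` and the file name are hypothetical stand-ins: it holds nested lists rather than the sparse tensors in the real graph_dict, and this thread does not confirm whether those tensors pickle cleanly. It only shows the general pattern, including opening the file in binary mode for both writing and reading.

```python
import os
import pickle
import tempfile

# Hypothetical stand-in for graph_dict; the real one holds sparse tensors.
graph_dict_demo = {
    "adj_norm": [[0.0, 0.5], [0.5, 0.0]],
    "adj_label": [[0, 1], [1, 0]],
}

# Hypothetical output location in a fresh temporary directory.
path = os.path.join(tempfile.mkdtemp(), "graph_dict.pkl")

# Pickle data is bytes, so the file must be opened in binary mode ("wb").
with open(path, "wb") as f:
    pickle.dump(graph_dict_demo, f, protocol=pickle.HIGHEST_PROTOCOL)

# Read it back, again in binary mode ("rb").
with open(path, "rb") as f:
    restored = pickle.load(f)
```

If the same pattern fails on the real graph_dict, the dense conversion shown in the reply above may be a workable fallback.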