churchmanlab / genewalk

GeneWalk identifies relevant gene functions for a biological context using network representation learning
https://churchman.med.harvard.edu/genewalk
BSD 2-Clause "Simplified" License

Illustration of the GeneWalk network #29

Closed: izu0421 closed this issue 3 years ago

izu0421 commented 3 years ago

Hi, Thanks a lot for your work. I ran a GeneWalk analysis and would like to visualise the network generated. I think that's saved in multi_graph.pkl? I tried to draw it with networkx & pyplot, but it didn't turn out very pretty. Do you have a script?

Thanks Yizhou

ri23 commented 3 years ago

Hi @izu0421, thanks for your interest. You are right: the GeneWalk network (GWN) is saved in multi_graph.pkl as a networkx.MultiGraph object. In general, the graph is too large to visualize sensibly as a whole, given that it consists of over 40k nodes. Even if you visualize only the genes and their interactions, it will likely still end up as a big hairball. I can give you a code snippet to visualize GWN subgraphs around certain genes of interest, similar to Figure 3B of our publication. Let me know if that is of interest to you. If you do want to visualize the whole network, you could consider visualizing the vector representations of all the nodes in a 2D embedding using UMAP. Hope this helps. Robert
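
As a quick check of the scale described above, here is a minimal sketch (not from the GeneWalk authors) that tallies nodes by type before any plotting is attempted. It assumes that GO term node names start with 'GO:' and that every other node is a gene or similar entity; the helper name `summarize_nodes` is hypothetical, and a toy graph stands in for the MultiGraph loaded from multi_graph.pkl.

```python
import networkx as nx


def summarize_nodes(graph):
    """Count GO term nodes and all other nodes in a GeneWalk-style graph.

    Assumption: GO term node names start with 'GO:'; every other node is
    counted as 'other' (mostly genes).
    """
    counts = {'go_terms': 0, 'other': 0}
    for node in graph.nodes():
        if str(node).startswith('GO:'):
            counts['go_terms'] += 1
        else:
            counts['other'] += 1
    counts['edges'] = graph.number_of_edges()
    return counts


# Toy stand-in for the MultiGraph loaded from multi_graph.pkl
toy = nx.MultiGraph()
toy.add_edges_from([('MAL', 'GO:0000001'), ('MAL', 'PLP1'), ('PLP1', 'GO:0000002')])
print(summarize_nodes(toy))
```

Running the same function on the real graph gives a sense of whether a whole-network plot is feasible or whether a subgraph is the better route.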

bgyori commented 3 years ago

Another option could be to export the graph from networkx into Cytoscape JSON and load the network in Cytoscape.
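
A minimal sketch of that export, assuming networkx's `json_graph.cytoscape_data` helper and a toy graph in place of the one loaded from multi_graph.pkl. The function name `to_cytoscape_json`, the `label` edge attribute, and the output file name are illustrative assumptions, not part of GeneWalk.

```python
import json

import networkx as nx
from networkx.readwrite import json_graph


def to_cytoscape_json(graph):
    """Serialize a networkx graph to a Cytoscape.js-style JSON string."""
    data = json_graph.cytoscape_data(graph)
    # default=str guards against attribute values that json cannot encode natively
    return json.dumps(data, default=str)


# Toy stand-in for the MultiGraph loaded from multi_graph.pkl
toy = nx.MultiGraph()
toy.add_edge('MAL', 'PLP1', label='interaction')
toy.add_edge('MAL', 'GO:0000001', label='annotation')
cyjs = to_cytoscape_json(toy)

# Writing to disk; the '.cyjs' file name is an arbitrary choice
# with open('genewalk_network.cyjs', 'w') as f:
#     f.write(cyjs)
```

For the full network it may be worth exporting only a subgraph of interest first, since the complete graph is large.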

ri23 commented 3 years ago

Just for completeness, below is the script to generate the GWN subgraph around three chosen genes of interest (see Figure 3B of our publication), also discussed here: https://github.com/churchmanlab/genewalk/issues/28#issuecomment-778684193

#!/usr/bin/env python
# coding: utf-8

# # GeneWalk network visualization

import os
import re
import copy
import pickle as pkl
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
plt.rcParams['pdf.fonttype'] = 42

# ### Load GeneWalk multigraph and results

path = '/home/genewalk/qki/'   

filename = 'multi_graph.pkl'
with open(os.path.join(path,filename), 'rb') as f:
    MG = pkl.load(f)

filename = 'genewalk_results.csv'
GW = pd.read_csv(os.path.join(path,filename)) 

# ## QKI subnetwork visualization

# Data preprocessing

# Choose genes of interest
GENES=['MAL','PLLP','PLP1']
labels = {}   
for node in GENES:
    labels[node] = node

# Genes and their direct neighbors
# NB = neighbors
GENES_NB=copy.deepcopy(GENES)
for source in GENES:
    GENES_NB.extend(list(MG.neighbors(source)))
    print(source, len(GENES_NB))
GENES_NB=sorted(list(set(GENES_NB)))

# Subset of neighbors that are genes (GO term node names contain 'GO:')
GENES_only_NB = [node for node in GENES_NB if not re.search('GO:', node)]

#Enumerate GO annotations of Mal according to GeneWalk ranking
gene = 'MAL'
MAL_GO_NB = list(GW[GW['hgnc_symbol']==gene].sort_values(by='global_padj')['go_id'])
# Label each GO term with its rank (1 = lowest global_padj)
labels_MAL_GO_NB = {go_id: str(rank) for rank, go_id in enumerate(MAL_GO_NB, start=1)}
MAL_GO_NB_edges = [(gene,gonode) for gonode in MAL_GO_NB]

# ### Generate SubGraph (for plotting)
SG = MG.subgraph(GENES_NB)

# nx.OrderedGraph was removed in networkx 3.0; nx.Graph keeps insertion order
SGplot = nx.Graph()
SGplot.add_nodes_from(GENES_NB)
SGplot.add_edges_from((u, v) for (u, v) in SG.edges() if u in SGplot and v in SGplot)

# ### Generate plot
plt.figure(figsize=(4,4))#units: inch

pos = nx.circular_layout(SGplot)

nx.draw(SGplot, pos=pos, node_color='white', with_labels=False, alpha=0.1, node_size=150)
nx.draw_networkx_nodes(SGplot, pos, nodelist=GENES_NB, node_size=150, node_color='#B82225', alpha=1)
nx.draw_networkx_nodes(SGplot, pos, nodelist=GENES_only_NB, node_size=150, node_color='#007EC3', alpha=1)
nx.draw_networkx_edges(SGplot, pos, edgelist=MAL_GO_NB_edges, width=2.0)#draw thick edges for Mal GO annotations

scale_factor = 1.05
posl = copy.deepcopy(pos)
for node in posl:
    posl[node] = scale_factor * pos[node]

#Add labels to the nodes you require
lab = nx.draw_networkx_labels(SGplot, pos=posl, labels=labels, font_size=8,font_weight='bold')
lab = nx.draw_networkx_labels(SGplot, pos=pos, labels=labels_MAL_GO_NB, font_size=10, 
                              font_color='white',font_weight='bold')

filename = 'subnetwork_circular'
plt.savefig(os.path.join(path, filename + '.pdf'),bbox_inches="tight",transparent=True)
plt.savefig(os.path.join(path, filename + '.png'),bbox_inches="tight",transparent=True)
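
The neighbor-collection step and the gene/GO split from the script above can also be wrapped in a reusable helper. This is a sketch rather than part of GeneWalk: the function name `neighborhood_graph` is hypothetical, the toy graph and its `GENE_X`/`GENE_Y` nodes are made up for illustration, and GO term nodes are assumed to contain 'GO:' in their name, as in the script.

```python
import networkx as nx


def neighborhood_graph(multigraph, genes):
    """Return a simple graph of `genes` plus their direct neighbors,
    together with the gene-like and GO-term node lists."""
    nodes = set(genes)
    for gene in genes:
        nodes.update(multigraph.neighbors(gene))
    nodes = sorted(nodes)

    # nx.Graph collapses parallel edges of the MultiGraph into single edges
    simple = nx.Graph()
    simple.add_nodes_from(nodes)
    simple.add_edges_from(multigraph.subgraph(nodes).edges())

    go_nodes = [n for n in nodes if 'GO:' in n]
    gene_nodes = [n for n in nodes if 'GO:' not in n]
    return simple, gene_nodes, go_nodes


# Toy stand-in for the MultiGraph loaded from multi_graph.pkl
toy = nx.MultiGraph()
toy.add_edges_from([
    ('MAL', 'PLP1'), ('MAL', 'PLP1'),          # parallel edges
    ('MAL', 'GO:0000001'),
    ('PLP1', 'GENE_X'), ('GENE_X', 'GENE_Y'),  # GENE_Y is not a neighbor of MAL
])
sub, gene_nodes, go_nodes = neighborhood_graph(toy, ['MAL'])
```

The returned graph and node lists can then be passed to the same `nx.circular_layout` and `nx.draw_networkx_nodes` calls used in the script.
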

izu0421 commented 3 years ago

Hi @ri23,

Thanks a lot! Much appreciated.

Yizhou