XingyanLiu / CAME

Cell-type Assignment and Module Extraction based on a heterogeneous graph neural network.
https://XingyanLiu.github.io/CAME
MIT License

AttributeError: module 'dgl.function' has no attribute 'copy_src' #24

Open DanLiu527 opened 1 year ago

DanLiu527 commented 1 year ago

Hi, when I run the example dataset, step [6] shows: AttributeError: module 'dgl.function' has no attribute 'copy_src'. Hope you can help me! Thanks a lot!

[leiden] Time used: 0.2771 s
already exists: /home/ld/CAME/came/sample_data/_temp/('Baron_human', 'Baron_mouse')-(02-11 16.32.15)/figs
already exists: /home/ld/CAME/came/sample_data/_temp/('Baron_human', 'Baron_mouse')-(02-11 16.32.15)
[*] Setting dataset names:
    0-->Baron_human
    1-->Baron_mouse
[*] Setting aligned features for observation nodes (self._features)
[*] Setting observation-by-variable adjacent matrices (self._ov_adjs) for making merged graph adjacent matrix of observation and variable nodes
-------------------- Summary of the DGL-Heterograph --------------------
Graph(num_nodes={'cell': 4028, 'gene': 3446},
      num_edges={('cell', 'express', 'gene'): 1677575,
                 ('cell', 'self_loop_cell', 'cell'): 4028,
                 ('cell', 'similar_to', 'cell'): 25760,
                 ('gene', 'expressed_by', 'cell'): 1677575,
                 ('gene', 'self_loop_gene', 'gene'): 3446},
      metagraph=[('cell', 'gene', 'express'), ('cell', 'cell', 'self_loop_cell'),
                 ('cell', 'cell', 'similar_to'), ('gene', 'cell', 'expressed_by'),
                 ('gene', 'gene', 'self_loop_gene')])
self-loops for observation-nodes: True
self-loops for variable-nodes: True

AlignedDataPair with 4028 obs- and 3446 var-nodes
n_obs1 (Baron_human): 2142
n_obs2 (Baron_mouse): 1886
Dimensions of the obs-node-features: 701
main directory: /home/ld/CAME/came/sample_data/_temp/('Baron_human', 'Baron_mouse')-(02-11 16.32.15)
model directory: /home/ld/CAME/came/sample_data/_temp/('Baron_human', 'Baron_mouse')-(02-11 16.32.15)/_models
============== start training (device='cpu') ===============

AttributeError                            Traceback (most recent call last)
Cell In[165], line 10
      1 came_inputs, (adata1, adata2) = pipeline.preprocess_aligned(
      2     adatas,
      3     key_class=key_class1,
    (...)
      7     df_varmap_1v1=df_varmap_1v1,  # set as None if NOT cross species
      8 )
---> 10 outputs = pipeline.main_for_aligned(
     11     **came_inputs,
     12     dataset_names=dsnames,
     13     key_class1=key_class1,
     14     key_class2=key_class2,
     15     do_normalize=True,
     16     n_epochs=n_epochs,
     17     resdir=resdir,
     18     n_pass=n_pass,
     19     batch_size=batch_size,
     20     plot_results=True,
     21 )
     22 dpair = outputs['dpair']
     23 trainer = outputs['trainer']

File ~/miniconda3/envs/env_came/lib/python3.8/site-packages/came/pipeline.py:213, in main_for_aligned(adatas, vars_feat, vars_as_nodes, scnets, dataset_names, key_class1, key_class2, do_normalize, batch_keys, n_epochs, resdir, tag_data, params_model, params_lossfunc, n_pass, batch_size, pred_batch_size, plot_results, norm_target_sum, save_hidden_list, save_dpair)
    207     trainer.train_minibatch(
    208         n_epochs=n_epochs,
    209         params_lossfunc=params_lossfunc,
    210         batch_size=batch_size,
    211         n_pass=n_pass, device=device)
    212 else:
--> 213     trainer.train(n_epochs=n_epochs,
    214                   params_lossfunc=params_lossfunc,
    215                   n_pass=n_pass, device=device)
    216 trainer.save_model_weights()
    217 # ========================== record results ========================

File ~/miniconda3/envs/env_came/lib/python3.8/site-packages/came/utils/train.py:295, in Trainer.train(self, n_epochs, use_class_weights, params_lossfunc, n_pass, eps, cat_class, device, info_stride, backup_stride, other_inputs)
    293 self.optimizer.zero_grad()
    294 t0 = time.time()
--> 295 outputs = model(feat_dict, g, other_inputs)
    296 logits = outputs[cat_class]
    297 # logits2 = model(feat_dict, g, other_inputs)[cat_class]
    298 # loss = ce_loss_with_rdrop(
    299 #     logits, logits2, labels=train_labels,
    (...)
    303 #     params_lossfunc
    304 # )

File ~/miniconda3/envs/env_came/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/env_came/lib/python3.8/site-packages/came/model/cggc.py:205, in CGGCNet.forward(self, feat_dict, g, other_inputs)
    203     h_dict['gene'] = relu(h_dict0['gene'] + h_dict['gene'])
    204 else:
--> 205     h_dict = self.embed_layer(g, feat_dict, )
    206 h_dict = self.rgcn.forward(g, h_dict, other_inputs).copy()
    208 h_dict['cell'] = self.cell_classifier.forward(g, h_dict)['cell']

File ~/miniconda3/envs/env_came/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/env_came/lib/python3.8/site-packages/came/model/base_layers.py:571, in GeneralRGCLayer.forward(self, g, inputs, etypes, norm, bias, activate, static_wdict)
    568 # inputs_src = inputs_dst = inputs
    569 inputs = {ntype: self.dropout_feat(feat) for ntype, feat in inputs.items()}
--> 571 hs = self.conv(g, inputs, etypes, mod_kwargs=wdict)
    573 def _apply(ntype, h):
    575     if self.use_batchnorm and norm:

File ~/miniconda3/envs/env_came/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/env_came/lib/python3.8/site-packages/came/model/heteroframe.py:173, in HeteroGraphConv.forward(self, g, inputs, etypes, mod_args, mod_kwargs)
    171 if stype not in inputs:
    172     continue
--> 173 dstdata = self.mods[etype](
    174     rel_graph,
    175     inputs[stype],
    176     *mod_args.get(etype, ()),
    177     **mod_kwargs.get(etype, {}))
    178 outputs[dtype].append(dstdata)
    179 rsts = {}

File ~/miniconda3/envs/env_came/lib/python3.8/site-packages/torch/nn/modules/module.py:1194, in Module._call_impl(self, *input, **kwargs)
   1190 # If we don't have any hooks, we want to skip the rest of the logic in
   1191 # this function, and just call forward.
   1192 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1193         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194     return forward_call(*input, **kwargs)
   1195 # Do not call functions when jit is used
   1196 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/env_came/lib/python3.8/site-packages/came/model/base_layers.py:875, in GraphConvLayer.forward(self, g, feat, weight, static_weight)
    873     weight = self.weight
    874 if static_weight is None:
--> 875     message_func = fn.copy_src(src='h', out='m')
    876 else:
    877     g.edata['w_static'] = static_weight

AttributeError: module 'dgl.function' has no attribute 'copy_src'
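
For reference, the installed DGL release can be checked directly; a minimal snippet like the following (run inside the same `env_came` environment from the paths above) shows the version and whether `copy_src` is still exposed by `dgl.function`:

```python
import dgl
import dgl.function as fn

print(dgl.__version__)            # the installed DGL release
print(hasattr(fn, "copy_src"))    # False on DGL >= 1.0, where copy_src was removed
print(hasattr(fn, "copy_u"))      # True; copy_u is the newer equivalent
```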

XingyanLiu commented 1 year ago

Re-installing DGL with a version below 1.0 (e.g. 0.9.x) may solve this problem. Here is an example environment:

macOS, Python 3.8
numpy           1.23.2
numpy-base      1.23.3
pandas          1.4.3
scanpy          1.9.1
scikit-learn    1.1.2
scipy           1.9.1
dgl             0.9.0   (< 1.0.0)
torch           1.12.1
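
The DGL pin can be expressed as, e.g., `pip install "dgl<1.0"` or `pip install dgl==0.9.0`. If downgrading is not an option, another workaround is a small shim: DGL 1.0 removed `fn.copy_src` in favor of `fn.copy_u`, which accepts the same source-field and output-field arguments positionally, so the old name can be re-added before training starts. This is only a sketch and not part of CAME itself:

```python
import dgl.function as fn

# DGL >= 1.0 dropped fn.copy_src; fn.copy_u is the replacement.
# CAME calls fn.copy_src(src='h', out='m'), so the shim accepts the old
# keyword names and forwards them positionally to copy_u.
if not hasattr(fn, "copy_src"):
    fn.copy_src = lambda src, out: fn.copy_u(src, out)
```

Running this once near the top of the notebook, before calling `pipeline.main_for_aligned`, should avoid the traceback above; pinning DGL below 1.0 as listed in the environment remains the simpler route.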