mert-kurttutan / torchview

torchview: visualize pytorch models
https://torchview.dev
MIT License

AttributeError: 'Tensor' object has no attribute 'training' #65

Closed: ZihanZhang0 closed this issue 1 year ago

ZihanZhang0 commented 1 year ago

Hi, I'm trying to draw the graph of the GMAN model. The source code is at https://github.com/VincLee8188/GMAN-PyTorch. Since the model's inputs are computed inside the function train, I added the following lines to the train.py file:

model_graph = draw_graph(model(X,TE), input_size=(1,12*325))
model_graph.visual_graph

Running the first line (the draw_graph call) throws the following error:

Traceback (most recent call last):
  File "main1.py", line 84, in <module>
    loss_train, loss_val = train(model, args, log, loss_criterion, optimizer, scheduler)
  File "/home/zzh/Downloads/GMAN-PyTorch/model/train1.py", line 48, in train
    model_graph = draw_graph(model(X,TE), input_size=(1,12*325))
  File "/home/zzh/.conda/envs/geocpu/lib/python3.8/site-packages/torchview/torchview.py", line 220, in draw_graph
    forward_prop(
  File "/home/zzh/.conda/envs/geocpu/lib/python3.8/site-packages/torchview/torchview.py", line 242, in forward_prop
    saved_model_mode = model.training
AttributeError: 'Tensor' object has no attribute 'training'

My understanding is that model.training is a boolean indicating whether the module is in training or evaluation mode. I ran the model_graph line after calling model.train(), and print(model.training) placed just before that line printed True.
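The apparent contradiction can be checked without torchview. The sketch below assumes only PyTorch and uses a hypothetical two-input module, TwoInputNet (arbitrary layer sizes, not GMAN itself), standing in for a model whose forward takes X and TE. It shows that training is an attribute of the nn.Module, while the expression model(X, TE) evaluates to a plain Tensor, which has no such attribute.

```python
import torch
from torch import nn


class TwoInputNet(nn.Module):
    """Hypothetical stand-in for a model whose forward takes two inputs (like X, TE)."""

    def __init__(self):
        super().__init__()
        self.fc_x = nn.Linear(4, 8)   # arbitrary sizes, not GMAN's
        self.fc_te = nn.Linear(2, 8)

    def forward(self, x, te):
        return self.fc_x(x) + self.fc_te(te)


model = TwoInputNet()
model.train()                      # sets model.training = True
x = torch.randn(1, 4)
te = torch.randn(1, 2)

out = model(x, te)                 # calling the module runs forward() and returns a Tensor

print(type(model).__name__, model.training)          # TwoInputNet True
print(type(out).__name__, hasattr(out, "training"))  # Tensor False
```

So print(model.training) reports on the module, but the first argument handed to draw_graph in the snippet above is out, not model.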

Thanks in advance for your help.

mert-kurttutan commented 1 year ago

As stated in the documentation (and in the docstring), the first argument should be the PyTorch model. In your example, it seems you are passing the output of the model, which is a tensor, rather than the model itself. Assuming that X and TE are the inputs to your model, the correct call is:

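from torchview import draw_graph

# Pass the model itself, not model(X, TE); input_data holds the inputs forwarded to it.
model_graph = draw_graph(model, input_data=(X, TE))
model_graph.visual_graph
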
ZihanZhang0 commented 1 year ago

It works! Thank you so much!