91Abdullah opened this issue 7 months ago
Thanks for posting. What model are you trying to look at? That will help me debug.
Are you trying to run ModelHistory.run_and_log_inputs_through_model by itself? That function is meant to be called internally, not by the user. If you want to save activations for a new input, the function to use is ModelHistory.save_new_activations. I should've marked the former function as internal with an underscore, so that's my mistake. Let me know if this answers your question.
@johnmarktaylor91 The model I was trying to use was MobileNetV3, but IMO the problem is with the last layer: I removed the last layer of MobileNetV3Large and replaced it with a custom QuantumLayer (implemented with PennyLane), which may be why it isn't working. I haven't tested it without the custom layer yet, but I will try that and let you know if it works.
I was able to get this minimal example to work with a QuantumLayer, though the visualization is odd (there is no direct path from input to output). Unfortunately, I'm not at all familiar with this kind of architecture, and I would have to look into how PennyLane interfaces with PyTorch under the hood. My guess is that some operations performed inside the QuantumLayer aren't being tracked by TorchLens for some reason; I will ask the PennyLane folks for some insight.
import torch
import pennylane as qml
import torchlens as tl
from os.path import join as opj

# Two-qubit simulator device
n_qubits = 2
dev = qml.device("default.qubit", wires=n_qubits)

# Quantum circuit: encode the inputs as rotation angles, apply trainable
# entangling layers, and return one Pauli-Z expectation value per qubit
@qml.qnode(dev)
def qnode(inputs, weights):
    qml.AngleEmbedding(inputs, wires=range(n_qubits))
    qml.BasicEntanglerLayers(weights, wires=range(n_qubits))
    return [qml.expval(qml.PauliZ(wires=i)) for i in range(n_qubits)]

# Wrap the circuit as a torch.nn.Module with trainable weights of shape (n_layers, n_qubits)
n_layers = 6
weight_shapes = {"weights": (n_layers, n_qubits)}
qlayer = qml.qnn.TorchLayer(qnode, weight_shapes)

# Hybrid model: classical -> quantum -> classical -> softmax
clayer_1 = torch.nn.Linear(2, 2)
clayer_2 = torch.nn.Linear(2, 2)
softmax = torch.nn.Softmax(dim=1)
layers = [clayer_1, qlayer, clayer_2, softmax]
model = torch.nn.Sequential(*layers)

# Log the forward pass and render the unrolled computational graph
model_history = tl.log_forward_pass(model, torch.rand(1, 2), vis_opt='unrolled',
                                    vis_outpath=opj('/home/jtaylor/projects/torchlens/local_jmt/Debugging Scripts/qnn'))
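The suspicion above, that some operations inside the quantum layer are not tracked, can be illustrated generically in plain PyTorch. This is only a hypothetical illustration of how computation can escape torch-level tracking, not a claim about how PennyLane's TorchLayer actually works: if a layer computes its result outside torch (here in NumPy) and wraps it in a fresh tensor, nothing links the output back to the input. The classes `TorchCos` and `NumpyCos` and the helper `output_linked_to_input` are invented for this sketch, and the check uses autograd's `grad_fn` as a simple proxy for "is there a recorded path".

```python
import numpy as np
import torch


class TorchCos(torch.nn.Module):
    """Computes cos(x) with a torch op, so the result stays linked to x."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.cos(x)


class NumpyCos(torch.nn.Module):
    """Hypothetical layer that leaves torch: computes cos(x) in NumPy, then builds a new tensor."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = np.cos(x.detach().cpu().numpy())
        return torch.from_numpy(result)  # fresh tensor with no recorded history


def output_linked_to_input(layer: torch.nn.Module) -> bool:
    """Return True if autograd recorded a path from the layer's input to its output."""
    x = torch.rand(1, 2, requires_grad=True)
    y = layer(x)
    return y.grad_fn is not None


print(output_linked_to_input(TorchCos()))  # True
print(output_linked_to_input(NumpyCos()))  # False
```

Note that PennyLane's TorchLayer is differentiable, so the real cause is presumably subtler than a plain NumPy round trip; the sketch only shows the general kind of break that would leave no path from input to output in a traced graph.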