Open rmchauhan03 opened 1 year ago
To give more background, both G and D are MLP O2Models with nn.BCECriterion:

generator = nn.Sequential(
    Reshape(30),
    nn.Linear(30, 50),
    nn.ReLU(),
    nn.Linear(50, 256),
)
criterion = nn.MSELoss()
geno2 = O2Model(generator, criterion)

discriminator = nn.Sequential(
    Reshape(256),
    nn.Linear(256, 30),
    nn.ReLU(),
    nn.Linear(30, 1),
    nn.Sigmoid()
)
criterion = nn.BCELoss()
disco2 = O2Model(discriminator, criterion)

The error I obtained was using standard SGD on the O2Model.
Hello, I am working on using the o2grad package for a GAN, and I am running into the following error.

G.zero_grad()
device = "cpu"
bs = 32
z = torch.randn(bs, 30).to(device)
y = torch.ones(bs, 1).to(device)
When I run this nested backpropagation code, I obtain this error:

File "/Users/rohan/Desktop/rsh/chaotic_neurips22/o2grad/o2grad/backprop/o2model.py", line 455, in _capture_backprops
layer._callbacks.on_capture()
File "/Users/rohan/Desktop/rsh/chaotic_neurips22/o2grad/o2grad/utils.py", line 35, in call
[cb() for cb in self.callbacks]
File "/Users/rohan/Desktop/rsh/chaotic_neurips22/o2grad/o2grad/utils.py", line 35, in
[cb() for cb in self.callbacks]
File "/Users/rohan/Desktop/rsh/chaotic_neurips22/o2grad/o2grad/modules/o2module.py", line 58, in
self._callbacks[key].add(lambda: cb(self))
File "/Users/rohan/Desktop/rsh/chaotic_neurips22/o2grad/o2grad/backprop/o2model.py", line 137, in _backprop_step
backprop_step(layer, diagonal_blocks, cache_cpu=True)
File "/Users/rohan/Desktop/rsh/chaotic_neurips22/o2grad/o2grad/backprop/o2backprop.py", line 51, in backprop_step
dL2dw2 = layer.get_loss_param_hessian(
File "/Users/rohan/Desktop/rsh/chaotic_neurips22/o2grad/o2grad/modules/o2parametric/o2parametric.py", line 162, in get_loss_param_hessian
dL2dw2_term2 = twin_matmul_mixed(dydw, dL2dy2)
File "/Users/rohan/Desktop/rsh/chaotic_neurips22/o2grad/o2grad/linalg.py", line 39, in twin_matmul_mixed
assert B.shape[1] == A.shape[0]
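In case it helps with debugging: the assertion that fails only compares the second dimension of B with the first dimension of A, which is the same condition a matrix product B @ A would need. Below is a minimal sketch of that check on plain shape tuples. The helper `check_inner_dims` and the example shapes are hypothetical and not part of o2grad; the actual shapes inside `twin_matmul_mixed` are not shown in the traceback.

```python
# Hypothetical helper, not part of o2grad: mirrors the failing check
# `assert B.shape[1] == A.shape[0]` using plain shape tuples.
def check_inner_dims(b_shape, a_shape):
    """Return (ok, message) for the condition B.shape[1] == A.shape[0]."""
    ok = b_shape[1] == a_shape[0]
    if ok:
        message = f"ok: B{b_shape} and A{a_shape} share inner dimension {b_shape[1]}"
    else:
        message = (
            f"mismatch: B.shape[1]={b_shape[1]} but A.shape[0]={a_shape[0]} "
            f"(B{b_shape}, A{a_shape})"
        )
    return ok, message


# Illustrative shapes only; these are assumptions, not values read from o2grad.
ok_case = check_inner_dims((256, 50), (50, 50))
bad_case = check_inner_dims((256, 50), (30, 30))
print(ok_case[1])
print(bad_case[1])
```

Printing `B.shape` and `A.shape` just before the assertion in a local copy of the library would give the real values to feed into a check like this.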
I believe this issue has to do with the nested backpropagation, and I was wondering if there is a workaround. When I take out the nested part of the code, everything runs as expected. Thanks again.