I have an issue where the weight matrix changes device from cuda to cpu after it is projected. Below are the main components of the model, in case you can spot anything I am doing wrong that causes this.
Imports and EMLP block in torch
from emlp.groups import S
from emlp.reps import V
import emlp
from torch.nn import Module
import emlp.nn.pytorch as emlp_torch

class EMLPBlock(Module):
    """ Basic building block of EMLP consisting of G-Linear, biLinear,
        and gated nonlinearity. """
    def __init__(self, rep_in, rep_out):
        super().__init__()
        rep_out_wgates = emlp_torch.gated(rep_out)
        self.linear = emlp_torch.Linear(rep_in, rep_out_wgates)
        self.bilinear = emlp_torch.BiLinear(rep_out_wgates, rep_out_wgates)
        self.nonlinearity = emlp_torch.GatedNonlinearity(rep_out)

    def __call__(self, x):
        # Debug prints: compare the weight's device before and after projection
        print(f'linear weight device : {self.linear.weight.device}')
        print(f'linear weight proj device : {self.linear.proj_w(self.linear.weight).device}')
        lin = self.linear(x)
        preact = self.bilinear(lin) + lin
        return self.nonlinearity(preact)
The input adj2 is on the cuda device, and I call model.to('cuda') before running the forward pass of the model.
Also, when I check the prints in the EMLP block, the linear weight is on cuda before being projected, but after the projection it is on the CPU.
Do you know why the projection moves the weights from cuda to cpu?
Just after posting this, I thought it might actually be an issue with my JAX installation, since the PyTorch version still uses JAX. That turned out to be the case, so I just needed to fix my JAX installation.
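For anyone who runs into the same symptom, here is a minimal sketch of a check on which backend JAX is using. The helper name describe_jax_backend is made up for illustration and is not part of emlp; it relies only on jax.default_backend() and jax.devices(). The link to the CPU result is an assumption based on the fix described above: if JAX can only see the CPU, anything routed through it would plausibly come back on the CPU.

```python
def describe_jax_backend():
    """Summarise which backend JAX is using (hypothetical helper, not part of emlp)."""
    try:
        import jax
    except ImportError:
        # JAX is not installed at all in this environment
        return {"installed": False, "backend": None, "devices": []}
    return {
        "installed": True,
        "backend": jax.default_backend(),  # "cpu" when JAX sees no accelerator
        "devices": [str(d) for d in jax.devices()],
    }


info = describe_jax_backend()
print(info)
if info["installed"] and info["backend"] == "cpu":
    # Assumption: a CPU-only JAX install is what sends projected weights to the CPU
    print("JAX only sees the CPU; check that a GPU-enabled jaxlib is installed.")
```

If the reported backend is "cpu" on a machine where PyTorch sees the GPU, the JAX installation is the first thing to look at.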
Many thanks, Josh