mfinzi / equivariant-MLP

A library for programmatically generating equivariant layers through constraint solving
MIT License

Linear projected weight changes device - EMLP in Pytorch #11

Closed JoshuaMitton closed 3 years ago

JoshuaMitton commented 3 years ago

I have an issue where, after projection, the weight matrix changes device from cuda to cpu. Below are the main components of the model, in case you can spot anything I am doing wrong that is causing this.

Imports and EMLP block in torch

from emlp.groups import S
from emlp.reps import V
import emlp
from torch.nn import Module
import emlp.nn.pytorch as emlp_torch

class EMLPBlock(Module):
    """ Basic building block of EMLP consisting of G-Linear, biLinear,
        and gated nonlinearity. """
    def __init__(self,rep_in,rep_out):
        super().__init__()
        rep_out_wgates = emlp_torch.gated(rep_out)
        self.linear = emlp_torch.Linear(rep_in,rep_out_wgates)
        self.bilinear = emlp_torch.BiLinear(rep_out_wgates,rep_out_wgates)
        self.nonlinearity = emlp_torch.GatedNonlinearity(rep_out)
    def __call__(self,x):
        print(f'linear weight device : {self.linear.weight.device}')
        print(f'linear weight proj device : {self.linear.proj_w(self.linear.weight).device}')
        lin = self.linear(x)
        preact = self.bilinear(lin) + lin
        return self.nonlinearity(preact)
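
For context on why the projected weight can come back on a different device: the pytorch layers in emlp wrap jax computations (as the resolution below confirms), so the output of proj_w is allocated by jax on jax's default backend. Here is a deliberately simplified sketch of that kind of torch-to-jax round trip; the name torchify_fn and the exact conversion calls are illustrative assumptions, not the library's actual code.

import numpy as np
import torch
import jax.numpy as jnp
from jax import dlpack as jax_dlpack
from torch.utils import dlpack as torch_dlpack

def torchify_fn(jax_fn):
    """Hypothetical sketch: wrap a jax function so it consumes and
    produces torch tensors."""
    def wrapped(x):
        # Copy the torch tensor into jax; jnp.asarray allocates on
        # jax's default backend (CPU if jaxlib has no CUDA support).
        y = jax_fn(jnp.asarray(np.asarray(x.detach().cpu())))
        # Convert the jax output back to torch; the tensor keeps the
        # device jax allocated it on, so a CPU-only jax install yields
        # CPU tensors even when the torch model lives on the GPU.
        return torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(y))
    return wrapped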

In the model init

rin_2 = 10*V(S(2))**2
rout_2 = 20*V(S(2))**2
rin_3 = 10*V(S(3))**2
rout_3 = 20*V(S(3))**2
rin_4 = 10*V(S(4))**2
rout_4 = 20*V(S(4))**2
rin_5 = 10*V(S(5))**2
rout_5 = 20*V(S(5))**2

print(f'rep in layer 1 S(2) : {rin_2}')
print(f'rep out layer 1 S(2) : {rout_2}')
print(f'rep in layer 1 S(3) : {rin_3}')
print(f'rep out layer 1 S(3) : {rout_3}')
print(f'rep in layer 1 S(4) : {rin_4}')
print(f'rep out layer 1 S(4) : {rout_4}')
print(f'rep in layer 1 S(5) : {rin_5}')
print(f'rep out layer 1 S(5) : {rout_5}')
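
As a sanity check on these reps (assuming emlp's Rep.size() method, which reports the total dimension of a representation), 10*V(S(2))**2 is ten copies of the rank-2 tensor rep of S(2):

# dim(V(S(n))**2) = n**2, so 10 copies give 10 * n**2 features.
print(rin_2.size())   # expected 10 * 2**2 = 40
print(rout_2.size())  # expected 20 * 2**2 = 80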

self.eqblock1_2 = EMLPBlock(rin_2,rout_2)
self.eqblock1_3 = EMLPBlock(rin_3,rout_3)
self.eqblock1_4 = EMLPBlock(rin_4,rout_4)
self.eqblock1_5 = EMLPBlock(rin_5,rout_5)

In the forward of the model

adj_2 = self.eqblock1_2(adj_2)

The input adj_2 is on device cuda, and I call model.to('cuda') before running the forward part of the model.
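
For completeness, a generic torch check (nothing emlp-specific) to confirm that both the parameters and the input actually land on the GPU before the forward pass:

model = model.to('cuda')
adj_2 = adj_2.to('cuda')
print(next(model.parameters()).device)  # expect cuda:0
print(adj_2.device)                     # expect cuda:0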

Also, the print statements show that the linear weight in the EMLP block is on device cuda before being projected, but after the projection it is on device cpu.

Do you know why the projection moves the weights from cuda to cpu?

Many thanks, Josh

JoshuaMitton commented 3 years ago

Just after posting this I thought it might actually be an issue with my jax installation, since the pytorch version still uses jax under the hood. That turned out to be the case: jax could not see the GPU, so the projection ran on CPU and returned CPU tensors. I just needed to fix my jax installation.
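
For anyone hitting the same thing, a quick way to check whether jax can actually see the GPU (standard jax calls, nothing emlp-specific):

import jax

# A CPU-only jaxlib lists only CPU devices and reports 'cpu' here;
# a CUDA-enabled install lists GPU devices and reports 'gpu'.
print(jax.devices())
print(jax.default_backend())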