choderalab / modelforge

Infrastructure to implement and train NNPs
https://modelforge.readthedocs.io/en/latest/
MIT License

Tensornet representation #167

Closed MarshallYan closed 2 months ago

MarshallYan commented 3 months ago

Description

Implement the modelforge TensorNetRepresentation module, which reproduces the output of TensorEmbedding.forward. In the original implementation, TensorEmbedding is used by the TensorNet class, and its forward is called within every TensorNet forward pass. The output of TensorEmbedding corresponds to the "X" tensor in the TensorNet paper.

MarshallYan commented 3 months ago

@wiederm It is unclear how edge_index, edge_vec, and edge_weight are calculated. The problems are:

  1. edge_weight may not have units of distance, which causes confusion in the unit transformations in the RSF.
  2. The calculation is defined in torchmdnet.extensions.__init__, which calls torch.ops.torchmdnet_extensions.get_neighbor_pairs. I am unable to trace further how this function is implemented.

MarshallYan commented 3 months ago

@wiederm It seems to me that the input wrappers (NNPInput and *NeuralNetworkData) are meant to generalize the information needed to build a model across separate layers. However, I can't find where TensorNet's box and batch inputs should go. Which data structure should be modified? Or how do we handle inputs that can't be generalized in the intended way?

Just as a reminder, edge_index, edge_weight, edge_vec = self.distance(pos, batch, box) is in the forward function of class TensorNet in torchmd-net.

MarshallYan commented 3 months ago

@wiederm TensorNet uses a pairlist with duplicated, permuted indices (both 0-1 and 1-0 are included) and orders these indices differently. Other than that, it seems that edge_weight == d_ij and edge_vec == r_ij. Is there an easier way to verify this? torch.allclose doesn't work directly, since the tensors differ in shape and order.
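One way to make torch.allclose usable here is to put both pairlists into a canonical order first. The sketch below sorts each (2, n_pairs) pairlist lexicographically by (i, j), reorders the per-pair values the same way, and then compares. The helper names canonicalize and same_pairwise_values are hypothetical, not part of modelforge or torchmd-net. It assumes both sides use the same set of ordered pairs; distances are symmetric, but edge_vec may additionally need a sign flip if the two codes use opposite r_ij conventions.

```python
import torch


def canonicalize(pairs: torch.Tensor, values: torch.Tensor, n_atoms: int):
    """Sort a (2, n_pairs) pairlist and its per-pair values by (i, j)."""
    key = pairs[0] * n_atoms + pairs[1]  # unique integer for each ordered pair (i, j)
    order = torch.argsort(key)
    return pairs[:, order], values[order]


def same_pairwise_values(pairs_a, values_a, pairs_b, values_b, n_atoms, atol=1e-6):
    """True if both pairlists hold the same ordered pairs with matching per-pair values."""
    if pairs_a.shape != pairs_b.shape:
        return False  # e.g. unique vs. full pairlist: expand one of them first
    pa, va = canonicalize(pairs_a, values_a, n_atoms)
    pb, vb = canonicalize(pairs_b, values_b, n_atoms)
    return torch.equal(pa, pb) and torch.allclose(va, vb, atol=atol)


# usage: the same full pairlist in two different orders
pos = torch.randn(4, 3)
pairs_a = torch.tensor([[0, 0, 0, 1, 1, 2, 1, 2, 3, 2, 3, 3],
                        [1, 2, 3, 2, 3, 3, 0, 0, 0, 1, 1, 2]])
pairs_b = pairs_a[:, torch.randperm(pairs_a.shape[1])]
d_a = (pos[pairs_a[0]] - pos[pairs_a[1]]).norm(dim=-1)
d_b = (pos[pairs_b[0]] - pos[pairs_b[1]]).norm(dim=-1)
print(same_pairwise_values(pairs_a, d_a, pairs_b, d_b, n_atoms=4))  # True
```

The same function works for (n_pairs, 3) edge vectors, because indexing with order reorders along the first dimension only.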

MarshallYan commented 3 months ago

@wiederm I added a representation unit system so that the model can be computed in either angstrom or nanometer. Angstrom is used to compare the results with TensorNet, but I am also curious whether using nanometer internally makes a difference during training.

The only change I made that may affect the 'main' branch is that 'CosineCutoff' now has a 'representation_unit' parameter that defaults to 'unit.nanometer'. (So I would assume this won't change anything outside TensorNet.)
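The reason a unit parameter on a cosine cutoff should be harmless is that the function only depends on the ratio d / cutoff. The sketch below uses the standard form 0.5 * (cos(pi * d / cutoff) + 1) inside the cutoff and 0 outside; it is a stand-in written for illustration, not modelforge's CosineCutoff class, and it ignores any lower cutoff. It checks that evaluating in nanometer and in angstrom gives the same values, as long as distances and cutoff are expressed in the same unit.

```python
import math
import torch


def cosine_cutoff(d: torch.Tensor, cutoff: float) -> torch.Tensor:
    """0.5 * (cos(pi * d / cutoff) + 1) for d < cutoff, and 0 beyond the cutoff."""
    f = 0.5 * (torch.cos(math.pi * d / cutoff) + 1.0)
    return f * (d < cutoff).to(d.dtype)


d_nm = torch.linspace(0.0, 0.6, 7, dtype=torch.float64)
f_nm = cosine_cutoff(d_nm, cutoff=0.5)          # distances and cutoff in nanometer
f_A = cosine_cutoff(d_nm * 10.0, cutoff=5.0)    # same distances and cutoff in angstrom
print(torch.allclose(f_nm, f_A))  # True
```

Problems only appear if the distances and the cutoff end up in different units, which is what a representation_unit-style parameter has to guard against.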

wiederm commented 3 months ago

It seems to me that input wrappers (NNPInput and *NeuralNetworkData) are trying to generalize the information needed to form a model in separated layers. However, I can't find where box and batch in TensorNet should be put. Which data structure should be modified? Or how do we manage input that couldn't be generalized in designed ways?

We discussed this offline: edge_weight is the pairwise distance, and edge_vec is the pairwise distance vector. The box vectors and batch are passed in as input and don't need to be generated. The relevant data structure is the NNPInput dataclass.

TensorNet uses pairlist with duplicated permute indices

So does modelforge! All pairwise interactions are listed, which means that every atom $i$ encounters every atom $j$ within its neighborhood.
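To make the quantities discussed above concrete, here is a brute-force sketch of a full pairlist (both i to j and j to i) within a cutoff, returning pair indices together with the distance vectors (the edge_vec role) and distances (the edge_weight role). It is an O(N^2) illustration only: it does not handle periodic boxes or batching, it is not the torchmd-net get_neighbor_pairs extension or modelforge's Pairlist, and the sign convention r_ij = x_j - x_i is an assumption that may differ from either code base.

```python
import torch


def full_pairlist(positions: torch.Tensor, cutoff: float):
    """Brute-force full neighbor list (both i->j and j->i) within a cutoff.

    Returns pair indices (2, n_pairs), vectors r_ij = x_j - x_i (n_pairs, 3)
    and distances d_ij (n_pairs,).
    """
    n = positions.shape[0]
    i, j = torch.meshgrid(torch.arange(n), torch.arange(n), indexing="ij")
    not_self = i != j  # drop self-pairs but keep both orderings
    i, j = i[not_self], j[not_self]
    r_ij = positions[j] - positions[i]
    d_ij = r_ij.norm(dim=-1)
    within = d_ij < cutoff
    return torch.stack([i[within], j[within]]), r_ij[within], d_ij[within]


# three atoms close together, one far away
positions = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [5.0, 5.0, 5.0]])
pairs, r_ij, d_ij = full_pairlist(positions, cutoff=2.0)
print(pairs)
```

Here the three nearby atoms produce 3 * 2 = 6 ordered pairs, and the distant atom produces none.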

wiederm commented 3 months ago

ANI uses nanometers in its original implementation

No, ANI uses angstroms. Testing is simple: if you have a class that implements some transformation, you can do:

import torch

# RefClass and MfClass are placeholders for the reference and the modelforge implementation
input_data_in_angstrom = torch.rand(5, 1) * 5  # generate random data on the interval [0, 5)
reference_implementation = RefClass()
modelforge_implementation = MfClass()

# the rbf is defined on an interval [min_distance, max_distance].
# you can convert between rbfs in different length units by scaling the rbfs,
# therefore you can convert from nm to A by multiplying with 10

assert torch.allclose(
    modelforge_implementation(input_data_in_angstrom / 10) * 10,
    reference_implementation(input_data_in_angstrom),
)

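As a self-contained version of this test pattern, the sketch below replaces RefClass and MfClass with a hypothetical GaussianRBF module, instantiated once on [0, 5] angstrom and once on [0, 0.5] nanometer, and feeds the nanometer version the same distances divided by 10. This basis is unnormalized (exp(-((d - mu_k) / width)^2)), so its outputs are dimensionless and the comparison needs no rescaling of the output; whether an output factor is needed depends on how a given implementation normalizes its basis functions.

```python
import torch


class GaussianRBF(torch.nn.Module):
    """Unnormalized Gaussian basis exp(-((d - mu_k) / width) ** 2) on [min_distance, max_distance]."""

    def __init__(self, n_rbf: int, min_distance: float, max_distance: float):
        super().__init__()
        centers = torch.linspace(min_distance, max_distance, n_rbf, dtype=torch.float64)
        self.register_buffer("centers", centers)
        self.width = (max_distance - min_distance) / n_rbf

    def forward(self, d: torch.Tensor) -> torch.Tensor:
        # d: (n_pairs, 1) broadcasts against centers (n_rbf,) -> (n_pairs, n_rbf)
        return torch.exp(-(((d - self.centers) / self.width) ** 2))


reference_in_angstrom = GaussianRBF(16, 0.0, 5.0)
candidate_in_nanometer = GaussianRBF(16, 0.0, 0.5)

input_data_in_angstrom = torch.rand(5, 1, dtype=torch.float64) * 5
out_ref = reference_in_angstrom(input_data_in_angstrom)
out_mf = candidate_in_nanometer(input_data_in_angstrom / 10)
print(torch.allclose(out_mf, out_ref))  # True
```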
MarshallYan commented 3 months ago

@wiederm I may have found a bug in TensorNet. The code applies the cutoff function twice, which is inconsistent with the paper. For now I have kept it the way it is coded, but we may need to discuss this issue further.

In class 'ExpNormalSmearing':

def forward(self, dist):
    dist = dist.unsqueeze(-1)
    return self.cutoff_fn(dist) * torch.exp(
        -self.betas
        * (torch.exp(self.alpha * (-dist + self.cutoff_lower)) - self.means) ** 2
    )

The output of this forward function is called 'edge_attr' in 'TensorEmbedding'. According to the paper, this should be $e_k^{RBF}$, without the cutoff function applied.

However, in 'TensorEmbedding':

def _get_tensor_messages(
    self, Zij: Tensor, edge_weight: Tensor, edge_vec_norm: Tensor, edge_attr: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
    C = self.cutoff(edge_weight).reshape(-1, 1, 1, 1) * Zij
    eye = torch.eye(3, 3, device=edge_vec_norm.device, dtype=edge_vec_norm.dtype)[
        None, None, ...
    ]
    Iij = self.distance_proj1(edge_attr)[..., None, None] * C * eye
    Aij = (
        self.distance_proj2(edge_attr)[..., None, None]
        * C
        * vector_to_skewtensor(edge_vec_norm)[..., None, :, :]
    )
    Sij = (
        self.distance_proj3(edge_attr)[..., None, None]
        * C
        * vector_to_symtensor(edge_vec_norm)[..., None, :, :]
    )
    return Iij, Aij, Sij

The cutoff function C is multiplied into I, A, and S a second time.
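The practical effect of the second multiplication can be checked numerically. In the sketch below, smeared_rbf is a hypothetical stand-in for ExpNormalSmearing (its alpha, means, and betas values are assumptions, with the lower cutoff taken as 0), and proj is a plain linear layer with a bias term standing in for a distance projection. Because edge_attr already carries one factor f(d), the projected message becomes (W(f·phi) + b)·f = f²·W·phi + f·b after the second cutoff: the RBF part gets an f² envelope, and the bias term is what the second factor actually forces to zero at the cutoff.

```python
import math
import torch

torch.manual_seed(0)


def cosine_cutoff(d: torch.Tensor, cutoff: float) -> torch.Tensor:
    """0.5 * (cos(pi d / cutoff) + 1) inside the cutoff, 0 outside."""
    return 0.5 * (torch.cos(math.pi * d / cutoff) + 1.0) * (d < cutoff).to(d.dtype)


def smeared_rbf(d: torch.Tensor, cutoff: float, n_rbf: int = 8) -> torch.Tensor:
    """Exp-normal-style basis with the cutoff already applied (first application)."""
    alpha = 5.0 / cutoff  # assumed parameterization
    start = math.exp(-cutoff)
    means = torch.linspace(start, 1.0, n_rbf, dtype=d.dtype)
    betas = torch.full((n_rbf,), (2.0 / n_rbf * (1.0 - start)) ** -2, dtype=d.dtype)
    d = d.unsqueeze(-1)
    return cosine_cutoff(d, cutoff) * torch.exp(-betas * (torch.exp(-alpha * d) - means) ** 2)


cutoff = 5.0
d = torch.tensor([1.0, 4.0, 4.99, 5.0], dtype=torch.float64)
edge_attr = smeared_rbf(d, cutoff)                   # already zero at d = cutoff
proj = torch.nn.Linear(8, 4).double()                # has a nonzero bias
once = proj(edge_attr)                               # bias survives at d = cutoff
twice = once * cosine_cutoff(d, cutoff)[:, None]     # second cutoff removes it
print(edge_attr[-1].abs().max().item(), once[-1].abs().max().item(), twice[-1].abs().max().item())
```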

MarshallYan commented 3 months ago

@wiederm The other thing that bothers me is that, mathematically, Iij shouldn't be initialized with the identity matrix; based on my calculation, it should be $\frac{1}{3}\lVert v\rVert^2\cdot I$.
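The calculation behind this can be checked with a small sketch. It splits a 3x3 tensor into an isotropic part I = (tr X / 3)·1, an antisymmetric part A, and a symmetric traceless part S, and applies this to the outer product X = v vᵀ of a single edge vector. Since tr(v vᵀ) = ‖v‖², the isotropic part is ‖v‖²/3 times the identity, and A vanishes. The helper name decompose is hypothetical. If the vectors are normalized (as the name edge_vec_norm suggests), this reduces to the identity scaled by 1/3, a constant factor that the preceding learned projection could in principle absorb.

```python
import torch


def decompose(X: torch.Tensor):
    """Split a 3x3 tensor into isotropic (I), antisymmetric (A) and symmetric traceless (S) parts."""
    eye = torch.eye(3, dtype=X.dtype)
    trace = X.diagonal(dim1=-2, dim2=-1).sum(-1)
    I = trace[..., None, None] / 3.0 * eye
    A = 0.5 * (X - X.transpose(-2, -1))
    S = 0.5 * (X + X.transpose(-2, -1)) - I
    return I, A, S


v = torch.tensor([0.3, -1.2, 2.0], dtype=torch.float64)
X = torch.outer(v, v)  # rank-1 tensor built from a single edge vector
I, A, S = decompose(X)
print(torch.allclose(I, v.dot(v) / 3.0 * torch.eye(3, dtype=torch.float64)))  # True
print(torch.allclose(A, torch.zeros(3, 3, dtype=torch.float64)))               # True
print(torch.allclose(I + A + S, X))                                            # True
```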

MarshallYan commented 2 months ago

@wiederm The pair indices used in torchmd-net's TensorNet differ from those in modelforge, so a different number of pairs is fed into nn.Linear, which produces different outputs.

In modelforge, if we have 4 atoms [0, 1, 2, 3], pair_indices is: [[0, 0, 0, 1, 1, 2], [1, 2, 3, 2, 3, 3]]; meanwhile, in torchmd-net, edge_index is: [[1, 2, 2, 3, 3, 3, 0, 0, 1, 0, 1, 2], [0, 0, 1, 0, 1, 2, 1, 2, 2, 3, 3, 3]]. This is not a big problem in itself, since I only need to append the reversed pairs to the end. However, the radial symmetry vector corresponding to these indices is fed into an nn.Linear layer, and in the example above the output contains 6 vs. 12 atom pairs.
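Appending the reversed pairs can be sketched in one line: flip the two rows of the unique pairlist and concatenate. The helper half_to_full is hypothetical. Using the two index lists quoted above, the result contains exactly the same 12 ordered pairs as torchmd-net's edge_index, only in a different order.

```python
import torch


def half_to_full(pair_indices: torch.Tensor) -> torch.Tensor:
    """Append the reversed (j, i) pairs to a unique (i < j) pairlist of shape (2, n_pairs)."""
    return torch.cat([pair_indices, pair_indices.flip(0)], dim=1)


modelforge_pairs = torch.tensor([[0, 0, 0, 1, 1, 2], [1, 2, 3, 2, 3, 3]])
torchmdnet_edge_index = torch.tensor(
    [[1, 2, 2, 3, 3, 3, 0, 0, 1, 0, 1, 2], [0, 0, 1, 0, 1, 2, 1, 2, 2, 3, 3, 3]]
)
full = half_to_full(modelforge_pairs)
# same set of ordered pairs, just in a different order
print(set(map(tuple, full.T.tolist())) == set(map(tuple, torchmdnet_edge_index.T.tolist())))  # True
```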

The real problem is this: in torchmd-net, the last 6 pairs are the first 6 pairs with i and j swapped, so I would expect the output of that linear layer to show the same duplication (the first 6 tensors equal to the last 6 tensors), but this is not the case. As a result, the later summation gives different results for the two index conventions.

One way that would definitely fix this problem is to convert the modelforge index convention to the torchmd-net one inside the representation module. But I think that may not be what we want in modelforge. We should discuss how to solve this.
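One plausible way for the later summation to diverge, independent of the linear layer, is the pair count itself: if per-pair messages are scatter-added onto atom i, a unique pairlist and a full pairlist give different per-atom sums even when every message is identical. The sketch below illustrates this with unit messages; the aggregate helper is hypothetical and uses torch.Tensor.index_add_ as the scatter-add.

```python
import torch

half = torch.tensor([[0, 0, 0, 1, 1, 2], [1, 2, 3, 2, 3, 3]])
full = torch.cat([half, half.flip(0)], dim=1)


def aggregate(pairs: torch.Tensor, messages: torch.Tensor, n_atoms: int) -> torch.Tensor:
    """Sum per-pair messages onto the first atom index i of each pair (scatter-add)."""
    out = torch.zeros(n_atoms, messages.shape[-1], dtype=messages.dtype)
    return out.index_add_(0, pairs[0], messages)


# identical message for every pair, so only the pairlist differs
msg_half = torch.ones(half.shape[1], 1)
msg_full = torch.ones(full.shape[1], 1)
print(aggregate(half, msg_half, n_atoms=4).squeeze(-1).tolist())  # [3.0, 2.0, 1.0, 0.0]
print(aggregate(full, msg_full, n_atoms=4).squeeze(-1).tolist())  # [3.0, 3.0, 3.0, 3.0]
```

With the unique pairlist, atom 3 never appears as i and receives nothing, whereas the full pairlist gives every atom all three of its neighbors.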

wiederm commented 2 months ago

Note the boolean set in the Pairlist: https://github.com/choderalab/modelforge/blob/5a312b0fe1c9383d7d4fef4651a1e36088a47e53/modelforge/potential/models.py#L56

If this is set to False (the default), it returns the same number of pairs as the pairlist in TensorNet. ANI is an outlier and requires this boolean to be set to True.

wiederm commented 2 months ago

I am not completely sure that I follow this:

The real problem here is that in torchmd-net, since the first 6 pairs are the same as the last 6 pairs, I assume the output of that linear layer should also be the case (the first 6 tensors are the same as the last 6 tensors), while it's not the case. Thus in later summation, results from two ways are different.

Can you provide a minimal example that exemplifies your concern?

wiederm commented 2 months ago

If possible, it would be very useful to discuss this in a working test.

MarshallYan commented 2 months ago

(Quoting wiederm's note above about the boolean in the Pairlist.)

Thanks! It has been fixed by now.

ArnNag commented 2 months ago

(Quoting MarshallYan's earlier comment about the cutoff function being applied twice, in both ExpNormalSmearing.forward and TensorEmbedding._get_tensor_messages.)

Was this resolved? It seems that they wanted to make sure that the outputs of intermediate linear layers also decay to 0 at the cutoff distance, even before they multiply by the cutoff-filtered distances.

wiederm commented 2 months ago

(Quoting ArnNag's question above about whether the double application of the cutoff function was resolved.)

Yes, this has been resolved! We came to the same conclusion and left it in the code.

wiederm commented 2 months ago

Closing this PR since it is superseded by PR #181.