FabianFuchsML / se3-transformer-public

code for the SE3 Transformers paper: https://arxiv.org/abs/2006.10503

fit the training data #4

Open Chen-Cai-OSU opened 3 years ago

Chen-Cai-OSU commented 3 years ago

Hi Fabian,

When I use the SE(3)-Transformer on my dataset, I find it quite difficult for the model to fit the training data. To understand why, I created a simple task in the following way.

I generate a few hundred point clouds sampled on the surface of an ellipsoid in 3D (centered at (0,0,0)). I construct a k-NN (k=5) graph for each point cloud. The goal is to predict the first eigenvector of the covariance matrix of each point cloud; this is a type-1 feature. I am using the following model:

ModuleList(
  (0): GSE3Res(
    (GMAB): ModuleDict(
      (v): GConvSE3Partial(structure=[(8, 0), (8, 1)])
      (k): GConvSE3Partial(structure=[(8, 0)])
      (q): G1x1SE3(structure=[(8, 0)])
      (attn): GMABSE3(n_heads=8, structure=[(8, 0), (8, 1)])
    )
    (project): G1x1SE3(structure=[(32, 0), (32, 1)])
    (add): GSum(structure=[(32, 0), (32, 1)])
  )
  (1): GNormSE3(num_layers=0, nonlin=ReLU(inplace=True))
  (2): GConvSE3(structure=[(32, 0), (32, 1)], self_interaction=True)
  (3): GNormSE3(num_layers=0, nonlin=ReLU(inplace=True))
  (4): GConvSE3(structure=[(32, 0), (32, 1)], self_interaction=True)
  (5): GNormSE3(num_layers=0, nonlin=ReLU(inplace=True))
  (6): GConvSE3(structure=[(32, 0), (32, 1)], self_interaction=True)
  (7): GNormSE3(num_layers=0, nonlin=ReLU(inplace=True))
  (8): GConvSE3(structure=[(1, 1)], self_interaction=True)
  (9): GAvgVecPooling(
    (pool): AvgPooling()
  )
)

However, when I train with the Adam optimizer (learning rate 1e-3) to minimize the MSE loss between the predicted and the true eigenvector, the training loss ends up at roughly 0.05, and the cosine similarity between the predicted and true eigenvectors is roughly 0.65 (which means the angle is more than 45 degrees).

The model has 643,296 parameters, which is probably not large, but my dataset is also tiny (200 neighborhood graphs constructed from the point clouds), so I am a bit surprised that the model cannot fit the data exactly. (I am trying to use more layers, but CUDA memory quickly runs out.)

Are there places I should pay special attention to when using an SE(3)-Transformer? Could it be that when the equivariance constraint is placed on the kernel, the model becomes less flexible and therefore harder to fit to the training data? Should I try increasing the model size, or different optimizers and learning rates? I can share the data if needed.

Thank you!

FabianFuchsML commented 3 years ago

Hi Chen,

That's a great toy experiment! It's also quite different from the ones we ran, but here are my thoughts:

Chen-Cai-OSU commented 3 years ago

Hi Fabian,

Thanks for all the great suggestions! I would like to report some empirical observations.

I reduced the number of k-NN graphs to 1. In that case, I am able to fit the data exactly. However, if I increase the number of graphs even slightly, to 5, I am no longer able to fit the data. It also seems that the more graphs I have, the harder it is for the SE(3)-Transformer to fit the data (the cosine similarity decreases as I increase the number of graphs).

I spent two hours trying different k, different graph sizes, varying the architecture (switching from GConvSE3 to GSE3Res), and learning-rate decay; however, I didn't see significant improvement.

According to this recent paper by Haggai Maron, https://arxiv.org/abs/2010.02449, SE(3)-Transformers are universal, so I would like to try my best to solve this toy problem before moving to real data.

Let me know if you want to try it yourself. I can share the script.

FabianFuchsML commented 3 years ago

Hi Chen, yes, it would be very interesting to have a look at your script. Feel free to send it over!

Chen-Cai-OSU commented 3 years ago
# Created at 2020-11-11
# Summary: syn data shared with Fabian

import os
import pickle

import dgl
import networkx as nx
import scipy.linalg  # needed for scipy.linalg.eigh below
import torch
from scipy.stats import special_ortho_group
from tqdm import tqdm

try:  # make sure load and save works for different versions of dgl
    from dgl import save_graphs, load_graphs
except ImportError:
    from dgl.data import load_graphs, save_graphs

import numpy as np
from scipy.spatial import distance
from sklearn.preprocessing import normalize

export_dir = os.path.join('pca-data', '')

def save_info(path, info):
    """ Save dataset related information into disk.

    Parameters
    ----------
    path : str
        File to save information.
    info : dict
        A python dict storing information to save on disk.
    """
    with open(path, "wb") as pf:
        pickle.dump(info, pf)

def load_info(path):
    """ Load dataset related information from disk.

    Parameters
    ----------
    path : str
        File to load information from.

    Returns
    -------
    info : dict
        A python dict storing information loaded from disk.
    """
    with open(path, "rb") as pf:
        info = pickle.load(pf)
    return info

class Data(object):
    def __init__(self, n=200, k=180, verbose=False):
        self.n = n
        self.k = k
        self.verbose = verbose
        self.node_attributes = ['x', 'y', 'z', 'dummy1', 'dummy2', 'dummy3']

    def _get_nx(self):
        """ generate points on a ellopsoid """
        points = np.random.rand(self.n, 3) - 0.5
        points = normalize(points, axis=1)
        points -= np.mean(points, axis=0)

        scale_matrix = np.diag(np.random.rand(3)) * 5  # np.array([[1,0,0],[0,2,0],[0,0,3]])
        rot = special_ortho_group.rvs(3)
        points = points @ scale_matrix @ rot

        D = distance.squareform(distance.pdist(points))
        closest = np.argsort(D, axis=1)
        closest = closest[:, 1:self.k + 1]
        g = nx.Graph()

        nodes = []
        for i in range(self.n):
            node = (i, {'x': points[i][0], 'y': points[i][1], 'z': points[i][2], 'dummy1': 1, 'dummy2': 1, 'dummy3': 1})
            nodes.append(node)
        g.add_nodes_from(nodes)

        edges = []
        for i in range(self.n):
            for j in closest[i]:
                edges.append((i, j))
        g.add_edges_from(edges)

        # Label: eigenvectors of the sample covariance matrix
        # (scipy.linalg.eigh returns them in ascending eigenvalue order).
        cov = points.T @ points * (1 / (self.n - 1))
        vals, vecs = scipy.linalg.eigh(cov)
        return g, vecs

    def read_gml_dgl(self, n_feat=10):
        #          DGLGraph(num_nodes=21, num_edges=420,
        #          ndata_schemes={'x': Scheme(shape=(0,), dtype=torch.float32), 'f': Scheme(shape=(16, 1), dtype=torch.float32)}
        #          edata_schemes={'d': Scheme(shape=(3,), dtype=torch.float32), 'w': Scheme(shape=(0,), dtype=torch.float32)})
        nxg, vecs = self._get_nx()
        label = {}
        label['vecs'] = torch.tensor(vecs)

        try:
            g = dgl.from_networkx(nxg, node_attrs=self.node_attributes)
        except AttributeError:
            g = dgl.DGLGraph()
            g.from_networkx(nxg, node_attrs=self.node_attributes)

        node_data = torch.stack([g.ndata[k] for k in self.node_attributes], 1).type(torch.FloatTensor)
        g.ndata['x'] = node_data
        for k in self.node_attributes:
            if k != 'x':
                g.ndata.pop(k)

        src, dst = g.edges()
        pos = g.ndata['x'][:, :3]
        g.ndata['f'] = g.ndata['x'][:, 3:, None]  # dummy type-0 node features
        g.edata['d'] = pos[dst, :] - pos[src, :]  # relative positions (type-1 edge data)
        try:
            g.edata['w'] = torch.rand(g.num_edges(), 0)  # empty edge-feature tensor
        except AttributeError:  # dgl 0.4.3.post2 has no num_edges()
            g.edata['w'] = torch.rand(g.number_of_edges(), 0)

        return g, label

    def get_data(self, lib='dgl'):
        g, label_dict = self.read_gml_dgl()
        return g, label_dict

class SynDGLDataset():
    def __init__(self, root=export_dir, n_graph=2, n_pts=200, n_nbrs=200, reload=False):
        # modified from https://bit.ly/3jx6Eue
        self.save_path = root
        self.n_graph = n_graph
        self.reload = reload
        self.name = f'pcd-cov-{n_graph}'
        self.n_pts = n_pts
        self.n_nbrs = n_nbrs

        if not os.path.exists(self.save_path):
            os.makedirs(self.save_path)
        self.graph_path = os.path.join(self.save_path, self.name + '_dgl_graph.bin')
        self.info_path = os.path.join(self.save_path, self.name + '_info.pkl')

    def get_graphs(self):
        self.graphs = []
        self.labels = {}
        for idx in tqdm(range(1, self.n_graph + 1)):
            g, label_dict = Data(n=self.n_pts, k=self.n_nbrs).get_data(lib='dgl')
            self.graphs.append(g)
            self.labels[idx] = label_dict

    def save(self):
        if self.has_cache():
            return

        self.get_graphs()
        # save graphs and labels
        save_graphs(self.graph_path, self.graphs)  # {'labels': self.labels}
        # save other information in python dict
        save_info(self.info_path, self.labels)

    def load(self):
        print(f'load pca-cov from {self.graph_path}')
        graphs, _ = load_graphs(self.graph_path)
        label_dict = load_info(self.info_path)
        return graphs, label_dict

    def has_cache(self):
        ret = os.path.exists(self.graph_path) and os.path.exists(self.info_path)
        if self.reload: ret = False
        if ret: print(f'{self.graph_path} exist')
        return ret

if __name__ == '__main__':
    for n in [5]:
        D = SynDGLDataset(n_graph=n, n_pts=200, n_nbrs=200, reload=True)
        D.save()
        graphs, label_dict = D.load()
        print(graphs)
Chen-Cai-OSU commented 3 years ago

I am using the following versions: dgl 0.4.3.post2, torch 1.6.0.

You can modify n_pts and n_nbrs to suit your needs. Let me know if you need any clarification. Thank you!

FabianFuchsML commented 3 years ago

Awesome, thank you! Hopefully, I will find the time to play around with it next week, but no promises.

FabianFuchsML commented 3 years ago

Another suggestion for analysing what's happening: you could try adding a few fully connected layers on each of the per-point outputs. This will obviously break the equivariance, but it could help you analyse where the ability to overfit breaks down. You can also directly compare that to applying the same number of fully connected layers (again in a per-point fashion) directly to the inputs.

Chen-Cai-OSU commented 3 years ago

> Awesome, thank you! Hopefully, I will find the time to play around with it next week, but no promises.

No worries. I guess right now everyone is busy with ICLR rebuttals :-)

Chen-Cai-OSU commented 3 years ago

Hi Fabian, I have another question regarding the time and memory costs of the SE(3)-Transformer (and of equivariant NNs in general). It seems that equivariance comes at the price of slower runtime and higher memory usage.

In the paper, you mention that it takes 2.5 days to train the network for QM9. If that is for a single regression task (or is it the total time for the 6 tasks?), it comes to roughly 72 min per epoch. In contrast, I ran a simple example on QM9 (https://github.com/rusty1s/pytorch_geometric/blob/master/examples/qm9_nn_conv.py) and it took 3 min per epoch.

I was wondering: is roughly 20x more time the right scale here? Also, it seems that equivariant nets in general are memory-expensive. Could you point out the sources of the slow speed and high memory usage? I am interested in improving them if it is not too hard.

blondegeek commented 3 years ago

Hi @Chen-Cai-OSU and @FabianFuchsML,

Multiple things to check here:

1) Make sure the model's L=1 output convention and the coordinate convention you use for your vector match. For example, in e3nn, L=1 features are given in the order (y, z, x) to match the conventions for real spherical harmonics, whereas Cartesian coordinates are (x, y, z).

2) One of the reasons the model is not fitting may be that the question is not symmetrically well-posed. There are actually degenerate answers for the first eigenvector -- it can be v OR -v (although the degeneracy can be higher depending on the shape). You can instead try to predict an L=2 feature, which is capable of representing a double-headed ray (similar to a vector, but symmetric under 180° rotations). For example, you would just plug the vector from your eigensolver into the expressions for the L=2 spherical harmonics (xy, yz, 2z^2 - x^2 - y^2, zx, x^2 - y^2) to get the appropriate prediction coefficients; see the sketch below.
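
For concreteness, a minimal sketch of this target construction (the helper name vec_to_l2 is hypothetical, and the real-spherical-harmonic normalisation constants are omitted):

```python
import numpy as np

def vec_to_l2(v):
    # Map a vector to the five L=2 components
    # (xy, yz, 2z^2 - x^2 - y^2, zx, x^2 - y^2), up to normalisation.
    # Each component is even in v, so v and -v give the same target.
    x, y, z = v
    return np.array([x * y, y * z, 2 * z**2 - x**2 - y**2, z * x, x**2 - y**2])

v = np.array([0.0, 0.6, 0.8])
assert np.allclose(vec_to_l2(v), vec_to_l2(-v))  # sign ambiguity is gone
```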

Regarding memory and runtime of equivariant networks, (at least for e3nn) these bottlenecks are primarily due to the combinatorial nature of the geometric tensor product (contracting two representation indices with Clebsch-Gordan coefficients) and the fact that there are no readily available CUDA kernels for doing these operations. There are many ways around this. For example, one can create specialized geometric tensor product modules that do not do all possible products, but rather a subset.

Hope that helps a bit! Tess

Chen-Cai-OSU commented 3 years ago

Hi @blondegeek!

Thanks a lot for all the suggestions and explanations. For 1), I have checked that the output rotates properly when I rotate the input. For 2), I am using (negative) absolute cosine similarity as both loss and metric (sketched below), so the 'up-to-sign' problem should already be solved. However, I am not able to fit even three point clouds (200 points each), which is a bit puzzling. I will wait for Fabian to take a look at the dataset.
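
For reference, the loss I describe, in sketch form:

```python
import torch
import torch.nn.functional as F

def neg_abs_cosine(pred, target, eps=1e-8):
    # |cos| is identical for targets v and -v, so this loss is
    # insensitive to the eigenvector's sign ambiguity.
    cos = F.cosine_similarity(pred, target, dim=-1, eps=eps)
    return -cos.abs().mean()
```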

Would you elaborate on the "combinatorial nature of the geometric tensor product"? Do you mean that computing the tensor product of type-a and type-b (a, b = 0, 1, 2, ...) irreducible representations is very expensive? How do you choose the subset of products? Thank you!

blondegeek commented 3 years ago

Hi @Chen-Cai-OSU,

Yes, the tensor product is expensive. How you choose subsets will depend on the application and be determined by experiment :) For example, you can take an approach similar to depthwise convolutions, where you only interact certain subsets of features with each other; see the sketch below. In general, hidden features with max L=1 seem effective for several tasks. There are more geometric tasks where you need higher L.
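
To make the combinatorics concrete, a small sketch (using the Clebsch-Gordan selection rule that l_in ⊗ l_filter couples to every l_out with |l_in - l_f| <= l_out <= l_in + l_f) of how capping the maximum L prunes tensor-product paths:

```python
# Enumerate allowed Clebsch-Gordan coupling paths and prune by a max degree.
max_l = 1  # cap on hidden feature degree
paths = [(l_in, l_f, l_out)
         for l_in in range(max_l + 1)
         for l_f in range(max_l + 1)
         for l_out in range(abs(l_in - l_f), l_in + l_f + 1)
         if l_out <= max_l]
print(paths)  # 5 paths survive; the (1, 1) -> 2 coupling is pruned
```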

Have you been able to successfully overfit to one example? That should help debug whether the task is set up correctly.

Another thing to be aware of: even if your loss function allows for both -v and v to be "correct", the network will ALWAYS output the linear combination of the two degenerate possibilities, which in this case is zero (more detail on why is in https://arxiv.org/abs/2007.02005; a short derivation is sketched below).
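
To spell out the argument from that paper for this task: the (idealised) ellipsoid is invariant under the 180° rotation about each of its principal axes. If $R$ is such a rotation and $f$ is equivariant, then

$$ f(X) \;=\; f(RX) \;=\; R\,f(X), $$

so the output must be a fixed point of every such $R$. A 180° rotation about an axis $u$ fixes only vectors parallel to $u$; imposing this for two different principal axes leaves only $f(X) = 0$.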

Best, Tess

FabianFuchsML commented 3 years ago

Hi both!

Some additional remarks about speed (I will put this somewhere in the readme):

Here are some ideas about speeding up the SE3-Transformer:

Best, Fabian

FabianFuchsML commented 3 years ago

@blondegeek

> Regarding memory and runtime of equivariant networks, (at least for e3nn) these bottlenecks are primarily due to the combinatorial nature of the geometric tensor product (contracting two representation indices with Clebsch-Gordan coefficients) and the fact that there are no readily available CUDA kernels for doing these operations. There are many ways around this. For example, one can create specialized geometric tensor product modules that do not do all possible products, but rather a subset.

This is super interesting! It sounds sort of similar to what I found out when I spent some time digging into the bottlenecks - but then also not quite. I wish I could remember more precisely what my findings were. In the beginning, the bottleneck was definitely purely the spherical harmonics. But after speeding them up by shifting the computations to the GPU (all the credit here goes to Daniel), the bottleneck was split roughly equally between multiple parts - one of them being constructing the basis vectors from the spherical harmonics and the Clebsch-Gordan coefficients. It sounds like there is some potential if one wanted to get into CUDA programming. (A quick way to check where the time goes is sketched below.)
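
As an aside, a sketch of how to locate such bottlenecks with the PyTorch profiler (available in torch >= 1.8; model and batch here are dummy stand-ins for the SE(3)-Transformer and one input batch):

```python
import torch
from torch.profiler import profile, ProfilerActivity

model = torch.nn.Linear(16, 16)  # stand-in for the real model
batch = torch.randn(32, 16)      # stand-in for a batch of graphs

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model(batch)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```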

Chen-Cai-OSU commented 3 years ago

Thanks @blondegeek for the reference and explanations.

> Have you been able to successfully overfit to one example? That should help debug whether the task is set up correctly.

Yes. For a single point cloud I can overfit. For two point clouds, I can also overfit. But starting from 3 point clouds, I cannot overfit anymore :-(

> Another thing to be aware of: even if your loss function allows for both -v and v to be "correct", the network will ALWAYS output the linear combination of the two degenerate possibilities, which in this case is zero.

I don't understand why this is the case. From your paper, I understand that it's easy to convert a rectangle into a square, as the latter has more symmetry, but not the other way around. But how is this related to the prediction of eigenvectors? If my loss is set to encourage the output to be either v or -v, why would the network want to output 0? I guess this is a subtle point that I haven't yet grasped.

FabianFuchsML commented 3 years ago

Hi Chen,

I had a little time to look through your code today. It's a great toy example and I would love to try to debug / crack it, but I am not too optimistic that I will have the time to get to it.

I looked at how you sample points on an ellipsoid in _get_nx(). At first, you seem to sample on a sphere. Then there is a line points -= np.mean(points, axis=0), which moves the sphere away from the origin. Is this on purpose? You did state at some point that you were trying to sample from an ellipsoid centered around 0. Maybe I am overlooking something and you are re-centering it later, but I would recommend working with ellipsoids which are centered around 0 (as you said), as this seems less prone to errors down the line. Also, did you visualise the ellipsoids together with their principal axes?

Chen-Cai-OSU commented 3 years ago

> Then there is a line points -= np.mean(points, axis=0), which moves the sphere away from the origin. Is this on purpose?

Hi Fabian, I remember trying both 1) points -= np.mean(points, axis=0) and 2) not re-centering. np.mean(points, axis=0) is very close to the origin anyway; it's just the sample mean. I tried both versions and didn't see significant differences in terms of fitting the training data.

Chen-Cai-OSU commented 3 years ago
[Screenshot attached, 2020-11-19]

I believe the code is correct.

blondegeek commented 3 years ago

Ok, so here's how I would do it with e3nn and torch_geometric (because I'm not familiar with the se3-transformer codebase or dgl).

Here's a simple network fitting on a single example. Note: I'm not fitting to eigenvectors but rather to the rotated scale matrix, which is basically the matrix you are getting the eigenvectors of.

Here's the same network fitting on multiple training examples. It needs more training, but it's starting to get the idea.

Chen-Cai-OSU commented 3 years ago

> Ok, so here's how I would do it with e3nn and torch_geometric (because I'm not familiar with the se3-transformer codebase or dgl).
>
> Here's a simple network fitting on a single example. Note: I'm not fitting to eigenvectors but rather to the rotated scale matrix, which is basically the matrix you are getting the eigenvectors of.
>
> Here's the same network fitting on multiple training examples. It needs more training, but it's starting to get the idea.

Thanks @blondegeek for the nice notebooks! I am starting to try out TFN now.

My quick question is about predicting the eigenvectors. Is it a bad idea to use an equivariant NN to predict the eigenvectors in this case? Is it because both v and -v are right answers, so the NN will tend to output 0? I still don't understand why this is the case. Even if I set the loss to account for this "up-to-sign" problem (maximizing the absolute cosine similarity between the predicted and true eigenvectors), is this still doomed to fail?

blondegeek commented 3 years ago

The key issue is that eigenvector solvers are not symmetry-preserving; they "pick" an eigenvector, typically based on a random initialization or a similarly arbitrary procedure. This becomes especially problematic for symmetric structures.

Let's consider two higher-symmetry cases. Say the scaling matrix is the identity, torch.eye(3). What are the principal eigenvectors? They are degenerate -- a sphere is radially symmetric, so any three orthogonal directions are equally valid and can be in any order.

How about if the scaling matrix is something like torch.eye(3) * [1, 1, 2], so that the ellipsoid is radially symmetric about one axis? You have a similar problem: there is no unique way to choose the eigenvectors in a rotation-equivariant manner (see the snippet below).
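
A short demonstration of that degeneracy:

```python
import numpy as np

# Axially symmetric ellipsoid: eigenvalues (1, 1, 2). The solver returns
# *some* orthonormal basis of the degenerate subspace; any rotation of
# that basis is equally valid, so the "first eigenvector" is not a
# well-defined equivariant target.
vals, vecs = np.linalg.eigh(np.diag([1.0, 1.0, 2.0]))
print(vals)         # [1. 1. 2.]
print(vecs[:, :2])  # arbitrary basis of the degenerate eigenspace
```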

So the issue is more than "just a sign" -- the issue is that the question is symmetrically ill-posed. Principal axes are not vectors; they are double-headed rays, and you need L=2 features to describe them. A 3x3 matrix can handle both the spherically symmetric case (it will just predict a matrix with only a scalar trace part, i.e. a multiple of the identity) and any less symmetric case; a small sketch of that decomposition follows.
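
In sketch form, a symmetric 3x3 target splits equivariantly into its trace, the L=0 (spherical) part, and a traceless symmetric remainder carrying the five L=2 components:

```python
import numpy as np

A = np.random.rand(3, 3)
M = A + A.T                # symmetric target, e.g. the rotated scale matrix
l0 = np.trace(M) / 3.0     # L=0 part: multiple of the identity
l2 = M - l0 * np.eye(3)    # L=2 part: traceless symmetric, 5 dof
assert np.isclose(np.trace(l2), 0.0)
# A sphere gives l2 = 0 (pure L=0); less symmetric shapes give l2 != 0.
```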

Hope that helps!

Chen-Cai-OSU commented 3 years ago

Thanks @blondegeek for the further clarification! I understand the case of eigenvalues with multiplicity, where any vector in the eigenspace can be taken as an eigenvector.

> Principal axes are not vectors; they are double-headed rays, and you need L=2 features to describe them

What are double-headed rays exactly? I saw this slide (at around 20 min) in your talk: https://sites.google.com/view/equiv-data-aug/home

[Screenshot of the slide, 2020-11-21]

I can find "pseudovector" on Wikipedia, but I didn't find any good references on double-headed rays and spirals. I am familiar with covariant/contravariant tensors but have never heard of double-headed rays and spirals. Would you mind pointing out some references?

blondegeek commented 3 years ago

https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.113.165502

Chen-Cai-OSU commented 3 years ago

@blondegeek Hi Tess, I am using e3nn now (it's nice that the change-of-basis matrix can be handled by to_irrep_transformation), but I had some issues verifying the equivariance for the 3x3 matrix output (Rs_out = [(1, 0, 1), (1, 2, 1)]). The generic form of the check I am running is sketched below.
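
(check_equivariance, model, and rep_out are hypothetical names; for a plain type-1 vector output, rep_out(R) would just be R itself.)

```python
import numpy as np
from scipy.stats import special_ortho_group

def check_equivariance(model, points, rep_out, atol=1e-4):
    # Rotating the input cloud (rows x -> R x) should transform the
    # output by the representation matrix rep_out(R).
    R = special_ortho_group.rvs(3)
    out_rot = model(points @ R.T)
    rot_out = rep_out(R) @ model(points)
    assert np.allclose(out_rot, rot_out, atol=atol)
```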

Would you mind taking a look at https://github.com/e3nn/e3nn/issues/149? Many thanks!