nmwsharp / diffusion-net

Pytorch implementation of DiffusionNet for fast and robust learning on 3D surfaces like meshes or point clouds.
https://arxiv.org/abs/2012.00888
MIT License
398 stars 50 forks source link

Segmentation on the simplified human dataset (face label) #31

Open dasbai opened 1 year ago

dasbai commented 1 year ago

Thank you for your great work! I'm a little confused about the segmentation accuracy on the Human dataset provided by PD-Mesh (simplified inputs with hard ground truth at faces). I tried DiffusionNet with xyz input features on it, but the accuracy is only 81%; with hks as the input, it is only 79%.
[image] I also tried adjusting C_width, but it seems to have no effect. Which parameters should I adjust?

My code for loading the data and training follows. human_seg_dataset.py:

import os
import sys
import numpy as np
import torch
from torch.utils.data import Dataset
import potpourri3d as pp3d
sys.path.append(os.path.join(os.path.dirname(__file__), "../../"))  # add the path to the DiffusionNet src
import src.diffusion_net as diffusion_net

def is_mesh_file(filename):
    """Return True if *filename* has a mesh extension this dataset reads.

    Only ``.obj`` and ``.off`` are accepted. The leading dot is required, so a
    name that merely ends in the letters "off" (e.g. ``"cutoff"``) is rejected.
    The check is case-sensitive.
    """
    # str.endswith accepts a tuple, testing every suffix in a single call.
    return filename.endswith(('.obj', '.off'))

class HumanDataset(Dataset):
    """Face-labelled human segmentation meshes with their DiffusionNet operators.

    Layout expected under ``root_dir``:
      * ``<phase>/`` -- mesh files (``.obj`` / ``.off``), searched recursively;
      * ``seg/``     -- one ``<mesh stem>.eseg`` label file per mesh.
    """

    def __init__(self, root_dir, phase, k_eig=128, op_cache_dir=None, normalize=True):
        """Index the meshes of one split and locate their label files.

        Args:
            root_dir: dataset root directory.
            phase: split sub-directory of ``root_dir`` (e.g. ``'train'``/``'test'``).
            k_eig: number of eigenpairs requested from ``get_operators``.
            op_cache_dir: forwarded to ``get_operators`` as ``op_cache_dir``.
            normalize: if True, vertex positions are passed through
                ``diffusion_net.geometry.normalize_positions`` before operators
                are computed, as the upstream DiffusionNet example datasets do.

        Raises:
            NotADirectoryError: if ``root_dir/phase`` is not a directory.
            FileNotFoundError: if a mesh has no matching file in ``root_dir/seg``.
        """
        self.k_eig = k_eig
        self.root_dir = root_dir
        self.cache_dir = os.path.join(root_dir, "cache")  # kept for callers; not read in this class
        self.op_cache_dir = op_cache_dir
        self.normalize = normalize
        self.dir = os.path.join(self.root_dir, phase)
        self.paths = self.make_dataset(self.dir)
        self.seg_paths = self.get_seg_files(self.paths, os.path.join(self.root_dir, 'seg'))

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        """Load mesh ``index`` and return it with its operators and labels.

        Returns:
            ``(verts, faces, frames, mass, L, evals, evecs, gradX, gradY, labels)``.
            The seven middle items come from ``get_operators``; ``labels`` is a
            ``long`` tensor read from the mesh's ``.eseg`` file.

        Raises:
            ValueError: if the label count differs from the face count.
        """
        path = self.paths[index]
        # Pass the path rather than an open() handle so numpy closes the file;
        # ndmin=1 keeps a one-line label file from collapsing to a 0-d array.
        labels = np.loadtxt(self.seg_paths[index], dtype='float64', ndmin=1)
        verts, faces = pp3d.read_mesh(path)
        verts = torch.tensor(np.ascontiguousarray(verts)).float()
        faces = torch.tensor(np.ascontiguousarray(faces))

        # Labels are consumed per face (the training script uses outputs_at='faces');
        # fail here with a clear message instead of deep inside the loss.
        if labels.shape[0] != faces.shape[0]:
            raise ValueError(
                "{}: {} labels but {} faces".format(self.seg_paths[index], labels.shape[0], faces.shape[0]))

        if self.normalize:
            # Center and unit-scale so xyz features and spectral operators do
            # not depend on each mesh's absolute size/placement.
            verts = diffusion_net.geometry.normalize_positions(verts)

        labels = torch.tensor(labels).long()
        frames, mass, L, evals, evecs, gradX, gradY = diffusion_net.geometry.get_operators(
            verts, faces, k_eig=self.k_eig, op_cache_dir=self.op_cache_dir)
        return verts, faces, frames, mass, L, evals, evecs, gradX, gradY, labels

    @staticmethod
    def get_seg_files(paths, seg_dir, seg_ext='.eseg'):
        """Return the label file in ``seg_dir`` for each mesh path, in order.

        Raises:
            FileNotFoundError: if a mesh has no ``<stem><seg_ext>`` file.
        """
        seg_files = []
        for mesh_path in paths:
            stem = os.path.splitext(os.path.basename(mesh_path))[0]
            seg_file = os.path.join(seg_dir, stem + seg_ext)
            # A real exception (not assert) so the check survives `python -O`.
            if not os.path.isfile(seg_file):
                raise FileNotFoundError("no label file for {}: {}".format(mesh_path, seg_file))
            seg_files.append(seg_file)
        return seg_files

    @staticmethod
    def make_dataset(path):
        """Recursively collect mesh files under ``path`` in a deterministic order.

        Raises:
            NotADirectoryError: if ``path`` is not a directory.
        """
        if not os.path.isdir(path):
            raise NotADirectoryError('%s is not a valid directory' % path)
        meshes = []
        for root, _, fnames in sorted(os.walk(path)):
            # os.walk lists files in filesystem order; sort so indices are reproducible.
            meshes.extend(os.path.join(root, fname) for fname in sorted(fnames) if is_mesh_file(fname))
        return meshes

seg_human.py

import os
import sys
import argparse
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

sys.path.append(os.path.join(os.path.dirname(__file__), "../../"))  # add the path to the DiffusionNet src
import src.diffusion_net as diffusion_net
from experiments.human_seg.human_seg_dataset import HumanDataset

# Parse a few args. Tuples (not sets) keep the order of choices stable in --help.
parser = argparse.ArgumentParser()
parser.add_argument("--input_features", type=str, choices=('xyz', 'hks'), default='hks',
                    help="('xyz' or 'hks')")
parser.add_argument("--features_width", type=int, choices=(32, 64, 128, 256), default=128)
parser.add_argument("--data_name", type=str, choices=('human',), default='human')
parser.add_argument("--n_class", type=int, choices=(8,), default=8)
args = parser.parse_args()

# system things
# Fall back to CPU instead of crashing on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32
input_features = args.input_features  # 'xyz' or 'hks', enforced by argparse choices
C_width = args.features_width
k_eig = 128
n_epoch = 200
lr = 1e-3  # module-level; train_epoch decays it in place via `global lr`
decay_every = 50
decay_rate = 0.5
# Rotation augmentation only matters for raw positions; HKS is built from
# Laplacian eigenpairs, which a rigid rotation does not change.
augment_random_rotate = (input_features == 'xyz')

# === Load datasets
data_name = args.data_name
base_path = os.path.dirname(__file__)
op_cache_dir = os.path.join(base_path, "data", "op_cache", data_name)
dataset_path = os.path.join(base_path, "data", data_name)
n_class = args.n_class
# batch_size=None: the loader yields one mesh at a time, unbatched.
test_dataset = HumanDataset(dataset_path, 'test', k_eig=k_eig, op_cache_dir=op_cache_dir)
test_loader = DataLoader(test_dataset, batch_size=None)
train_dataset = HumanDataset(dataset_path, 'train', k_eig=k_eig, op_cache_dir=op_cache_dir)
train_loader = DataLoader(train_dataset, batch_size=None, shuffle=True)

# === Create the model
C_in = {'xyz': 3, 'hks': 16}[input_features]  # dimension of input features

# log_softmax output pairs with nll_loss in train_epoch; outputs_at='faces'
# makes the network emit one prediction row per face.
model = diffusion_net.layers.DiffusionNet(C_in=C_in,
                                          C_out=n_class,
                                          C_width=C_width,
                                          N_block=4,
                                          last_activation=lambda x: torch.nn.functional.log_softmax(x, dim=-1),
                                          outputs_at='faces',
                                          dropout=True)

model = model.to(device)

# === Optimize
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

def train_epoch(epoch):
    """Run one training epoch over ``train_loader`` and return its face accuracy.

    Step learning-rate decay: on every epoch > 0 divisible by ``decay_every``,
    the module-level ``lr`` is multiplied by ``decay_rate`` and written into
    every optimizer param group.

    Args:
        epoch: zero-based epoch index; only used for the decay schedule.

    Returns:
        Fraction of correctly predicted faces, measured on the training-mode
        (dropout-active) predictions.

    Raises:
        ValueError: if ``input_features`` is neither 'xyz' nor 'hks'.
        RuntimeError: if the loader yields no labels (empty dataset).
    """
    global lr

    if epoch > 0 and epoch % decay_every == 0:
        lr *= decay_rate
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    model.train()
    optimizer.zero_grad()

    correct = 0
    total_num = 0
    for data in tqdm(train_loader):
        # Move every tensor of the sample to the device in one pass.
        verts, faces, frames, mass, L, evals, evecs, gradX, gradY, labels = (t.to(device) for t in data)

        if augment_random_rotate:
            verts = diffusion_net.utils.random_rotate_points(verts)

        if input_features == 'xyz':
            features = verts
        elif input_features == 'hks':
            features = diffusion_net.geometry.compute_hks_autoscale(evals, evecs, 16)
        else:
            # Previously this left `features` unbound and failed with a confusing error.
            raise ValueError("unknown input_features: {!r}".format(input_features))

        preds = model(features, mass, L=L, evals=evals, evecs=evecs, gradX=gradX, gradY=gradY, faces=faces)

        # preds are log-probabilities (log_softmax last activation), as nll_loss expects.
        loss = torch.nn.functional.nll_loss(preds, labels)
        loss.backward()

        correct += preds.argmax(dim=-1).eq(labels).sum().item()
        total_num += labels.shape[0]

        optimizer.step()
        optimizer.zero_grad()

    if total_num == 0:
        raise RuntimeError("train_loader produced no labelled faces; check the dataset path")
    return correct / total_num

# Do an evaluation pass on the test dataset 
def test():
    """Evaluate the model on ``test_loader`` and return face accuracy.

    Runs in eval mode (dropout off) under ``torch.no_grad()``; no rotation
    augmentation is applied.

    Returns:
        Fraction of correctly predicted faces over the whole test set.

    Raises:
        ValueError: if ``input_features`` is neither 'xyz' nor 'hks'.
        RuntimeError: if the loader yields no labels (empty dataset).
    """
    model.eval()

    correct = 0
    total_num = 0
    with torch.no_grad():
        for data in tqdm(test_loader):
            verts, faces, frames, mass, L, evals, evecs, gradX, gradY, labels = (t.to(device) for t in data)

            if input_features == 'xyz':
                features = verts
            elif input_features == 'hks':
                features = diffusion_net.geometry.compute_hks_autoscale(evals, evecs, 16)
            else:
                raise ValueError("unknown input_features: {!r}".format(input_features))

            preds = model(features, mass, L=L, evals=evals, evecs=evecs, gradX=gradX, gradY=gradY, faces=faces)

            correct += preds.argmax(dim=-1).eq(labels).sum().item()
            total_num += labels.shape[0]

    if total_num == 0:
        raise RuntimeError("test_loader produced no labelled faces; check the dataset path")
    return correct / total_num

print("Training...")

test_acc = None  # stays None only if n_epoch == 0
for epoch in range(n_epoch):
    train_acc = train_epoch(epoch)
    test_acc = test()
    print("Epoch {} - Train overall: {:06.3f}%  Test overall: {:06.3f}%".format(epoch, 100*train_acc, 100*test_acc))

# Final report. The weights have not changed since the last epoch's test()
# call (eval mode, no_grad), so reuse that result instead of a second full
# pass over the test set (differences could only come from GPU nondeterminism).
if test_acc is None:
    test_acc = test()
print("Overall test accuracy: {:06.3f}%".format(100*test_acc))

Sorry for the long post. Any hints on this problem would be appreciated! Thanks, and have a nice day!