BunnySoCrazy / SECAD-Net

This repository provides the official code of SECAD-Net.
MIT License

Plan for updating the functions 'create_CAD_mesh' and 'draw_2d_im_sketch' in utils/cad_meshing.py #2

Closed: eunzihong closed this issue 1 year ago

eunzihong commented 1 year ago

Hi,

Thank you for the great work on reconstructing meshes and CAD models from voxel grids. My colleagues and I recently came across your work, were inspired by it, and are now trying to build upon your brilliant codebase.

I tested fine-tuning the models on some shapes and confirmed that the mesh outputs look good. I would also like to check the CAD reconstruction outputs; however, I noticed that the functions 'create_CAD_mesh' and 'draw_2d_im_sketch' in utils/cad_meshing.py are not provided.

Would you mind sharing the code for research purposes? You can reach me by email at eunji.hong@kaist.ac.kr

Thank you. Eunji Hong.

BunnySoCrazy commented 1 year ago

Hello Ms. Hong. I am honored by your interest in SECAD-Net. As mentioned in the README, I am currently occupied with other tasks. Still, I appreciate your encouragement, and I will do my best to release this code as soon as possible, although I cannot give you an exact release date.

eunzihong commented 1 year ago

Hi, Pu Li. I completely understand your situation, and thank you for letting me know. Please reply to this thread when you update the code. For now, I'll close this issue.

Ma-Weijian commented 9 months ago

Hi, there.

I was wondering whether the sketch conversion code is available yet.

Thanks a lot.

BunnySoCrazy commented 9 months ago

Apologies for my late response.

Here is a preliminary implementation of the code, which will help you obtain the model in BREP format or as the corresponding mesh. Please note that the current code does not use splines when creating sketches, although doing so is entirely feasible.

I plan to add this code to the repository once I've made sure it is robust. In the meantime, I hope this version helps those who urgently need to test it.

Please feel free to reach out if you have any further questions or concerns.


import os
import time
import torch
import mcubes
from utils import utils
import cv2
import trimesh
import numpy as np
from shapely.geometry import Polygon, MultiPolygon
from pyquaternion import Quaternion
from .utils import add_latent
import torch.nn as nn
import OCC.Core.gp as gp
import OCC.Core.BRepPrimAPI as BRepPrimAPI
import OCC.Core.BRepBuilderAPI as BRepBuilderAPI
import OCC.Core.TopoDS as TopoDS
import OCC.Core.BRep as BRep
import OCC.Core.BRepAlgoAPI as BRepAlgoAPI
from OCC.Core.BRepMesh import BRepMesh_IncrementalMesh
from OCC.Core.gp import gp_Pnt, gp_Pnt2d, gp_Vec
from OCC.Core.GeomAPI import GeomAPI_PointsToBSpline
from OCC.Core.TColgp import TColgp_Array1OfPnt
from OCC.Core.BRepBuilderAPI import BRepBuilderAPI_MakeEdge, BRepBuilderAPI_MakeWire, BRepBuilderAPI_MakeFace
from OCC.Core.BRepTools import breptools

def get_sketch_list(generator, shape_code, wh):
    """
    Sample the 2D sketch shapes from the implicit network's sketch heads and convert them to polygons.

    Notes:
        - This function is currently a preliminary implementation.
    """
    B, N = 1, wh*wh
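    # wh x wh grid over [-0.5, 0.5]^2; the complex step gives the number of samples (endpoints included)
    x1, x2 = np.mgrid[-0.5:0.5:complex(0, wh), -0.5:0.5:complex(0, wh)]
    x1 = torch.from_numpy(x1)
    x2 = torch.from_numpy(x2)
    sample_points = torch.dstack((x1, x2)).view(-1, 2).unsqueeze(0).cuda()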

    shape_code_cuda = shape_code.cuda()
    latent_list = [add_latent(sample_points, shape_code_cuda).float() for _ in range(4)]

    sdfs_2d_list = []
    for i in range(4):
        head = getattr(generator, f'sketch_head_{i}')
        latent = latent_list[i]
        sdfs_2d = head(latent).reshape(B,N,-1).float().squeeze().detach().cpu().unsqueeze(-1).numpy()
        sdfs_2d_list.append(sdfs_2d)

    # Shift sample coordinates from [-0.5, 0.5] to [0, 1]
    sample_points = sample_points.detach().cpu().numpy()[0][:, :2] + 0.5

    fill_sk_list=[]
    for dis in sdfs_2d_list:
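        a = np.hstack((sample_points, dis))
        # Rasterize the SDF samples. The canvas is padded because samples on the
        # upper edge map to index wh; the crop to [:wh, :wh] discards them.
        canvas = np.zeros((wh + 80, wh + 80))
        for i in a:
            canvas[int(i[1] * wh)][int(i[0] * wh)] = i[2]
        result = canvas[:wh, :wh]
        # Pixels with SDF < -0.01 are treated as inside the sketch
        imgray = (result < -0.01).astype('uint8') * 255

        ret, thresh = cv2.threshold(imgray, 254, 255, 0)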
        contours, b = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
        my_contour_list = []  
        fill_polygon_list = []
        if len(contours)==0:
            my_contour_list.append(None)
            fill_sk_list.append(None)
            continue

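        # The hierarchy has shape (1, n, 4) with rows [next, previous, first_child, parent];
        # keep the parent index of each contour
        hir = b[0][..., 3]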

        for c in contours:
            my_contour_list.append(c[:,0,:])

        polygon_list = trimesh.path.polygons.paths_to_polygons(my_contour_list)
        for polygon in polygon_list:
            if polygon is None:
                continue

            path_2d = trimesh.path.exchange.misc.polygon_to_path(polygon)
            path = trimesh.path.Path2D(path_2d['entities'],path_2d['vertices'])

            max_values = np.max(path_2d['vertices'], axis=0)
            min_values = np.min(path_2d['vertices'], axis=0)
            size = np.linalg.norm(max_values - min_values)
            smooth_value = size
            sm_path = trimesh.path.simplify.simplify_spline(path, smooth=smooth_value, verbose=True)

            a,_ = sm_path.triangulate()
            polygon = trimesh.path.polygons.paths_to_polygons([a])
            if polygon[0] is None:
                continue
            # Shift pixel coordinates so the sketch is centered at the origin
            Matrix = np.eye(3)
            Matrix[0, 2] = -wh / 2
            Matrix[1, 2] = -wh / 2
            polygon = trimesh.path.polygons.transform_polygon(polygon[0], Matrix)
            fill_polygon_list.append(polygon)
        if len(fill_polygon_list)==0:
            fill_sk_list.append(None)
            continue
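        # Union or subtract each further contour based on the parity of its parent index
        # (top-level contours have parent -1, and -1 % 2 == 1 in Python, so they are unioned).
        # This assumes fill_polygon_list[i] still corresponds to contour i, i.e. no polygon was skipped above.
        fill_sk = fill_polygon_list[0]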
        for i in range(1,len(fill_polygon_list)):
            if hir[i]%2==1:
                fill_sk = fill_sk | fill_polygon_list[i]
            else:
                fill_sk = fill_sk - fill_polygon_list[i]
        fill_sk_list.append(fill_sk)
    return fill_sk_list

def create_cylinder(polygon, builder, compound, height, wh):
    """
    Extrude 3D cylinders from 2D contours.
    External contours are added to the compound;
    internal contours (holes) are subtracted from it.

    Notes:
        - This function is currently a preliminary implementation.
    """
    for ring_idx, ring in enumerate([polygon.exterior, *polygon.interiors]):
        # Ring coordinates repeat the first point at the end; drop it, then close the
        # point list explicitly so the B-spline starts and ends at the same point
        points = [gp_Pnt(float(pt[0]), float(pt[1]), 0) for pt in np.array(ring.coords[:-1])]
        points.append(gp_Pnt(float(points[0].X()), float(points[0].Y()), 0))

        # OCC arrays are 1-indexed
        points_array = TColgp_Array1OfPnt(1, len(points))
        for i_, pt in enumerate(points):
            points_array.SetValue(i_ + 1, pt)

        # Approximate the ring with a single B-spline edge and build a planar face from it
        bspline_curve = GeomAPI_PointsToBSpline(points_array).Curve()
        edge = BRepBuilderAPI_MakeEdge(bspline_curve).Edge()

        wire_builder = BRepBuilderAPI_MakeWire()
        wire_builder.Add(edge)
        face = BRepBuilderAPI_MakeFace(wire_builder.Wire()).Face()

        # Extrude the face along +Z by 2 * |height| * wh (pixel units)
        extrude_vec = gp_Vec(0, 0, abs(float(height)) * wh * 2 + np.finfo(float).eps)
        shape = BRepPrimAPI.BRepPrimAPI_MakePrism(face, extrude_vec).Shape()

        if ring_idx == 0:
            # Exterior ring: add the solid to the compound
            builder.Add(compound, shape)
        else:
            # Interior ring (hole): subtract it from what has been built so far
            compound = BRepAlgoAPI.BRepAlgoAPI_Cut(compound, shape).Shape()
    return compound

def create_CAD_mesh(generator, shape_code, shape_3d, CAD_mesh_filepath):
    """
    Reconstruct shapes with sketch-extrude operations.

    Notes:
        - This function is currently a preliminary implementation.
    """
    wh = 500
    fill_sk_list = get_sketch_list(generator, shape_code, wh)
    ext_3d_list=[]
    for i in range(len(fill_sk_list)):
        if fill_sk_list[i] is None:
            continue

        rotation_qua = shape_3d[0, :4, i].detach().cpu().numpy()
        # Use float64 / Python floats for values that are passed to the OCC constructors
        translation = shape_3d[0, 4:7, i].detach().cpu().numpy().astype(np.float64)
        height = shape_3d[0, 7, i].item()
        quaternion = Quaternion(rotation_qua)  #[w,x,y,z]

        inverse = quaternion.inverse
        quaternion = np.asarray([inverse[3], inverse[0], inverse[1],inverse[2]]) # [x,y,z,w]
        # Skip parts whose extrusion would be thinner than one pixel
        if abs(height) * wh * 2 + np.finfo(float).eps < 1:
            continue

        compound = TopoDS.TopoDS_Compound()
        builder = BRep.BRep_Builder()
        builder.MakeCompound(compound)

        # process each part separately if MultiPolygon
        if isinstance(fill_sk_list[i], MultiPolygon):
            for polygon in fill_sk_list[i].geoms:  # multi-part geometries are not iterable in Shapely 2.x
                compound = create_cylinder(polygon, builder, compound, height, wh)
        else:
            compound = create_cylinder(fill_sk_list[i], builder, compound, height, wh)

        # Center the extrusion on the sketch plane by shifting it -|height| * wh along Z
        transformation = gp.gp_Trsf()
        transformation.SetTranslationPart(gp.gp_Vec(0, 0, -abs(height) * wh))
        compound = BRepBuilderAPI.BRepBuilderAPI_Transform(compound, transformation).Shape()

        # Scale from pixel units back to normalized coordinates (divide by wh)
        transformation = gp.gp_Trsf()
        transformation.SetScaleFactor(1/wh)
        compound = BRepBuilderAPI.BRepBuilderAPI_Transform(compound, transformation).Shape()

        # Apply quaternion rotation
        transformation = gp.gp_Trsf()
        quaternion =  np.asarray(nn.functional.normalize(torch.from_numpy(quaternion), dim=-1))

        transformation.SetRotation(gp.gp_Quaternion(quaternion[0], quaternion[1], quaternion[2], quaternion[3]))
        compound = BRepBuilderAPI.BRepBuilderAPI_Transform(compound, transformation).Shape()

        # Apply translation along X, Y, and Z axes
        transformation = gp.gp_Trsf()
        transformation.SetTranslationPart(gp.gp_Vec(float(translation[0]), float(translation[1]), float(translation[2])))
        compound = BRepBuilderAPI.BRepBuilderAPI_Transform(compound, transformation).Shape()

        ext_3d_list.append(compound)

    # Create a compound to hold all the cylinders
    compound = TopoDS.TopoDS_Compound()
    builder = BRep.BRep_Builder()
    builder.MakeCompound(compound)

    for shape in ext_3d_list:
        builder.Add(compound, shape)

    # Export the shapes to stl format (meshed)
    mesh = BRepMesh_IncrementalMesh(compound, 0.05)
    from OCC.Core.StlAPI import StlAPI_Writer
    writer = StlAPI_Writer()
    writer.Write(mesh.Shape(), CAD_mesh_filepath + '_CAD.stl')

    status = breptools.Write(compound, CAD_mesh_filepath + '_CAD.brep')

    if status:
        print('CAD saving ', CAD_mesh_filepath + '_CAD.brep')
    else:
        print('Failed to save the CAD file.')
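The sampling and rasterization steps at the start of `get_sketch_list` can be tried without a trained network. The sketch below is a minimal illustration, not part of the released code: it assumes a hypothetical analytic circle SDF (`circle_sdf`) in place of a sketch head, and it writes the samples to the canvas with one vectorized assignment instead of the per-point loop.

```python
import numpy as np


def circle_sdf(points, radius):
    """Hypothetical stand-in for a sketch head: signed distance to a circle (negative inside)."""
    return np.linalg.norm(points, axis=1) - radius


def rasterize_inside(wh, radius, threshold=-0.01):
    # Same grid as get_sketch_list: wh x wh samples over [-0.5, 0.5]^2, endpoints included
    x1, x2 = np.mgrid[-0.5:0.5:complex(0, wh), -0.5:0.5:complex(0, wh)]
    points = np.stack((x1, x2), axis=-1).reshape(-1, 2)
    sdf = circle_sdf(points, radius)

    # Map [-0.5, 0.5] to pixel indices; the upper edge lands on index wh,
    # so the canvas gets one extra row/column that is cropped away afterwards
    idx = ((points + 0.5) * wh).astype(int)
    canvas = np.zeros((wh + 1, wh + 1))
    canvas[idx[:, 1], idx[:, 0]] = sdf
    result = canvas[:wh, :wh]

    # Inside pixels as a 0/255 uint8 image, the kind of input cv2.findContours expects
    return (result < threshold).astype(np.uint8) * 255


img = rasterize_inside(wh=200, radius=0.25)
# Each inside pixel stands for one grid sample, spaced 1 / (wh - 1) apart
area = (img > 0).sum() / (200 - 1) ** 2
```

With `wh=200` and `radius=0.25`, `area` comes out close to the area of a disc of radius 0.24 rather than 0.25, because the `-0.01` threshold shrinks the shape slightly.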
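In `get_sketch_list`, union versus subtraction is decided by the parity of each contour's parent index. A common alternative, shown here as a hedged sketch rather than the author's method, is to use the nesting depth obtained by walking the parent chain of an OpenCV `RETR_TREE` hierarchy: contours at even depth bound solid regions and contours at odd depth bound holes. The function names and the hierarchy below are made up for illustration.

```python
import numpy as np


def nesting_depths(parents):
    """Depth of each contour in an OpenCV RETR_TREE hierarchy.

    `parents` is the parent column, i.e. hierarchy[0][:, 3]; -1 marks a top-level contour.
    """
    depths = []
    for i in range(len(parents)):
        depth, p = 0, parents[i]
        while p != -1:  # walk up the parent chain until a top-level contour
            depth += 1
            p = parents[p]
        depths.append(depth)
    return depths


def solid_or_hole(parents):
    # Even depth: outer boundary of a solid region; odd depth: a hole inside it
    return ["solid" if d % 2 == 0 else "hole" for d in nesting_depths(parents)]


# Hypothetical hierarchy: two top-level blobs (0, 1); contour 2 is a hole in blob 1,
# and contour 3 is an island inside that hole
parents = np.array([-1, -1, 1, 2])
print(solid_or_hole(parents))  # ['solid', 'solid', 'hole', 'solid']
```

Under the parent-index parity rule, contour 2 in this example (parent index 1, which is odd) would be unioned rather than subtracted, which is where the two rules differ.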
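The add/subtract rule in `create_cylinder` and the scaling in `create_CAD_mesh` together determine the volume of each extruded part. As a quick sanity check, a sketch that ignores the B-spline approximation of the rings and assumes every hole lies inside the exterior ring, the expected volume can be computed directly from the ring coordinates with the shoelace formula. `ring_area` and `extruded_volume` are hypothetical helpers, not part of the released code.

```python
import numpy as np


def ring_area(ring):
    """Unsigned area of a ring given as an (n, 2) sequence of points (shoelace formula)."""
    x, y = np.asarray(ring, dtype=float).T
    return 0.5 * abs(np.dot(x, np.roll(y, -1)) - np.dot(y, np.roll(x, -1)))


def extruded_volume(exterior, interiors, height, wh):
    # The exterior ring adds material and each interior ring removes it, as in create_cylinder
    area_px = ring_area(exterior) - sum(ring_area(r) for r in interiors)
    # Prism length is 2 * |height| * wh pixels; all three axes are later scaled by 1 / wh
    return area_px * (2 * abs(height) * wh) / wh ** 3


# Hypothetical example: 100 px square with a 50 px square hole, height=0.1, wh=500
outer = [(-50, -50), (50, -50), (50, 50), (-50, 50)]
hole = [(-25, -25), (25, -25), (25, 25), (-25, 25)]
volume = extruded_volume(outer, [hole], height=0.1, wh=500)
```

Here the normalized area is 7500 / 500² = 0.03 and the prism length is 2 × 0.1 = 0.2, so `volume` is 0.006 up to floating-point rounding.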
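The four `gp_Trsf` steps in `create_CAD_mesh` (shift along Z, scale by 1/wh, rotate, translate) can be mirrored in NumPy to check where a part ends up. The sketch below is only an illustration: the hypothetical `place_extrusion` takes the rotation to apply directly as a unit quaternion in [w, x, y, z] order and builds its rotation matrix itself. In `create_CAD_mesh`, by contrast, the predicted quaternion is inverted and reordered before being passed to `gp_Quaternion`, whose constructor takes its components in (x, y, z, w) order.

```python
import numpy as np


def quat_wxyz_to_matrix(q):
    """Rotation matrix of a quaternion given in [w, x, y, z] order (normalized first)."""
    w, x, y, z = np.asarray(q, dtype=float) / np.linalg.norm(q)
    return np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
        [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
        [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
    ])


def place_extrusion(points_px, height, wh, q_wxyz, translation):
    """Apply the same order of transforms as create_CAD_mesh to prism vertices.

    points_px: (n, 3) vertices in pixel units, with z in [0, 2 * |height| * wh].
    """
    p = np.asarray(points_px, dtype=float).copy()
    p[:, 2] -= abs(height) * wh              # 1. center the prism on the sketch plane
    p /= wh                                  # 2. pixel units -> normalized units
    p = p @ quat_wxyz_to_matrix(q_wxyz).T    # 3. rotate each row vector
    return p + np.asarray(translation, dtype=float)  # 4. translate


# Corners of a 100 x 100 px square prism with height=0.1 and wh=500 (z spans 0..100 px)
verts = np.array([[x, y, z] for x in (-50, 50) for y in (-50, 50) for z in (0, 100)], dtype=float)
placed = place_extrusion(verts, height=0.1, wh=500, q_wxyz=[1, 0, 0, 0], translation=[0, 0, 0.2])
```

With the identity rotation, the placed prism spans [-0.1, 0.1] in x and y and [0.1, 0.3] in z, which shows that `height` acts as a half-extent of the extrusion in normalized units.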