WarBean / tps_stn_pytorch

PyTorch implementation of Spatial Transformer Network (STN) with Thin Plate Spline (TPS)

Please help me!!! #18

Open viet24dung opened 3 years ago

viet24dung commented 3 years ago

Awesome repo!

I have a question. I warped my image using grid control points and noisy control points, and it works perfectly. Now I have many points on the original image, and I want to get their corresponding coordinates on the transformed image. How can I do this?

Currently, my solution is to take each point, search the transformed grid for the closest location, and use the pixel coordinates of that location. For example, see the picture below. (image: Untitled1)

My code

```python
import os
import torch
from PIL import Image
import cv2
import numpy as np
from sklearn.metrics.pairwise import euclidean_distances
import morphops as mops
import torchvision

....

# Landmark points as (x, y) pairs normalized to [0, 1]
points = torch.tensor([0.33145134998962683,0.37168062334870805,0.291440680477757,0.36570985028286535,0.5959663317625441,0.35376830415117994,0.6433863845173527,0.3268998253548878,0.6737648558134021,0.3224217455555058,0.7078480187309208,0.32988521188780917,0.7330399217569128,0.34705118445210686,0.705625203758039,0.36645619691609566,0.667096410894757,0.3701879300822474,0.6367179395987077,0.36421715701640467,0.3603479446370884,0.7135073813682026,0.4099908123647787,0.6829071694057588,0.46852493998399575,0.6597704237756183,0.4996443496043389,0.6672338901079217,0.5307637592246821,0.6605167704088487,0.589297886843899,0.6829071694057588,0.6359770012744138,0.716492767901124,0.5841113185738419,0.7403758601644947,0.5441006490619721,0.7530637529294105,0.4981624729557511,0.7575418327287925,0.4529652351738241,0.7538100995626408,0.41369550398624816,0.7388831668980341,0.37442577279867223,0.7105219948352812,0.459633680092469,0.6978341020703654,0.4996443496043389,0.6963414088039048,0.5403959574405026,0.6970877554371352,0.6263448030585934,0.7142537280014329,0.5389140807919149,0.7135073813682026,0.4996443496043389,0.7157464212678936,0.46037461841676297,0.7127610347349723,0.32552384339527574,0.34555849118564624,0.6730239174891082,0.34406579791918557]).reshape((-1,2))

# tps, source_control_points, h and w come from the elided part above.
# For every output pixel, tps returns its (x, y) sampling location in [-1, 1].
warped_C = tps(source_control_points.unsqueeze(0))    # shape (1, h*w, 2)
warped_C = warped_C[0].detach().cpu().numpy()        # shape (h*w, 2), as sklearn expects 2-D input

# Distance from every point (rescaled to [-1, 1]) to every grid location: an N x (h*w) matrix
dis = euclidean_distances(points.cpu().numpy() * 2 - 1, warped_C)
idx = np.argmin(dis, axis=1)                          # nearest output pixel for each point
Y, X = np.unravel_index(idx, (h, w))                  # flat index -> (row, col)
new_point = np.stack([X, Y], axis=1).astype(int)      # (x, y) pixel coordinates in the warped image
```
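The dense `euclidean_distances` call builds an N × (h·w) matrix, which is where most of the time and memory goes. A minimal sketch of the same nearest-location lookup using a KD-tree from SciPy (`scipy.spatial.cKDTree`, not used in the original code), assuming `warped_grid` is the `(h*w, 2)` array of (x, y) sampling locations returned by `tps(...)` in row-major pixel order; `nearest_grid_pixels` is a hypothetical helper name:

```python
import numpy as np
from scipy.spatial import cKDTree


def nearest_grid_pixels(points_norm, warped_grid, h, w):
    """Hypothetical helper: for each query point, find the output pixel whose
    sampling location (from the TPS grid) is closest to it.

    points_norm: (M, 2) array of (x, y) in [-1, 1]
    warped_grid: (h*w, 2) array of (x, y) sampling locations, row-major over (h, w)
    returns:     (M, 2) integer array of (x, y) pixel coordinates in the warped image
    """
    tree = cKDTree(warped_grid)                # built once over the h*w grid locations
    _, idx = tree.query(points_norm, k=1)      # index of the nearest grid location per point
    ys, xs = np.unravel_index(idx, (h, w))     # flat index -> (row, col)
    return np.stack([xs, ys], axis=1)
```

This returns the same pixels as the `argmin` over the full distance matrix (up to ties), but memory stays linear in the grid size instead of growing with N × h·w.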

This takes a lot of time and memory, since `euclidean_distances` computes the distance from every point to every pixel of the h×w grid. Is there another way that relies on the matrices (K, U, P)?
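One way to use the K, U, P formulation directly is to fit a second TPS with the roles of the control points swapped: source control points (in the original image) as the knots and target control points (in the warped image) as the values, then evaluate it at the query points. A minimal NumPy sketch, assuming U(r) = ½·r²·log(r²) with U(0) = 0 and an affine part P = [1, x, y]; `fit_tps` and `apply_tps` are hypothetical helper names, not functions from this repo. Note that this swapped TPS passes exactly through the control-point pairs but is only an approximation of the true inverse of the original warp elsewhere:

```python
import numpy as np


def tps_kernel(d2):
    # U(r) = 0.5 * r^2 * log(r^2); 0 * log(0) gives nan, which is mapped to U(0) = 0
    with np.errstate(divide="ignore", invalid="ignore"):
        u = 0.5 * d2 * np.log(d2)
    return np.nan_to_num(u)


def pairwise_sq_dist(a, b):
    # (M, 2) x (N, 2) -> (M, N) squared Euclidean distances
    diff = a[:, None, :] - b[None, :, :]
    return (diff ** 2).sum(-1)


def fit_tps(src, dst):
    """Hypothetical helper: solve for TPS coefficients mapping src (N, 2) onto dst (N, 2)."""
    n = src.shape[0]
    K = tps_kernel(pairwise_sq_dist(src, src))    # N x N radial part
    P = np.hstack([np.ones((n, 1)), src])         # N x 3 affine part [1, x, y]
    L = np.zeros((n + 3, n + 3))
    L[:n, :n] = K
    L[:n, n:] = P
    L[n:, :n] = P.T
    rhs = np.vstack([dst, np.zeros((3, 2))])      # (N + 3) x 2
    return np.linalg.solve(L, rhs)                # N radial weights, then 3 affine rows


def apply_tps(coef, src, pts):
    """Hypothetical helper: evaluate the fitted TPS at pts (M, 2)."""
    n = src.shape[0]
    U = tps_kernel(pairwise_sq_dist(pts, src))        # M x N
    P = np.hstack([np.ones((pts.shape[0], 1)), pts])  # M x 3
    return U @ coef[:n] + P @ coef[n:]
```

Under these assumptions, `fit_tps(source_control_points.numpy(), target_control_points.numpy())` followed by `apply_tps(coef, source_control_points.numpy(), points.numpy() * 2 - 1)` gives normalized (x, y) positions in the warped image, which could then be converted to pixels with `(x + 1) / 2 * (w - 1)` and `(y + 1) / 2 * (h - 1)`, assuming the grid uses that normalization. The cost is one (N + 3) × (N + 3) solve plus an M × N kernel evaluation, with no h × w grid involved.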