Yuanhang8605 opened this issue 6 years ago
The Cython code is correct; I have tested it. After modifying some of your code, I got a good pretrained result on SynthText after fewer than 20,000 training steps (about half a day).
@Yuanhang8605 Thanks for your efforts. I tried the code you provided and training is now up to 2x faster than before. But I ran into a problem: after switching to your code, the loss suddenly becomes NaN. I'm using the default settings and the RCTW dataset. Do you know why? Thanks.
@infinitas-loop I don't think my code is the cause. I suggest using a try/except block to catch the exception, like this:
```python
import numpy as np
import cv2

# fit_parallelograms, rectangle_from_parallelogram, sort_rectangle, shrink_poly
# and the compiled Cython module gen_geo_map are helpers from the EAST code base.


def _generate_gt_maps(gt_box, num_gt_box, gt_mask, image_size):
    h, w = image_size
    # define the masks.
    poly_mask = np.zeros((h, w), dtype=np.uint8)
    score_map = np.zeros((h, w), dtype=np.uint8)
    geo_map = np.zeros((h, w, 5), dtype=np.float32)
    shrinked_poly_list = []
    rectangle_list = []
    rotate_angle_list = []
    for poly_idx in range(num_gt_box):
        poly = gt_box[poly_idx]
        tag = gt_mask[poly_idx]
        # defined before the try block so the except branch can always use it
        ex_poly = poly.astype(np.int32)[np.newaxis, :, :]
        try:
            if tag > 0:
                # fit the RBox
                # first fit parallelograms
                parallelogram = fit_parallelograms(poly)
                # fit the rbox
                rectangle = rectangle_from_parallelogram(parallelogram)
                # sort the rbox and use the AABB rule to generate the angle.
                rectangle, rotate_angle, is_right = sort_rectangle(rectangle)
                if not is_right:
                    cv2.fillPoly(score_map, ex_poly, 2)
                    continue
                rectangle_list.append(rectangle)
                rotate_angle_list.append(rotate_angle)
                # compute r_i from the paper: the shorter of the two edges adjacent to vertex i.
                r = [None, None, None, None]
                for i in range(4):
                    r[i] = min(np.linalg.norm(poly[i] - poly[(i + 1) % 4]),
                               np.linalg.norm(poly[i] - poly[(i - 1) % 4]))
                # score map
                shrinked_poly = shrink_poly(poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
                shrinked_poly_list.append(shrinked_poly)
                # first fill the whole poly with 2, meaning ignore,
                # then fill the shrunk part with 1, meaning positive.
                cv2.fillPoly(score_map, ex_poly, 2)
                cv2.fillPoly(score_map, shrinked_poly, 1)
                # poly_mask marks which pixels belong to the current poly.
                cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
                xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
                gen_geo_map.gen_geo_map(geo_map, xy_in_poly, rectangle, rotate_angle)
            else:
                # ignored poly.
                cv2.fillPoly(score_map, ex_poly, 2)
        except Exception:
            # if anything fails, mark the whole poly as ignored
            cv2.fillPoly(score_map, ex_poly, 2)
            continue
    return score_map, geo_map, shrinked_poly_list, rectangle_list, rotate_angle_list
```
This is my rewritten code; you may need to adapt parts of it. The idea is to wrap the per-polygon work in a try block and, in the except branch, fill the polygon with a value that marks it as ignored. I use 2 as the ignore label, which means you may have to change the loss accordingly; if you don't want to modify the loss code, you can simply fill with 0 instead.
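If you keep 2 as the ignore label, the loss has to skip those pixels. A minimal NumPy sketch of one way to do that, assuming a dice-style loss on the score map; `masked_dice_loss` and `IGNORE_LABEL` are hypothetical names, not code from the repository:

```python
import numpy as np

IGNORE_LABEL = 2  # assumption: the value written by cv2.fillPoly for ignored polygons


def masked_dice_loss(score_map, pred_score, eps=1e-5):
    """Dice loss over the score map that skips pixels labelled IGNORE_LABEL.

    score_map: (H, W) ground truth with 0 = background, 1 = text, 2 = ignore.
    pred_score: (H, W) predicted text probability in [0, 1].
    """
    valid = (score_map != IGNORE_LABEL).astype(np.float32)  # 1 where the pixel counts
    gt = (score_map == 1).astype(np.float32) * valid         # positive pixels only
    pred = pred_score.astype(np.float32) * valid             # drop predictions on ignored pixels
    intersection = np.sum(gt * pred)
    union = np.sum(gt) + np.sum(pred) + eps
    return 1.0 - 2.0 * intersection / union
```

In a training graph the same masking would be written with tensor ops; the point is only that pixels labelled 2 contribute to neither the intersection nor the union.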
@infinitas-loop I have rewritten the sort_rectangle function to avoid wrong results, like this:
```python
import numpy as np


def sort_rectangle(poly):
    # sort the four coordinates of the polygon, points in poly should be sorted clockwise
    # First find the lowest point (largest y in image coordinates)
    p_lowest = np.argmax(poly[:, 1])
    if np.count_nonzero(np.abs(poly[:, 1] - poly[p_lowest, 1]) < 1e-3) == 2:
        # the bottom edge is parallel to the x-axis, so p0 is the upper-left corner
        p0_index = np.argmin(np.sum(poly, axis=1))
        p1_index = (p0_index + 1) % 4
        p2_index = (p0_index + 2) % 4
        p3_index = (p0_index + 3) % 4
        return poly[[p0_index, p1_index, p2_index, p3_index]], 0., True
    else:
        # find the point to the right of the lowest point
        p_lowest_right = (p_lowest - 1) % 4
        p_lowest_left = (p_lowest + 1) % 4
        angle = np.arctan(-(poly[p_lowest][1] - poly[p_lowest_right][1]) / (poly[p_lowest][0] - poly[p_lowest_right][0] + 1e-3))
        # assert angle > 0
        if angle <= 0:
            # unexpected orientation: return the poly unsorted and flag it so the caller ignores it
            return poly, angle, False
        # print(angle, poly[p_lowest], poly[p_lowest_right])
        if angle / np.pi * 180 > 45:
            # this point is p2
            p2_index = p_lowest
            p1_index = (p2_index - 1) % 4
            p0_index = (p2_index - 2) % 4
            p3_index = (p2_index + 1) % 4
            return poly[[p0_index, p1_index, p2_index, p3_index]], -(np.pi / 2 - angle), True
        else:
            # this point is p3
            p3_index = p_lowest
            p0_index = (p3_index + 1) % 4
            p1_index = (p3_index + 2) % 4
            p2_index = (p3_index + 3) % 4
            return poly[[p0_index, p1_index, p2_index, p3_index]], angle, True
```
Use it like this:
```python
rectangle, rotate_angle, is_right = sort_rectangle(rectangle)
if not is_right:
    cv2.fillPoly(score_map, ex_poly, 2)
    continue
```
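The first branch of `sort_rectangle` checks whether the bottom edge is horizontal, i.e. whether exactly two vertices share the largest y. Because `y - y.max()` is never positive, that comparison needs an absolute value; with a signed `< 1e-3` test all four vertices match. A standalone sketch of the check, assuming vertices are (x, y) in image coordinates; `has_horizontal_bottom` is a hypothetical helper:

```python
import numpy as np


def has_horizontal_bottom(poly, eps=1e-3):
    """True if exactly two vertices share the largest y (y grows downward in images).

    poly: (4, 2) array of (x, y) vertices.
    """
    y = poly[:, 1]
    lowest_y = y.max()  # the "lowest" point on screen has the largest y
    # y - lowest_y is never positive, so only the absolute difference
    # separates vertices on the bottom edge from the others
    return np.count_nonzero(np.abs(y - lowest_y) < eps) == 2
```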
@infinitas-loop I also rewrote the following functions to avoid wrong results:
```python
import numpy as np


def point_dist_to_line(p1, p2, p3):
    # compute the distance from p3 to p1-p2
    # return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1)
    a = np.linalg.norm(p1 - p2)
    b = np.linalg.norm(p2 - p3)
    c = np.linalg.norm(p3 - p1)
    # Heron's formula for the area of triangle p1-p2-p3
    s = (a + b + c) / 2.0
    area = (s*(s-a)*(s-b)*(s-c)) ** 0.5
    if a < 1.0:
        # p1 and p2 (almost) coincide: fall back to the mean distance to both points
        return (b + c)/2.0
    # distance = height over base a = 2 * area / a
    return 2 * area / a


def fit_line(p1, p2):
    """fit a line ax+by+c = 0
    Args:
        p1: [x1, x2]  (x-coordinates of the two points)
        p2: [y1, y2]  (y-coordinates of the two points)
    """
    if abs(p1[0] - p1[1]) < 1e-3:
        # vertical line x = x1
        return [1., 0., -p1[0]]
    else:
        # [k, b] = np.polyfit(p1, p2, deg=1)
        x1, x2 = p1
        y1, y2 = p2
        k = (y2 - y1) / (x2 - x1)
        b = y1 - k * x1
        return [k, -1., b]
```
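The `[a, b, c]` returned by `fit_line` plugs directly into the standard point-to-line distance |a*x0 + b*y0 + c| / sqrt(a^2 + b^2), which never takes the square root of a possibly negative number. A small sketch, assuming that formula is what you want; `fit_line` is repeated so the block runs on its own, and `dist_to_fitted_line` is a hypothetical helper:

```python
import math


def fit_line(p1, p2):
    """Line a*x + b*y + c = 0 through two points; p1 = [x1, x2], p2 = [y1, y2]."""
    if abs(p1[0] - p1[1]) < 1e-3:  # vertical line x = x1
        return [1., 0., -p1[0]]
    x1, x2 = p1
    y1, y2 = p2
    k = (y2 - y1) / (x2 - x1)
    b = y1 - k * x1
    return [k, -1., b]  # k*x - y + b = 0


def dist_to_fitted_line(line, point):
    """Distance |a*x0 + b*y0 + c| / sqrt(a^2 + b^2) from point (x0, y0) to the line."""
    a, b, c = line
    x0, y0 = point
    return abs(a * x0 + b * y0 + c) / math.hypot(a, b)
```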
@Yuanhang8605 Sorry for the late reply; I have been busy with other things!
I did some debugging and found that the NaN appears in geo_map after `gen_geo_map.gen_geo_map(geo_map, xy_in_poly, rectange, rotate_angle)`. I checked your .pyx code and added this test code:
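```cython
# requires: from libc.math cimport isnan; from libc.stdio cimport printf
if isnan(geo_map[y, x, 0]):
    printf("p0:%f,%f; p1:%f,%f; p:%d,%d\n", p0[0], p0[1], p1[0], p1[1], x, y)
    printf("a:%f,b:%f,c:%f\n", a, b, c)
    printf("tri:%f\n", tri_area(a, b, c))
    printf("geo_map: %f\n", geo_map[y, x, 0])
```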
Sometimes the output looks like this:
```
p0:379.192535,165.468445; p1:401.493347,163.520020; p:396,164
a:22.385767,b:16.871490,c:5.514277
tri:-nan
geo_map: -nan
```
The values a, b, c may not form a valid triangle because of floating-point precision (here a = b + c), so the computed triangle area becomes NaN. I think that is where the problem lies, so I added this code to avoid it:
```cython
if isnan(tri_area(a, b, c)):
    printf("area is nan!\n")
    geo_map[y, x, 0] = 0
else:
    geo_map[y, x, 0] = 2 * tri_area(a, b, c) / a
```
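The logged numbers show what happens: the pixel p lies almost exactly on segment p0-p1, so the triangle is degenerate and Heron's radicand can round to a small negative value. The cross-product formula (commented out in `point_dist_to_line`) only divides by |p1 - p0|, so it stays finite as long as p0 and p1 differ. A quick check with the logged coordinates; `cross_dist` is a hypothetical name:

```python
import numpy as np


def cross_dist(p0, p1, p):
    """Distance from p to the line through p0 and p1 via the 2-D cross product."""
    v = p1 - p0
    w = p0 - p
    cross = v[0] * w[1] - v[1] * w[0]  # z-component of the 2-D cross product
    return abs(cross) / np.hypot(v[0], v[1])


# coordinates copied from the debug output above
p0 = np.array([379.192535, 165.468445])
p1 = np.array([401.493347, 163.520020])
p = np.array([396.0, 164.0])

a = np.linalg.norm(p0 - p1)
b = np.linalg.norm(p1 - p)
c = np.linalg.norm(p - p0)
print(a, b + c)               # nearly equal: p lies almost on segment p0-p1
print(cross_dist(p0, p1, p))  # a tiny, finite distance
```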
@infinitas-loop I have updated my code; you can download it and try. I think the problem has been solved. Thank you!
@infinitas-loop Actually, the problem is in the `area = (s*(s-a)*(s-b)*(s-c)) ** 0.5` line: the product `s*(s-a)*(s-b)*(s-c)` under the square root can be slightly negative. So I changed it to `area = abs(s*(s-a)*(s-b)*(s-c)) ** 0.5`.
Thank you! It's a very small bug, but it's very hard to find.
@Yuanhang8605 Yes, I have seen your update. I think a negative value has no mathematical meaning here: a triangle's area is never negative, and the product only drops below zero through rounding when the triangle is degenerate, i.e. when its true area is 0. So setting the value to zero may be more reasonable.
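Both fixes keep the square root real. A sketch of the clamp-to-zero variant suggested here, assuming plain Python floats; `safe_triangle_area` is a hypothetical name:

```python
import math


def safe_triangle_area(a, b, c):
    """Heron's formula with the radicand clamped at zero.

    For (near-)degenerate triangles, rounding can make s*(s-a)*(s-b)*(s-c)
    slightly negative; the true area there is ~0, so clamp instead of using abs().
    """
    s = (a + b + c) / 2.0
    radicand = s * (s - a) * (s - b) * (s - c)
    return math.sqrt(max(radicand, 0.0))
```

For rounding-sized negative values, `abs()` and clamping give practically the same distance; clamping just reports the degenerate case as exactly zero.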
Thank you for sharing your code!
Generating the ground truth takes a long time. I vectorized the `caclulate_geo_map` function, and training is now twice as fast:
```python
import numpy as np


def caclulate_geo_map(geo_map, xy_in_poly, rectange, rotate_angle):
    p0_rect, p1_rect, p2_rect, p3_rect = rectange
    height = (point_dist_to_line(p0_rect, p1_rect, p2_rect) + point_dist_to_line(p0_rect, p1_rect, p3_rect)) / 2
    width = (point_dist_to_line(p3_rect, p0_rect, p1_rect) + point_dist_to_line(p3_rect, p0_rect, p2_rect)) / 2
    # xy_in_poly comes from np.argwhere, so each row is (y, x)
    ys = xy_in_poly[:, 0]
    xs = xy_in_poly[:, 1]
    num_points = xy_in_poly.shape[0]
    # distance of every pixel to the top edge p0-p1; [:, ::-1] flips (y, x) to (x, y)
    top_distance_tmp = point_dist_to_line(np.tile(p0_rect, (num_points, 1)),
                                          np.tile(p1_rect, (num_points, 1)),
                                          xy_in_poly[:, ::-1])
    geo_map[ys, xs, 0] = top_distance_tmp
    # distance of every pixel to the right edge p1-p2
    right_distance_tmp = point_dist_to_line(np.tile(p1_rect, (num_points, 1)),
                                            np.tile(p2_rect, (num_points, 1)),
                                            xy_in_poly[:, ::-1])
    geo_map[ys, xs, 1] = right_distance_tmp
    # bottom and left distances follow from the box height and width
    geo_map[ys, xs, 2] = height - top_distance_tmp
    geo_map[ys, xs, 3] = width - right_distance_tmp
    geo_map[ys, xs, 4] = rotate_angle
    return geo_map


def point_dist_to_line(p1, p2, p3):
    # distance from p3 (one point, or an (N, 2) array of points) to line p1-p2
    if len(p3.shape) < 2:
        return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1)
    else:
        points = p3.shape[0]
        cross_product = np.cross(p2 - p1, p1 - p3)
        cross_product = np.resize(cross_product, [points, 1])
        return np.linalg.norm(cross_product, axis=1) / np.linalg.norm(p2 - p1, axis=1)
```
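Broadcasting makes the `np.tile` calls unnecessary, and writing the 2-D cross product out by hand avoids calling `np.cross` on 2-element vectors, which NumPy 2.0 deprecates. A sketch under those assumptions; `point_dist_to_line_vec` is a hypothetical name, not part of the code above:

```python
import numpy as np


def point_dist_to_line_vec(p1, p2, points):
    """Distances from each row of `points` (N, 2) to the line through p1 and p2.

    p1, p2: (2,) arrays of (x, y); points: (N, 2) array of (x, y).
    """
    d = p2 - p1                                # direction of the line
    w = p1 - points                            # (N, 2) via broadcasting, no np.tile
    cross = d[0] * w[:, 1] - d[1] * w[:, 0]    # 2-D cross product, one per point
    return np.abs(cross) / np.hypot(d[0], d[1])
```

For example, `point_dist_to_line_vec(p0_rect, p1_rect, xy_in_poly[:, ::-1])` could stand in for the tiled call that computes `top_distance_tmp`.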
Nice work!
I have created a repo for you: "https://github.com/Yuanhang8605/geo_map_gen-for-argman-east.git". You can download the Cython code and check whether it's correct. It only covers the `for y, x in xy_in_poly` part, so you can import the compiled Cython module to replace that loop. I'm trying to use some of your code to reimplement this article, https://arxiv.org/abs/1801.01671, which achieves good results on ICDAR 2015 and is very similar to the EAST framework. Thanks for sharing your code!