argman / EAST

A TensorFlow implementation of the EAST text detector
GNU General Public License v3.0
3.02k stars 1.05k forks

Fix slow training problem #160

Status: Open. Yuanhang8605 opened this issue 6 years ago

Yuanhang8605 commented 6 years ago

I have created a repo for you at "https://github.com/Yuanhang8605/geo_map_gen-for-argman-east.git". You can download the Cython code and check whether it is correct. It only replaces the "for y, x in xy_in_poly" loop: you can import the compiled Cython library in place of that part. I'm trying to use some of your code to reimplement this paper, https://arxiv.org/abs/1801.01671, which achieves a good result on ICDAR 2015 and is very similar to the EAST framework. Thanks for sharing your code!
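For comparison with the Cython module, here is a minimal pure-Python sketch of the kind of per-pixel loop it replaces. It is modeled on the "for y, x in xy_in_poly" loop in EAST's ground-truth code, but the function name fill_geo_map_loop and the corner order (p0..p3 = top-left, top-right, bottom-right, bottom-left) are assumptions for illustration, not the actual contents of either repository.

```python
import numpy as np


def point_dist_to_line(p1, p2, p3):
    """Distance from point p3 to the infinite line through p1 and p2."""
    d = p2 - p1
    e = p1 - p3
    # |z-component of the 2-D cross product| = parallelogram area; divide by base
    return abs(d[0] * e[1] - d[1] * e[0]) / np.linalg.norm(d)


def fill_geo_map_loop(geo_map, xy_in_poly, rectangle, rotate_angle):
    """Slow reference fill of the 5-channel RBOX geo map (hypothetical helper).

    Assumes rectangle holds corners p0..p3 as (x, y) in the order
    top-left, top-right, bottom-right, bottom-left, and xy_in_poly holds
    (row, col) = (y, x) pairs as returned by np.argwhere.
    """
    p0, p1, p2, p3 = rectangle
    for y, x in xy_in_poly:
        point = np.array([x, y], dtype=np.float32)
        geo_map[y, x, 0] = point_dist_to_line(p0, p1, point)  # top
        geo_map[y, x, 1] = point_dist_to_line(p1, p2, point)  # right
        geo_map[y, x, 2] = point_dist_to_line(p2, p3, point)  # bottom
        geo_map[y, x, 3] = point_dist_to_line(p3, p0, point)  # left
        geo_map[y, x, 4] = rotate_angle                       # angle
    return geo_map
```

Running this and the Cython version on the same inputs and comparing with np.allclose is a simple way to check that the two agree.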

Yuanhang8605 commented 6 years ago

The Cython code is correct; I have tested it. After modifying some of your code, I get good pretrained results on SynthText after fewer than 20,000 training steps, about half a day.

infinitas-loop commented 6 years ago

@Yuanhang8605 Thanks for your efforts. I tried the code you provided and my training is now up to 2x faster than before. But I have a problem: the loss suddenly becomes NaN after I switch to your code. I use the default settings and the RCTW dataset. Do you know why? Thanks.
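Since a NaN loss can come from either the network or the labels, one cheap way to narrow it down is to validate the generated ground-truth maps before they reach the loss. The sketch below is a hypothetical helper (check_gt_maps is not part of EAST) and assumes both maps are NumPy arrays.

```python
import numpy as np


def check_gt_maps(score_map, geo_map, tag=""):
    """Raise ValueError if a ground-truth map contains NaN or inf."""
    for name, arr in (("score_map", score_map), ("geo_map", geo_map)):
        bad = ~np.isfinite(arr)  # integer arrays are always finite
        if bad.any():
            # np.argwhere gives the (row, col[, channel]) of every bad entry
            first = tuple(int(i) for i in np.argwhere(bad)[0])
            raise ValueError(f"{tag}: {name} has {int(bad.sum())} non-finite "
                             f"values, first at {first}")
```

Calling it right after each sample's maps are generated, with the image path as tag, points directly at the offending image and pixel.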

Yuanhang8605 commented 6 years ago

@infinitas-loop I don't think this is caused by my code. I suggest using a try/except block to catch exceptions, like this:

import numpy as np
import cv2
import gen_geo_map  # compiled Cython extension from the geo_map_gen repo above

# fit_parallelograms, rectangle_from_parallelogram, sort_rectangle and
# shrink_poly are assumed to be defined elsewhere (modified EAST data helpers).

def _generate_gt_maps(gt_box, num_gt_box, gt_mask, image_size):
  h, w = image_size
  # define the masks.
  poly_mask = np.zeros((h, w), dtype=np.uint8)
  score_map = np.zeros((h, w), dtype=np.uint8)
  geo_map = np.zeros((h, w, 5), dtype=np.float32)

  shrinked_poly_list = []
  rectangle_list = []
  rotate_angle_list = []

  for poly_idx in range(num_gt_box):
    poly = gt_box[poly_idx]
    tag = gt_mask[poly_idx]
    # built before the try, so the except branch never uses a stale ex_poly
    ex_poly = poly.astype(np.int32)[np.newaxis, :, :]
    try:
      if tag > 0:
        # fit the RBox
        # first fit parallelograms
        parallelogram = fit_parallelograms(poly)
        # fit the rbox
        rectangle = rectangle_from_parallelogram(parallelogram)
        # sort the rbox and use the AABB rule to generate the angle.
        rectangle, rotate_angle, is_right = sort_rectangle(rectangle)

        if not is_right:
          cv2.fillPoly(score_map, ex_poly, 2)
          continue

        rectangle_list.append(rectangle)
        rotate_angle_list.append(rotate_angle)

        # compute r_i from the paper: the shorter of the two edges at vertex i.
        r = [None, None, None, None]
        for i in range(4):
          r[i] = min(np.linalg.norm(poly[i] - poly[(i + 1) % 4]),
                     np.linalg.norm(poly[i] - poly[(i - 1) % 4]))
        # score map
        shrinked_poly = shrink_poly(poly.copy(), r).astype(np.int32)[np.newaxis, :, :]
        shrinked_poly_list.append(shrinked_poly)
        # first fill the whole poly with 2, meaning ignore;
        # then fill the shrunk part with 1, meaning positive.
        cv2.fillPoly(score_map, ex_poly, 2)
        cv2.fillPoly(score_map, shrinked_poly, 1)
        # poly_mask marks the pixels of the current poly.
        cv2.fillPoly(poly_mask, shrinked_poly, poly_idx + 1)
        xy_in_poly = np.argwhere(poly_mask == (poly_idx + 1))
        gen_geo_map.gen_geo_map(geo_map, xy_in_poly, rectangle, rotate_angle)

      else:
        # ignored poly.
        cv2.fillPoly(score_map, ex_poly, 2)

    except Exception:
      # on any exception, mark the whole poly as ignored
      cv2.fillPoly(score_map, ex_poly, 2)
      continue

  return score_map, geo_map, shrinked_poly_list, rectangle_list, rotate_angle_list

This is my rewritten code; you may need to change some parts. The idea is to catch exceptions with try: in the except branch, you fill the polygon with a value that marks it as ignored. (I use 2 to label ignored regions, so you may have to change the loss part accordingly.) If you don't want to modify the loss, you can just fill it with 0.
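If you keep the ignore label 2, the loss needs a mask that drops those pixels. The sketch below, written in NumPy for illustration, splits a 0/1/2 score map into a target and a training mask and applies them in a masked dice loss similar in spirit to the one in EAST's TensorFlow model code; split_score_map and masked_dice_loss are hypothetical names.

```python
import numpy as np


def split_score_map(score_map):
    """Score map labeled 0 (background), 1 (positive), 2 (ignore) -> target, mask."""
    score_target = (score_map == 1).astype(np.float32)
    # 0 where the pixel must not contribute to the loss, 1 elsewhere
    training_mask = (score_map != 2).astype(np.float32)
    return score_target, training_mask


def masked_dice_loss(score_target, score_pred, training_mask, eps=1e-5):
    """Dice loss computed only over pixels where training_mask == 1."""
    inter = np.sum(score_target * score_pred * training_mask)
    union = (np.sum(score_target * training_mask)
             + np.sum(score_pred * training_mask) + eps)
    return 1.0 - 2.0 * inter / union
```

With this split, a prediction that is wrong only on ignored pixels gives a loss close to zero.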

Yuanhang8605 commented 6 years ago

@infinitas-loop I have rewritten the sort_rectangle function to avoid wrong results, like this:

import numpy as np

def sort_rectangle(poly):
  # sort the four coordinates of the polygon; points in poly should be sorted clockwise.
  # Returns (sorted_poly, angle, is_right); is_right is False for an invalid angle.
  # First find the lowest point (largest y in image coordinates)
  p_lowest = np.argmax(poly[:, 1])
  if np.count_nonzero(np.abs(poly[:, 1] - poly[p_lowest, 1]) < 1e-3) == 2:
    # the bottom edge is parallel to the x-axis, so p0 is the top-left corner
    p0_index = np.argmin(np.sum(poly, axis=1))
    p1_index = (p0_index + 1) % 4
    p2_index = (p0_index + 2) % 4
    p3_index = (p0_index + 3) % 4
    return poly[[p0_index, p1_index, p2_index, p3_index]], 0., True
  else:
    # find the point to the right of the lowest point
    p_lowest_right = (p_lowest - 1) % 4
    p_lowest_left = (p_lowest + 1) % 4
    angle = np.arctan(-(poly[p_lowest][1] - poly[p_lowest_right][1])/(poly[p_lowest][0] - poly[p_lowest_right][0] + 1e-3))
    # assert angle > 0
    if angle <= 0:
      # invalid rectangle: return it unsorted and let the caller ignore it
      return poly, angle, False
    #   print(angle, poly[p_lowest], poly[p_lowest_right])
    if angle/np.pi * 180 > 45:
      # this point is p2
      p2_index = p_lowest
      p1_index = (p2_index - 1) % 4
      p0_index = (p2_index - 2) % 4
      p3_index = (p2_index + 1) % 4
      return poly[[p0_index, p1_index, p2_index, p3_index]], -(np.pi/2 - angle), True
    else:
      # this point is p3
      p3_index = p_lowest
      p0_index = (p3_index + 1) % 4
      p1_index = (p3_index + 2) % 4
      p2_index = (p3_index + 3) % 4
      return poly[[p0_index, p1_index, p2_index, p3_index]], angle, True

When using it, handle the invalid case like this:

        rectangle, rotate_angle, is_right = sort_rectangle(rectangle)

        if not is_right:
          cv2.fillPoly(score_map, ex_poly, 2)
          continue
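Note that the bottom-edge check should compare the absolute difference. Because p_lowest is the point with the largest y, poly[:, 1] - poly[p_lowest, 1] is never positive, so a test without np.abs passes all four points and the count is 4, never 2. A small NumPy check on a made-up axis-aligned box shows this:

```python
import numpy as np

# Axis-aligned box in image coordinates (y grows downward); two points share the lowest y
poly = np.array([[0, 0], [10, 0], [10, 4], [0, 4]], dtype=np.float32)
p_lowest = np.argmax(poly[:, 1])  # index of a point with the largest y

# Without abs: every difference is <= 0, so all four points pass the test
no_abs = np.count_nonzero((poly[:, 1] - poly[p_lowest, 1]) < 1e-3)
# With abs: only the points on the bottom edge pass
with_abs = np.count_nonzero(np.abs(poly[:, 1] - poly[p_lowest, 1]) < 1e-3)
print(no_abs, with_abs)  # 4 2
```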
Yuanhang8605 commented 6 years ago

@infinitas-loop Rewrite the following functions as well to avoid wrong results:

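import numpy as np

def point_dist_to_line(p1, p2, p3):
  # compute the distance from p3 to the line p1-p2 via Heron's formula
  # return np.linalg.norm(np.cross(p2 - p1, p1 - p3)) / np.linalg.norm(p2 - p1)
  a = np.linalg.norm(p1 - p2)
  b = np.linalg.norm(p2 - p3)
  c = np.linalg.norm(p3 - p1)
  s = (a + b + c) / 2.0
  # clamp at 0: rounding can make the product slightly negative for near-collinear points
  area = max(s*(s-a)*(s-b)*(s-c), 0.0) ** 0.5
  if a < 1.0:
    # p1 and p2 almost coincide, so the line is ill-defined; use the mean distance
    return (b + c)/2.0
  return 2 * area / a

def fit_line(p1, p2):
  """fit a line ax + by + c = 0 through two points
  Args:
    p1: x-coordinates of the two points, [x1, x2]
    p2: y-coordinates of the two points, [y1, y2]
  Returns:
    [a, b, c]
  """
  if abs(p1[0] - p1[1]) < 1e-3:
    # vertical line x = x1
    return [1., 0., -p1[0]]
  else:
    # [k, b] = np.polyfit(p1, p2, deg=1)
    x1, x2 = p1
    y1, y2 = p2
    k = (y2 - y1) / (x2 - x1)
    b = y1 - k * x1
    # y = kx + b  <=>  kx - y + b = 0
    return [k, -1., b]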
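As a check on the [a, b, c] convention returned by fit_line, the standard point-to-line distance |a*x0 + b*y0 + c| / sqrt(a^2 + b^2) can be applied to it directly. This sketch copies fit_line so it runs standalone; dist_to_fitted_line is a hypothetical helper, not part of EAST.

```python
import numpy as np


def fit_line(p1, p2):
    """p1 = [x1, x2], p2 = [y1, y2]; returns [a, b, c] with ax + by + c = 0."""
    if abs(p1[0] - p1[1]) < 1e-3:
        return [1., 0., -p1[0]]  # vertical line x = x1
    x1, x2 = p1
    y1, y2 = p2
    k = (y2 - y1) / (x2 - x1)
    return [k, -1., y1 - k * x1]


def dist_to_fitted_line(line, point):
    """Distance |a*x0 + b*y0 + c| / sqrt(a^2 + b^2) from point to line."""
    a, b, c = line
    x0, y0 = point
    return abs(a * x0 + b * y0 + c) / np.hypot(a, b)


vertical = fit_line([2., 2.], [0., 5.])  # line x = 2
diagonal = fit_line([0., 1.], [0., 1.])  # line y = x
print(dist_to_fitted_line(vertical, (5., 1.)))  # 3.0
print(dist_to_fitted_line(diagonal, (1., 0.)))  # about 0.7071
```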
infinitas-loop commented 6 years ago

@Yuanhang8605 Sorry for the late reply; I've been busy with other things. I did some debugging and found that the NaN appears in geo_map after gen_geo_map.gen_geo_map(geo_map, xy_in_poly, rectange, rotate_angle). I checked your .pyx code and added this test code:

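# module-level cimports used by the check below
from libc.math cimport isnan
from libc.stdio cimport printf

if isnan(geo_map[y, x, 0]):
    printf("p0:%f,%f; p1:%f,%f; p:%d,%d\n",p0[0],p0[1],p1[0],p1[1],x,y)
    printf("a:%f,b:%f,c:%f\n",a,b,c)
    printf("tri:%f\n",tri_area(a,b,c))
    printf("geo_map: %f\n",geo_map[y,x,0])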

Sometimes the output looks like this:

p0:379.192535,165.468445; p1:401.493347,163.520020; p:396,164
a:22.385767,b:16.871490,c:5.514277
tri:-nan
geo_map: -nan

Here a, b and c may not form a valid triangle because of floating-point precision (a ≈ b + c), so the computed triangle area becomes NaN. I think that is where the problem lies, so I added this code to avoid it:

if isnan(tri_area(a, b, c)):
    printf("area is nan!\n")
    geo_map[y, x, 0] =0
else:
    geo_map[y, x, 0] = 2 * tri_area(a, b, c) / a
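The same guard can also live inside the area function itself. Below is a minimal Python sketch of Heron's formula with the radicand clamped at zero, checked against the side lengths from the log above; tri_area mirrors the name used in the .pyx code, but this version is written here for illustration and is not the repository's implementation.

```python
import math


def tri_area(a, b, c):
    """Heron's formula with the radicand clamped at zero.

    For nearly collinear points (a close to b + c) rounding can make
    s*(s-a)*(s-b)*(s-c) slightly negative. sqrt of a negative number is NaN
    in C and NumPy, and math.sqrt raises ValueError, so clamp first.
    """
    s = (a + b + c) / 2.0
    return math.sqrt(max(s * (s - a) * (s - b) * (s - c), 0.0))


def height_over_side_a(a, b, c):
    """Distance from the vertex opposite side a to that side: 2 * area / a."""
    return 2.0 * tri_area(a, b, c) / a


print(tri_area(3.0, 4.0, 5.0))  # 6.0
print(height_over_side_a(22.385767, 16.871490, 5.514277))  # finite, not NaN
```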
Yuanhang8605 commented 6 years ago

@infinitas-loop I have updated my code; you can download it and try. I think the problem has been solved. Thank you!

Yuanhang8605 commented 6 years ago

@infinitas-loop Actually, the cause is the line area = (s*(s-a)*(s-b)*(s-c)) ** 0.5: due to rounding, s*(s-a)*(s-b)*(s-c) can be negative, and the square root of a negative number is NaN. So I changed it to: area = abs(s*(s-a)*(s-b)*(s-c)) ** 0.5

Thank you! It's a very small bug, but it's very hard to find.

infinitas-loop commented 6 years ago

@Yuanhang8605 Yes, I have seen your update. I think a negative value has no mathematical meaning here (a triangle's area is never negative), so setting the value to zero may be more reasonable.

Thank you for sharing your code!

bidai541 commented 6 years ago

Generating the ground truth takes a very long time. I vectorized the "caclulate_geo_map" function, and training is now twice as fast.

import numpy as np

def caclulate_geo_map(geo_map, xy_in_poly, rectange, rotate_angle):
  # rectange: corners p0..p3 as (x, y); xy_in_poly: (y, x) rows from np.argwhere
  p0_rect, p1_rect, p2_rect, p3_rect = rectange
  height = (point_dist_to_line(p0_rect, p1_rect, p2_rect) + point_dist_to_line(p0_rect, p1_rect, p3_rect)) / 2
  width = (point_dist_to_line(p3_rect, p0_rect, p1_rect) + point_dist_to_line(p3_rect, p0_rect, p2_rect)) / 2

  ys = xy_in_poly[:, 0]
  xs = xy_in_poly[:, 1]
  points = xy_in_poly[:, ::-1]  # (x, y) order, shape (N, 2)
  # distances to the top (p0-p1) and right (p1-p2) edges for all pixels at once
  top_distance_tmp = point_dist_to_line(p0_rect, p1_rect, points)
  geo_map[ys, xs, 0] = top_distance_tmp
  right_distance_tmp = point_dist_to_line(p1_rect, p2_rect, points)
  geo_map[ys, xs, 1] = right_distance_tmp
  # bottom and left follow from the rectangle size (exact for pixels inside it)
  geo_map[ys, xs, 2] = height - top_distance_tmp
  geo_map[ys, xs, 3] = width - right_distance_tmp
  geo_map[ys, xs, 4] = rotate_angle
  return geo_map

def point_dist_to_line(p1, p2, p3):
  # distance from p3 (one point or an (N, 2) array of points) to the line p1-p2;
  # broadcasting replaces np.tile, and the 2-D cross product is written out
  # because np.cross on 2-element vectors is deprecated in NumPy 2.0
  d = p2 - p1
  e = p1 - p3
  cross = d[0] * e[..., 1] - d[1] * e[..., 0]
  return np.abs(cross) / np.linalg.norm(d)
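To sanity-check a vectorized rewrite like this against the per-pixel version, a small self-contained comparison helps. The sketch below uses NumPy broadcasting with an explicit 2-D cross product; the helper names and the test rectangle are made up for illustration. It also checks that height minus the top distance equals the bottom distance, which holds for pixels inside the fitted rectangle (pixels outside it would get a negative value).

```python
import numpy as np


def dist_to_line_vec(p1, p2, pts):
    """Distances from each row of pts (N, 2) to the line through p1 and p2."""
    d = p2 - p1                               # line direction, shape (2,)
    e = p1 - pts                              # broadcasts to (N, 2)
    cross = d[0] * e[:, 1] - d[1] * e[:, 0]   # 2-D cross product, shape (N,)
    return np.abs(cross) / np.linalg.norm(d)


def dist_to_line_loop(p1, p2, pts):
    """Per-point version, for comparison."""
    d = p2 - p1
    out = []
    for x, y in pts:
        e = p1 - np.array([x, y], dtype=np.float64)
        out.append(abs(d[0] * e[1] - d[1] * e[0]) / np.linalg.norm(d))
    return np.array(out)


# A rotated rectangle (corners p0..p3) and pixel coordinates inside it
rect = np.array([[2.0, 1.0], [9.0, 3.0], [8.0, 6.5], [1.0, 4.5]])
xy_in_poly = np.array([[3, 3], [4, 5], [4, 6]])  # (y, x) rows, as from np.argwhere
pts = xy_in_poly[:, ::-1].astype(np.float64)     # reorder to (x, y)

top_vec = dist_to_line_vec(rect[0], rect[1], pts)
top_loop = dist_to_line_loop(rect[0], rect[1], pts)
height = dist_to_line_vec(rect[0], rect[1], rect[2:3])[0]  # p2 to the top edge
bottom_loop = dist_to_line_loop(rect[2], rect[3], pts)
print(np.allclose(top_vec, top_loop))              # True
print(np.allclose(height - top_vec, bottom_loop))  # True
```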
Zhang-O commented 5 years ago

Nice work.