yizt / keras-faster-rcnn

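Faster R-CNN implemented in Keras, with end-to-end training and prediction. Under active development (see the todo list). You are welcome to try it out, follow the project, and report issues.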

Can this be trained on my own dataset, and what would need to be changed? #2

Open huangzhenjie opened 5 years ago

huangzhenjie commented 5 years ago

Thanks.

yizt commented 5 years ago

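@huangzhenjie Yes, although it has not been tested on other datasets yet. The changes needed are:

1. Create your own Config class in faster_rcnn.config.
2. Create your own Dataset class in faster_rcnn.preprocess.input, making sure it initializes the image_info_list field.
3. In train and inference, import your own Config and Dataset.

If you run into any problems, please report them. Thanks.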
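The three steps above can be sketched roughly as follows. This is only an illustration, not the repository's actual code: the `Config` base class here is a hypothetical stand-in for whatever lives in faster_rcnn.config, and the dictionary keys inside `image_info_list` (`image_path`, `boxes`, `labels`) are assumptions, so check the existing Dataset class in faster_rcnn.preprocess.input for the real field names.

```python
class Config:
    """Hypothetical stand-in for the base class in faster_rcnn.config."""
    IMAGE_MAX_DIM = 720      # assumed default, for illustration only
    MAX_GT_INSTANCES = 50    # assumed default, for illustration only


class MyConfig(Config):
    """Step 1: a Config subclass for a custom dataset."""
    IMAGE_MAX_DIM = 1280
    MAX_GT_INSTANCES = 100


class MyDataset:
    """Step 2: a Dataset that initializes image_info_list.

    The per-image keys below are assumptions, not the repository's schema.
    """

    def __init__(self, annotations):
        # annotations: iterable of (image_path, boxes, labels) tuples,
        # with boxes given as [y1, x1, y2, x2] lists (assumed order).
        self.image_info_list = []
        for image_path, boxes, labels in annotations:
            self.image_info_list.append({
                "image_path": image_path,
                "boxes": boxes,
                "labels": labels,
            })


# Step 3: train/inference scripts would import and use these instead
# of the defaults.
config = MyConfig()
dataset = MyDataset([
    ("images/0001.jpg", [[10, 20, 110, 220]], [1]),
    ("images/0002.jpg", [[5, 5, 50, 60], [30, 40, 90, 120]], [1, 2]),
])
```

Keeping the dataset-specific values in a subclass leaves the defaults untouched, so switching back is a one-line import change.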

wudi00 commented 5 years ago

I got the following error when training on my own data:

`Traceback (most recent call last):
  File "/opt/anaconda3/envs/py36/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "/opt/anaconda3/envs/py36/lib/python3.6/site-packages/keras/utils/data_utils.py", line 626, in next_sample
    return six.next(_SHARED_SEQUENCES[uid])
  File "/media/disk/wudi/faster/faster_rcnn/utils/generator.py", line 100, in gen
    batch_gt_boxes[i] = np_utils.pad_to_fixed_size(gt_boxes, self.max_gt_num)
ValueError: could not broadcast input array from shape (71,5) into shape (50,5)`

What is going wrong here, and how can I fix it?

yizt commented 5 years ago

@wudi00 Thanks for the feedback. Increase the value of the MAX_GT_INSTANCES attribute in config.py.
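To see why a larger MAX_GT_INSTANCES helps, the sketch below reproduces the same kind of NumPy broadcast failure (an image with 71 boxes assigned into a slot sized for 50) and then scans a dataset for the largest per-image box count. The helper `suggest_max_gt_instances` and its `headroom` parameter are hypothetical names introduced here, not part of the repository.

```python
import numpy as np


def suggest_max_gt_instances(boxes_per_image, headroom=1.2):
    """Return the largest per-image box count, scaled by some headroom.

    Hypothetical helper: boxes_per_image is a list of (n, 4) arrays.
    """
    largest = max(len(boxes) for boxes in boxes_per_image)
    return int(np.ceil(largest * headroom))


# Reproduce the failure: the batch slot holds 50 rows, the image has 71.
batch_gt_boxes = np.zeros((1, 50, 5))
too_many = np.ones((71, 5))
try:
    batch_gt_boxes[0] = too_many
    error_message = None
except ValueError as exc:
    error_message = str(exc)

# Scan (synthetic) annotations to choose a safe limit.
boxes_per_image = [np.zeros((12, 4)), np.zeros((71, 4)), np.zeros((30, 4))]
suggested = suggest_max_gt_instances(boxes_per_image)
print(error_message)
print("suggested MAX_GT_INSTANCES:", suggested)
```

Any value at or above the largest per-image count avoids the error; the headroom only guards against images added later.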

wudi00 commented 5 years ago

@yizt Thanks.

1. I have another problem, the following error:

Traceback (most recent call last):
  File "/opt/anaconda3/envs/py36/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "/opt/anaconda3/envs/py36/lib/python3.6/site-packages/keras/utils/data_utils.py", line 626, in next_sample
    return six.next(_SHARED_SEQUENCES[uid])
  File "/media/disk/wudi/faster/faster_rcnn/utils/generator.py", line 93, in gen
    image, gt_boxes = image_crop(image, gt_boxes)
  File "/media/disk/wudi/faster/faster_rcnn/utils/generator.py", line 42, in image_crop
    image, crop_window = image_utils.random_crop_image(image, [min_y, min_x, max_y, max_x])
  File "/media/disk/wudi/faster/faster_rcnn/utils/image.py", line 199, in random_crop_image
    wy2 = h - np.random.randint(min(h - y2 + 1, h // 20))
ValueError: low >= high

2. If the images in my own dataset are 800*1280, can I change the network input size to 800*1280, or should I resize the images to 720*720?
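A note on the first error, based only on the line shown in the traceback: `np.random.randint(n)` raises a ValueError whenever `n <= 0`, and `min(h - y2 + 1, h // 20)` is non-positive when the bottom edge `y2` lies more than a pixel below the image (`y2 > h`) or the image is under 20 pixels tall. One plausible cause is therefore an annotation that extends past the image boundary. The sketch below is a hypothetical sanity check, assuming boxes are stored as `[y1, x1, y2, x2]`; the function names are not from the repository.

```python
import numpy as np


def crop_upper_bound(h, y2):
    """The value passed to np.random.randint in the traceback line."""
    return min(h - y2 + 1, h // 20)


def find_invalid_boxes(h, w, gt_boxes):
    """Return indices of boxes that leave the image or have no area.

    Assumes each box is [y1, x1, y2, x2] in pixel coordinates.
    """
    boxes = np.asarray(gt_boxes, dtype=np.float64).reshape(-1, 4)
    y1, x1, y2, x2 = boxes.T
    bad = (y1 < 0) | (x1 < 0) | (y2 > h) | (x2 > w) | (y2 <= y1) | (x2 <= x1)
    return np.flatnonzero(bad)


# A box whose bottom edge (801) exceeds the image height (800).
boxes = [[10, 10, 100, 100], [700, 20, 801, 60]]
bad_indices = find_invalid_boxes(800, 1280, boxes)
print("invalid box indices:", bad_indices.tolist())
print("randint argument:", crop_upper_bound(800, 801))
```

Running a check like this over the whole dataset before training, and clipping or dropping the flagged boxes, would confirm or rule out this explanation.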

yizt commented 5 years ago

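The network input is currently square (equal height and width). Images are scaled so that the long side matches the input size, preserving the aspect ratio, and the short side is padded. You can change it to 1280*1280 if your GPU has enough memory.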
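The scale-then-pad scheme described above can be worked through numerically. The sketch below only computes the scale factor, the resized shape, and the amount of padding for a given square input size; it does not resize pixels, and where the padding is placed (one side or split evenly) is not stated in this thread. `resize_plan` is a hypothetical helper name.

```python
def resize_plan(h, w, max_dim):
    """Scale the long side to max_dim, then pad the short side.

    Returns (scale, (new_h, new_w), (pad_h, pad_w)), where pad_* is the
    total padding needed to reach a max_dim x max_dim square.
    """
    scale = max_dim / max(h, w)
    new_h = int(round(h * scale))
    new_w = int(round(w * scale))
    return scale, (new_h, new_w), (max_dim - new_h, max_dim - new_w)


# An 800*1280 image with a 1280*1280 input: no scaling, 480 rows of padding.
plan_1280 = resize_plan(800, 1280, 1280)

# The same image with a 720*720 input: scaled by 0.5625, 270 rows of padding.
plan_720 = resize_plan(800, 1280, 720)

print(plan_1280)
print(plan_720)
```

Ground-truth boxes need the same scale factor (and any padding offset) applied so that they stay aligned with the resized image.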

(Sent as an email reply to wudi00's comment of 2019-05-09 15:43:08, which is quoted in full above.)

yizt commented 5 years ago

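@wudi00 For now the input is always square, with equal height and width, but the size itself is arbitrary: just change the IMAGE_MAX_DIM attribute. After changing IMAGE_MAX_DIM, though, I recommend running python gt_cluster.py --clusters 9 to reset the anchor sizes.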
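For readers wondering what re-clustering the anchors involves, here is a minimal sketch of the general idea: run k-means over the (height, width) of the ground-truth boxes, measured at the new input scale, and use the cluster centers as anchor sizes. This is not the code in gt_cluster.py, which is not shown in this thread and may use a different distance measure (for example one based on IoU); the function name and parameters here are assumptions.

```python
import numpy as np


def kmeans_box_sizes(sizes, k, iters=50, seed=0):
    """Plain Euclidean k-means over (height, width) pairs.

    Returns a (k, 2) array of cluster centers sorted by area.
    Requires at least k input boxes.
    """
    rng = np.random.default_rng(seed)
    sizes = np.asarray(sizes, dtype=np.float64)
    centers = sizes[rng.choice(len(sizes), size=k, replace=False)]
    for _ in range(iters):
        # Distance from every box to every center, shape (n, k).
        dists = np.linalg.norm(sizes[:, None, :] - centers[None, :, :], axis=2)
        assign = dists.argmin(axis=1)
        # Recompute each center; keep the old one if its cluster is empty.
        new_centers = np.array([
            sizes[assign == j].mean(axis=0) if np.any(assign == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return centers[np.argsort(centers.prod(axis=1))]


# Synthetic box sizes standing in for a real dataset's annotations.
rng = np.random.default_rng(42)
box_sizes = np.concatenate([
    rng.normal([32, 32], 4, size=(50, 2)),
    rng.normal([128, 64], 8, size=(50, 2)),
    rng.normal([256, 256], 16, size=(50, 2)),
])
anchors = kmeans_box_sizes(box_sizes, k=3)
print(np.round(anchors, 1))
```

Box sizes scale with IMAGE_MAX_DIM, so anchors fitted at one input size will generally match the boxes poorly at another, which is why re-running the clustering is recommended.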

sibadakesi commented 5 years ago

Did you get it working? I can help if you need it.