quanghuy0497 / Mask_R-CNN

A modified version of Mask R-CNN based on Matterport's version. Featuring: polygon annotating mask generation and k-fold cross-validation training.
15 stars 7 forks source link

Training on K-fold #1

Closed stevenlee168 closed 2 years ago

stevenlee168 commented 2 years ago

Chào Quang Huy, Mình có sử dụng cách training K-fold của bạn để áp dụng vào project của mình. Nhưng có một vài vấn đề mình hơi thắc mắc.

Tổng thể chúng ta sẽ có 5 fold trong quá trình training. Mỗi lần training chúng ta sẽ sử dụng 4 fold để train, 1 fold còn lại để val, nên tổng cộng sẽ training 5 model.

Theo lí thuyết thì hình như sau mỗi lần traning xong, chúng ta phải hủy model và training lại từ đầu, nhưng hình như model của bạn lúc training lần 2 sẽ sử dụng weight từ epoch của lần training trước. Điều này có thể thấy rõ bằng việc vẽ hình loss ra, loss của epoch đầu tiên của lần tranining 1 sẽ cao hơn rất nhiều so với loss đầu tiên của lần training 2. mình tự hỏi có cách nào để khắc phục chuyện này không?

Cảm ơn bạn!

quanghuy0497 commented 2 years ago

Hi @stevenlee168, Ý kiến của bạn hay lắm, thực sự lúc code mình ko để ý chuyện này. Với lại mình cũng drop cái project này khá lâu nên cũng hơi lười fix lại 😅😅😅

Mình nghĩ cách giải quyết cũng đơn giản thôi, ở file training_with_K_fold.py, bạn chỉ cần cho một vòng lặp for fold in range(K_fold) và gọi hàm model.load_weight lẫn train(model, fold). Mặc khác, trong hàm def train bạn truyền thêm biến fold để tạo train-val set tương ứng và bỏ vòng for bên trong hàm train đi nhé.

Mình nghĩ pseudo-code để train và test có thể là như sau:

Result = []
For fold in range(num_folds):
   model.create()
   model.load_weight()
   train(model, fold)
   Result.append(predict(model, test_data))
   model.reset()  # or delete

Final_result = mean(result)

Nếu bạn code được ổn áp rồi thì bạn gửi 1 pull request giúp mình để update code được ko?

Cảm ơn bạn nhiều lắm nhé! Nếu câu trả lời này giúp được bạn thì báo cho mình biết lẫn close issue giúp mình nhé 😉

stevenlee168 commented 2 years ago

Hi @quanghuy0497,

Không ngờ là bạn trả lời nhanh như vậy. Mình cũng nghĩ tới việc này qua rồi nhưng có vẻ nó sẽ hơi rối tí. Để mình tìm hiểu kĩ thêm về phần code của bạn và của vả matterport, khi nào mình giải quyết xong mình sẽ close issue và pull request!

Cảm ơn bạn nhiều!