krumo / Detectron-DA-Faster-RCNN

Domain Adaptive Faster R-CNN in Detectron

Cannot reproduce the same results for Sim10k->Cityscapes task #2

Closed halqasir closed 5 years ago

halqasir commented 5 years ago

Hi @krumo ,

Thanks for this great work. I have some questions. I'm trying to reproduce the results of your example (adapting from the Sim10k dataset to the Cityscapes dataset). I get a higher value for the second line (image-level alignment only), AP50 = 38.07, and a lower value for the third line (instance-level alignment only), AP50 = 34.88. The AP50 values for the remaining lines are acceptable, however, so I suppose the hyperparameters differ between experiments.

  1. Could you please confirm whether you use DA_IMG_GRL_WEIGHT: 0.2 and DA_INS_GRL_WEIGHT: 0.2 in all the experiments for this dataset, or whether you use a different DA_IMG_GRL_WEIGHT when using image-level alignment only and a different DA_INS_GRL_WEIGHT when using instance-level alignment only?

  2. To make sure I am doing it correctly: when I want to run the second-line experiment (image-level alignment only), I just comment out these lines: https://github.com/krumo/Detectron-DA-Faster-RCNN/blob/44e5b9acda4957c6016c36384e17e0c80e47894a/detectron/modeling/model_builder.py#L227-L230 Is there anything else to do?

  3. You said in #1 that you don't provide a configuration option for the loss weight, and that the weights for the two-level adaptation are the GRL weights at the two levels. Does that mean DA_IMG_GRL_WEIGHT = DA_INS_GRL_WEIGHT = λ? The paper says:

λ is a trade-off parameter to balance the Faster R-CNN loss and our newly added domain adaptation components.

  4. If that is true, where do DA_IMG_GRL_WEIGHT: 0.2 and DA_INS_GRL_WEIGHT: 0.2 come from? The paper doesn't mention such a value; it says:

we set λ = 0.1 for all experiments

  5. Finally, in this case L_cst is not multiplied by λ, yet the paper gives

L = L_det + λ(L_img + L_ins + L_cst)

Please correct me if I am wrong.

Thanks again

krumo commented 5 years ago

Hi @h13hiba, Thanks for your attention! In my experience, adversarial training is quite unstable and it is not surprising to get these results. However, a consistent improvement should be observed after combining image-level and instance-level adaptation.

  1. Yes, I use the same parameter setting for the Sim10k->Cityscapes task with the different alignment methods. Both DA_IMG_GRL_WEIGHT and DA_INS_GRL_WEIGHT equal 0.2.
  2. Yes, that is also what I did to test image-level-only adaptation.
  3. Yes, DA_IMG_GRL_WEIGHT = DA_INS_GRL_WEIGHT = λ (see the sketch below). My implementation here is consistent with the original one in Caffe.
  4. I got them by testing different parameter settings, because I could not reproduce the results reported in the paper with the paper's parameter setting on Caffe2 and Detectron.
  5. In my implementation, L_cst is also multiplied by λ. Please look at Line 342 and Line 348.
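
For context on how a GRL weight plays the role of λ: a gradient reversal layer is the identity in the forward pass and multiplies the incoming gradient by -λ in the backward pass, so DA_IMG_GRL_WEIGHT and DA_INS_GRL_WEIGHT control how strongly the domain classifiers' gradients push back into the shared features. Below is a minimal NumPy sketch of that behaviour (a generic illustration with hypothetical names, not the Caffe2 operator used in this repository):

import numpy as np

def grl_forward(x):
    # Forward pass: features pass through unchanged (identity).
    return x

def grl_backward(grad_output, lam=0.2):
    # Backward pass: the domain classifier's gradient is reversed and scaled
    # by lam before it reaches the shared features, so lam trades off L_det
    # against the adversarial alignment terms (the lambda of the paper).
    return -lam * np.asarray(grad_output)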

Please let me know if you still have any confusion.

halqasir commented 5 years ago

Thanks @krumo for your reply. Just one last question: why did you use the L1 distance instead of the L2 distance?

https://github.com/krumo/Detectron-DA-Faster-RCNN/blob/44e5b9acda4957c6016c36384e17e0c80e47894a/detectron/modeling/model_builder.py#L364-L365

krumo commented 5 years ago

@h13hiba In my understanding, in the original paper L_cst is a sum of L2 distances between scalars, and for a scalar the L1 distance equals the L2 distance. That's why I use the L1 distance here.
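
To make the scalar point concrete, here is a quick check (plain NumPy with hypothetical values, not the repository's operator) that the per-element L1 and L2 distances coincide for scalars, so summing either over all elements gives the same L_cst:

import numpy as np

p = np.random.rand(256)  # e.g. instance-level domain probabilities
q = np.random.rand(256)  # the broadcast image-level probabilities they are compared to

l1 = np.abs(p - q)          # L1 distance of each scalar pair
l2 = np.sqrt((p - q) ** 2)  # L2 distance of each scalar pair
assert np.allclose(l1, l2)  # identical element-wise, so the sums match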

JeromeMutgeert commented 5 years ago

Hi @krumo ,

Which paper do you mean here:

  4. I got them by testing different parameter settings, because I could not reproduce the results reported in the paper with the paper's parameter setting on Caffe2 and Detectron.

Baby47 commented 5 years ago

Hi @h13hiba, since you have trained this model successfully, I want to discuss some errors that occurred in my training process. I'm trying to reproduce the results of the example addressed in the article (adapting from the Sim10k dataset to the Cityscapes dataset). Training ran fine at first, but at around 60 iterations it failed with the following errors:

terminate called after throwing an instance of 'caffe2::EnforceNotMet'
  what(): [enforce fail at distance_op.cu:64] X.dim32(i) == Y.dim32(i). 498 vs 499. Mismatch in dimensions / Error from operator:
  input: "gpu_0/repeated_img_probs" input: "gpu_0/ins_probs" output: "gpu_0/consistency_dist" name: "" type: "SquaredL2Distance" device_option { device_type: 1 cuda_gpu_id: 0 }
Aborted at 1553173945 (unix time) try "date -d @1553173945" if you are using GNU date
PC: @ 0x7fcff4083428 gsignal
SIGABRT (@0x3e8000019ce) received by PID 6606 (TID 0x7fceb27fc700) from PID 6606; stack trace:
  @ 0x7fcff4429390 (unknown)
  @ 0x7fcff4083428 gsignal
  @ 0x7fcff408502a abort
  @ 0x7fcfed9a784d __gnu_cxx::__verbose_terminate_handler()
  @ 0x7fcfed9a56b6 (unknown)
  @ 0x7fcfed9a5701 std::terminate()
  @ 0x7fcfed9d0d38 (unknown)
  @ 0x7fcff441f6ba start_thread
  @ 0x7fcff415541d clone
  @ 0x0 (unknown)

I have checked the training dataset carefully and made sure it is correctly converted. Can you figure out why training terminated?

JeromeMutgeert commented 5 years ago

Hi @Baby47,

I've encountered this problem too, and figured out that it is caused by the code in the _add_consistency_loss method in model_builder.py.

In this method the mean image-level domain predictions are manually broadcast along the instance-level domain predictions, but the code assumes there are always 256 instances per image. When there are fewer instances (off by an odd number), your error shows up. Note that it is unusual to get that few detections, and note also that the selection of these 256 detections relies on the target-set labels.

I've rewritten the core methods of the Python layer that does the reshaping so that the correct size and distribution are used. I've replaced the second input, 'ins_probs', with 'rois':

def expand_as(inputs, outputs):
    # Forward pass: broadcast the per-image mean of the image-level domain
    # predictions to one value per RoI, using the image index stored in
    # rois[:, 0] instead of assuming a fixed 256 RoIs per image.
    import numpy as np
    img_prob = inputs[0].data
    rois = inputs[1].data
    mean_da_conv = np.mean(img_prob, (1, 2, 3))
    repeated_da_conv = np.expand_dims(mean_da_conv[rois[:, 0].astype(np.int32)], axis=1)
    outputs[0].feed(repeated_da_conv)

def grad_expand_as(inputs, outputs):
    # Backward pass: scatter the RoI-level gradients back onto the image-level
    # blob. For this op the inputs are [img_probs, rois, output, grad_output].
    import numpy as np
    img_prob = inputs[0].data
    rois = inputs[1].data
    grad_output = inputs[3]
    grad_input = outputs[0]

    grad_o = grad_output.data[...]
    # Sum the gradient contributions per image; np.bincount with weights is a
    # vectorized equivalent of looping over rois[:, 0] and accumulating.
    # minlength guards against a trailing image that received no RoIs.
    sums = np.bincount(rois[:, 0].astype(np.int32), grad_o[:, 0],
                       minlength=img_prob.shape[0]).astype(np.float32)
    # Each element of img_prob contributed 1 / (C*H*W) to the per-image mean,
    # so divide the accumulated gradient by the number of elements per image.
    g_is = sums / (img_prob.shape[1] * img_prob.shape[2] * img_prob.shape[3])

    grad_i = np.empty(img_prob.shape, dtype=np.float32)
    for b in range(img_prob.shape[0]):
        grad_i[b, ...] = g_is[b]

    grad_input.reshape(img_prob.shape)
    grad_input.data[...] = grad_i

and in the later parts of the method replace 'ins_probs' with 'rois':

model.net.Python(f=expand_as, grad_f=grad_expand_as, grad_input_indices=[0], grad_output_indices=[0])(['img_probs', 'rois'], ['repeated_img_probs'])

I know this should have been in a pull request, but for now here it is.

Baby47 commented 5 years ago

Thank you a lot, @JeromeMutgeert. Have you tried reducing the 256 instances per image? I will try your code and see whether it works.