ZHKKKe / MODNet

A Trimap-Free Portrait Matting Solution in Real Time [AAAI 2022]
Apache License 2.0

Argument Types not matching in dataloader (in training script) #190

Open yogeshwari-20000609 opened 2 years ago

yogeshwari-20000609 commented 2 years ago

I wrote the get_data() function below, which I use to create the dataloader in trainer.py. However, I get an argument type error for the arguments I pass to supervised_training_iter(modnet, optimizer, image, trimap, gt_matte) while enumerating through the dataloader. Error description:

TypeError: conv2d() received an invalid combination of arguments - got (numpy.ndarray, Parameter, NoneType, tuple, tuple, tuple, int), but expected one of:
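The traceback says conv2d() received a numpy.ndarray as its first argument. PyTorch layers such as torch.nn.Conv2d take torch.Tensor inputs, so NumPy arrays have to be converted before they reach the model. A minimal sketch of that conversion for a single image, assuming an H x W x 3 float array in [0, 1] (the random array stands in for the output of the post's get_image(); to_image_tensor() is a hypothetical helper):

```python
import numpy as np
import torch


def to_image_tensor(image_hwc: np.ndarray) -> torch.Tensor:
    """Convert an H x W x 3 NumPy image into a 1 x 3 x H x W float32 tensor."""
    # Wrap the NumPy array as a tensor, then convert to float32,
    # the default dtype of nn.Conv2d weights.
    tensor = torch.from_numpy(image_hwc).float()
    # Conv2d expects channels-first (C, H, W); add a leading batch dimension.
    return tensor.permute(2, 0, 1).unsqueeze(0)


image = np.random.rand(64, 48, 3)  # stand-in for get_image() output (float64)
x = to_image_tensor(image)         # shape (1, 3, 64, 48), dtype float32

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
y = conv(x)                        # works; conv(image) would raise the TypeError above
```

torch.from_numpy() wraps the array without copying; .float() then converts to float32 (copying when the source is float64). Passing a float64 tensor to a float32 Conv2d would fail with a dtype mismatch error instead.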

Function:

    import os

    def get_data(bs, folder_path):
        output = []
        final = []
        direc = os.listdir(folder_path)

        # Creating batch
        batch = [direc[i:i + bs] for i in range(0, len(direc), bs)]

        for group in batch:
            output = []
            for i in range(len(group)):
                listl = []
                group[i] = folder_path + group[i]
                image_array = get_image(group[i])
                listl.append(image_array)
                trimap_array = generate_trimap(group[i])
                listl.append(trimap_array)
                mask_array = get_mask(group[i])
                listl.append(mask_array)
                output.append(listl)
            final.append(output)
        return final

    dataloader = get_data(bs, folder_path)
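Each training iteration needs three batched inputs (image, trimap, gt_matte), but each element of final above is a list of per-sample [image, trimap, mask] lists. One way to regroup a single batch into three stacked arrays is sketched below, assuming H x W x 3 images and H x W trimaps and masks of equal size (these shapes are assumptions, not stated in the post; collate_batch() is a hypothetical helper). The result is still NumPy, so it must be converted with torch.from_numpy() before calling the model.

```python
import numpy as np


def collate_batch(group):
    """Regroup one batch of [image, trimap, mask] samples into three arrays.

    Assumed per-sample shapes: image H x W x 3, trimap H x W, mask H x W.
    Returns arrays shaped (B, 3, H, W), (B, 1, H, W) and (B, 1, H, W).
    """
    images = np.stack([sample[0] for sample in group]).astype(np.float32)
    trimaps = np.stack([sample[1] for sample in group]).astype(np.float32)
    mattes = np.stack([sample[2] for sample in group]).astype(np.float32)
    # Move image channels first (NCHW) and keep memory contiguous.
    images = np.ascontiguousarray(images.transpose(0, 3, 1, 2))
    # Add a channel axis to the single-channel maps.
    trimaps = trimaps[:, np.newaxis, :, :]
    mattes = mattes[:, np.newaxis, :, :]
    return images, trimaps, mattes


# Synthetic batch standing in for one element of get_data()'s output.
h, w, bs = 32, 32, 4
group = [[np.random.rand(h, w, 3),
          np.random.choice([0.0, 0.5, 1.0], size=(h, w)),
          np.random.randint(0, 2, size=(h, w)).astype(float)]
         for _ in range(bs)]
images, trimaps, mattes = collate_batch(group)
```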

In get_data(), I create NumPy arrays for the image (values from 0 to 1), the trimap (values 0, 0.5, 1) and the mask (values 0, 1), in that order, and append them to a per-sample list. Each batch's list of these [image, trimap, mask] lists is then appended to final, which is returned. Is this the right way to pass data to the dataloader? Please advise on what argument types the dataloader should provide.
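A more conventional structure is a torch.utils.data.Dataset that returns one (image, trimap, gt_matte) tuple of tensors per sample, wrapped in a torch.utils.data.DataLoader, which stacks the samples into batches. In this minimal sketch, MattingDataset is a hypothetical name, the random arrays stand in for the outputs of get_image(), generate_trimap() and get_mask(), and mapping RGB values to [-1, 1] is an assumption about the model's expected input normalization that should be checked against the repository.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class MattingDataset(Dataset):
    """Hypothetical dataset over (image, trimap, mask) NumPy samples.

    Assumed shapes: image H x W x 3 in [0, 1], trimap H x W in {0, 0.5, 1},
    mask H x W in [0, 1].
    """

    def __init__(self, samples):
        self.samples = samples  # list of (image, trimap, mask) tuples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image, trimap, mask = self.samples[idx]
        # H x W x 3 -> 3 x H x W float32 tensor.
        image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))).float()
        # ASSUMPTION: the model expects RGB normalized to [-1, 1].
        image = (image - 0.5) / 0.5
        trimap = torch.from_numpy(trimap).float().unsqueeze(0)   # 1 x H x W
        gt_matte = torch.from_numpy(mask).float().unsqueeze(0)   # 1 x H x W
        return image, trimap, gt_matte


h, w = 32, 32
samples = [(np.random.rand(h, w, 3),
            np.random.choice([0.0, 0.5, 1.0], size=(h, w)),
            np.random.rand(h, w))
           for _ in range(10)]
dataloader = DataLoader(MattingDataset(samples), batch_size=4, shuffle=True)

for idx, (image, trimap, gt_matte) in enumerate(dataloader):
    # Each name is now a batched torch.Tensor, e.g. image is 4 x 3 x 32 x 32
    # (the last batch holds 2 samples); pass them to
    # supervised_training_iter(modnet, optimizer, image, trimap, gt_matte) here.
    pass
```

In practice, __getitem__ could take a file path and call the loading helpers itself, so images are read lazily instead of all being held in memory as get_data() does.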