Closed: nickvazz closed this 2 years ago.
Add all of the useful detections into a single mask:
# Combine the instance masks of the "useful" detections of the first few
# samples into one per-image mask, then plot image, combined mask and the
# first detection's raw mask side by side.
#
# NOTE(review): `dataset`, `np`, `plt` and `mpimg` are expected to be defined
# by earlier cells / the surrounding file; they are not imported here.
import math
import warnings
from itertools import islice, product

MAX_SAMPLES = 6  # the original loop broke at idx > 5, i.e. processed 6 samples
TARGET_LABELS = ("teddy_bear", "bowl")

# (round start, round end) pairs tried, in this order, when converting a
# fractional bbox edge into integer pixel slice bounds.
_ROUNDINGS = (
    (math.floor, math.floor),
    (math.ceil, math.floor),
    (math.floor, math.ceil),
    (math.ceil, math.ceil),
)


def _candidate_slices(start, size, extent):
    """Return the candidate pixel slices for one axis of a relative bbox.

    ``start`` and ``size`` are scaled by ``extent`` (the image size along this
    axis). Each of the floor/ceil combinations in ``_ROUNDINGS`` is applied to
    the two edges.
    """
    lo = start * extent
    hi = (start + size) * extent
    return [slice(round_lo(lo), round_hi(hi)) for round_lo, round_hi in _ROUNDINGS]


def _possible_slices(x, w, y, h, shape):
    """Return every (x-slice, y-slice) candidate for a bbox in an image of ``shape``.

    Assumes ``x, y, w, h`` are relative to the image width/height, as the
    scaling by ``shape`` implies. The order matches the original: x candidates
    form the outer loop.
    """
    height, width = shape[:2]
    return list(product(_candidate_slices(x, w, width),
                        _candidate_slices(y, h, height)))


def _add_detection_mask(mask, detection):
    """Add ``detection['mask']`` into ``mask`` at its bounding box.

    The fractional bbox edges can round to several pixel regions.
    Presumably the stored mask matches one of them, so each candidate is
    tried until the shapes are compatible.

    Returns True if the mask was added. Returns False if it is missing or
    fits no candidate region.
    """
    det_mask = detection["mask"]
    if det_mask is None:
        return False
    x, y, w, h = detection["bounding_box"]
    for xslice, yslice in _possible_slices(x, w, y, h, mask.shape):
        try:
            mask[yslice, xslice] += det_mask
        except ValueError:
            # Shape mismatch: this rounding does not match the mask's size.
            continue
        return True
    return False


def _bbox_corners(bbox, shape):
    """Return pixel (xs, ys) of the four corners of a relative ``[x, y, w, h]`` bbox."""
    x, y, w, h = bbox
    height, width = shape[:2]
    xs = width * np.array([x, x, x + w, x + w])
    ys = height * np.array([y, y + h, y, y + h])
    return xs, ys


for idx, sample in enumerate(islice(dataset, MAX_SAMPLES)):
    img = mpimg.imread(sample["filepath"])
    useful_detections = [
        d for d in sample["ground_truth"]["detections"]
        if d["label"] in TARGET_LABELS
    ]

    # Combined mask with the image's height x width.
    mask = np.zeros(img.shape[:2])
    for det_idx, detection in enumerate(useful_detections):
        if not _add_detection_mask(mask, detection):
            # Previously these were silently swallowed by a bare `except`.
            warnings.warn(
                f"sample {idx}: could not place mask of detection {det_idx} "
                f"(bbox={detection['bounding_box']})"
            )

    fig, ax = plt.subplots(ncols=3, figsize=(20, 6), sharex=True, sharey=True)
    fig.suptitle(len(useful_detections))
    ax[0].imshow(img)
    ax[1].imshow(mask)
    corners = [_bbox_corners(d["bounding_box"], img.shape) for d in useful_detections]
    for xs, ys in corners:
        ax[2].scatter(xs, ys, color="r")
    # Only samples that actually have a useful detection have a mask to show.
    if useful_detections and useful_detections[0]["mask"] is not None:
        ax[2].imshow(useful_detections[0]["mask"])
    ax[0].imshow(mask, alpha=0.5)
    for xs, ys in corners:
        ax[0].scatter(xs, ys, color="r")
    plt.tight_layout()
    plt.show()