Status: Open — ankurhanda opened this issue 3 years ago.
@ankurhanda What kind of reward sketches do you want? Label Studio (LS) has ratings, checkboxes, and text areas, so you could build something like this for video.
@makseq I'd like to draw a continuous function with the frame number on the X-axis and the reward on the Y-axis. To see what that means, check the reward sketching section at https://sites.google.com/view/data-driven-robotics/, where you slide (and click) your mouse along and it draws the curve for you.
Thanks! Now I've got it. It's a very cool feature!
Do you think it is worth adding that in label-studio?
I have already implemented it in Python (using matplotlib), but it's fairly basic and often requires a lot of clicking.
The code looks like this in case you're interested.
` import math import numpy as np
import matplotlib.pyplot as plt from matplotlib.backend_bases import MouseEvent
class DraggablePlotExample(object):
    u"""Interactive reward-sketching tool built on matplotlib.

    The top subplot shows one frame of the demo; the bottom subplot holds the
    sketched reward curve (x = frame index, y = reward in ``[y_min, y_max]``).

    Mouse controls (bottom subplot only):

    * left click: grab an existing point near the cursor, or add a new one;
    * left drag: add a point at every mouse-motion event, sketching the curve;
    * right click: remove the point near the cursor, if any;
    * moving the mouse shows the frame under the cursor in the top subplot.

    Keyboard: pressing ``w`` writes the sketched points to
    ``reward_dict_<demo_num>.csv``.
    """

    def __init__(self, hdf5File=None, demo_num=2, image_glob='./fetch_push/*.jpg'):
        u"""Load the frames to annotate, then build and show the figure.

        :param hdf5File: path to an HDF5 file; frames are read from
            ``data/demo_<demo_num>/obs``. When ``None``, JPEG frames matching
            ``image_glob`` are loaded with OpenCV and resized to 256x256.
        :param demo_num: demo index inside the HDF5 file; also used in the
            output file name written by the ``w`` key.
        :param image_glob: glob pattern used when ``hdf5File`` is ``None``.
        :raises ValueError: if no frames were found.
        """
        self._figure, self._axes, self._line = None, None, None
        self._line2 = None       # fill_between artist under the sketched curve
        self._img_artist = None  # AxesImage currently shown in the top subplot
        self._img_plot = None
        self._dragging_point = None
        self._points = {}        # frame index -> sketched reward
        self.prev_x, self.prev_y = -1, -1
        self.cur_x = 0
        self.demo_num = demo_num
        if hdf5File is not None:
            import h5py
            # Open read-only and close the file once every frame has been
            # copied into memory.
            with h5py.File(hdf5File, 'r') as demos:
                demo_obs = demos['data']['demo_' + str(self.demo_num)]['obs']
                self.dataset_len = len(demo_obs)
                self.img_dict = {i: demo_obs[i] for i in range(self.dataset_len)}
        else:
            import glob
            import cv2
            dataset = sorted(glob.glob(image_glob))
            self.dataset_len = len(dataset)
            self.img_dict = {}
            for i, img_name in enumerate(dataset):
                img = cv2.imread(img_name)
                img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_NEAREST)
                # OpenCV loads BGR; matplotlib's imshow expects RGB.
                self.img_dict[i] = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.dataset_len == 0:
            raise ValueError('no frames found to annotate')
        self._init_plot()

    def _init_plot(self):
        u"""Create the image and reward subplots, connect callbacks, show."""
        self._figure = plt.figure("Example plot")
        img_plot = plt.subplot(2, 1, 1)
        img_plot.axis('off')
        self._img_plot = img_plot
        self._img_artist = img_plot.imshow(self.img_dict[0])
        self.y_min = -100
        self.y_max = 0
        axes = plt.subplot(2, 1, 2)
        axes.set_xlim(0, self.dataset_len)
        axes.set_ylim(self.y_min, self.y_max)
        # Raw string: "\m" is not a valid Python escape sequence.
        axes.set_xlabel(r'frame $\mathcal{I}_t$')
        axes.set_ylabel('reward $r_t$')
        self._axes = axes
        self._figure.canvas.mpl_connect('button_press_event', self._on_click)
        self._figure.canvas.mpl_connect('button_release_event', self._on_release)
        self._figure.canvas.mpl_connect('motion_notify_event', self._on_motion)
        self._figure.canvas.mpl_connect('key_press_event', self._on_key_press)
        plt.show()

    def _update_plot(self):
        u"""Redraw the reward curve, its filled area, and the current frame."""
        # Drop the previous fill first so deleting the last point also
        # clears the shaded area.
        if self._line2 is not None:
            self._line2.remove()
            self._line2 = None
        if not self._points:
            if self._line is not None:
                self._line.set_data([], [])
        else:
            x, y = zip(*sorted(self._points.items()))
            if self._line is None:
                self._line, = self._axes.plot(x, y, "b", marker="o", markersize=10)
            else:
                self._line.set_data(x, y)
            # Shade from the bottom of the reward axis (y_min) up to the curve.
            self._line2 = self._axes.fill_between(
                x, self.y_min, y, facecolor='blue', alpha=0.5)
        self._update_image()
        self._figure.canvas.draw()

    def _frame_index(self, x):
        u"""Clamp a data-space x coordinate to a valid frame index."""
        return min(max(int(x), 0), self.dataset_len - 1)

    def _update_image(self):
        u"""Show frame ``cur_x`` in the top subplot, replacing the old image."""
        # Remove the previous AxesImage so images do not pile up on the axes
        # with every mouse move.
        if self._img_artist is not None:
            self._img_artist.remove()
        self._img_artist = self._img_plot.imshow(
            self.img_dict[self._frame_index(self.cur_x)])

    def _add_point(self, x, y=None):
        u"""Store a point, taking coordinates from a MouseEvent if given one.

        Coordinates from an event are truncated to ``int``.

        :return: the ``(x, y)`` pair that was stored.
        """
        if isinstance(x, MouseEvent):
            x, y = int(x.xdata), int(x.ydata)
        self._points[x] = y
        return x, y

    def _remove_point(self, x, _):
        u"""Delete the point at frame ``x`` if present; the y value is ignored."""
        self._points.pop(x, None)

    def _find_neighbor_point(self, event):
        u"""Find the stored point closest to the mouse position.

        Distance is Euclidean in data coordinates.

        :rtype: ((int, int)|None)
        :return: ``(x, y)`` of the closest point if it is within the
            threshold, else ``None``.
        """
        distance_threshold = 3.0
        nearest_point = None
        min_distance = math.sqrt(2 * (100 ** 2))
        for x, y in self._points.items():
            distance = math.hypot(event.xdata - x, event.ydata - y)
            if distance < min_distance:
                min_distance = distance
                nearest_point = (x, y)
        if min_distance < distance_threshold:
            return nearest_point
        return None

    def _on_click(self, event):
        u"""Handle a mouse press in the reward subplot.

        Left click grabs a nearby point for dragging or adds a new point;
        right click removes the nearby point, if any.

        :type event: MouseEvent
        """
        if event.inaxes is not self._axes:
            return
        if event.button == 1:
            point = self._find_neighbor_point(event)
            if point:
                self._dragging_point = point
            else:
                self._add_point(event)
            self._update_plot()
        elif event.button == 3:
            point = self._find_neighbor_point(event)
            if point:
                self._remove_point(*point)
                self._update_plot()

    def _on_release(self, event):
        u"""End a left-button drag.

        The drag ends wherever the button is released, including outside
        the reward axes.

        :type event: MouseEvent
        """
        if event.button == 1 and self._dragging_point:
            self._dragging_point = None
            self._update_plot()

    def _on_key_press(self, event):
        u"""Save the sketched reward when the ``w`` key is pressed.

        Writes one ``"<frame> <reward>"`` line per point, sorted by frame, to
        ``reward_dict_<demo_num>.csv`` in the current working directory.
        """
        if event.key != 'w':
            return
        file_name = "reward_dict_" + str(self.demo_num) + '.csv'
        # NOTE: despite the .csv extension, fields are space-separated.
        with open(file_name, "w") as f:
            for key, value in sorted(self._points.items()):
                f.write(str(key) + ' ' + str(value) + '\n')
        print('logged sketched reward')

    def _on_motion(self, event):
        u"""Track the cursor over the reward subplot.

        Shows the frame under the cursor and, while dragging, adds a point
        at every motion event so the curve follows the mouse.

        :type event: MouseEvent
        """
        # Only the reward axes use frame-index coordinates; xdata from the
        # image subplot is in pixel space.
        if event.inaxes is not self._axes or event.xdata is None or event.ydata is None:
            return
        self.cur_x = self._frame_index(event.xdata)
        if self._dragging_point:
            self._dragging_point = self._add_point(event)
            self._update_plot()  # also refreshes the image and redraws
        else:
            self._update_image()
            self._figure.canvas.draw()
# Script entry point: open the sketching tool on the demo stored in images.hdf5.
if __name__ == "__main__":
    plot = DraggablePlotExample(hdf5File='images.hdf5')
Some of this is inspired by https://github.com/yuma-m/matplotlib-draggable-plot/blob/master/draggable_plot.py
Thank you very much! But we need to re-implement it in JS React for LS.
Yes, sure. Let me know if you ever get to it. I'm happy to do testing for you. Thanks a lot for listening.
https://labelstud.io/tags/number.html#main — we have introduced the Number tag, which can help here.
Is your feature request related to a problem? Please describe. Label Studio currently has no way to annotate rewards. Reward annotation is useful for training RL agents.
Describe the solution you'd like: I'd like a way to create reward sketches on video data, like the ones shown at https://sites.google.com/view/data-driven-robotics/.
Describe alternatives you've considered: I have not seen anything close to this in the repo.
Additional context: These reward sketches are useful for training RL agents; they are the closest you can get to a supervised signal in robotics.