HumanSignal / label-studio

Label Studio is a multi-type data labeling and annotation tool with standardized output format
https://labelstud.io
Apache License 2.0
18.9k stars, 2.35k forks

Adding a new feature for reward sketching #814

Open: ankurhanda opened this issue 3 years ago

ankurhanda commented 3 years ago

Is your feature request related to a problem? Please describe.
Label Studio has no way to annotate rewards. Reward annotation is useful for training RL agents.

Describe the solution you'd like
I'd like a way to create reward sketches on video data, like the ones shown here: https://sites.google.com/view/data-driven-robotics/

Describe alternatives you've considered
I have not seen anything close to this in the repo.

Additional context
Reward sketches are useful for training RL agents: they are the closest you can get to a supervised signal in robotics.

makseq commented 3 years ago

@ankurhanda What kind of reward sketches do you want to have? Label Studio (LS) has ratings, checkboxes, and text areas, so you could set up something like that for video.

ankurhanda commented 3 years ago

@makseq I'd like to draw a continuous function with the frame number on the X-axis and the reward on the Y-axis. To see what I mean, check the reward sketching section in https://sites.google.com/view/data-driven-robotics/, where you click and slide the mouse along and it draws the curve for you.
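If a sketch is recorded as sparse (frame, reward) points from the mouse path, a reward for every frame can be recovered by interpolating between them. The snippet below is a minimal sketch assuming linear interpolation, with the first and last sketched values held flat before and after the sketched range; `densify_sketch` is a hypothetical helper, not part of Label Studio.

```python
from bisect import bisect_right


def densify_sketch(points, num_frames):
    """Turn sparse {frame: reward} sketch points into one reward per frame.

    Hypothetical helper: linear interpolation between sketched points;
    frames outside the sketched range take the nearest endpoint's value.
    """
    if not points:
        raise ValueError("the sketch has no points")
    xs = sorted(points)                 # sketched frame indices, increasing
    ys = [points[x] for x in xs]        # matching rewards
    rewards = []
    for frame in range(num_frames):
        if frame <= xs[0]:
            rewards.append(float(ys[0]))    # hold the first value
        elif frame >= xs[-1]:
            rewards.append(float(ys[-1]))   # hold the last value
        else:
            i = bisect_right(xs, frame)     # xs[i - 1] <= frame < xs[i]
            x0, x1 = xs[i - 1], xs[i]
            y0, y1 = ys[i - 1], ys[i]
            t = (frame - x0) / (x1 - x0)
            rewards.append(y0 + t * (y1 - y0))
    return rewards


if __name__ == "__main__":
    print(densify_sketch({0: -100, 5: -50, 9: 0}, 10))
```

Linear interpolation is only one choice; a step function (holding each value until the next point) would be an equally simple alternative if the sketch is meant to be piecewise constant.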

makseq commented 3 years ago

Thanks! Now I've got it. It's a very cool feature!

ankurhanda commented 3 years ago

Do you think it is worth adding that in label-studio?

I've already implemented it in Python (using matplotlib), but it's fairly basic and needs a lot of clicking.

ankurhanda commented 3 years ago

The code looks like this in case you're interested.

```python
import math

import matplotlib.pyplot as plt
from matplotlib.backend_bases import MouseEvent


class DraggablePlotExample:
    """An example of a plot with draggable markers, used to sketch a reward per video frame."""

    def __init__(self, hdf5File=None):
        self._figure, self._axes, self._line = None, None, None
        self._fill = None          # shaded area under the reward curve
        self._img_artist = None    # AxesImage showing the current frame
        self._dragging_point = None
        self._points = {}          # frame index -> sketched reward

        self.cur_x = 0
        self.demo_num = 2
        self.img_dict = {}

        if hdf5File is not None:
            import h5py

            # Open read-only and copy this demo's frames into memory.
            with h5py.File(hdf5File, 'r') as demos:
                demo_obs = demos['data']['demo_' + str(self.demo_num)]['obs']
                self.dataset_len = len(demo_obs)
                for i in range(self.dataset_len):
                    self.img_dict[i] = demo_obs[i]
        else:
            import glob

            import cv2

            dataset = sorted(glob.glob('./fetch_push/*.jpg'))
            self.dataset_len = len(dataset)

            for i, img_name in enumerate(dataset):
                img = cv2.imread(img_name)
                img = cv2.resize(img, (256, 256), interpolation=cv2.INTER_NEAREST)
                # OpenCV loads images as BGR; matplotlib expects RGB.
                self.img_dict[i] = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        self._init_plot()

    def _init_plot(self):
        self._figure = plt.figure("Example plot")

        # Top panel: the video frame under the cursor.
        img_plot = plt.subplot(2, 1, 1)
        img_plot.axis('off')
        self._img_artist = img_plot.imshow(self.img_dict[0])
        self._img_plot = img_plot

        # Bottom panel: reward (y) against frame index (x).
        self.y_min = -100
        self.y_max = 0

        axes = plt.subplot(2, 1, 2)
        axes.set_xlim(0, self.dataset_len)
        axes.set_ylim(self.y_min, self.y_max)
        axes.set_xlabel(r'frame $\mathcal{I}_t$')
        axes.set_ylabel(r'reward $r_t$')
        self._axes = axes

        canvas = self._figure.canvas
        canvas.mpl_connect('button_press_event', self._on_click)
        canvas.mpl_connect('button_release_event', self._on_release)
        canvas.mpl_connect('motion_notify_event', self._on_motion)
        canvas.mpl_connect('key_press_event', self._on_key_press)
        plt.show()

    def _update_plot(self):
        # Drop the old shaded area; it is rebuilt from the current points.
        if self._fill is not None:
            self._fill.remove()
            self._fill = None

        if not self._points:
            if self._line is not None:
                self._line.set_data([], [])
        else:
            x, y = zip(*sorted(self._points.items()))
            if self._line is None:
                self._line, = self._axes.plot(x, y, "b", marker="o", markersize=10)
            else:
                self._line.set_data(x, y)
            # Shade from the bottom of the axis up to the curve.
            self._fill = self._axes.fill_between(x, self.y_min, y, facecolor='blue', alpha=0.5)

        self._update_image()
        self._figure.canvas.draw_idle()

    def _update_image(self):
        # Clamp: the x-axis runs up to dataset_len, one past the last frame.
        frame = min(max(int(self.cur_x), 0), self.dataset_len - 1)
        self._img_artist.set_data(self.img_dict[frame])

    def _add_point(self, x, y=None):
        if isinstance(x, MouseEvent):
            x, y = int(x.xdata), int(x.ydata)
        self._points[x] = y
        return x, y

    def _remove_point(self, x, _):
        if x in self._points:
            self._points.pop(x)

    def _find_neighbor_point(self, event):
        """Return the (x, y) point nearest the mouse, or None if none is close enough.

        Distances are measured in data units (frames vs. reward).
        """
        distance_threshold = 3.0
        nearest_point = None
        min_distance = math.inf
        for x, y in self._points.items():
            distance = math.hypot(event.xdata - x, event.ydata - y)
            if distance < min_distance:
                min_distance = distance
                nearest_point = (x, y)
        if min_distance < distance_threshold:
            return nearest_point
        return None

    def _on_click(self, event):
        """Left click: grab a nearby point or add a new one. Right click: remove a nearby point."""
        if event.inaxes is not self._axes:
            return
        if event.button == 1:
            point = self._find_neighbor_point(event)
            if point:
                self._dragging_point = point
            else:
                self._add_point(event)
            self._update_plot()
        elif event.button == 3:
            point = self._find_neighbor_point(event)
            if point:
                self._remove_point(*point)
                self._update_plot()

    def _on_release(self, event):
        """Stop dragging when the left button is released, wherever the cursor is."""
        if event.button == 1 and self._dragging_point:
            self._dragging_point = None
            self._update_plot()

    def _on_key_press(self, event):
        # Press 'w' to save the sketch as "frame,reward" rows.
        if event.key == 'w':
            file_name = 'reward_dict_' + str(self.demo_num) + '.csv'
            with open(file_name, 'w') as f:
                for key, value in sorted(self._points.items()):
                    f.write(f'{key},{value}\n')
            print('logged sketched reward')

    def _on_motion(self, event):
        """Show the frame under the cursor and, while dragging, extend the sketch."""
        # Only the reward axes map x to a frame index; ignore the image panel.
        if event.inaxes is not self._axes or event.xdata is None or event.ydata is None:
            return

        self.cur_x = int(event.xdata)
        if not self._dragging_point:
            self._update_image()
            self._figure.canvas.draw_idle()
            return

        # Removing the previous point is disabled, so dragging leaves a trail
        # of points that forms the sketched curve.
        # self._remove_point(*self._dragging_point)
        self._dragging_point = self._add_point(event)
        self._update_plot()


if __name__ == "__main__":
    plot = DraggablePlotExample(hdf5File='images.hdf5')
```
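The `w` key handler in the script saves the sketch with one frame/reward pair per line. A minimal sketch for reading that file back into a `{frame: reward}` dict, assuming that layout and accepting either a space or a comma as the separator; `parse_reward_sketch` and `load_reward_sketch` are hypothetical helpers, not part of the script or of Label Studio.

```python
def parse_reward_sketch(lines):
    """Parse "frame reward" (or "frame,reward") lines into {frame: reward}.

    Hypothetical helper: blank lines are skipped; any other malformed line
    raises ValueError so a corrupt file is not silently truncated.
    """
    points = {}
    for line_no, raw in enumerate(lines, start=1):
        line = raw.strip()
        if not line:
            continue
        fields = line.replace(",", " ").split()
        if len(fields) != 2:
            raise ValueError(f"line {line_no}: expected 'frame reward', got {raw!r}")
        points[int(fields[0])] = float(fields[1])
    return points


def load_reward_sketch(path):
    """Read a saved sketch file, e.g. 'reward_dict_2.csv'."""
    with open(path) as f:
        return parse_reward_sketch(f)


if __name__ == "__main__":
    print(parse_reward_sketch(["0 -100\n", "5,-50\n", "\n", "9 0\n"]))
```

Splitting the parser from the file-opening wrapper keeps the parsing logic testable without touching the filesystem.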

ankurhanda commented 3 years ago

Some of this is inspired by https://github.com/yuma-m/matplotlib-draggable-plot/blob/master/draggable_plot.py

makseq commented 3 years ago

Thank you very much! We would need to re-implement it in JavaScript (React) for LS, though.

ankurhanda commented 3 years ago

Yes, sure. Let me know if you ever get to it; I'm happy to help with testing. Thanks a lot for listening.

makseq commented 2 years ago

We introduced the Number tag, which can help here: https://labelstud.io/tags/number.html#main