MishaLaskin / rad

RAD: Reinforcement Learning with Augmented Data
400 stars · 71 forks

"Translate" augmentation? #5

Closed qxcv closed 3 years ago

qxcv commented 4 years ago

Thanks for sharing this code! I'm trying to reproduce some of the experiments in the RAD paper. Most of them seem to use "translate" as the only augmentation, but I can't find a random-translation function in data_augs.py. How can I reproduce the translate augmentation used in the paper?

MishaLaskin commented 4 years ago

Hi Sam! The random translate aug is:

# test time aug
def center_translate(imgs, size):
    n, c, h, w = imgs.shape
    assert size >= h and size >= w
    outs = np.zeros((n, c, size, size), dtype=imgs.dtype)
    h1 = (size - h) // 2
    w1 = (size - w) // 2
    outs[:, :, h1:h1 + h, w1:w1 + w] = imgs
    return outs

# train time aug
def random_translate(imgs, size, return_random_idxs=False, h1s=None, w1s=None):
    n, c, h, w = imgs.shape
    assert size >= h and size >= w
    outs = np.zeros((n, c, size, size), dtype=imgs.dtype)
    h1s = np.random.randint(0, size - h + 1, n) if h1s is None else h1s
    w1s = np.random.randint(0, size - w + 1, n) if w1s is None else w1s
    for out, img, h1, w1 in zip(outs, imgs, h1s, w1s):
        out[:, h1:h1 + h, w1:w1 + w] = img
    if return_random_idxs:  # So can do the same to another set of imgs.
        return outs, dict(h1s=h1s, w1s=w1s)
    return outs
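A quick self-contained sanity check of how the train-time augmentation behaves (random_translate is restated from above so the snippet runs on its own; the shapes are illustrative, not from the paper):

```python
import numpy as np

def random_translate(imgs, size, return_random_idxs=False, h1s=None, w1s=None):
    # Restated from the snippet above: randomly place each (C, H, W) frame
    # inside a zero-padded (C, size, size) canvas.
    n, c, h, w = imgs.shape
    assert size >= h and size >= w
    outs = np.zeros((n, c, size, size), dtype=imgs.dtype)
    h1s = np.random.randint(0, size - h + 1, n) if h1s is None else h1s
    w1s = np.random.randint(0, size - w + 1, n) if w1s is None else w1s
    for out, img, h1, w1 in zip(outs, imgs, h1s, w1s):
        out[:, h1:h1 + h, w1:w1 + w] = img
    if return_random_idxs:
        return outs, dict(h1s=h1s, w1s=w1s)
    return outs

# Illustrative shapes: a batch of 2 RGB frames at 84x84, padded to 108x108.
obs = np.random.randint(0, 256, size=(2, 3, 84, 84), dtype=np.uint8)
next_obs = np.random.randint(0, 256, size=(2, 3, 84, 84), dtype=np.uint8)

# Reusing the returned offsets keeps obs and next_obs spatially aligned,
# which is what the replay buffer does when sampling paired observations.
out, idxs = random_translate(obs, 108, return_random_idxs=True)
next_out = random_translate(next_obs, 108, **idxs)
assert out.shape == next_out.shape == (2, 3, 108, 108)
```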

I'll let you know when it's added to the main codebase.

qxcv commented 4 years ago

Great, thanks for that. We're using RAD to train "expert demonstrator" policies as part of an image-based imitation learning project, so having the translate augmentation will be helpful (although even random cropping on its own has worked pretty well for most tasks).

MishaLaskin commented 4 years ago

BTW, if it's helpful, here's how the replay buffer needs to be refactored to make this aug work (thanks to @WendyShang for providing this snippet):

# in utils.py (method of the replay buffer class)
def sample_rad(self, aug_funcs):
    # augs are specified as flags; curl_sac organizes the flags into
    # aug funcs and passes them into this sampler
    idxs = np.random.randint(
        0, self.capacity if self.full else self.idx, size=self.batch_size
    )
    obses = self.obses[idxs]
    next_obses = self.next_obses[idxs]
    # NOTE: 100 is hard-coded for the image size of the DMC suite!
    og_obses = center_crop_images(obses, 100)
    og_next_obses = center_crop_images(next_obses, 100)
    if aug_funcs:
        for aug, func in aug_funcs.items():
            # apply crop and cutout first
            if 'crop' in aug or 'cutout' in aug or 'window' in aug:
                obses = func(obses, self.image_size)
                next_obses = func(next_obses, self.image_size)
            if 'translate' in aug:
                obses, rndm_idxs = func(og_obses, self.image_size, return_random_idxs=True)
                if self.augment_target_same_rnd:
                    # reuse the same random offsets for the target frames
                    next_obses = func(og_next_obses, self.image_size, **rndm_idxs)
                else:
                    next_obses = func(og_next_obses, self.image_size)
    obses = torch.as_tensor(obses, device=self.device).float()
    next_obses = torch.as_tensor(next_obses, device=self.device).float()
    actions = torch.as_tensor(self.actions[idxs], device=self.device)
    rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
    not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)
    obses = obses / 255.
    next_obses = next_obses / 255.
    # remaining augmentations go here
    if aug_funcs:
        for aug, func in aug_funcs.items():
            # skip augs already applied above
            if 'crop' in aug or 'cutout' in aug or 'translate' in aug or 'window' in aug:
                continue
            obses = func(obses)
            next_obses = func(next_obses)
    return obses, actions, rewards, next_obses, not_dones
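Note that center_crop_images isn't shown in the snippet above. A minimal sketch of what it presumably does (a batched center crop over NCHW arrays; the exact helper in the repo may differ):

```python
import numpy as np

def center_crop_images(imgs, out_size):
    # Hypothetical stand-in for the helper used above: center-crop a
    # batch of (N, C, H, W) frames to (N, C, out_size, out_size).
    n, c, h, w = imgs.shape
    assert h >= out_size and w >= out_size
    top = (h - out_size) // 2
    left = (w - out_size) // 2
    return imgs[:, :, top:top + out_size, left:left + out_size]

# Illustrative: crop 116x116 replay frames down to the 100x100 size
# that sample_rad hard-codes for the DMC suite.
frames = np.arange(4 * 3 * 116 * 116, dtype=np.int64).reshape(4, 3, 116, 116)
cropped = center_crop_images(frames, 100)
assert cropped.shape == (4, 3, 100, 100)
```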
WendyShang commented 4 years ago

Hi Misha,

I can refactor the replay buffer and submit a pull request after submitting my thesis (after Sept 18th) 😅 There are a few more changes to clean up the augmentation code.

Wendy

MishaLaskin commented 3 years ago

Hi @qxcv , just merged this into the codebase!

qxcv commented 3 years ago

Excellent, thank you!