Closed qxcv closed 3 years ago
Hi Sam! The random translate aug is:
# test time aug
def center_translate(imgs, size):
    """Zero-pad a batch of images to ``size`` x ``size``, content centred.

    Intended as the test-time counterpart of the random translate
    augmentation.

    Args:
        imgs: 4-D array whose shape unpacks as ``(n, c, h, w)``.
        size: Side length of the square output canvas; must be at least
            ``h`` and ``w``.

    Returns:
        A new ``(n, c, size, size)`` array with ``imgs.dtype``. The input is
        pasted in the middle (when the padding is odd, the top/left margin
        gets the smaller half) and every other entry is zero.

    Raises:
        AssertionError: If ``size`` is smaller than the image height or width.
    """
    batch, channels, height, width = imgs.shape
    assert size >= height and size >= width

    # Offsets of the pasted region's top-left corner inside the canvas.
    top = (size - height) // 2
    left = (size - width) // 2

    canvas = np.zeros((batch, channels, size, size), dtype=imgs.dtype)
    canvas[..., top:top + height, left:left + width] = imgs
    return canvas
# train time aug
def random_translate(imgs, size, return_random_idxs=False, h1s=None, w1s=None):
    """Paste each image at its own offset on a zero ``size`` x ``size`` canvas.

    Args:
        imgs: 4-D array whose shape unpacks as ``(n, c, h, w)``.
        size: Side length of the square output; must be at least ``h`` and
            ``w``.
        return_random_idxs: When true, also return the offsets that were
            used so the same shifts can be applied to another batch.
        h1s: Optional per-image top offsets. Drawn with
            ``np.random.randint(0, size - h + 1, n)`` when ``None``.
        w1s: Optional per-image left offsets. Drawn with
            ``np.random.randint(0, size - w + 1, n)`` when ``None``.

    Returns:
        The ``(n, c, size, size)`` array of shifted images (same dtype as
        ``imgs``), or the tuple ``(array, {"h1s": h1s, "w1s": w1s})`` when
        ``return_random_idxs`` is true.

    Raises:
        AssertionError: If ``size`` is smaller than the image height or width.
    """
    batch, channels, height, width = imgs.shape
    assert size >= height and size >= width

    canvas = np.zeros((batch, channels, size, size), dtype=imgs.dtype)

    # Rows are drawn before columns so the global NumPy RNG is consumed in a
    # fixed order.
    if h1s is None:
        h1s = np.random.randint(0, size - height + 1, batch)
    if w1s is None:
        w1s = np.random.randint(0, size - width + 1, batch)

    # zip() stops at the shortest input, so surplus offsets are ignored and
    # images without an offset stay all-zero.
    for idx, top, left in zip(range(batch), h1s, w1s):
        canvas[idx, :, top:top + height, left:left + width] = imgs[idx]

    if not return_random_idxs:
        return canvas
    # Hand the offsets back so the same shifts can be replayed on another batch.
    return canvas, {"h1s": h1s, "w1s": w1s}
I'll let you know when it's added to the main codebase.
Great, thanks for that. We're using RAD to train "expert demonstrator" policies as part of an image-based imitation learning project, so having the translate
augmentation will be helpful (although even random cropping on its own has worked pretty well for most tasks).
btw if it's helpful, here's how the replay buffer needs to be refactored to make this aug work (thanks to @WendyShang for providing this snippet)
# in utils.py
def sample_rad(self, aug_funcs, pre_translate_size=100):
    """Sample a random batch of transitions and apply the given augmentations.

    Augmentations are dispatched by substrings of their name:

    * names containing ``'crop'``, ``'cutout'`` or ``'window'`` are called as
      ``func(batch, self.image_size)`` on the sampled arrays;
    * names containing ``'translate'`` are called on centre-cropped copies of
      the sampled arrays (with ``return_random_idxs=True`` for the
      observations) and replace whatever an earlier augmentation produced;
    * every other entry is called as ``func(batch)`` on the float tensors
      after they have been divided by 255.

    Args:
        aug_funcs: Mapping from augmentation name to callable, or a falsy
            value to skip augmentation.
        pre_translate_size: Size passed to ``center_crop_images`` before a
            translate augmentation. The default of 100 is the value that was
            previously hard-coded (noted in the original snippet as the size
            for the DMC suite).

    Returns:
        ``(obses, actions, rewards, next_obses, not_dones)`` as torch tensors
        created on ``self.device``; both observation tensors are float and
        scaled by ``1 / 255``.
    """
    upper = self.capacity if self.full else self.idx
    idxs = np.random.randint(0, upper, size=self.batch_size)

    obses = self.obses[idxs]
    next_obses = self.next_obses[idxs]

    # Only translate consumes the centre-cropped copies, so skip the two
    # batch-sized crops unless a translate augmentation is requested. They
    # are still taken from the freshly sampled arrays, before any other
    # augmentation runs.
    if aug_funcs and any('translate' in aug for aug in aug_funcs):
        og_obses = center_crop_images(obses, pre_translate_size)
        og_next_obses = center_crop_images(next_obses, pre_translate_size)

    if aug_funcs:
        for aug, func in aug_funcs.items():
            # Array-level augmentations run before the tensor conversion.
            if 'crop' in aug or 'cutout' in aug or 'window' in aug:
                obses = func(obses, self.image_size)
                next_obses = func(next_obses, self.image_size)
            if 'translate' in aug:
                obses, rndm_idxs = func(
                    og_obses, self.image_size, return_random_idxs=True
                )
                if self.augment_target_same_rnd:
                    # Reuse the offsets so both frames are shifted alike.
                    next_obses = func(
                        og_next_obses, self.image_size, **rndm_idxs
                    )
                else:
                    next_obses = func(og_next_obses, self.image_size)

    obses = torch.as_tensor(obses, device=self.device).float()
    next_obses = torch.as_tensor(next_obses, device=self.device).float()
    actions = torch.as_tensor(self.actions[idxs], device=self.device)
    rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
    not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)

    obses = obses / 255.
    next_obses = next_obses / 255.

    # Tensor-level augmentations: everything not already handled above.
    if aug_funcs:
        for aug, func in aug_funcs.items():
            if ('crop' in aug or 'cutout' in aug
                    or 'translate' in aug or 'window' in aug):
                continue
            obses = func(obses)
            next_obses = func(next_obses)

    return obses, actions, rewards, next_obses, not_dones
Hi Misha,
I can refactor the replay buffer and submit a pull request after submitting my thesis (after Sept 18th) 😅 There are a few more changes to clean up the augmentation code.
Wendy
On Sun, Sep 6, 2020 at 8:49 AM Michael Laskin notifications@github.com wrote:
btw if it's helpful, here's how the replay buffer needs to be refactored to make this aug work (thanks to @WendyShang https://github.com/WendyShang for providing this snippet)
in utils.py
def sample_rad(self, aug_funcs, pre_translate_size=100):
    """Sample a random batch of transitions and apply the given augmentations.

    Augmentations are dispatched by substrings of their name:

    * names containing ``'crop'``, ``'cutout'`` or ``'window'`` are called as
      ``func(batch, self.image_size)`` on the sampled arrays;
    * names containing ``'translate'`` are called on centre-cropped copies of
      the sampled arrays (with ``return_random_idxs=True`` for the
      observations) and replace whatever an earlier augmentation produced;
    * every other entry is called as ``func(batch)`` on the float tensors
      after they have been divided by 255.

    Args:
        aug_funcs: Mapping from augmentation name to callable, or a falsy
            value to skip augmentation.
        pre_translate_size: Size passed to ``center_crop_images`` before a
            translate augmentation. The default of 100 is the value that was
            previously hard-coded (noted in the original snippet as the size
            for the DMC suite).

    Returns:
        ``(obses, actions, rewards, next_obses, not_dones)`` as torch tensors
        created on ``self.device``; both observation tensors are float and
        scaled by ``1 / 255``.
    """
    upper = self.capacity if self.full else self.idx
    idxs = np.random.randint(0, upper, size=self.batch_size)

    obses = self.obses[idxs]
    next_obses = self.next_obses[idxs]

    # Only translate consumes the centre-cropped copies, so skip the two
    # batch-sized crops unless a translate augmentation is requested. They
    # are still taken from the freshly sampled arrays, before any other
    # augmentation runs.
    if aug_funcs and any('translate' in aug for aug in aug_funcs):
        og_obses = center_crop_images(obses, pre_translate_size)
        og_next_obses = center_crop_images(next_obses, pre_translate_size)

    if aug_funcs:
        for aug, func in aug_funcs.items():
            # Array-level augmentations run before the tensor conversion.
            if 'crop' in aug or 'cutout' in aug or 'window' in aug:
                obses = func(obses, self.image_size)
                next_obses = func(next_obses, self.image_size)
            if 'translate' in aug:
                obses, rndm_idxs = func(
                    og_obses, self.image_size, return_random_idxs=True
                )
                if self.augment_target_same_rnd:
                    # Reuse the offsets so both frames are shifted alike.
                    next_obses = func(
                        og_next_obses, self.image_size, **rndm_idxs
                    )
                else:
                    next_obses = func(og_next_obses, self.image_size)

    obses = torch.as_tensor(obses, device=self.device).float()
    next_obses = torch.as_tensor(next_obses, device=self.device).float()
    actions = torch.as_tensor(self.actions[idxs], device=self.device)
    rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
    not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device)

    obses = obses / 255.
    next_obses = next_obses / 255.

    # Tensor-level augmentations: everything not already handled above.
    if aug_funcs:
        for aug, func in aug_funcs.items():
            if ('crop' in aug or 'cutout' in aug
                    or 'translate' in aug or 'window' in aug):
                continue
            obses = func(obses)
            next_obses = func(next_obses)

    return obses, actions, rewards, next_obses, not_dones
— You are receiving this because you were mentioned. Reply to this email directly, view it on GitHub https://github.com/MishaLaskin/rad/issues/5#issuecomment-687824637, or unsubscribe https://github.com/notifications/unsubscribe-auth/ABEXHXPUHO2GLBRWPUYB65TSEOVPVANCNFSM4OLUWV6Q .
Hi @qxcv , just merged this into the codebase!
Excellent, thank you!
Thanks for sharing this code! I'm trying to repeat some of the experiments in the RAD paper. It seems like most of them used "translate" as the only augmentation, but I can't find a function for random translation in `data_augs.py`. What do I have to do to get the same augmentation as in the paper?