Stable-Baselines-Team / stable-baselines3-contrib

Contrib package for Stable-Baselines3 - Experimental reinforcement learning (RL) code
https://sb3-contrib.readthedocs.io
MIT License

Custom network with image augmentation layer #8

Open ac-93 opened 3 years ago

ac-93 commented 3 years ago

Hi all,

First off thanks for the hard work you have put in creating stable baselines 3, it's helped me a bunch.

I have a fairly simple suggestion that I think fits in the contrib repo. In my work I've been using a custom network that applies image augmentation (almost exclusively random translations); this seems to help boost performance and stabilize training.
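
For reference, a random translation of this kind can be sketched in plain NumPy as "pad, then randomly crop back to the original size", the same idea as the random-shift augmentation used in DrQ. This is a minimal sketch assuming channels-first uint8 batches of shape (N, C, H, W); the function name `random_shift` and the `pad=4` value are illustrative choices, not part of any library.

```python
import numpy as np


def random_shift(images: np.ndarray, pad: int, rng: np.random.Generator) -> np.ndarray:
    """Randomly translate a batch of channels-first images (N, C, H, W).

    Each image is padded by `pad` pixels on every side with edge replication,
    then cropped back to (H, W) at an independent random offset, so the
    content moves by up to `pad` pixels in each direction.
    """
    n, _, h, w = images.shape
    padded = np.pad(images, ((0, 0), (0, 0), (pad, pad), (pad, pad)), mode="edge")
    out = np.empty_like(images)
    # One independent (dy, dx) offset per image, each in [0, 2 * pad].
    offsets = rng.integers(0, 2 * pad + 1, size=(n, 2))
    for i, (dy, dx) in enumerate(offsets):
        out[i] = padded[i, :, dy:dy + h, dx:dx + w]
    return out


rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(8, 3, 84, 84), dtype=np.uint8)
shifted = random_shift(batch, pad=4, rng=rng)
print(shifted.shape)  # (8, 3, 84, 84)
```

Because the output keeps the input's H x W, this kind of shift can sit in front of an unchanged CNN.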

There have been a few fairly recent papers that apply this effectively: CURL, RAD, DrQ.

I've been using Kornia as a simple drop-in to apply augmentations before feeding observations into the feature extractor layers, so the network looks something like this:

```python
from typing import Tuple

import gym  # newer SB3 releases use `gymnasium` instead
import kornia.augmentation as K
import torch as th
import torch.nn as nn
from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class ImageAugNatureCNN(BaseFeaturesExtractor):
    """NatureCNN feature extractor with an optional Kornia augmentation layer."""

    def __init__(
        self,
        observation_space: gym.spaces.Box,
        features_dim: int = 512,
        apply_augmentation: bool = True,
        shift_range: Tuple[float, float] = (0.0, 0.0),
        zoom_range: Tuple[float, float] = (1.0, 1.0),
    ):
        super().__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper
        assert is_image_space(observation_space), (
            "You should use NatureCNN "
            f"only with images not with {observation_space} "
            "(you are probably using `CnnPolicy` instead of `MlpPolicy`)"
        )

        self.apply_augmentation = apply_augmentation
        # Random translation (max fraction of width/height) and zoom, no rotation
        self.augmentation = nn.Sequential(
            K.RandomAffine(degrees=0, translate=shift_range, scale=zoom_range)
        )

        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute shape by doing one forward pass
        with th.no_grad():
            n_flatten = self.cnn(th.as_tensor(observation_space.sample()[None]).float()).shape[1]

        self.linear = nn.Sequential(nn.Linear(n_flatten, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        # SB3 has already converted the images to float (and, by default,
        # scaled them to [0, 1]) before they reach the features extractor.
        # This runs on every forward pass, including action selection.
        if self.apply_augmentation:
            observations = self.augmentation(observations)

        # visualise_augmentation(observations)  # debugging helper, not defined here

        cnn_out = self.cnn(observations)
        return self.linear(cnn_out)
```
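
The `n_flatten` forward pass in the code above can also be checked by hand with the standard Conv2d output-size formula, floor((size + 2 * padding - kernel) / stride) + 1 (dilation 1). A small sketch, assuming 84x84 inputs (the usual Atari preprocessing size); the helper names are hypothetical:

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Spatial output size of a Conv2d layer with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def nature_cnn_flatten_dim(height: int, width: int) -> int:
    """Flattened feature size after the three NatureCNN conv layers."""
    # (kernel, stride) per layer, all with padding=0, as in the extractor above.
    layers = [(8, 4), (4, 2), (3, 1)]
    for kernel, stride in layers:
        height = conv_out(height, kernel, stride)
        width = conv_out(width, kernel, stride)
    return 64 * height * width  # the last conv layer has 64 output channels


print(nature_cnn_flatten_dim(84, 84))  # 3136  (84 -> 20 -> 9 -> 7, and 64 * 7 * 7)
```

Since `K.RandomAffine` keeps the image size unchanged, the augmentation layer does not change `n_flatten`.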

Let me know if you think this would be a good fit or if there are any improvements you can think of.

Edit: My bad, hit submit a bit early...

araffin commented 3 years ago

Hello,

Can you elaborate a bit more, what do you want to implement exactly? Where should it be included? Which algorithms do you plan to support?

Miffyli commented 3 years ago

@araffin

I believe this is not about specific algorithms, but rather a set of wrappers (or tools) to do image augmentation.

I think this would be a nice fit, although I am not sure if the library you linked would work here (observations are NumPy arrays up until the algorithm, and I'd rather have the augmentations in the wrappers). For completeness' sake, implementations of CURL/RAD could be added so one can try the setup as described by the authors in the original papers, but I think general wrappers/tools would be more useful to people.

ac-93 commented 3 years ago

> I think this would be a nice fit, although I am not sure if the library you linked would work here (observations are NumPy arrays up until the algorithm, and I'd rather have the augmentations in the wrappers). For completeness' sake, implementations of CURL/RAD could be added so one can try the setup as described by the authors in the original papers, but I think general wrappers/tools would be more useful to people.

Yeah, this is what I meant. The reason for applying the augmentation at the network input instead of in a wrapper is to make the best use of multiple passes over the data. For example, instead of storing an augmented image and reusing that same augmentation, a new random augmentation is applied each time the images are sampled from the buffer. I'm not sure how much difference this would actually make in practice, but I think this way round is closer to what's done in the RAD and DrQ papers. I agree that there's probably a better way than relying on another library; this was just the simplest method I could think of. Perhaps there's a better place to do it, right after sampling from the buffer while the data are still NumPy arrays?

araffin commented 3 years ago

> The reason for applying the augmentation at the network input instead of in a wrapper is to make the best use of multiple passes over the data. For example, instead of storing an augmented image and reusing that same augmentation, a new random augmentation is applied each time the images are sampled from the buffer.

What you are describing applies mostly to off-policy algorithms then, and could be implemented at sampling time? What you propose in the code example, however, applies to on-policy algorithms too (which is fine).

The reason I'm asking for more details is that I don't want to have something half-working and half-useful that will then not be used. If we don't have a concrete reference, it will be hard to validate the implementation. For image augmentation, I would rather go for imgaug if we use NumPy arrays; otherwise, kornia is better when handling PyTorch tensors.

Miffyli commented 3 years ago

> The reason I'm asking for more details is that I don't want to have something half-working and half-useful that will then not be used. If we don't have a concrete reference, it will be hard to validate the implementation.

The linked RAD and DrQ are extensions of existing algorithms with augmentation, so I think we have reference results (and likely also implementations available). I think these would be a nice addition to the repo, especially if the augmentation tools (e.g. a buffer that augments samples when reading from it) can be used with other algorithms too. I would also rely on an existing image augmentation library to make sure the augmentations work as intended, unless something in the library really prevents its use.

ac-93 commented 3 years ago

> What you are describing applies mostly to off-policy algorithms then, and could be implemented at sampling time?

Yeah, it mostly applies to off-policy algorithms, though the RAD paper shows some generalization gains when using PPO, so I think it should be applicable to both.

> The reason I'm asking for more details is that I don't want to have something half-working and half-useful that will then not be used. If we don't have a concrete reference, it will be hard to validate the implementation.

Makes sense. I think we could drop CURL as a reference, as I believe RAD is a simplified follow-up that keeps most of the performance gains. RAD and DrQ are concurrent works that do practically the same thing, although I think DrQ has some extra features; there's some discussion between the authors on the differences here.

> The linked RAD and DrQ are extensions of existing algorithms with augmentation, so I think we have reference results (and likely also implementations available).

Yep, there are author implementations of both here and here.

In more detail, I'd propose an alternative replay buffer that takes augmentation parameters as arguments and applies the augmentations when data is sampled from the buffer (using imgaug). I would start with off-policy algorithms, as maybe more care needs to be taken when doing this for on-policy ones? The aim would be to recreate the results from the RAD paper (as the DrQ paper has additional algorithmic changes, although both should probably be cited). I'm not sure I'll have the time (or the MuJoCo license) to run all of the experiments, however.
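
A minimal sketch of such a buffer, assuming a plain NumPy class rather than a subclass of SB3's `ReplayBuffer`, and with NumPy pad-and-crop standing in for imgaug. The class `AugmentedReplayBuffer`, its `pad` parameter, and the continuous `action_dim` layout are hypothetical illustrations, not an existing sb3-contrib API. Observations are stored un-augmented as uint8, and a fresh random shift is drawn on every `sample` call, so a transition reused across several gradient steps is seen under different augmentations.

```python
import numpy as np


class AugmentedReplayBuffer:
    """Hypothetical FIFO replay buffer that augments image observations at sample time."""

    def __init__(self, capacity: int, obs_shape: tuple, action_dim: int,
                 pad: int = 4, seed: int = 0):
        self.capacity = capacity
        self.pad = pad  # augmentation parameter: maximum shift in pixels
        self.rng = np.random.default_rng(seed)
        # Raw (un-augmented) channels-first observations are stored.
        self.obs = np.zeros((capacity, *obs_shape), dtype=np.uint8)
        self.next_obs = np.zeros((capacity, *obs_shape), dtype=np.uint8)
        self.actions = np.zeros((capacity, action_dim), dtype=np.float32)
        self.rewards = np.zeros((capacity,), dtype=np.float32)
        self.dones = np.zeros((capacity,), dtype=np.float32)
        self.pos = 0
        self.full = False

    def add(self, obs, next_obs, action, reward, done) -> None:
        self.obs[self.pos] = obs
        self.next_obs[self.pos] = next_obs
        self.actions[self.pos] = action
        self.rewards[self.pos] = reward
        self.dones[self.pos] = done
        # Overwrite the oldest entry once the buffer wraps around.
        self.pos = (self.pos + 1) % self.capacity
        self.full = self.full or self.pos == 0

    def _shift(self, images: np.ndarray) -> np.ndarray:
        # Pad with edge replication, then crop back at a random offset per image.
        n, _, h, w = images.shape
        p = self.pad
        padded = np.pad(images, ((0, 0), (0, 0), (p, p), (p, p)), mode="edge")
        offsets = self.rng.integers(0, 2 * p + 1, size=(n, 2))
        return np.stack([padded[i, :, dy:dy + h, dx:dx + w]
                         for i, (dy, dx) in enumerate(offsets)])

    def sample(self, batch_size: int):
        upper = self.capacity if self.full else self.pos
        idx = self.rng.integers(0, upper, size=batch_size)
        # A new augmentation is drawn on every call.
        return (self._shift(self.obs[idx]), self._shift(self.next_obs[idx]),
                self.actions[idx], self.rewards[idx], self.dones[idx])


buf = AugmentedReplayBuffer(capacity=100, obs_shape=(3, 84, 84), action_dim=2)
frame = np.random.default_rng(1).integers(0, 256, (3, 84, 84), dtype=np.uint8)
for _ in range(10):
    buf.add(frame, frame, np.zeros(2), 1.0, False)
obs, next_obs, actions, rewards, dones = buf.sample(32)
print(obs.shape)  # (32, 3, 84, 84)
```

Here `obs` and `next_obs` get independent offsets; whether they should share one is a design choice to check against the RAD and DrQ reference implementations.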

Miffyli commented 3 years ago

> The aim would be to recreate the results from the RAD paper (as the DrQ paper has additional algorithmic changes, although both should probably be cited). I'm not sure I'll have the time (or the MuJoCo license) to run all of the experiments, however.

This sounds good! It should not be too difficult to implement, nor too messy. Implement them as wrappers and/or as a separate algorithm if you need to modify buffers.

I think you can focus on some of the clearest results in the paper to check whether your implementation is correct (Procgen/CoinRun is open and free). You can also use PyBullet via this trick, although I cannot say what the results should be like.