facebookresearch / EGG

EGG: Emergence of lanGuage in Games
MIT License

Custom LoggingStrategy #169

Closed (nicofirst1 closed this issue 3 years ago)

nicofirst1 commented 3 years ago

Is your proposal related to a problem?

I am working with the COCO dataset and the SenderReceiverRnnReinforce architecture in a similar fashion to the basic game. Both my sender and receiver get as input an image of size [batch, 3, 299, 299].

In order to save the interaction between the two agents I am using the default LoggingStrategy, which saves the required information for every batch of an epoch. This causes memory to fill up quickly with images and eventually run out.

In my case with:

It takes around 20 batches to slow my machine down and 25 for the process to be killed by the system.

Describe the solution you'd like to have implemented

Since the interaction saving is done in the forward pass, inside the batch loop, it can be tackled in two ways which I will now consider.

Inside the batch loop

The easiest fix would be to discard the interaction inside the batch loop according to some criterion, which can be either stochastic (e.g. discard interaction i with probability p) or batch dependent (i.e. keep interaction i every n batches).
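
The two in-loop policies described above can be sketched as follows. This is only an illustration and not EGG code: `keep_stochastic`, `keep_every_n`, and the loop over batch indices are hypothetical names introduced here, and a string stands in for a real interaction.

```python
import random


def keep_stochastic(p_discard: float, rng: random.Random) -> bool:
    """Stochastic policy: discard the interaction with probability p_discard."""
    return rng.random() >= p_discard


def keep_every_n(batch_id: int, n: int) -> bool:
    """Batch-dependent policy: keep the interaction once every n batches."""
    return batch_id % n == 0


rng = random.Random(0)  # seeded only to make the sketch reproducible
stored_stochastic = []
stored_periodic = []
for batch_id in range(100):
    interaction = f"interaction-{batch_id}"  # stand-in for a real interaction
    if keep_stochastic(p_discard=0.9, rng=rng):
        stored_stochastic.append(interaction)
    if keep_every_n(batch_id, n=10):
        stored_periodic.append(interaction)
```

The periodic policy gives a predictable memory footprint but needs the batch index; the stochastic one needs no index, at the cost of a variable number of stored interactions.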

With the LoggingStrategy

Although the previous fix sounds easier, it is not modular since it does not use the LoggingStrategy class. So a much cleaner way is to allow the user to define a custom LoggingStrategy and use it to filter the interaction in the forward pass. In this case the only available solution would be the stochastic one, since no information about the current batch id is passed to the forward pass.

Necessary modification

To achieve the above-mentioned result, the LoggingStrategy class should change the store_* attributes from bool to functions. These functions can be either:

eugene-kharitonov commented 3 years ago

Hello,

I wonder if the minimal logging strategy would solve your issue? https://github.com/facebookresearch/EGG/blob/master/egg/core/interaction.py#L44

nicofirst1 commented 3 years ago

If I understood correctly the minimal logging strategy doesn't save anything but the message length. I need to have some interaction saved, but cannot afford (in terms of memory) for all to be saved.

eugene-kharitonov commented 3 years ago

Right, I see. Typically, I'd use minimal for the training data, and would save more for the test data. However, I guess this might not work if the test dataset is large, too.

Regarding the options: I would prefer to keep the entire filtering logic in the LoggingStrategy, but having functions is perhaps too involved. Would it work if we store a boolean in LoggingStrategy and use it to subsample interactions there when filtering?

PS. Interactions are not stored on GPU, so I find it surprising that 600 images would overflow 32 GB? Any ideas why this could happen?

nicofirst1 commented 3 years ago

Interactions are not stored on GPU, so I find it surprising that 600 images would overflow 32 GB?

Sorry I mixed the numbers up, the slowdown occurs around the 400th batch iteration (12800 images).
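
As a rough sanity check of that figure, the memory taken by the stored images can be estimated as below. This is a back-of-the-envelope sketch under two assumptions that are not stated in the thread: the images are kept as float32 tensors, and the default strategy stores both the sender input and the receiver input.

```python
BYTES_PER_FLOAT32 = 4  # assumption: images are stored as float32 tensors

values_per_image = 3 * 299 * 299  # channels * height * width
bytes_per_image = values_per_image * BYTES_PER_FLOAT32
n_images = 12_800  # reached around the 400th batch iteration

bytes_per_field = bytes_per_image * n_images
# Assumption: both sender_input and receiver_input are stored.
bytes_both_inputs = 2 * bytes_per_field

print(f"one input field: {bytes_per_field / 1e9:.1f} GB")
print(f"both inputs:     {bytes_both_inputs / 1e9:.1f} GB")
```

Under these assumptions the stored inputs alone come to roughly 27 GB, which would be in line with a 32 GB machine slowing down around that point.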

Would it work if we store a boolean in LoggingStrategy and use it to subsample interactions there when filtering?

What kind of strategy would the subsample use?

but having functions is perhaps too involved.

If you're saying that it is too much work, I could do it and then pull request.

eugene-kharitonov commented 3 years ago

I wonder if uniform subsampling would be enough?

Here (https://github.com/facebookresearch/EGG/blob/master/egg/core/interaction.py#L34), something like this, in non-working pseudo-code:

to_take = bernoulli(self.subsample_prob, size=interactions.size(0))
...
return Interaction(
    sender_input=sender_input[to_take] if self.store_sender_input else None,
    ....
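
A runnable sketch of the uniform subsampling idea in the pseudo-code above. To stay free of dependencies it uses plain Python lists as stand-ins for tensors; `sample_mask` and `take` are hypothetical helpers introduced here, not part of EGG. With PyTorch tensors, an equivalent mask could be drawn as `torch.rand(batch_size) < subsample_prob` and applied with boolean indexing.

```python
import random
from typing import List, Optional, Sequence


def sample_mask(batch_size: int, subsample_prob: float, rng: random.Random) -> List[bool]:
    """Draw one Bernoulli(subsample_prob) keep/drop decision per example."""
    return [rng.random() < subsample_prob for _ in range(batch_size)]


def take(field: Optional[Sequence], mask: List[bool]) -> Optional[list]:
    """Keep only the masked rows; a field that is not stored stays None."""
    if field is None:
        return None
    return [row for row, keep in zip(field, mask) if keep]


rng = random.Random(0)  # seeded only to make the sketch reproducible
sender_input = ["img0", "img1", "img2", "img3"]  # stand-ins for image rows
labels = [0, 1, 2, 3]

# One mask per batch, reused for every field so the kept rows stay aligned.
to_take = sample_mask(len(sender_input), subsample_prob=0.5, rng=rng)
kept_inputs = take(sender_input, to_take)
kept_labels = take(labels, to_take)
```

Drawing the mask once per batch and reusing it for every stored field is what keeps inputs, labels, and messages referring to the same examples.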

nicofirst1 commented 3 years ago

It would work, but in my opinion it does not extend to other filtering logic.

What about something like this?


from typing import Dict, Optional

import torch


# Plain class: the hand-written __init__ below takes the place of dataclass fields.
# Interaction is defined in the same module (egg/core/interaction.py).
class LoggingStrategy:

    def __init__(self, store_sender_input=True, store_receiver_input=True, store_labels=True, store_message=True,
                 store_receiver_output=True, store_message_length=True):
        # Each store_* attribute is now a callable that maps a tensor to what gets logged.
        self.store_sender_input = self.bool_filtering(store_sender_input)
        self.store_receiver_input = self.bool_filtering(store_receiver_input)
        self.store_labels = self.bool_filtering(store_labels)
        self.store_message = self.bool_filtering(store_message)
        self.store_receiver_output = self.bool_filtering(store_receiver_output)
        self.store_message_length = self.bool_filtering(store_message_length)

    @staticmethod
    def bool_filtering(to_store: bool):
        # Build a filter that passes the input through when to_store is True
        # and drops it (returns None) otherwise.
        def inner_filtering(inp):
            return inp if to_store else None

        return inner_filtering

    def filtered_interaction(
            self,
            sender_input: Optional[torch.Tensor],
            receiver_input: Optional[torch.Tensor],
            labels: Optional[torch.Tensor],
            message: Optional[torch.Tensor],
            receiver_output: Optional[torch.Tensor],
            message_length: Optional[torch.Tensor],
            aux: Dict[str, torch.Tensor],
    ):
        # Apply the per-field filter to each input; aux is always stored.
        return Interaction(
            sender_input=self.store_sender_input(sender_input),
            receiver_input=self.store_receiver_input(receiver_input),
            labels=self.store_labels(labels),
            message=self.store_message(message),
            receiver_output=self.store_receiver_output(receiver_output),
            message_length=self.store_message_length(message_length),
            aux=aux,
        )

nicofirst1 commented 3 years ago

You could also make the filtering function an argument of __init__ (defaulting to bool_filtering) to allow easier customization, like this:

class LoggingStrategy:

    def __init__(self, store_sender_input=True, store_receiver_input=True, store_labels=True, store_message=True,
                 store_receiver_output=True, store_message_length=True, default_filtering=None):
        # default_filtering maps a store_* flag to a filter function;
        # fall back to bool_filtering when none is given.
        if default_filtering is None:
            default_filtering = self.bool_filtering
        self.store_sender_input = default_filtering(store_sender_input)
        self.store_receiver_input = default_filtering(store_receiver_input)
        self.store_labels = default_filtering(store_labels)
        self.store_message = default_filtering(store_message)
        self.store_receiver_output = default_filtering(store_receiver_output)
        self.store_message_length = default_filtering(store_message_length)
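
To illustrate how such a `default_filtering` hook could carry stochastic logic, here is a minimal sketch of a factory with the same shape as `bool_filtering`. The name `make_random_filtering` and its parameters are hypothetical and introduced only for this example; nothing here is existing EGG API.

```python
import random
from typing import Any, Callable, Optional


def make_random_filtering(keep_prob: float, seed: Optional[int] = None):
    """Hypothetical factory: returns a callable usable as default_filtering."""
    rng = random.Random(seed)

    def factory(to_store: bool) -> Callable[[Any], Any]:
        def inner_filtering(inp):
            # A disabled field is never stored; an enabled one is kept
            # with probability keep_prob on each call.
            if not to_store:
                return None
            return inp if rng.random() < keep_prob else None

        return inner_filtering

    return factory


always = make_random_filtering(keep_prob=1.0)(True)
never = make_random_filtering(keep_prob=0.0)(True)
disabled = make_random_filtering(keep_prob=1.0)(False)
```

With the constructor proposed above it would be passed as `LoggingStrategy(default_filtering=make_random_filtering(0.1))`. Note that in this sketch every field draws its own random decision, so within one batch some fields may be kept while others are dropped; sharing a single decision per batch would need extra state.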

eugene-kharitonov commented 3 years ago

That would kind of work, sure, but is that other logic needed for your use case? Otherwise, I'd prefer to wait until we actually have a need for this: as you see, it is not hard to implement, so we can do that when it is needed.

nicofirst1 commented 3 years ago

Well, since I now need a stochastic approach, I do need a custom logging strategy. With this implementation I could have my approach and keep compatibility with the previous versions (default_filtering excluded).

eugene-kharitonov commented 3 years ago

By stochastic, do you mean some non-uniform subsampling of interactions?

nicofirst1 commented 3 years ago

For now uniform is enough for me, but I will later investigate non-uniform too.

nicofirst1 commented 3 years ago

While implementing the random logging, I am blocked by this check.

I don't quite understand why it wouldn't be OK to append empty and non-empty interaction logs.

robertodessi commented 3 years ago

torch.cat would fail. How would you append a None tensor with a filled one? Can you remove the None/empty ones?

nicofirst1 commented 3 years ago

Filtering out the None values makes everything work:


        def _check_cat(lst):
            if all(x is None for x in lst):
                return None
            # if some but not all are None: filter out None
            if any(x is None for x in lst):
                lst = [elem for elem in lst if elem is not None]
            return torch.cat(lst, dim=0)

robertodessi commented 3 years ago

If we had this by default, we would miss potential cases where the game, for some buggy reason, returns empty tensors/values when it is not supposed to. I don't know if it makes sense to have such filtering by default, @eugene-kharitonov?

eugene-kharitonov commented 3 years ago

Mmm, let's get rid of this if branch? I don't think I put it there for any strong reason; I was just being overly cautious.

(adapting @nicofirst1 's snippet)

        def _check_cat(lst):
            if all(x is None for x in lst):
                return None
            # if some but not all are None: filter out None
            lst = [x for x in lst if x is not None]
            return torch.cat(lst, dim=0)
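
To see what the adapted helper does with mixed logs, here is a stand-in that mirrors its control flow. Plain Python lists replace tensors and list concatenation replaces `torch.cat`, purely so the sketch runs without PyTorch; `check_cat` is a hypothetical name used only here.

```python
from typing import List, Optional


def check_cat(lst: List[Optional[list]]) -> Optional[list]:
    """Same control flow as _check_cat, with list concatenation in place of torch.cat."""
    if all(x is None for x in lst):
        return None
    # Some entries may be None (skipped batches): drop them before concatenating.
    kept = [x for x in lst if x is not None]
    return [row for batch in kept for row in batch]


all_empty = check_cat([None, None])
mixed = check_cat([[1, 2], None, [3]])
full = check_cat([[1], [2, 3]])
```

A field that was never stored still comes back as None, while a field stored for only some batches is concatenated over the batches that have it.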

nicofirst1 commented 3 years ago

Would you accept a PR with this and the LoggingStrategy?

eugene-kharitonov commented 3 years ago

Thanks. Send over a PR for a review!

robertodessi commented 3 years ago

#177 addressed this.