Hello,
I wonder if the minimal logging strategy would solve your issues? https://github.com/facebookresearch/EGG/blob/master/egg/core/interaction.py#L44
If I understood correctly, the minimal logging strategy doesn't save anything but the message length. I need to have some interactions saved, but cannot afford (in terms of memory) to save all of them.
Right, I see. Typically, I'd use minimal for the training data and save more for the test data. However, I guess this might not work if the test dataset is large, too.
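For instance, that split could look roughly like the sketch below. The LoggingStrategy flags match the constructor discussed later in this thread, while the train_logging_strategy/test_logging_strategy arguments of the game are an assumption about the game's constructor:

```python
from egg.core.interaction import LoggingStrategy

# Minimal logging for the (large) training set: drop the heavy fields.
train_logging = LoggingStrategy(
    store_sender_input=False,
    store_receiver_input=False,
    store_labels=False,
    store_message=False,
    store_receiver_output=False,
    store_message_length=True,
)

# Richer logging for the (smaller) test set.
test_logging = LoggingStrategy()

# Hypothetical wiring: passing the strategies to the game, assuming it
# exposes train_logging_strategy/test_logging_strategy arguments.
# game = SenderReceiverRnnReinforce(
#     sender, receiver, loss,
#     train_logging_strategy=train_logging,
#     test_logging_strategy=test_logging,
# )
```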
Regarding the options: I would prefer to keep the entire filtering logic in the LoggingStrategy, but having functions is perhaps too involved. Would it work if we store a boolean in LoggingStrategy and use it to subsample interactions there when filtering?
PS. Interactions are not stored on GPU, so I find it surprising that 600 images would overflow 32Gb? Any ideas why this could happen?
> Interactions are not stored on GPU, so I find it surprising that 600 images would overflow 32Gb?
Sorry I mixed the numbers up, the slowdown occurs around the 400th batch iteration (12800 images).
> Would it work if we store a boolean in LoggingStrategy and use it to subsample interactions there when filtering?
What kind of strategy would the subsample use?
> but having functions is perhaps too involved.
If you're saying that it is too much work, I could do it and then open a pull request.
I wonder if uniform subsampling would be enough?
Here, https://github.com/facebookresearch/EGG/blob/master/egg/core/interaction.py#L34, something like this in non-working pseudo-code:
```python
to_take = bernoulli(self.subsample_prob, size=interactions.size(0))
...
return Interaction(
    sender_input=sender_input[to_take] if self.store_sender_input else None,
    ....
```
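A runnable version of that idea could look like the following sketch; the helper name and the subsample_prob value are hypothetical:

```python
import torch


def subsample_mask(batch_size: int, subsample_prob: float) -> torch.Tensor:
    """Boolean mask that keeps each row with probability subsample_prob."""
    return torch.bernoulli(torch.full((batch_size,), subsample_prob)).bool()


# Example: keep roughly 10% of a batch of 32 interactions.
to_take = subsample_mask(32, 0.1)
# sender_input[to_take] then selects only the kept rows.
```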
It would work, but it's not scalable to other logic in my opinion.
What about something like this?
```python
# Proposed replacement for LoggingStrategy in egg/core/interaction.py
# (Interaction is defined in the same module).
from typing import Dict, Optional

import torch


class LoggingStrategy:
    def __init__(
        self,
        store_sender_input=True,
        store_receiver_input=True,
        store_labels=True,
        store_message=True,
        store_receiver_output=True,
        store_message_length=True,
    ):
        # Each store_* attribute becomes a function that either passes the
        # corresponding field through or replaces it with None.
        self.store_sender_input = self.bool_filtering(store_sender_input)
        self.store_receiver_input = self.bool_filtering(store_receiver_input)
        self.store_labels = self.bool_filtering(store_labels)
        self.store_message = self.bool_filtering(store_message)
        self.store_receiver_output = self.bool_filtering(store_receiver_output)
        self.store_message_length = self.bool_filtering(store_message_length)

    @staticmethod
    def bool_filtering(to_store: bool):
        def inner_filtering(inp):
            if to_store:
                return inp
            else:
                return None

        return inner_filtering

    def filtered_interaction(
        self,
        sender_input: Optional[torch.Tensor],
        receiver_input: Optional[torch.Tensor],
        labels: Optional[torch.Tensor],
        message: Optional[torch.Tensor],
        receiver_output: Optional[torch.Tensor],
        message_length: Optional[torch.Tensor],
        aux: Dict[str, torch.Tensor],
    ):
        return Interaction(
            sender_input=self.store_sender_input(sender_input),
            receiver_input=self.store_receiver_input(receiver_input),
            labels=self.store_labels(labels),
            message=self.store_message(message),
            receiver_output=self.store_receiver_output(receiver_output),
            message_length=self.store_message_length(message_length),
            aux=aux,
        )
```
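For context, the game's forward would then build the logged interaction through this method, along these lines (an illustrative sketch; the attribute name train_logging_strategy is an assumption about how the game stores its strategy):

```python
# Illustrative: inside the game's forward pass.
interaction = self.train_logging_strategy.filtered_interaction(
    sender_input=sender_input,
    receiver_input=receiver_input,
    labels=labels,
    message=message,
    receiver_output=receiver_output,
    message_length=message_length,
    aux={},
)
```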
You could also expose bool_filtering as an argument of the init to allow easier subclassing, like this:
```python
class LoggingStrategy:
    def __init__(
        self,
        store_sender_input=True,
        store_receiver_input=True,
        store_labels=True,
        store_message=True,
        store_receiver_output=True,
        store_message_length=True,
        default_filtering=None,
    ):
        # Subclasses can inject their own filtering factory; fall back to
        # the boolean one otherwise.
        if default_filtering is None:
            default_filtering = self.bool_filtering
        self.store_sender_input = default_filtering(store_sender_input)
        self.store_receiver_input = default_filtering(store_receiver_input)
        self.store_labels = default_filtering(store_labels)
        self.store_message = default_filtering(store_message)
        self.store_receiver_output = default_filtering(store_receiver_output)
        self.store_message_length = default_filtering(store_message_length)
```
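For instance, the stochastic use-case mentioned below could then be a small subclass. This is only a sketch, assuming the two snippets above are merged into one class; StochasticLoggingStrategy and keep_prob are hypothetical names:

```python
import torch


class StochasticLoggingStrategy(LoggingStrategy):
    """Keep a batch's interaction with probability keep_prob, store nothing otherwise."""

    def __init__(self, keep_prob: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        self.keep_prob = keep_prob

    def filtered_interaction(
        self, sender_input, receiver_input, labels,
        message, receiver_output, message_length, aux,
    ):
        if torch.rand(1).item() >= self.keep_prob:
            # Drop the payload of this batch; aux is kept as in the parent.
            sender_input = receiver_input = labels = None
            message = receiver_output = message_length = None
        return super().filtered_interaction(
            sender_input, receiver_input, labels,
            message, receiver_output, message_length, aux,
        )
```

Batches that are dropped end up with None fields, which is why relaxing the concatenation check discussed below becomes necessary.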
That would kind of work, sure, but is that other logic needed for your use-case? Otherwise, I'd prefer to wait until we actually have a need for this: as you can see, it is not hard to implement, so we can do it when it is needed.
Well, since I now need a stochastic approach, a custom logging strategy is necessary for me. With this implementation I could have my approach and keep compatibility with the previous versions (default_filtering excluded).
Here by stochastic you mean some non-uniform subsampling of interactions?
For now uniform is enough for me, but I will later investigate non-uniform too.
While implementing the random logging, I am blocked by this check. I don't quite understand why it wouldn't be ok to append empty and non-empty interaction logs.
torch.cat would fail. How would you concatenate a None tensor with a filled one? Can you remove the None/empty ones?
By filtering out the None values everything works:
```python
import torch


def _check_cat(lst):
    if all(x is None for x in lst):
        return None
    # if some but not all are None: filter out the None entries
    if any(x is None for x in lst):
        lst = [elem for elem in lst if elem is not None]
    return torch.cat(lst, dim=0)
```
If we had this by default, it would mask potential cases where the game, for some weird buggy reason, returns empty tensors/values when it's not supposed to. I don't know if it makes sense to have such filtering by default, @eugene-kharitonov?
Mmm, let's get rid of this if branch? I don't think I put it there for any strong reason, just being overly cautious.
(adapting @nicofirst1's snippet)
```python
def _check_cat(lst):
    if all(x is None for x in lst):
        return None
    # if some but not all are None: filter out the None entries
    lst = [x for x in lst if x is not None]
    return torch.cat(lst, dim=0)
```
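A quick illustration of the relaxed behaviour (tensor shapes chosen arbitrarily):

```python
import torch

a = torch.zeros(4, 3)
b = torch.ones(2, 3)

# torch.cat([a, None, b], dim=0) would raise a TypeError;
# the filtered version simply skips the missing batches.
print(_check_cat([a, None, b]).shape)  # torch.Size([6, 3])
print(_check_cat([None, None]))        # None
```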
Would you accept a PR with this and the LoggingStrategy?
Thanks. Send over a PR for a review!
Is your proposal related to a problem?
I am working with the COCO dataset and the SenderReceiverRnnReinforce architecture, in a similar fashion to the basic game. Both my sender and receiver get as input an image of size [batch, 3, 299, 299].
In order to save the interaction between the two agents, I am using the default LoggingStrategy, which saves the required info for every batch of an epoch. This causes memory to quickly fill up with images and explode.
In my case it takes around 20 batches to slow my machine down and around 25 before the process is killed by the system.
Describe the solution you'd like to have implemented
Since the interaction saving is done in the forward pass, inside the batch loop, it can be tackled in two ways which I will now consider.
Inside the batch loop
The easiest fix would be to discard interactions inside the batch loop according to some criterion, which can be either stochastic, e.g. discard interaction i with probability p, or batch dependent, i.e. keep interaction i every n batches.
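A rough sketch of what this could look like in the training loop; data_loader, game, optimizer, keep_prob and every_n are placeholders, and the (loss, interaction) return value follows the EGG game interface:

```python
import random

keep_prob, every_n = 0.1, 50  # placeholder values
logged_interactions = []

for batch_id, batch in enumerate(data_loader):
    loss, interaction = game(*batch)  # EGG games return (loss, interaction)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Stochastic: keep this batch's interaction with probability keep_prob ...
    if random.random() < keep_prob:
        logged_interactions.append(interaction)
    # ... or batch dependent: keep one interaction every `every_n` batches.
    # if batch_id % every_n == 0:
    #     logged_interactions.append(interaction)
```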
With the LoggingStrategy
Although the previous fix sounds easier, it is not modular, since it does not use the LoggingStrategy class. A much cleaner way is to allow the user to define a custom LoggingStrategy and use it to filter the interaction in the forward pass. In this case the only available solution would be the stochastic one, since no information about the current batch id is passed to the forward pass.
Necessary modification
To achieve the above-mentioned result, the LoggingStrategy class should change the store_* attributes from bool to functions. These functions can be defined at init time if no other information is needed, e.g. for the stochastic approach.
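As an illustration of such an init-time store function for the stochastic case (the factory name and keep_prob are hypothetical):

```python
import torch


def make_stochastic_store(keep_prob: float):
    """Build a store_* function that keeps its input with probability keep_prob."""
    def store(inp):
        if inp is None or torch.rand(1).item() >= keep_prob:
            return None
        return inp
    return store


# At init time the strategy would then hold, e.g.,
# self.store_sender_input = make_stochastic_store(0.1)
# instead of a plain boolean, and call it in filtered_interaction.
```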