ClashLuke / HeavyBall

Efficient optimizers
BSD 2-Clause "Simplified" License

CacheLimitExceeded error when training #16

Open marthinwurer opened 13 hours ago

marthinwurer commented 13 hours ago
W1126 16:03:16.529497 1228885 torch/_dynamo/convert_frame.py:844] [1/8] torch._dynamo hit config.cache_size_limit (8)
W1126 16:03:16.529497 1228885 torch/_dynamo/convert_frame.py:844] [1/8]    function: '_compilable_update_' (/home/marthinwurer/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/utils.py:694)
W1126 16:03:16.529497 1228885 torch/_dynamo/convert_frame.py:844] [1/8]    last reason: 1/0: tensor 'L['p'][0]' size mismatch at index 0. expected 16, actual 64
W1126 16:03:16.529497 1228885 torch/_dynamo/convert_frame.py:844] [1/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1126 16:03:16.529497 1228885 torch/_dynamo/convert_frame.py:844] [1/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.
1batch [00:14, 14.43s/batch]

---------------------------------------------------------------------------
CacheLimitExceeded                        Traceback (most recent call last)
Cell In[21], line 5
      3 results = []
      4 for i in range(runs):
----> 5     result, stats = train_run(1e-3)
      6     results.append(result)
      7     print("result:", result)

Cell In[20], line 20, in train_run(lr)
     17 stats = defaultdict(list)
     19 start_time = time.time()
---> 20 train_loop(autoencoder, train_dataloader, optimizer, 10, stats)
     21 result = test_loop(autoencoder, test_dataloader)
     22 return result, stats

Cell In[14], line 24, in train_loop(model, dataloader, optimizer, epochs, stats)
     19             image = data.cuda()
     20 #             print(image.shape)
     21 #             break
     22 
     23     #         loss = train_batch(image, model, optimizer, autoencoder.spectral_loss)
---> 24             full_loss, losses = train_batch(model, [image], image, optimizer, F.mse_loss)
     26             if isinstance(full_loss, tuple):
     27                 loss = full_loss

Cell In[13], line 19, in train_batch(model, inputs, targets, optimizer, loss_func)
     15 loss = torch.mean(losses)
     17 loss.backward()
---> 19 optimizer.step()
     21 return loss.detach().item(), losses.detach()

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/torch/optim/optimizer.py:487, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    482         else:
    483             raise RuntimeError(
    484                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    485             )
--> 487 out = func(*args, **kwargs)
    488 self._optimizer_step_code()
    490 # call optimizer step post hooks

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/utils.py:601, in StatefulOptimizer.step(self, closure)
    599 for top_group in self.param_groups:
    600     for group in self.get_groups(top_group):
--> 601         self._step(group)
    602         if self.use_ema:
    603             self.ema_update(group)

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/precond_schedule_palm_foreach_soap.py:105, in PrecondSchedulePaLMForeachSOAP._step(self, group)
    102 precond = project(exp_avg_projected, state['Q'], True)
    104 update_preconditioner(g, state, max_precond_dim, precondition_1d, old_debiased2, update_precond)
--> 105 update_param_([p], [precond], step_size, group["weight_decay"], caution=group['caution'], grad=[g])

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/heavyball/utils.py:719, in update_param_(param, update, lr, decay, add_fn, caution, grad)
    717 if add_fn is None:
    718     add_fn = stochastic_add_
--> 719 _compilable_update_(param, update, decay, add_fn, lr, caution, grad)

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py:465, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    460 saved_dynamic_layer_stack_depth = (
    461     torch._C._functorch.get_dynamic_layer_stack_depth()
    462 )
    464 try:
--> 465     return fn(*args, **kwargs)
    466 finally:
    467     # Restore the dynamic layer stack depth if necessary.
    468     torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
    469         saved_dynamic_layer_stack_depth
    470     )

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:1269, in CatchErrorsWrapper.__call__(self, frame, cache_entry, frame_state)
   1263             return hijacked_callback(
   1264                 frame, cache_entry, self.hooks, frame_state
   1265             )
   1267 with compile_lock, _disable_current_modes():
   1268     # skip=1: skip this frame
-> 1269     return self._torchdynamo_orig_callable(
   1270         frame, cache_entry, self.hooks, frame_state, skip=1
   1271     )

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:526, in ConvertFrameAssert.__call__(self, frame, cache_entry, hooks, frame_state, skip)
    510 compile_id = CompileId(frame_id, frame_compile_id)
    512 signpost_event(
    513     "dynamo",
    514     "_convert_frame_assert._compile",
   (...)
    523     },
    524 )
--> 526 return _compile(
    527     frame.f_code,
    528     frame.f_globals,
    529     frame.f_locals,
    530     frame.f_builtins,
    531     self._torchdynamo_orig_callable,
    532     self._one_graph,
    533     self._export,
    534     self._export_constraints,
    535     hooks,
    536     cache_entry,
    537     cache_size,
    538     frame,
    539     frame_state=frame_state,
    540     compile_id=compile_id,
    541     skip=skip + 1,
    542 )

File ~/.pyenv/versions/3.9.6/envs/minerl/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py:859, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip)
    844 log.warning(
    845     "torch._dynamo hit config.%s (%s)\n"
    846     "   function: %s\n"
   (...)
    854     troubleshooting_url,
    855 )
    856 if config.skip_code_recursive_on_cache_limit_hit and justknobs_check(
    857     "pytorch/compiler:skip_code_recursive_on_cache_limit_hit"
    858 ):
--> 859     raise CacheLimitExceeded(f"{limit_type} reached")
    860 else:
    861     # do not recursively skip frames
    862     unimplemented(f"{limit_type} reached")

CacheLimitExceeded: cache_size_limit reached

I started training a relatively complex ResNet autoencoder and got this error.

ClashLuke commented 13 hours ago

Oh, good catch. Following @gau-nernst's recommendation, I've disabled dynamic compilation, which means the optimizer will hit the cache size limit when the default limit is small. You can safely work around this by increasing the cache size limit (128 works well in my runs).

I'll add documentation for that tomorrow.

marthinwurer commented 13 hours ago

I increased it to 128 with

import torch._dynamo.config
torch._dynamo.config.cache_size_limit = 128

and now it runs. However, it takes 15-30 s to compile and run the first batch.

gau-nernst commented 12 hours ago

@ClashLuke Btw, you can also use torch._dynamo.utils.disable_cache_limit() within the optimizer. It should be safe, I think:

https://github.com/pytorch/ao/blob/5eb6339e0b6f413c74a3dfd5e7f53449474723fc/torchao/prototype/low_bit_optim/adam.py#L90-L92
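For reference, a minimal sketch of that idea (not heavyball's actual code; safe_step is a hypothetical helper):

import torch
import torch._dynamo.utils

def safe_step(optimizer: torch.optim.Optimizer) -> None:
    # disable_cache_limit() is a context manager that temporarily lifts dynamo's
    # cache_size_limit and restores it on exit, so per-shape recompiles of the
    # optimizer step can't raise CacheLimitExceeded.
    with torch._dynamo.utils.disable_cache_limit():
        optimizer.step()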

Also, did you observe a speedup compared to the dynamic-shape compile used previously?

ClashLuke commented 1 hour ago

@gau-nernst super cool, wasn't aware of that! I'll add that now, thank you for the pointer

ClashLuke commented 1 hour ago

@marthinwurer

However, it takes 15-30 s to compile and run the first batch.

Hm, yeah, that's not ideal for prototyping, though the speedups are definitely useful for longer runs. I'll expose compile_mode a bit more aggressively.
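A hypothetical illustration of what that could look like for a user; the attribute location and accepted values here are assumptions, not the released API:

import heavyball.utils

# Hypothetical knob: trade autotuning throughput for faster first-batch
# compilation while prototyping. Exact name and values are assumptions.
heavyball.utils.compile_mode = "default"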