poets-ai / elegy

A High Level API for Deep Learning in JAX
https://poets-ai.github.io/elegy/
MIT License

Donate model's memory buffer to jit/pmap functions. #226

Closed · lkhphuc closed this 2 years ago

lkhphuc commented 2 years ago

As discussed on Discord, using donate_argnums=1 in jit/pmap reduces GPU/TPU memory usage by about a third.

[Before/after screenshots of device memory usage, showing the reduction after donation.]

Technically, donating the argument in train_step_fn alone would be enough, since all of the other _step_fn functions are called inside train_step_fn anyway, but for consistency I added donate_argnums to every step_fn. A sketch of the technique is below.
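For anyone unfamiliar with buffer donation, here is a minimal sketch of the idea (the names here are illustrative, not elegy's actual step API): donating an argument lets XLA reuse that argument's device buffers for the outputs, so the old and new model states never coexist in memory.

```python
import jax
import jax.numpy as jnp

def train_step(params, inputs, labels):
    # Toy SGD step standing in for a train_step_fn: consumes the old
    # params and returns updated ones.
    loss_fn = lambda p: jnp.mean((inputs @ p - labels) ** 2)
    grads = jax.grad(loss_fn)(params)
    return params - 0.1 * grads

# Donating argument 0 lets XLA reuse the params buffers for the returned
# params, instead of keeping two copies of the model state on the device.
train_step = jax.jit(train_step, donate_argnums=0)

params = jnp.zeros((3,))
inputs, labels = jnp.ones((8, 3)), jnp.ones((8,))
params = train_step(params, inputs, labels)  # the old params buffer is now invalid
```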

cgarciae commented 2 years ago

Amazing! Thanks a lot for doing this. LGTM.

Merging.

bhoov commented 2 years ago

My code failed seemingly at random at the end of an epoch of training. After some digging I found that the call to train_on_batch caused the failure. Rolling back to elegy==0.8.5 makes everything work again, so I'm fairly sure this change is breaking (I am running the simple MNIST example code).

```
2022-04-30 18:39:44.022094: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2140] Execution of replica 0 failed: INVALID_ARGUMENT: Invalid buffer passed to Execute() as argument 0 to replica 0: INVALID_ARGUMENT: Donation requested for invalid buffer
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [52], in <cell line: 1>()
----> 1 history = model.fit(
      2     inputs=X_train,
      3     labels=y_train,
      4     epochs=5,
      5     steps_per_epoch=2,
      6     batch_size=10,
      7     validation_data=(X_test, y_test),
      8     shuffle=True,
      9     callbacks=[eg.callbacks.ModelCheckpoint("models/high-level", save_best_only=True)],
     10 )

Input In [39], in fit(self, inputs, labels, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, drop_remaining)
    102 if drop_remaining and not data_utils.has_batch_size(
    103     batch, data_handler.batch_size
    104 ):
    105     continue
--> 107 tmp_logs = self.train_on_batch(
    108     inputs=inputs,
    109     labels=labels,
    110 )

Input In [50], in train_on_batch(self, inputs, labels)
     88     labels = dict(target=labels)
     91 train_step_fn = self.train_step_fn[self._distributed_strategy]
---> 92 logs, model = train_step_fn(self, inputs, labels)
     93 print("Ending train step")
     94 return {}

ValueError: INVALID_ARGUMENT: Invalid buffer passed to Execute() as argument 0 to replica 0: INVALID_ARGUMENT: Donation requested for invalid buffer
```
lkhphuc commented 2 years ago

Hi @bhoov, can you paste the full stack trace and a minimal example? I'll try to take a look.
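For reference, JAX raises `Donation requested for invalid buffer` when an array whose buffer was already donated is passed back into a donating function. A minimal sketch of that failure mode on GPU/TPU (illustrative, not elegy's code):

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=0)
def step(x):
    return x + 1

x = jnp.zeros((1024,))
y = step(x)  # x's buffer is donated to y; x is now invalid on the device
step(x)      # on GPU/TPU: "Donation requested for invalid buffer"
```

If the fit loop (or a callback such as ModelCheckpoint) were to pass the same, already-donated model object into the next step, that could plausibly produce exactly this error, but we'd need the full trace to confirm.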