Closed lkhphuc closed 2 years ago
Amazing! Thanks a lot for doing this. LGTM.
Merging.
My code randomly failed at the end of a training epoch. After modifying the code a bit, I found that the call to train_on_batch
is what fails. I rolled back to elegy==0.8.5
and everything works, so I'm fairly sure this change is breaking (I am running the simple MNIST example code).
2022-04-30 18:39:44.022094: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2140] Execution of replica 0 failed: INVALID_ARGUMENT: Invalid buffer passed to Execute() as argument 0 to replica 0: INVALID_ARGUMENT: Donation requested for invalid buffer
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [52], in <cell line: 1>()
----> 1 history = model.fit(
2 inputs=X_train,
3 labels=y_train,
4 epochs=5,
5 steps_per_epoch=2,
6 batch_size=10,
7 validation_data=(X_test, y_test),
8 shuffle=True,
9 callbacks=[eg.callbacks.ModelCheckpoint("models/high-level", save_best_only=True)],
10 )
Input In [39], in fit(self, inputs, labels, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, drop_remaining)
102 if drop_remaining and not data_utils.has_batch_size(
103 batch, data_handler.batch_size
104 ):
105 continue
--> 107 tmp_logs = self.train_on_batch(
108 inputs=inputs,
109 labels=labels,
110 )
Input In [50], in train_on_batch(self, inputs, labels)
88 labels = dict(target=labels)
91 train_step_fn = self.train_step_fn[self._distributed_strategy]
---> 92 logs, model = train_step_fn(self, inputs, labels)
93 print("Ending train step")
94 return {}
ValueError: INVALID_ARGUMENT: Invalid buffer passed to Execute() as argument 0 to replica 0: INVALID_ARGUMENT: Donation requested for invalid buffer
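For context, the error above is what JAX raises when a donated buffer is touched again after the call. This is a hedged, minimal sketch (not the elegy code) of the donation semantics involved; `update`, `params`, and `grads` are hypothetical names:

```python
from functools import partial

import jax
import jax.numpy as jnp

# With donate_argnums, XLA may reuse the donated input's buffer for the
# output, so the caller must not read the donated array after the call.
# Reusing it on GPU/TPU yields errors like
# "Donation requested for invalid buffer".
@partial(jax.jit, donate_argnums=(0,))
def update(params, grads):
    # toy SGD-style step; params' buffer may be recycled for the result
    return params - 0.1 * grads

params = jnp.ones(3)
new_params = update(params, jnp.full(3, 2.0))
# On CPU, donation is currently ignored (with a warning), so `params`
# stays valid there; on GPU/TPU it must not be reused after this point.
print(new_params)
```

So a stale reference to an already-donated array, e.g. one kept across an epoch boundary, would trigger exactly this ValueError.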
Hi @bhoov, can you paste a full stack trace and a minimal example? I will try to take a look.
As discussed in Discord, using
donate_argnums=1
in jit/pmap reduces GPU/TPU memory usage by about a third. (Memory profiles before and after the change were attached as screenshots.)
Technically, donating the argnum in train_step_fn alone would be enough, since all the other _step_fns are called inside train_step_fn anyway. But for consistency I added the donate argnum to every step_fn.
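The caller-side pattern that donation requires can be sketched like this (a hypothetical toy `train_step`, not the elegy implementation): each iteration must rebind to the returned state instead of reusing the donated one.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical step function: the state argument is donated, so its
# buffer may be recycled for the returned state.
@partial(jax.jit, donate_argnums=(0,))
def train_step(state, x):
    # toy "training step": accumulate the batch mean into the state
    return state + x.mean()

state = jnp.zeros(())
for step in range(3):
    batch = jnp.full((10,), float(step))
    # Rebinding `state` each step keeps the live reference valid;
    # holding on to the old `state` would hit the donation error on GPU/TPU.
    state = train_step(state, batch)
```

This is also why `train_step_fn` returns the model: the caller must drop the pre-call reference and use the returned one.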