Tried to wrap nn.DataParallel(model) with keras.layers.TorchModuleWrapper. However, it doesn't automatically handle tensor device placement across GPUs.
distribute_model = keras.Sequential(
    [
        keras.layers.TorchModuleWrapper(
            nn.DataParallel(model)
        )
    ]
)
Full code
Run on 2x Tesla T4 GPUs, where it fails to handle device placement. With a single device, it works.
import os
os.environ["KERAS_BACKEND"] = "torch"

import torch
import torch.nn as nn
import keras
import numpy as np

def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x = keras.layers.Dense(512)(inputs)
    outputs = keras.layers.Dense(10, name="predictions")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

model = get_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device  # device(type='cuda')

distribute_model = keras.Sequential(
    [
        keras.layers.TorchModuleWrapper(
            nn.DataParallel(model)
        )
    ]
)
distribute_model = distribute_model.to(device)
distribute_model.summary()  # OK

x_train = np.random.random((1000, 784))
y_train = np.random.randint(2, size=(1000, 10))

distribute_model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy"]
)
distribute_model.fit(
    x_train,
    y_train,
    epochs=2,
    batch_size=32,
    validation_split=0.2
)
logs
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[11], line 6
1 distribute_model.compile(
2 optimizer="adam",
3 loss="binary_crossentropy",
4 metrics=["accuracy"]
5 )
----> 6 distribute_model.fit(
7 x_train,
8 y_train,
9 epochs=2,
10 batch_size=32,
11 validation_split=0.2
12 )
File /opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py:123, in filter_traceback.<locals>.error_handler(*args, **kwargs)
120 filtered_tb = _process_traceback_frames(e.__traceback__)
121 # To get the full stack trace, call:
122 # `keras.config.disable_traceback_filtering()`
--> 123 raise e.with_traceback(filtered_tb) from None
124 finally:
125 del filtered_tb
File /opt/conda/lib/python3.10/site-packages/keras/src/trainers/trainer.py:923, in Trainer._symbolic_build(self, iterator, data_batch)
921 y_pred = backend.compute_output_spec(self, x)
922 except Exception as e:
--> 923 raise RuntimeError(
924 "Unable to automatically build the model. "
925 "Please build it yourself before calling "
926 "fit/evaluate/predict. "
927 "A model is 'built' when its variables have "
928 "been created and its `self.built` attribute "
929 "is True. Usually, calling the model on a batch "
930 "of data is the right way to build it.\n"
931 "Exception encountered:\n"
932 f"'{e}'"
933 )
934 if compile_metrics_unbuilt:
935 # Build all metric state with `backend.compute_output_spec`.
936 backend.compute_output_spec(
937 self.compute_metrics,
938 x,
(...)
941 sample_weight=sample_weight,
942 )
RuntimeError: Unable to automatically build the model. Please build it yourself before calling fit/evaluate/predict. A model is 'built' when its variables have been created and its `self.built` attribute is True. Usually, calling the model on a batch of data is the right way to build it.
Exception encountered:
'Exception encountered when calling TorchModuleWrapper.call().
Could not automatically infer the output shape / dtype of 'torch_module_wrapper_2' (of type TorchModuleWrapper). Either the `TorchModuleWrapper.call()` method is incorrect, or you need to implement the `TorchModuleWrapper.compute_output_spec() / compute_output_shape()` method. Error encountered:
Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/parallel_apply.py", line 85, in _worker
output = module(*input, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 123, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
RuntimeError: Exception encountered when calling Dense.call().
Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)
Arguments received by Dense.call():
• inputs=torch.Tensor(shape=torch.Size([16, 784]), dtype=float32)
Arguments received by TorchModuleWrapper.call():
• args=('<KerasTensor shape=(32, 784), dtype=float32, sparse=None, name=keras_tensor_14>',)
• kwargs=<class 'inspect._empty'>'
@fchollet The instructions here regarding distributed training with torch are fine, but is it possible with model.fit, in other words, in a more Keras-like style? And can keras.distribution be used in this case?
Hi @innat, once nn.DataParallel is called on a Keras model, you cannot use the Keras APIs for training. Please use the PyTorch APIs for further training. Here is a useful example: https://keras.io/guides/distributed_training_with_torch/
In the future, we plan to have a keras.distribution API that will have support for the PyTorch backend.
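For reference, a minimal sketch of the PyTorch-API route from the linked guide: one process per GPU, the Keras model (torch backend) wrapped in DistributedDataParallel, and a plain torch training loop. This is a sketch under the assumption that get_model(), x_train, and y_train are the ones from the reproduction above; the function name train_one_process is hypothetical, and process launching (torchrun or torch.multiprocessing) is elided.

import os
os.environ["KERAS_BACKEND"] = "torch"

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def train_one_process(rank, world_size, x_train, y_train):
    # One process per GPU. Requires MASTER_ADDR / MASTER_PORT in the
    # environment (torchrun sets them automatically).
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # With the torch backend, a Keras Model is a torch.nn.Module, so it can be
    # moved to a device and wrapped in DDP like any other module.
    model = get_model().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    dataset = torch.utils.data.TensorDataset(
        torch.tensor(x_train, dtype=torch.float32),
        torch.tensor(y_train, dtype=torch.float32),
    )
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)

    # Dense(10) has no activation in the repro, so use the with-logits loss.
    loss_fn = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)

    for epoch in range(2):
        sampler.set_epoch(epoch)
        for x_batch, y_batch in loader:
            x_batch, y_batch = x_batch.to(rank), y_batch.to(rank)
            optimizer.zero_grad()
            loss = loss_fn(ddp_model(x_batch), y_batch)
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()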
@nkovela1 Please let the original poster close the ticket. What's the rush!
I understand that if we wrap a Keras layer with torch.nn.DataParallel, we shouldn't be able to make it work in Keras. However, as noted in my first comment, I wrapped that module in keras.TorchModuleWrapper, which should most probably make it work in Keras.
Apologies on closing prematurely, @innat. The module does work on a single device with keras.TorchModuleWrapper, but if you are trying to use distribution with nn.DataParallel, you should use the corresponding Torch distributed training APIs instead, since PyTorch controls the distribution and device placement from that point forward.
We don't have Keras-native support for this yet through our own training API.
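For contrast, a sketch of the single-device case that does work: TorchModuleWrapper around a plain torch module, trained with the usual Keras fit(). The nn.Linear stack and the names torch_module and wrapped are hypothetical stand-ins for the model in the reproduction above.

import os
os.environ["KERAS_BACKEND"] = "torch"

import torch.nn as nn
import keras
import numpy as np

# A plain torch module standing in for the real model.
torch_module = nn.Sequential(nn.Linear(784, 512), nn.ReLU(), nn.Linear(512, 10))

wrapped = keras.Sequential(
    [
        keras.layers.InputLayer(shape=(784,)),
        keras.layers.TorchModuleWrapper(torch_module),
    ]
)
wrapped.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

x_train = np.random.random((1000, 784)).astype("float32")
y_train = np.random.randint(2, size=(1000, 10)).astype("float32")
wrapped.fit(x_train, y_train, epochs=2, batch_size=32)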
@nkovela1 Thanks for the details. The DDP approach shown in the official guide should work. I was looking for something a bit more transparent (like torch.nn.DataParallel) that would work with model.fit when it's wrapped as a torch module.
Is there any possibility of making this work in the API? If so, what do you think the main challenges would be?
I wrapped that module in keras.TorchModuleWrapper, which should most probably make it work in Keras.
That turns it into a Layer, but Keras Layers don't have a summary() method -- that's only on the Model class. Hence the error. This is working as intended (WAI).
I was looking for something a bit more transparent (like torch.nn.DataParallel) that would work with model.fit when it's wrapped as a torch module.
We don't have that for now, but in the future you should be able to use keras.distribution with torch directly.
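For context, a sketch of the existing keras.distribution API. As of Keras 3 it is implemented for the JAX backend only, so this shows the shape of the API that would eventually be extended to torch, not a torch solution today.

import os
os.environ["KERAS_BACKEND"] = "jax"

import keras

# Replicate model variables on every device and shard each batch across them.
devices = keras.distribution.list_devices()
keras.distribution.set_distribution(keras.distribution.DataParallel(devices=devices))

# Models built after this point pick up the distribution; fit() works as usual.
model = keras.Sequential(
    [
        keras.layers.InputLayer(shape=(784,)),
        keras.layers.Dense(512, activation="relu"),
        keras.layers.Dense(10),
    ]
)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])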
Thanks!
With the torch backend, multi-GPU training with DataParallel causes the issue described above.