Alxe1 opened 2 years ago
Hi @Alxe1
Thanks for filing the issue!
I tried your script - this is indeed a bug.
We have multiple workers on the same node in this case, and they all write to the same directory (the directory the script is launched from) by the time model.save() is called, which causes contention. It doesn't matter whether the format is h5 or tf.
The way to fix this is to give each worker its own directory.
This is exactly what the AIR trainer does. Take a look here.
I also want to mention that you need to use the Session
API to have the saved checkpoint synced to the driver or to cloud storage in a multi-node setup. Otherwise, the saved checkpoint only shows up on whichever node the worker happens to be running on, which may or may not be the head node.
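To make the per-worker-directory idea concrete, here is a minimal sketch (the function name save_per_worker and the temp-dir path are just illustrative, not from your script) of saving under a rank-specific directory and reporting it through the session API:

import os
import tempfile

from ray.air import session
from ray.air.checkpoint import Checkpoint

def save_per_worker(model):
    # Write into a directory unique to this worker's rank so workers on the
    # same node don't step on each other's files.
    rank = session.get_world_rank()
    save_dir = os.path.join(tempfile.gettempdir(), f"model_rank_{rank}")
    model.save(save_dir, save_format="tf")
    # Reporting the checkpoint is what gets it synced off the worker node.
    session.report({"saved": True}, checkpoint=Checkpoint.from_directory(save_dir))

The session.report() call is the piece that moves the checkpoint to the driver (or cloud) rather than leaving it stranded on the worker node.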
The following is how I modified your original script to use the new AIR API. PTAL. This needs to run with our Ray 2.0.0 wheels (rc wheels are fine).
import argparse
import json
import os

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import Callback

import ray.train as train
from ray.train import Trainer
from ray.train.tensorflow import TensorflowTrainer
from ray.air import session
from ray.air.config import ScalingConfig
from ray.air.checkpoint import Checkpoint


class TrainReportCallback(Callback):
    def on_epoch_end(self, epoch, logs=None):
        train.report(**logs)


def mnist_dataset(batch_size):
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    x_train = x_train / np.float32(255)
    y_train = y_train.astype(np.int64)
    train_dataset = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train))
        .shuffle(60000)
        .repeat()
        .batch(batch_size)
    )
    return train_dataset


def build_and_compile_cnn_model(config):
    learning_rate = config.get("lr", 0.001)
    model = tf.keras.Sequential(
        [
            tf.keras.Input(shape=(28, 28)),
            tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(32, 3, activation="relu"),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dense(10),
        ]
    )
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
        metrics=["accuracy"],
    )
    return model
def train_func(config):
    per_worker_batch_size = config.get("batch_size", 64)
    epochs = config.get("epochs", 1)
    steps_per_epoch = config.get("steps_per_epoch", 70)

    tf_config = json.loads(os.environ["TF_CONFIG"])
    num_workers = len(tf_config["cluster"]["worker"])

    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    global_batch_size = per_worker_batch_size * num_workers
    multi_worker_dataset = mnist_dataset(global_batch_size)

    with strategy.scope():
        multi_worker_model = build_and_compile_cnn_model(config)

    history = multi_worker_model.fit(
        multi_worker_dataset,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        callbacks=[TrainReportCallback()],
    )
    # Save the model, wrap the saved directory in a Checkpoint, and report it
    # so it gets synced to the driver / cloud storage.
    multi_worker_model.save("./multi_worker_model", save_format="tf")
    ckpt = Checkpoint.from_directory("./multi_worker_model")
    result = history.history
    session.report(result, checkpoint=ckpt)
    # Feel free to delete "./multi_worker_model" at this point.
    return result
def train_tf_mnist(num_workers=2, use_gpu=False, epochs=1):
    # trainer = Trainer(backend="tensorflow", num_workers=num_workers, use_gpu=use_gpu)
    # trainer.start()
    # results = trainer.run(
    #     train_func=train_func,
    #     config={"lr": 1e-3, "batch_size": 64, "epochs": epochs}
    # )
    # trainer.shutdown()
    # print(f"Results: {results[0]}")
    trainer = TensorflowTrainer(
        train_loop_per_worker=train_func,
        train_loop_config={"lr": 1e-3, "batch_size": 64, "epochs": epochs},
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
    )
    trainer.fit()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--address", required=False, type=str, help="The address to use for ray")
    parser.add_argument("--num_workers", "-n", type=int, default=4, help="Sets number of workers for training")
    parser.add_argument("--use_gpu", action="store_true", default=False, help="Enable GPU training")
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for")
    parser.add_argument("--smoke-test", action="store_true", default=False, help="Finish quickly for testing")
    args, _ = parser.parse_known_args()

    import ray

    if args.smoke_test:
        ray.init(num_cpus=2)
        train_tf_mnist()
    else:
        ray.init(address=args.address)
        train_tf_mnist(
            num_workers=args.num_workers,
            use_gpu=args.use_gpu,
            epochs=args.epochs,
        )
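Once trainer.fit() returns, the synced checkpoint is available on the driver through the returned Result object. A rough, untested sketch of loading it back (assuming you capture the Result that trainer.fit() returns, which the script above currently discards) would be:

import tensorflow as tf

result = trainer.fit()
# Materialize the checkpoint into a local directory on the driver,
# then load the SavedModel from it.
local_dir = result.checkpoint.to_directory()
restored_model = tf.keras.models.load_model(local_dir)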
cc @amogkam for visibility
Thank you, I will try it.
What happened + What you expected to happen
Saving the TensorFlow multi-worker model in 'tf' format raises an error:
Versions / Dependencies
Reproduction script
Issue Severity
No response