microsoft / tensorflow-directml-plugin

DirectML PluggableDevice plugin for TensorFlow 2

Support XLA #332

Closed Zhaopudark closed 1 year ago

Zhaopudark commented 1 year ago

It seems that XLA (jit compile) is not supported, as shown in the following demo:

import tensorflow as tf

# Enable memory growth on the first GPU (DML) device
physical_devices = tf.config.experimental.list_physical_devices(device_type='GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], True)
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'],
              jit_compile=True)
model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test,  y_test, verbose=2)

The error is shown below:

W tensorflow/core/framework/op_kernel.cc:1780] OP_REQUIRES failed at xla_ops.cc:296 : UNIMPLEMENTED: Could not find compiler for platform DML: NOT_FOUND: could not find registered compiler for platform DML -- was support for that platform linked in?

Thanks in advance. Best wishes.

PatriceVignola commented 1 year ago

DML doesn't support XLA because it works fundamentally differently from the compilers used by other backends like CUDA and AMD. To make the code work, remove the jit_compile=True line above.
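
For reference, a minimal sketch of the compile call without XLA, assuming the rest of the demo above is unchanged:

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])  # jit_compile omitted, so no XLA compilation is requested

Passing jit_compile=False instead of omitting the argument should also work and makes the intent explicit.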

Zhaopudark commented 1 year ago

Thanks. 😊