intel / intel-extension-for-tensorflow

Intel® Extension for TensorFlow*

BFloat16 Mixed Precision Training Support For Arc Alchemist? If Not Could Float16 Training PLEASE Be Supported? #14

Closed tedliosu closed 12 months ago

tedliosu commented 1 year ago

Basically this question, but I'd like to know the answer for ITEX as well :smile:

Please let me know if there's any additional info I need to provide for my question to be answered. :+1:

ip2016 commented 1 year ago

Just to add to the question above, I have tried to enable automatic mixed-precision support for a training task (fine-tuning an NLP BERT model) on an Arc A380. Here is what I did:

export ITEX_AUTO_MIXED_PRECISION=1
export ITEX_AUTO_MIXED_PRECISION_DATA_TYPE="BFLOAT16" #"FLOAT16"
export ITEX_AUTO_MIXED_PRECISION_ALLOWLIST_ADD="AvgPool3D,AvgPool"
export ITEX_AUTO_MIXED_PRECISION_INFERLIST_REMOVE="AvgPool3D,AvgPool"
export ITEX_AUTO_MIXED_PRECISION_LOG_PATH="./log"
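
For reference, a related (but not identical) way to request bf16 compute is Keras's built-in mixed-precision policy; this is a sketch of that alternative, not what I ran here:

# Sketch of an alternative to the ITEX AMP graph rewrite configured above:
# Keras's own mixed-precision policy. Note this changes layer compute/variable
# dtypes rather than rewriting the graph, so it is related but not equivalent.
import tensorflow as tf
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")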

I'm always getting the same error:

Epoch 1/2
2022-11-15 08:39:17.818540: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type XPU is enabled.
2022-11-15 08:39:21.456059: I itex/core/graph/auto_mixed_precision/auto_mixed_precision.cc:109] Run advanced auto mixed precision datatype BFLOAT16 on XPU
2022-11-15 08:39:21.520306: I itex/core/graph/auto_mixed_precision/auto_mixed_precision.cc:800] Saved pre-optimization graph as binary to /home/ubot/Projects/intel_tensorflow/log/graphdef_preop_1668519561512.pb
2022-11-15 08:39:21.620719: I itex/core/graph/auto_mixed_precision/auto_mixed_precision.cc:814] Saved pre-optimization graph as text to /home/ubot/Projects/intel_tensorflow/log/graphdef_preop_1668519561512.pb.txt
2022-11-15 08:39:23.113351: I itex/core/graph/auto_mixed_precision/auto_mixed_precision.cc:1723] Converted 1691/10080 nodes to bfloat16 precision using 138 cast(s) to bfloat16 (excluding Const and Variable casts)
2022-11-15 08:39:23.121577: I itex/core/graph/auto_mixed_precision/auto_mixed_precision.cc:800] Saved post-optimization graph as binary to /home/ubot/Projects/intel_tensorflow/log/graphdef_AutoMixedPrecision_1668519561512.pb
2022-11-15 08:39:23.228788: I itex/core/graph/auto_mixed_precision/auto_mixed_precision.cc:814] Saved post-optimization graph as text to /home/ubot/Projects/intel_tensorflow/log/graphdef_AutoMixedPrecision_1668519561512.pb.txt
2022-11-15 08:39:23.228966: I itex/core/graph/auto_mixed_precision/auto_mixed_precision.cc:846] Saved paint bucket info to /home/ubot/Projects/intel_tensorflow/log/paintbuckets_AutoMixedPrecision_1668519561512.txt
2022-11-15 08:39:27.780308: I itex/core/graph/auto_mixed_precision/auto_mixed_precision.cc:109] Run advanced auto mixed precision datatype BFLOAT16 on XPU
2022-11-15 08:39:27.861052: I itex/core/graph/auto_mixed_precision/auto_mixed_precision.cc:800] Saved pre-optimization graph as binary to /home/ubot/Projects/intel_tensorflow/log/graphdef_preop_1668519567850.pb
2022-11-15 08:39:27.990075: I itex/core/graph/auto_mixed_precision/auto_mixed_precision.cc:814] Saved pre-optimization graph as text to /home/ubot/Projects/intel_tensorflow/log/graphdef_preop_1668519567850.pb.txt
2022-11-15 08:39:28.393374: I itex/core/graph/auto_mixed_precision/auto_mixed_precision.cc:1723] Converted 0/12215 nodes to bfloat16 precision using 0 cast(s) to bfloat16 (excluding Const and Variable casts)
2022-11-15 08:39:28.403793: I itex/core/graph/auto_mixed_precision/auto_mixed_precision.cc:800] Saved post-optimization graph as binary to /home/ubot/Projects/intel_tensorflow/log/graphdef_AutoMixedPrecision_1668519567850.pb
2022-11-15 08:39:28.534511: I itex/core/graph/auto_mixed_precision/auto_mixed_precision.cc:814] Saved post-optimization graph as text to /home/ubot/Projects/intel_tensorflow/log/graphdef_AutoMixedPrecision_1668519567850.pb.txt
2022-11-15 08:39:28.534650: I itex/core/graph/auto_mixed_precision/auto_mixed_precision.cc:846] Saved paint bucket info to /home/ubot/Projects/intel_tensorflow/log/paintbuckets_AutoMixedPrecision_1668519567850.txt
2022-11-15 08:39:38.066532: W itex/core/utils/op_kernel.cc:342] ./itex/core/kernels/common/matmul_op.h895Aborted: Operation received an exception:Status: 3, message: could not create a primitive descriptor, in file ./itex/core/kernels/common/matmul_op.h:892
terminate called after throwing an instance of 'dnnl::error'
  what():  object is not initialized
Aborted (core dumped)

Extension version (manually built):

intel-extension-for-tensorflow @ file:///home/ubot/Downloads/intel_extension_for_tensorflow-1.1.0-cp310-cp310-linux_x86_64.whl
intel-extension-for-tensorflow-lib @ file:///home/ubot/Downloads/intel_extension_for_tensorflow_lib-1.1.0.1-cp310-cp310-linux_x86_64.whl
yiqianglee commented 1 year ago

@guizili0 please have a look; this seems like a bug to me.

@tedliosu @ip2016 BF16 is not a compiler-native data type. Xe-Cores have native support for BF16 computation (such as conv/gemm), while other ops fall back to FP32 as the computation data type. So technically we should be able to support BF16 training on Arc GPUs, but we haven't fully validated it yet and there may be bugs. Feel free to give it a try and report any issues you hit.
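
A quick way to probe this (a minimal smoke test, assuming ITEX is installed so the XPU device is registered with TensorFlow) is to run a bf16 matmul directly on the device:

import tensorflow as tf

# Minimal bf16 smoke test: matmul should hit the Xe-Core native BF16 path,
# while ops without native BF16 support would compute internally in FP32.
with tf.device("/XPU:0"):
    a = tf.random.normal([256, 256])
    b = tf.random.normal([256, 256])
    c = tf.matmul(tf.cast(a, tf.bfloat16), tf.cast(b, tf.bfloat16))
print(c.dtype)  # expect <dtype: 'bfloat16'> if the op ran in bf16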

guizili0 commented 1 year ago

@ip2016 This issue should be fixed by https://github.com/intel/intel-extension-for-tensorflow/commit/1d4db87247f63d6c2cfdb2a454c8f0a8b3df8eba; you can build a new ITEX wheel locally to verify the fix.

ip2016 commented 1 year ago

Thanks for the fix, @guizili0. I'm not getting the error above anymore. However, I'm now getting another error:

W itex/core/utils/op_kernel.cc:571] itex/core/kernels/gpu/cast_op.cc: 162Unimplemented: Cast int64 to bfloat16 is not supported
Traceback (most recent call last):
  File "/home/ubot/Projects/intel_tensorflow/v3.py", line 75, in <module>
    model.fit(
  File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/tensorflow/python/eager/execute.py", line 52, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.UnimplementedError: Graph execution error:

Detected at node 'tf_bert_for_sequence_classification/bert/Cast' defined at (most recent call last):
    File "/home/ubot/Projects/intel_tensorflow/v3.py", line 75, in <module>
      model.fit(
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/keras/engine/training.py", line 1673, in fit
      tmp_logs = self.train_function(iterator)
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/keras/engine/training.py", line 1272, in train_function
      return step_function(self, iterator)
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/keras/engine/training.py", line 1256, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/keras/engine/training.py", line 1245, in run_step
      outputs = model.train_step(data)
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/transformers/modeling_tf_utils.py", line 1420, in train_step
      y_pred = self(x, training=True)
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/keras/engine/training.py", line 558, in __call__
      return super().__call__(*args, **kwargs)
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/keras/engine/base_layer.py", line 1132, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/transformers/modeling_tf_utils.py", line 1652, in run_call_with_unpacked_inputs
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/transformers/models/bert/modeling_tf_bert.py", line 1665, in call
      outputs = self.bert(
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 65, in error_handler
      return fn(*args, **kwargs)
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/keras/engine/base_layer.py", line 1132, in __call__
      outputs = call_fn(inputs, *args, **kwargs)
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 96, in error_handler
      return fn(*args, **kwargs)
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/transformers/modeling_tf_utils.py", line 1652, in run_call_with_unpacked_inputs
    File "/home/ubot/Projects/intel_tensorflow/venv/lib/python3.10/site-packages/transformers/models/bert/modeling_tf_bert.py", line 837, in call
      extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype)
Node: 'tf_bert_for_sequence_classification/bert/Cast'
Cast int64 to bfloat16 is not supported

Looks like it happens when the attention mask is converted to bfloat16, and somehow the code thinks attention_mask has dtype int64 even though I converted it to tf.int32 before running training. Here is the output for my train and validation datasets:

<MapDataset element_spec=({'input_ids': TensorSpec(shape=(None, None), dtype=tf.int32, name=None), 'attention_mask': TensorSpec(shape=(None, None), dtype=tf.int32, name=None)}, TensorSpec(shape=(None,), dtype=tf.int32, name=None))>
<MapDataset element_spec=({'input_ids': TensorSpec(shape=(None, 86), dtype=tf.int32, name=None), 'attention_mask': TensorSpec(shape=(None, 86), dtype=tf.int32, name=None)}, TensorSpec(shape=(None,), dtype=tf.int32, name=None))>
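
The failing pattern can be reproduced in isolation (a sketch; the two-step cast mirrors the int32 workaround described above):

import tensorflow as tf

# The op the ITEX GPU kernel rejected: a direct int64 -> bfloat16 cast.
mask = tf.constant([[1, 1, 0]], dtype=tf.int64)
# direct = tf.cast(mask, tf.bfloat16)  # "Cast int64 to bfloat16 is not supported"
two_step = tf.cast(tf.cast(mask, tf.int32), tf.bfloat16)  # int64 -> int32 -> bf16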
guizili0 commented 1 year ago

Thanks for the verification. Could you share your reproducer and your TF version?

ip2016 commented 1 year ago

Here is the source code:

# %%
import tensorflow as tf

tf.config.list_physical_devices()

# %%
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding, DefaultDataCollator
import numpy as np

raw_datasets = load_dataset("glue", "mrpc")
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def tokenize_function(example):
    return tokenizer(example["sentence1"], example["sentence2"], truncation=True, padding=True)

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

# %%
print(tokenized_datasets)

# %%

data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
#data_collator = DefaultDataCollator(return_tensors="tf")

tf_train_dataset = tokenized_datasets["train"].to_tf_dataset(
    columns=["attention_mask", "input_ids", "token_type_ids"],
    label_cols=["labels"],
    shuffle=True,
    collate_fn=data_collator,
    batch_size=8,
)

tf_validation_dataset = tokenized_datasets["validation"].to_tf_dataset(
    columns=["attention_mask", "input_ids", "token_type_ids"],
    label_cols=["labels"],
    shuffle=False,
    collate_fn=data_collator,
    batch_size=8,
)

# %%
def cast_to_int32(x, y):
    # Force inputs and labels to int32 so nothing reaches the GPU as int64.
    t1 = tf.dtypes.cast(x['input_ids'], tf.int32)
    t2 = tf.dtypes.cast(x['attention_mask'], tf.int32)
    t3 = tf.dtypes.cast(y, tf.int32)
    return {'input_ids': t1, 'attention_mask': t2}, t3

tf_train_dataset = tf_train_dataset.map(cast_to_int32)
tf_validation_dataset = tf_validation_dataset.map(cast_to_int32)
print(tf_train_dataset)
print(tf_validation_dataset)

# %%
from transformers import TFAutoModelForSequenceClassification

model = TFAutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

# %%
model.summary()

# %%
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

model.compile(
    optimizer="adam",
    loss=loss,
    metrics=["accuracy"],
)
model.fit(
    tf_train_dataset,
    validation_data=tf_validation_dataset,
    epochs=2
)
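
One detail worth noting (an assumption based on how plugin graph optimizers are typically configured, not something verified here): the ITEX_AUTO_MIXED_PRECISION variables appear to be read when the plugin initializes, so when setting them from Python instead of the shell they should be exported before tensorflow is imported:

import os

# Set the ITEX AMP knobs before TensorFlow (and the ITEX plugin) load,
# mirroring the shell exports earlier in this thread.
os.environ["ITEX_AUTO_MIXED_PRECISION"] = "1"
os.environ["ITEX_AUTO_MIXED_PRECISION_DATA_TYPE"] = "BFLOAT16"

import tensorflow as tf  # import only after the environment is prepared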
guizili0 commented 1 year ago

@ip2016 Sorry for the delayed response; we have fixed your issue on the latest master. Please help check it. Thanks!

tedliosu commented 12 months ago

According to this, bfloat16 should be supported on "Intel GPU"; unfortunately, I don't have an Arc Alchemist card to test this, nor do I plan on getting one anytime soon. Closing this issue for now.