google-research / human-scene-transformer

Human Scene Transformer: A framework for trajectory prediction and wrappers for reframing the JRDB dataset for the prediction task.
https://human-scene-transformer.github.io/
Apache License 2.0

tf Softmax input dimension error #11

Closed AlfredMoore closed 1 year ago

AlfredMoore commented 1 year ago

System version: Gcloud debian 11 Cpu: C3 8vCPU Memory: 64 GB

Software version 1:

Python 3.9.2
numpy                         1.26.0
tensorflow                    2.14.0
open3d                        0.17.0
opencv-python-headless        4.8.1.78

Traceback:

Traceback (most recent call last):
  File "/home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/train.py", line 141, in <module>
    app.run(main)
  File "/home/47800/.local/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/47800/.local/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/train.py", line 124, in main
    train_model.train_model(
  File "/home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/train_model.py", line 252, in train_model
    train_step(train_iter)
  File "/home/47800/.local/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/47800/.local/lib/python3.9/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
tensorflow.python.framework.errors_impl.AbortedError: Graph execution error:

Detected at node while/body/_1/while/human_trajectory_scene_transformer/feature_attn_agent_encoder_learned_layer/multi_head_attention/softmax/Softmax defined at (most recent call last):
<stack traces unavailable>
Input dims must be <= 5 and >=1
         [[{{node while/body/_1/while/human_trajectory_scene_transformer/feature_attn_agent_encoder_learned_layer/multi_head_attention/softmax/Softmax}}]] [Op:__inference_train_step_67952]

This happens when I try to run 'train.py'. I think it might be caused by the data loading or by the preprocessed data itself. I am still trying to fix it.

Software version 2:

Python == 3.8
tensorflow == 2.13

With these versions, the behavior changes:

WARNING:tensorflow:From /home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/jrdb/input_fn.py:555: load (from tensorflow.python.data.experimental.ops.io) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.load(...)` instead.
W1016 02:08:06.098888 140408121303680 deprecation.py:364] From /home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/jrdb/input_fn.py:555: load (from tensorflow.python.data.experimental.ops.io) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.data.Dataset.load(...)` instead.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 4645555079012573616
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:1: Invalid control characters encountered in text.
......
......
[libprotobuf ERROR external/com_google_protobuf/src/google/protobuf/text_format.cc:337] Error parsing text-format tensorflow.data.experimental.DistributedSnapshotMetadata: 1:3: Expected identifier, got: 6459351093901513222
I1016 02:08:23.444825 140408121303680 train_model.py:151] Model created on device.
2023-10-16 02:08:23.611301: W tensorflow/core/framework/dataset.cc:956] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
I1016 02:08:24.023096 140408121303680 train_model.py:245] Beginning training.
Traceback (most recent call last):
  File "train.py", line 141, in <module>
    app.run(main)
  File "/home/47800/miniconda3/envs/hstpy38/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/47800/miniconda3/envs/hstpy38/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "train.py", line 124, in main
    train_model.train_model(
  File "/home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/train_model.py", line 252, in train_model
    train_step(train_iter)
  File "/home/47800/miniconda3/envs/hstpy38/lib/python3.8/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/tmp/__autograph_generated_fileq49vjt2f.py", line 93, in tf__train_step
    ag__.for_stmt(ag__.converted_call(ag__.ld(tf).range, (ag__.converted_call(ag__.ld(tf).constant, (ag__.ld(train_params).batches_per_train_step,), None, fscope),), None, fscope), None, loop_body_1, get_state_4, set_state_4, (), {'iterate_names': '_'})
  File "/tmp/__autograph_generated_fileq49vjt2f.py", line 91, in loop_body_1
    ag__.converted_call(ag__.ld(strategy).run, (ag__.ld(step_fn),), dict(args=(ag__.converted_call(ag__.ld(next), (ag__.ld(iterator),), None, fscope),), options=ag__.converted_call(ag__.ld(tf).distribute.RunOptions, (), dict(experimental_enable_dynamic_batch_size=False), fscope)), fscope)
  File "/tmp/__autograph_generated_fileq49vjt2f.py", line 21, in step_fn
    loss_dict = ag__.converted_call(ag__.ld(loss_obj), (ag__.ld(output_batch), ag__.ld(predictions)), None, fscope_1)
  File "/tmp/__autograph_generated_file4bviil8d.py", line 12, in tf____call__
    retval_ = ag__.converted_call(ag__.ld(self).call, (ag__.ld(input_batch), ag__.ld(predictions)), None, fscope)
  File "/tmp/__autograph_generated_filevzzr7lrl.py", line 13, in tf__call
    loss_dict = (ag__.ld(position_loss) | ag__.ld(mixture_loss))
TypeError: in user code:

    File "/home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/train_model.py", line 184, in step_fn  *
        loss_dict = loss_obj(output_batch, predictions)
    File "/home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/losses.py", line 37, in __call__  *
        return self.call(input_batch, predictions)
    File "/home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/losses.py", line 456, in call  *
        loss_dict = position_loss | mixture_loss

    TypeError: unsupported operand type(s) for |: 'dict' and 'dict'
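
A note on this second traceback: the `|` operator for `dict` was added in Python 3.9 (PEP 584), so `position_loss | mixture_loss` cannot work under Python 3.8. A minimal sketch of a merge that behaves the same on older versions, using hypothetical loss dictionaries as stand-ins for the real ones in losses.py:

```python
# Hypothetical stand-ins for the two loss dictionaries merged in losses.py.
position_loss = {"position_loss": 0.42}
mixture_loss = {"mixture_loss": 0.07}


def merge_dicts(left, right):
    """Merge two dicts; keys in `right` win on conflict, like `left | right`."""
    return {**left, **right}


loss_dict = merge_dicts(position_loss, mixture_loss)
print(loss_dict)
```

Running the code on Python 3.9 or newer avoids the need for any such change.
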
Tim-Salzmann commented 1 year ago

Hi Mo,

Thanks for reaching out! Could you please clarify which gin configuration files you are using? E.g. please post the exact command you are running and if you made any changes to the gin configurations.

python train.py --gin_files ...

Best Tim

AlfredMoore commented 1 year ago

Hi Dr. Tim,

Thank you so much for your help. The same errors show up with Python 3.10. I am running the commands below.

cd /home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer
export PYTHONPATH=/path/to/directory:/home/47800/SocialNavigation_v2/human-scene-transformer
python3 train.py --model_base_dir=./model/jrdb  --gin_files=./config/jrdb/training_params.gin --gin_files=./config/jrdb/model_params.gin --gin_files=./config/jrdb/dataset_params.gin --gin_files=./config/jrdb/metrics.gin --dataset=JRDB

with the following differences in the gin configuration:

--- /home/47800/originalHST/human-scene-transformer/human_scene_transformer/config/jrdb/dataset_params.gin      2023-10-18 02:40:21.036342842 +0000
+++ /home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/config/jrdb/dataset_params.gin      2023-10-15 21:01:53.859915432 +0000
@@ -55,7 +55,7 @@
  'tressider-2019-04-26_3_test']

-JRDBDatasetParams.path = '<dataset_path>'
+JRDBDatasetParams.path = '/home/47800/SocialNavigation_v2/pre_tf_dataset'

 JRDBDatasetParams.train_scenes = %TRAIN_SCENES
 JRDBDatasetParams.eval_scenes = %TEST_SCENES

In '/home/47800/SocialNavigation_v2/pre_tf_dataset', I have:

(base) 47800@instance-3:~/originalHST$ ls /home/47800/SocialNavigation_v2/pre_tf_dataset
bytes-cafe-2019-02-07_0                     hewlett-class-2019-01-23_0_test              outdoor-coupa-cafe-2019-02-06_0_test
clark-center-2019-02-28_0                   hewlett-class-2019-01-23_1_test              packard-poster-session-2019-03-20_0
clark-center-2019-02-28_1                   hewlett-packard-intersection-2019-01-24_0    packard-poster-session-2019-03-20_1
clark-center-intersection-2019-02-28_0      huang-2-2019-01-25_0                         packard-poster-session-2019-03-20_2
cubberly-auditorium-2019-04-22_0            huang-2-2019-01-25_1_test                    quarry-road-2019-02-28_0_test
cubberly-auditorium-2019-04-22_1_test       huang-basement-2019-01-25_0                  serra-street-2019-01-30_0_test
discovery-walk-2019-02-28_0_test            huang-intersection-2019-01-22_0_test         stlc-111-2019-04-19_0
discovery-walk-2019-02-28_1_test            huang-lane-2019-02-12_0                      stlc-111-2019-04-19_1_test
food-trucks-2019-02-12_0_test               indoor-coupa-cafe-2019-02-06_0_test          stlc-111-2019-04-19_2_test
forbes-cafe-2019-01-22_0                    jordan-hall-2019-04-22_0                     svl-meeting-gates-2-2019-04-08_0
gates-159-group-meeting-2019-04-03_0        lomita-serra-intersection-2019-01-30_0_test  svl-meeting-gates-2-2019-04-08_1
gates-ai-lab-2019-02-08_0                   memorial-court-2019-03-16_0                  tressider-2019-03-16_0
gates-ai-lab-2019-04-17_0_test              meyer-green-2019-03-16_0                     tressider-2019-03-16_1
gates-basement-elevators-2019-01-17_0_test  meyer-green-2019-03-16_1_test                tressider-2019-03-16_2_test
gates-basement-elevators-2019-01-17_1       nvidia-aud-2019-01-25_0_test                 tressider-2019-04-26_0_test
gates-foyer-2019-01-17_0_test               nvidia-aud-2019-04-18_0                      tressider-2019-04-26_1_test
gates-to-clark-2019-02-28_0_test            nvidia-aud-2019-04-18_1_test                 tressider-2019-04-26_2
gates-to-clark-2019-02-28_1                 nvidia-aud-2019-04-18_2_test                 tressider-2019-04-26_3_test

Update: I ran these lines

tf.print('\n**************ds_train:**************\n',ds_train)
tf.print('\n**************dist_train_dataset:**************\n',dist_train_dataset)

got

**************ds_train:**************
 <_ShuffleDataset element_spec={'agents/position': TensorSpec(shape=(8, 19, 2), dtype=tf.float32, name=None), 'agents/orientation': TensorSpec(shape=(8, 19, 1), dtype=tf.float32, name=None), 'agents/keypoints': TensorSpec(shape=(8, 19, 99), dtype=tf.float32, name=None), 'robot/position': TensorSpec(shape=(19, 3), dtype=tf.float32, name=None), 'scene/id': TensorSpec(shape=(), dtype=tf.string, name=None), 'scene/timestamp': TensorSpec(shape=(), dtype=tf.int64, name=None), 'agents/gaze': TensorSpec(shape=(8, 19, 1), dtype=tf.float32, name=None)}>

**************dist_train_dataset:**************
 <tensorflow.python.distribute.input_lib.DistributedDataset object at 0x7f39da9e83a0>

Q: Could there be a problem with 'scene/timestamp': TensorSpec(shape=(), dtype=tf.int64, name=None)?

I also ran

tf.print('samples:',
  tf.data.experimental.sample_from_datasets(ds_train, weights=None, seed=None, stop_on_empty_dataset=False)
  )

got

  File "/home/47800/miniconda3/envs/hstpy310/lib/python3.10/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 526, in __len__
    raise TypeError("The dataset is infinite.")
TypeError: The dataset is infinite.

Q: Is the dataset supposed to be infinite?

Tim-Salzmann commented 1 year ago

Hi Mo,

Unfortunately, we are still struggling to reproduce the error.

Does the error occur in (or even before) the very first training iteration?

Does the same error occur when you are not training but evaluating?

The shapes you are getting from tf.print('\n**************ds_train:**************\n',ds_train) look good to me except that the batch dimension is missing. Am I right in assuming that you are running this before ds_train = ds_train.batch(train_params.batch_size, drop_remainder=True) is run in train.py?

Q: Could there be a problem with 'scene/timestamp': TensorSpec(shape=(), dtype=tf.int64, name=None)?

The timestamp is a scalar, so this shape is expected.

Q: Is the dataset supposed to be infinite?

Yes! Once the training dataset reaches its end, it is repeated: https://github.com/google-research/human-scene-transformer/blob/7e9b9278f253b4ba48c1931d6eccaca8362efa4b/human_scene_transformer/jrdb/input_fn.py#L696-L697
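
The `TypeError: The dataset is infinite.` quoted earlier is raised from `__len__`, that is, when a repeated dataset is asked for its length. As a rough analogy in plain Python (this uses `itertools` rather than `tf.data`, and none of the names below come from the repository), an endlessly repeating iterator has no length, and the usual way to inspect it is to take a bounded number of elements:

```python
import itertools

# Hypothetical finite "dataset" of three samples.
samples = ["scene_a", "scene_b", "scene_c"]

# Repeat it forever, similar in spirit to calling .repeat() on a tf.data.Dataset.
repeated = itertools.cycle(samples)

# An endless iterator has no len(); take a bounded slice to inspect it instead.
first_five = list(itertools.islice(repeated, 5))
print(first_five)
```

With `tf.data`, `Dataset.take(n)` plays the role that `islice` plays here.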

AlfredMoore commented 1 year ago

Does the error occur in (or even before) the very first training iteration?

Yes, at the first iteration.

Does the same error occur when you are not training but evaluating?


/home/47800/miniconda3/envs/hstpy310/lib/python3.10/site-packages/tensorflow/python/data/ops/map_op.py:35: UserWarning: The `deterministic` argument has no effect unless the `num_parallel_calls` argument is specified.
warnings.warn("The `deterministic` argument has no effect unless the "
2023-10-18 18:13:54.933894: W tensorflow/core/framework/op_kernel.cc:1839] OP_REQUIRES failed at mkl_softmax_op.cc:252 : ABORTED: Input dims must be <= 5 and >=1
Traceback (most recent call last):
File "/home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/jrdb/eval.py", line 166, in <module>
app.run(main)
File "/home/47800/miniconda3/envs/hstpy310/lib/python3.10/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/47800/miniconda3/envs/hstpy310/lib/python3.10/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
File "/home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/jrdb/eval.py", line 162, in main
evaluation(_CKPT_PATH.value)
File "/home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/jrdb/eval.py", line 64, in evaluation
_, _ = model(next(iter(dataset.batch(1))), training=False)
File "/home/47800/miniconda3/envs/hstpy310/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/model/model.py", line 178, in call
input_batch = self.agent_encoding_layer(input_batch, training=training)
File "/home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/model/agent_encoder.py", line 189, in call
attn_out, attn_score = self.attn_layer(
tensorflow.python.framework.errors_impl.AbortedError: Exception encountered when calling layer 'softmax' (type Softmax).

{{function_node __wrapped__Softmax_device_/job:localhost/replica:0/task:0/device:CPU:0}} Input dims must be <= 5 and >=1 [Op:Softmax] name:

Call arguments received by layer 'softmax' (type Softmax):
  • inputs=tf.Tensor(shape=(1, 11, 19, 4, 1, 4), dtype=float32)
  • mask=tf.Tensor(shape=(1, 11, 19, 1, 1, 4), dtype=bool)


The good news is that the dataset now loads as expected, but there might be an extra dimension in the input data.
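
The shape reported in the traceback can be checked against the limit quoted in the error message. A small sketch in plain Python (the helper name is hypothetical; the limit of 5 is taken from the message `Input dims must be <= 5 and >=1`):

```python
# Shape of the softmax input reported in the traceback above.
softmax_input_shape = (1, 11, 19, 4, 1, 4)

# Upper bound quoted in the error message.
MAX_SUPPORTED_RANK = 5


def exceeds_rank_limit(shape, limit=MAX_SUPPORTED_RANK):
    """Return True if a tensor with this shape has more dimensions than `limit`."""
    return len(shape) > limit


print(len(softmax_input_shape), exceeds_rank_limit(softmax_input_shape))
```

This only shows that the input has rank 6, one more than the failing kernel accepts. By itself it does not say whether any dimension is redundant.
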
Tim-Salzmann commented 1 year ago

Hi Mo,

Could you please try and set the environment variable TF_DISABLE_MKL=1 to disable MKL backend for tf?

Thanks Tim

AlfredMoore commented 1 year ago

Hi Tim,

Could you please try and set the environment variable TF_DISABLE_MKL=1 to disable MKL backend for tf?

Now I am using these commands,

export TF_DISABLE_MKL=1
cd /home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer
export PYTHONPATH=/path/to/directory:/home/47800/SocialNavigation_v2/human-scene-transformer
python /home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/jrdb/eval.py --model_path=./model/jrdb/ --checkpoint_path=./model/jrdb/ckpts/ckpt-30

but I still get the additional dimension:

  File "/home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/jrdb/eval.py", line 64, in evaluation
    _, _ = model(next(iter(dataset.batch(1))), training=False)
  File "/home/47800/miniconda3/envs/hstpy310/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/model/model.py", line 178, in call
    input_batch = self.agent_encoding_layer(input_batch, training=training)
  File "/home/47800/SocialNavigation_v2/human-scene-transformer/human_scene_transformer/model/agent_encoder.py", line 189, in call
    attn_out, attn_score = self.attn_layer(
tensorflow.python.framework.errors_impl.AbortedError: Exception encountered when calling layer 'softmax' (type Softmax).

{{function_node __wrapped__Softmax_device_/job:localhost/replica:0/task:0/device:CPU:0}} Input dims must be <= 5 and >=1 [Op:Softmax] name: 

Call arguments received by layer 'softmax' (type Softmax):
  • inputs=tf.Tensor(shape=(1, 11, 19, 4, 1, 4), dtype=float32)
  • mask=tf.Tensor(shape=(1, 11, 19, 1, 1, 4), dtype=bool)

Do you know which dim is redundant?

BTW, is the 'M1' chip in the runtime table the APPLE M1 or Google cloud M1 ?

Tim-Salzmann commented 1 year ago

Hi Mo,

Could you, in addition also set TF_ENABLE_ONEDNN_OPTS=0?

Could you please outline the procedure of how you installed tensorflow? Conda / pip?

Do you know which dim is redundant?

Unfortunately, none of the dimensions is redundant.

BTW, is the 'M1' chip in the runtime table the APPLE M1 or Google cloud M1 ?

This is Apple M1

Best Tim

AlfredMoore commented 1 year ago

Hi Tim

Could you, in addition also set TF_ENABLE_ONEDNN_OPTS=0?

Thank you so much. It works :)

Could you please outline the procedure of how you installed tensorflow? Conda / pip?

Roughly like the commands below:

conda create -n env_name python=3.x
pip3 install -r human-scene-transformer/requirement.txt

By the way, could you tell me why it works with TF_ENABLE_ONEDNN_OPTS=0 and TF_DISABLE_MKL=1, and why the dimension error disappears?

Update, new question: I added a marker in the loop. Does this output mean that training is stuck and never moves forward? It has been sitting there for several hours.

I1020 17:57:29.060910 140299564987008 train_model.py:256] Beginning training.
I1020 17:57:29.061050 140299564987008 train_model.py:259] 0
I1020 17:57:29.164841 140299564987008 api.py:460] train_step
I1020 17:57:34.702615 140299564987008 api.py:460] iter
I1020 17:57:35.291868 140299564987008 api.py:460] train_step
I1020 17:57:38.506911 140299564987008 api.py:460] iter
2023-10-20 17:57:51.474034: W tensorflow/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 3825205248 exceeds 10% of free system memory.
2023-10-20 17:57:51.608297: W tensorflow/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 3825205248 exceeds 10% of free system memory.
2023-10-20 17:57:52.013815: W tensorflow/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 3825205248 exceeds 10% of free system memory.
2023-10-20 17:57:52.038450: W tensorflow/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 3825205248 exceeds 10% of free system memory.
2023-10-20 17:58:01.058731: W tensorflow/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 3825205248 exceeds 10% of free system memory.

Fortunately, eval.py runs well with the checkpoint:

3881it [21:54,  2.95it/s]
MinADE: 0.26
MinADE @ 1s: 0.12
MinADE @ 2s: 0.20
MinADE @ 3s: 0.28
MinADE @ 4s: 0.37
MLADE: 0.45
MLADE @ 1s: 0.21
MLADE @ 2s: 0.39
MLADE @ 3s: 0.56
MLADE @ 4s: 0.71
NLL: -0.59
NLL @ 1s: -0.90
NLL @ 2s: -0.65
NLL @ 3s: -0.08
NLL @ 4s: 0.32
Tim-Salzmann commented 1 year ago

Hi Mo,

Glad it works now!

By the way, could you tell me why it works with TF_ENABLE_ONEDNN_OPTS=0 and TF_DISABLE_MKL=1, and why the dimension error disappears?

MKL is a TensorFlow backend by Intel, optimized for their CPUs. Unfortunately, for some operations it does not support tensors with more than 5 dimensions, which makes it incompatible with this codebase (we added a note to the README).
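
The two settings from this thread can also be applied from Python, as long as that happens before TensorFlow is imported. A minimal sketch (the helper name is hypothetical; the variable names and values are the ones suggested in this thread):

```python
import os


def disable_mkl_onednn():
    """Set the two environment variables suggested in this thread.

    Call this before `import tensorflow`. Setting them afterwards is
    assumed to have no effect, since they are read when the library loads.
    """
    os.environ["TF_DISABLE_MKL"] = "1"
    os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"


disable_mkl_onednn()
# import tensorflow as tf  # import only after the variables are set
print(os.environ["TF_DISABLE_MKL"], os.environ["TF_ENABLE_ONEDNN_OPTS"])
```

Exporting the variables in the shell before launching train.py or eval.py, as done above, achieves the same thing without touching the code.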

I added a marker in the loop. Does this output mean that training is stuck and never moves forward? It has been sitting there for several hours.

Not necessarily. We combine many training iterations in one tf.function context, which avoids overhead on TPU / GPU. This could take a long time on CPU. You could set the values on the following lines to 1 to get more fine-grained feedback over training iterations.

https://github.com/google-research/human-scene-transformer/blob/7e9b9278f253b4ba48c1931d6eccaca8362efa4b/human_scene_transformer/config/jrdb/training_params.gin#L7-L8C6
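
To illustrate why a marker placed in the outer loop can stay quiet for a long time: the autograph traceback earlier in the thread shows `train_step` looping over `train_params.batches_per_train_step` inside one compiled call. A plain-Python sketch of that structure (all names other than `batches_per_train_step` are hypothetical, the numbers are made up, and no TensorFlow is involved):

```python
def count_markers(total_batches, batches_per_train_step):
    """Count how often a marker in the outer loop would be logged."""
    markers = 0
    batches_done = 0
    while batches_done < total_batches:
        markers += 1  # the marker sits outside the inner loop
        # Inner loop: stands in for the batches fused into one train_step call.
        for _ in range(batches_per_train_step):
            batches_done += 1
    return markers


coarse = count_markers(total_batches=1000, batches_per_train_step=500)
fine = count_markers(total_batches=1000, batches_per_train_step=1)
print(coarse, fine)
```

With a large value the marker appears only once per many batches, so on a slow CPU the log can look frozen while training is still progressing. A value of 1 gives one marker per batch, at the cost of more per-call overhead.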

Feel free to close the issue should this have solved your problem.

Best Tim

AlfredMoore commented 1 year ago

Hi Dr. Tim,

It is solved perfectly. Thank you so much! I will close this issue soon.

Best regards, Mo