awslabs / gluonts

Probabilistic time series modeling in Python
https://ts.gluon.ai
Apache License 2.0

Export a GluonTS TFT model to ONNX #2981

Open Charles1710 opened 1 year ago

Charles1710 commented 1 year ago

Description

Hi, I trained a GluonTS TFT model with MXNet. It works well for forecasting. My goal is to export this model to ONNX so that I can import it into MATLAB. To export to ONNX, I first need to serialize the model; I can then pass the resulting prediction_net-0000.params and prediction_net-symbol.json files to the ONNX exporter. I'm not sure whether the problem is in the serialization step or in the export step. After training, I have a predictor that is a gluonts.mx.model.predictor.RepresentableBlockPredictor object. I create the .json and .params files with:

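from pathlib import Path
# Convert the trained RepresentableBlockPredictor into a SymbolBlockPredictor (hybridized graph)
p = predictor.as_symbol_block_predictor(dataset=training_data)
p.serialize_prediction_net(path=Path(save_path) / onnx_dir)  # writes prediction_net-symbol.json and prediction_net-0000.params

The as_symbol_block_predictor call emits the warning below. I'm not sure whether it indicates a problem, but the serialize_prediction_net call does create the .params and .json files.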

[...]\mxnet\gluon\block.py:1512: UserWarning: Cannot decide type for the following arguments. Consider providing them as input:
        data0: None
  input_sym_arg_type = in_param.infer_type()[0]
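The warning mentions data0, which appears to be an input variable of the serialized graph (it is also the only non-parameter name in the infer_shape argument dump further down). To check how many data inputs the serialized network actually expects, one could inspect prediction_net-symbol.json directly. A minimal stdlib sketch, assuming the usual MXNet symbol JSON layout (a top-level "nodes" list in which variables are nodes with op "null") and assuming that every learned parameter name ends in one of the listed suffixes; list_data_inputs and PARAM_SUFFIXES are hypothetical names, not part of MXNet or GluonTS:

```python
import json
from typing import List

# Assumption: learned parameters end in one of these suffixes; anything else is treated as a data input.
PARAM_SUFFIXES = ("_weight", "_bias", "_gamma", "_beta", "_running_mean", "_running_var")


def list_data_inputs(symbol_json: str) -> List[str]:
    """Return variable names in an MXNet symbol JSON that look like data inputs, in node order."""
    graph = json.loads(symbol_json)
    names = []
    for node in graph["nodes"]:
        # Variables (both data inputs and parameters) are stored as nodes whose op is "null".
        if node["op"] != "null":
            continue
        # Skip names that look like learned parameters.
        if node["name"].endswith(PARAM_SUFFIXES):
            continue
        names.append(node["name"])
    return names
```

For example, `with open("prediction_net-symbol.json") as f: print(list_data_inputs(f.read()))`. If this prints more than one name, the network takes several inputs, each of which would presumably need its own shape at export time.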

Then, to export to ONNX, I run the following:

import os
import numpy as np
from mxnet.contrib import onnx as onnx_mxnet

model_dir = os.path.join(ROOT, 'models', 'TFT', 'ONNX_r_w')
sym_p = os.path.join(model_dir, 'prediction_net-symbol.json')
params_p = os.path.join(model_dir, 'prediction_net-0000.params')
input_shape = (32, 100, 1)  # (batch_size, sequence_length, num_features)? not sure this is right
onnx_file_path = os.path.join(model_dir, 'TFT.onnx')
onnx_mxnet.export_model(sym_p, params_p, [input_shape], np.float32, onnx_file_path)
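In the infer_shape argument dump, data0 is the only data input that receives a shape. If the serialized TFT network has more inputs than that (an assumption to verify, for example by listing the "null" nodes in the symbol JSON), export_model would need one shape per input. A minimal stdlib sketch with two hypothetical helpers: export_file_names reproduces the `<prefix>-symbol.json` / `<prefix>-NNNN.params` naming that MXNet's HybridBlock.export uses (matching the file names above), and ordered_input_shapes builds the shape list in input order, failing loudly if any input has no shape. It assumes, as the positional list argument suggests, that the exporter pairs shapes with inputs by position; the actual shapes must come from your own data, and the values in the test are placeholders.

```python
from typing import Dict, List, Sequence, Tuple

Shape = Tuple[int, ...]


def export_file_names(prefix: str, epoch: int = 0) -> Tuple[str, str]:
    """Symbol and params file names that HybridBlock.export writes for a prefix and epoch."""
    return f"{prefix}-symbol.json", f"{prefix}-{epoch:04d}.params"


def ordered_input_shapes(input_names: Sequence[str], shapes: Dict[str, Shape]) -> List[Shape]:
    """Return one shape per data input, in the given input order.

    Raises ValueError if any input name has no entry in `shapes`, instead of
    silently passing a shorter list to the exporter.
    """
    missing = [name for name in input_names if name not in shapes]
    if missing:
        raise ValueError(f"no shape given for inputs: {missing}")
    return [shapes[name] for name in input_names]
```

The resulting list would then replace `[input_shape]` in the `onnx_mxnet.export_model(...)` call.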

Here is the output, including the error, from the onnx_mxnet.export_model() call:

infer_shape error. Arguments:
  data0: (32, 100, 8)
  temporalfusiontransformerpredictionnetwork0_target_projection_weight: (120, 1)
  temporalfusiontransformerpredictionnetwork0_target_projection_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork1_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork1_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork1_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork1_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork1_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork1_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork1_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork1_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_past_feat_dynamic_real_0_projection_weight: (120, 1)
  temporalfusiontransformerpredictionnetwork0_past_feat_dynamic_real_0_projection_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork2_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork2_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork2_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork2_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork2_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork2_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork2_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork2_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_past_feat_dynamic_cat_0_embedding_weight: (1, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork3_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork3_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork3_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork3_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork3_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork3_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork3_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork3_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_feat_dynamic_real_0_projection_weight: (120, 1)
  temporalfusiontransformerpredictionnetwork0_feat_dynamic_real_0_projection_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork4_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork4_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork4_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork4_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork4_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork4_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork4_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork4_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_feat_dynamic_real_1_projection_weight: (120, 1)
  temporalfusiontransformerpredictionnetwork0_feat_dynamic_real_1_projection_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork5_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork5_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork5_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork5_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork5_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork5_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork5_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork5_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_feat_dynamic_real_2_projection_weight: (120, 1)
  temporalfusiontransformerpredictionnetwork0_feat_dynamic_real_2_projection_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork6_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork6_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork6_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork6_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork6_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork6_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork6_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork6_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_feat_dynamic_real_3_projection_weight: (120, 1)
  temporalfusiontransformerpredictionnetwork0_feat_dynamic_real_3_projection_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork7_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork7_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork7_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork7_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork7_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork7_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork7_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork7_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_feat_dynamic_real_4_projection_weight: (120, 1)
  temporalfusiontransformerpredictionnetwork0_feat_dynamic_real_4_projection_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork8_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork8_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork8_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork8_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork8_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork8_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork8_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork8_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_feat_dynamic_real_5_projection_weight: (120, 1)
  temporalfusiontransformerpredictionnetwork0_feat_dynamic_real_5_projection_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork9_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork9_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork9_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork9_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork9_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork9_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork9_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork9_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_feat_dynamic_cat_0_embedding_weight: (1, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork10_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork10_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork10_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork10_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork10_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork10_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork10_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork10_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_feat_static_real_0_projection_weight: (120, 1)
  temporalfusiontransformerpredictionnetwork0_feat_static_real_0_projection_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork1_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork1_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork1_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork1_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork1_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork1_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork1_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork1_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_feat_static_cat_0_embedding_weight: (1, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork2_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork2_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork2_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork2_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork2_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork2_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork2_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork2_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork0_dense1_weight: (120, 240)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork0_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork0_dense2_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork0_dense2_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork0_dense3_weight: (4, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork0_dense3_bias: (4,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork0_dense0_weight: (2, 240)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork0_dense0_bias: (2,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork0_layernorm0_gamma: (2,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork0_gatedresidualnetwork0_layernorm0_beta: (2,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork0_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork0_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork0_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork0_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork0_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork0_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork0_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork0_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork0_dense1_weight: (120, 1320)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork0_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork0_dense2_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork0_dense2_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork0_dense3_weight: (20, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork0_dense3_bias: (20,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork0_dense0_weight: (10, 1200)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork0_dense0_bias: (10,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork0_layernorm0_gamma: (10,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork0_layernorm0_beta: (10,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork1_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork1_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork1_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork1_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork1_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork1_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork1_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork1_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork2_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork2_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork2_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork2_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork2_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork2_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork2_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork2_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork3_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork3_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork3_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork3_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork3_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork3_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork3_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork3_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork4_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork4_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork4_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork4_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork4_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork4_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork4_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork4_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork5_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork5_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork5_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork5_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork5_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork5_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork5_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork5_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork6_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork6_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork6_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork6_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork6_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork6_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork6_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork6_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork7_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork7_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork7_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork7_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork7_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork7_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork7_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork7_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork0_dense1_weight: (120, 960)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork0_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork0_dense2_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork0_dense2_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork0_dense3_weight: (14, 120)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork0_dense3_bias: (14,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork0_dense0_weight: (7, 840)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork0_dense0_bias: (7,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork0_layernorm0_gamma: (7,)
  temporalfusiontransformerpredictionnetwork0_variableselectionnetwork2_gatedresidualnetwork0_layernorm0_beta: (7,)
  temporalfusiontransformerpredictionnetwork0_temporalfusionencoder0_lstm0_i2h_weight: (480, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusionencoder0_lstm0_i2h_bias: (480,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork2_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork2_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork2_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork2_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork2_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork2_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork2_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork2_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusionencoder0_lstm0_h2h_weight: (480, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusionencoder0_lstm0_h2h_bias: (480,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork3_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork3_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork3_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork3_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork3_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork3_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork3_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork3_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusionencoder0_lstm1_i2h_weight: (480, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusionencoder0_lstm1_i2h_bias: (480,)
  temporalfusiontransformerpredictionnetwork0_temporalfusionencoder0_lstm1_h2h_weight: (480, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusionencoder0_lstm1_h2h_bias: (480,)
  temporalfusiontransformerpredictionnetwork0_temporalfusionencoder0_dense0_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusionencoder0_dense0_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_temporalfusionencoder0_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusionencoder0_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork1_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork1_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork1_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork1_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork1_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork1_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork1_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_gatedresidualnetwork1_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork0_dense0_weight: (120, 240)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork0_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork0_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork0_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork0_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork0_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork0_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork0_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_selfattention0_q_proj_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_selfattention0_q_proj_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_selfattention0_k_proj_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_selfattention0_k_proj_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_selfattention0_v_proj_weight: (12, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_selfattention0_v_proj_bias: (12,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_selfattention0_out_proj_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_selfattention0_out_proj_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_dense0_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_dense0_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork1_dense0_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork1_dense0_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork1_dense1_weight: (120, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork1_dense1_bias: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork1_dense2_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork1_dense2_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork1_layernorm0_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_gatedresidualnetwork1_layernorm0_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_dense1_weight: (240, 120)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_dense1_bias: (240,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_layernorm1_gamma: (120,)
  temporalfusiontransformerpredictionnetwork0_temporalfusiondecoder0_layernorm1_beta: (120,)
  temporalfusiontransformerpredictionnetwork0_dense0_weight: (3, 120)
  temporalfusiontransformerpredictionnetwork0_dense0_bias: (3,)
Traceback (most recent call last):
  File "C:\Users\techm\anaconda3\envs\onnx_env\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "C:\Users\techm\anaconda3\envs\onnx_env\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "c:\Users\techm\.vscode\extensions\ms-python.python-2023.14.0\pythonFiles\lib\python\debugpy\adapter/../..\debugpy\launcher/../..\debugpy\__main__.py", line 39, in <module>
    cli.main()
  File "c:\Users\techm\.vscode\extensions\ms-python.python-2023.14.0\pythonFiles\lib\python\debugpy\adapter/../..\debugpy\launcher/../..\debugpy/..\debugpy\server\cli.py", line 430, in main
    run()
  File "c:\Users\techm\.vscode\extensions\ms-python.python-2023.14.0\pythonFiles\lib\python\debugpy\adapter/../..\debugpy\launcher/../..\debugpy/..\debugpy\server\cli.py", line 284, in run_file
    runpy.run_path(target, run_name="__main__")
  File "c:\Users\techm\.vscode\extensions\ms-python.python-2023.14.0\pythonFiles\lib\python\debugpy\_vendored\pydevd\_pydevd_bundle\pydevd_runpy.py", line 322, in run_path      
    pkg_name=pkg_name, script_name=fname)
  File "c:\Users\techm\.vscode\extensions\ms-python.python-2023.14.0\pythonFiles\lib\python\debugpy\_vendored\pydevd\_pydevd_bundle\pydevd_runpy.py", line 136, in _run_module_code
    mod_name, mod_spec, pkg_name, script_name)
  File "c:\Users\techm\.vscode\extensions\ms-python.python-2023.14.0\pythonFiles\lib\python\debugpy\_vendored\pydevd\_pydevd_bundle\pydevd_runpy.py", line 124, in _run_code     
    exec(code, run_globals)
  File "c:\Users\techm\Desktop\LTI informatique et génie\948\AT20\PIP\pip\toOnnx.py", line 73, in <module>
    onnx_mxnet.export_model(sym_p, params_p, [input_shape], np.float32, onnx_file_path)
  File "C:\Users\techm\anaconda3\envs\onnx_env\lib\site-packages\mxnet\contrib\onnx\mx2onnx\export_model.py", line 79, in export_model
    verbose=verbose)
  File "C:\Users\techm\anaconda3\envs\onnx_env\lib\site-packages\mxnet\contrib\onnx\mx2onnx\export_onnx.py", line 207, in create_onnx_graph_proto
    graph_outputs = MXNetGraph.get_outputs(sym, params, in_shape, output_label)
  File "C:\Users\techm\anaconda3\envs\onnx_env\lib\site-packages\mxnet\contrib\onnx\mx2onnx\export_onnx.py", line 138, in get_outputs
    _, out_shapes, _ = sym.infer_shape(**inputs)
  File "C:\Users\techm\anaconda3\envs\onnx_env\lib\site-packages\mxnet\symbol\symbol.py", line 1101, in infer_shape
    res = self._infer_shape_impl(False, *args, **kwargs)
  File "C:\Users\techm\anaconda3\envs\onnx_env\lib\site-packages\mxnet\symbol\symbol.py", line 1265, in _infer_shape_impl
    ctypes.byref(complete)))
  File "C:\Users\techm\anaconda3\envs\onnx_env\lib\site-packages\mxnet\base.py", line 246, in check_call
    raise get_last_ffi_error()
mxnet.base.MXNetError: MXNetError: Error in operator temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork1_layernorm0_layernorm0: [09:08:17] C:\Jenkins\workspace\mxnet-tag\mxnet\src\operator\nn\layer_norm.cc:47: Check failed: axis >= 0 && axis < dshape.ndim(): Channel axis out of range: axis=-1

The error is raised for the axis of the temporalfusiontransformerpredictionnetwork0_variableselectionnetwork1_gatedresidualnetwork1_layernorm0_layernorm0 operator defined in the prediction_net-symbol.json file. I'm not sure whether the problem comes from the serialization or from the ONNX export. As the first line of the error message suggests, there is a problem with the input_shape parameter. The model takes as input a time series vector of shape (100,) in a PandasDataset, and I have no features. What is the input shape of this type of model? I can't find a similar situation on the internet. Is it (batch_size, sequence_length, num_features)?
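One way to see what the exported graph actually expects is to count its data inputs. The sketch below reads prediction_net-symbol.json with the standard library only and returns the variable nodes named data, data0, data1, and so on, in the order they appear in the file. It assumes Gluon's export naming convention (positional inputs of a traced block become data0, data1, ..., which the data0 in the UserWarning above hints at); list_data_inputs is a hypothetical helper, not part of MXNet or GluonTS. If it returns eight names (one per entry of input_names in parameters.json), a single shape in the input_shape list cannot describe the graph.

```python
import json
import re
from typing import Dict, List, Union


def list_data_inputs(symbol_json: Union[str, Dict]) -> List[str]:
    """Return the data-input names of an MXNet symbol graph, in file order.

    `symbol_json` is either a path to prediction_net-symbol.json or the
    already-parsed dictionary. Assumption: data inputs are the variable
    nodes (op == "null") named "data" or "data<N>", as produced by Gluon's
    export; every other variable node is treated as a parameter.
    """
    if isinstance(symbol_json, str):
        with open(symbol_json, encoding="utf-8") as f:
            graph = json.load(f)
    else:
        graph = symbol_json

    return [
        node["name"]
        for node in graph["nodes"]
        # Variables (inputs, weights, biases) are stored with op "null".
        if node["op"] == "null" and re.fullmatch(r"data\d*", node["name"])
    ]
```

For example, print(list_data_inputs(sym_p)) with the sym_p path used above.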

The parameters.json file created with the p.serialize_prediction_net() function has the following information:

{
  "batch_size": 32,
  "ctx": {"__kind__": "instance", "args": ["cpu", 0], "class": "mxnet.context.Context"},
  "dtype": {"__kind__": "type", "class": "numpy.float32"},
  "forecast_generator": {
    "__kind__": "instance",
    "args": [],
    "class": "gluonts.model.forecast_generator.QuantileForecastGenerator",
    "kwargs": {"quantiles": ["0.5", "0.1", "0.9"]}
  },
  "input_names": [
    "past_target",
    "past_observed_values",
    "past_feat_dynamic_real",
    "past_feat_dynamic_cat",
    "feat_dynamic_real",
    "feat_dynamic_cat",
    "feat_static_real",
    "feat_static_cat"
  ],
  "lead_time": 0,
  "prediction_length": 50
}

Can we infer the model input shape from this? Thank you for your help!
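parameters.json records the names and order of the inputs but not their shapes, so the shapes are probably easiest to read off a real batch. The sketch below assumes you already have one batch as a dict of NumPy arrays keyed by field name (for example, one batch from a data loader that applies the predictor's input transformation; obtaining it is not shown) and the graph's data input names in the order export_model pairs them with shapes (per MXNet's export_onnx.py, that is sym.list_inputs() with parameter names removed). It also assumes graph input dataN corresponds to input_names[N]. load_input_names and shapes_for_export are hypothetical helpers, not GluonTS or MXNet APIs.

```python
import json
import re
from typing import Dict, List, Sequence, Tuple

import numpy as np


def load_input_names(parameters_json_path: str) -> List[str]:
    """Read the ordered input names saved by serialize_prediction_net()."""
    with open(parameters_json_path, encoding="utf-8") as f:
        return json.load(f)["input_names"]


def shapes_for_export(
    data_names: Sequence[str],
    input_names: Sequence[str],
    batch: Dict[str, np.ndarray],
) -> List[Tuple[int, ...]]:
    """Build one shape per graph input for export_model's input_shape list.

    Assumptions (not confirmed by GluonTS/MXNet documentation):
    - graph input "dataN" is the N-th positional argument, i.e. input_names[N];
    - data_names is already in the order export_model uses.
    """
    shapes = []
    for name in data_names:
        match = re.fullmatch(r"data(\d+)", name)
        if match is None:
            raise ValueError(f"unexpected graph input name: {name!r}")
        field = input_names[int(match.group(1))]
        shapes.append(tuple(np.asarray(batch[field]).shape))
    return shapes
```

The resulting list would then replace [input_shape] in the onnx_mxnet.export_model(...) call; whether every input should also be declared as np.float32 is a separate question this sketch does not answer.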

Environment

Charles1710 commented 1 year ago

Does anyone have the same problem, or know how to fix it?