import torch
from transformers import DecisionTransformerModel
# Export a pretrained Decision Transformer to ONNX with dynamic batch and
# episode-length axes.

# 1. Load model
model_id = "edbeeching/decision-transformer-gym-hopper-medium"
model = DecisionTransformerModel.from_pretrained(model_id)

# 2. Define inputs
# states (`torch.FloatTensor` of shape `(batch_size, episode_length, state_dim)`):
#     The states for each step in the trajectory
# actions (`torch.FloatTensor` of shape `(batch_size, episode_length, act_dim)`):
#     The actions taken by the "expert" policy for the current state, these are
#     masked for auto regressive prediction
# rewards (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
#     The rewards for each state, action
# returns_to_go (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
#     The returns for each state in the trajectory
# timesteps (`torch.LongTensor` of shape `(batch_size, episode_length)`):
#     The timestep for each step in the trajectory
# attention_mask (`torch.FloatTensor` of shape `(batch_size, episode_length)`):
#     Masking, used to mask the actions when performing autoregressive prediction
batch_size = 2
episode_length = 16
state_dim = model.config.state_dim
act_dim = model.config.act_dim

# The dummy inputs only drive the trace; their values are arbitrary, but their
# shapes and dtypes must match what the model will receive at inference time.
states = torch.randn((batch_size, episode_length, state_dim))
actions = torch.randn((batch_size, episode_length, act_dim))
rewards = torch.randn((batch_size, episode_length, 1))
returns_to_go = torch.randn((batch_size, episode_length, 1))
# One timestep index per trajectory step, repeated for every batch row, so the
# traced shape is (batch_size, episode_length) like the other inputs and agrees
# with the dynamic axes declared for 'timesteps' below.
timesteps = torch.arange(episode_length, dtype=torch.long).repeat(batch_size, 1)
# A mask holds 0/1 values (1 = attend to this step); random Gaussian noise is
# not a valid mask, so attend to every step of the dummy trajectory.
attention_mask = torch.ones((batch_size, episode_length))

# 3. Define outputs
# last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
#     Sequence of hidden-states at the output of the last layer of the model.
# state_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, state_dim)`):
#     Environment state predictions
# action_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, action_dim)`):
#     Model action predictions
# return_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, 1)`):
#     Predicted returns for each state
# hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when
#     `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
#     Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for
#     the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
#     Hidden-states of the model at the output of each layer plus the initial
#     embedding outputs.
# attentions (`tuple(torch.FloatTensor)`, *optional*, returned when
#     `output_attentions=True` is passed or when `config.output_attentions=True`):
#     Tuple of `torch.FloatTensor` (one for each layer) of shape
#     `(batch_size, num_heads, sequence_length, sequence_length)`.
#     Attentions weights after the attention softmax, used to compute the weighted
#     average in the self-attention heads.

# Axes shared by every per-step tensor: axis 0 is the batch, axis 1 the episode.
dynamic_axes = {0: 'batch_size', 1: 'episode_length'}
# NOTE(review): last_hidden_state appears to come from the encoder, which runs
# over the interleaved (return, state, action) tokens, so its axis 1 should not
# be tied to 'episode_length'. Give it its own symbolic name -- confirm against
# the model implementation.
hidden_state_axes = {0: 'batch_size', 1: 'sequence_length'}

# 4. Export the model
torch.onnx.export(
    model,  # model being run
    # model input (or a tuple for multiple inputs)
    (states, actions, rewards, returns_to_go, timesteps, attention_mask),
    "model.onnx",  # where to save the model (can be a file or file-like object)
    export_params=True,  # store the trained parameter weights inside the model file
    opset_version=13,  # the ONNX version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    # the model's input names
    input_names=['states', 'actions', 'rewards', 'returns_to_go', 'timesteps', 'attention_mask'],
    # the model's output names
    output_names=['state_preds', 'action_preds', 'return_preds', 'last_hidden_state'],
    dynamic_axes={
        'states': dynamic_axes,
        'actions': dynamic_axes,
        'rewards': dynamic_axes,
        'returns_to_go': dynamic_axes,
        'timesteps': dynamic_axes,
        'attention_mask': dynamic_axes,
        'state_preds': dynamic_axes,
        'action_preds': dynamic_axes,
        'return_preds': dynamic_axes,
        'last_hidden_state': hidden_state_axes,
    },
)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Example code:
Export models to ONNX:
Requirements:
Code: