xenova / transformers.js

State-of-the-art Machine Learning for the web. Run 🤗 Transformers directly in your browser, with no need for a server!
https://huggingface.co/docs/transformers.js
Apache License 2.0

Add support for decision transformer (Closes #794) #795

Open · xenova opened 2 weeks ago

xenova commented 2 weeks ago

Example code:

import { AutoModel, Tensor } from '@xenova/transformers';

// Load model
const model_id = 'onnx-community/decision-transformer-gym-hopper-expert';
const model = await AutoModel.from_pretrained(model_id, { quantized: false });

// Helper function to generate random tensor
function rand(dims) {
    const data = Float32Array.from({ length: dims.reduce((a, b) => a * b) }, () => Math.random());
    return new Tensor('float32', data, dims);
}

// Define config
const batch_size = 2;
const episode_length = 16;
const state_dim = model.config.state_dim;
const act_dim = model.config.act_dim;

// Generate random input
const states = rand([batch_size, episode_length, state_dim]);
const actions = rand([batch_size, episode_length, act_dim]);
const rewards = rand([batch_size, episode_length, 1]);
const returns_to_go = rand([batch_size, episode_length, 1]);
const timesteps = new Tensor('int64', new BigInt64Array([BigInt(episode_length)]), [1, 1]); // single timestep index; broadcast across batch/sequence by the exported graph
const attention_mask = rand([batch_size, episode_length]); // random values just to exercise the graph; a real mask holds 0s and 1s

// Call model
const input = { states, actions, rewards, returns_to_go, timesteps, attention_mask };
const output = await model(input);
console.log(output);
// {
//     state_preds: Tensor {
//         dims: [2, 16, 11],
//         type: 'float32',
//         data: Float32Array(352)[ ... ],
//         size: 352
//     },
//     action_preds: Tensor {
//         dims: [2, 16, 3],
//         type: 'float32',
//         data: Float32Array(96)[ ... ],
//         size: 96
//     },
//     return_preds: Tensor {
//         dims: [2, 16, 1],
//         type: 'float32',
//         data: Float32Array(32)[ ... ],
//         size: 32
//     },
//     last_hidden_state: Tensor {
//         dims: [2, 48, 128],
//         type: 'float32',
//         data: Float32Array(12288)[ ... ],
//         size: 12288
//     }
// }
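
Not part of the PR, but for illustration: a minimal sketch of how one might read the predicted action for the final timestep out of the flat Float32Array backing action_preds (the offset arithmetic assumes the row-major layout that transformers.js tensors use):

// Read the predicted action for batch 0 at the last timestep
const [b, t, a] = output.action_preds.dims;        // [2, 16, 3]
const offset = (0 * t + (t - 1)) * a;              // batch 0, timestep t - 1
const last_action = Array.from(output.action_preds.data.slice(offset, offset + a));
console.log(last_action); // 3 values, one per hopper action dimension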

Export models to ONNX:

Requirements:

pip install torch transformers onnx==1.13.1

Code:

import torch
from transformers import DecisionTransformerModel

# 1. Load model
model_id = "edbeeching/decision-transformer-gym-hopper-medium"
model = DecisionTransformerModel.from_pretrained(model_id)

# 2. Define inputs
# states (`torch.FloatTensor` of shape `(batch_size, episode_length, state_dim)`):
#     The states for each step in the trajectory
# actions (`torch.FloatTensor` of shape `(batch_size, episode_length, act_dim)`):
#     The actions taken by the "expert" policy for the current state; these are masked for autoregressive
#     prediction
# rewards (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
#     The rewards for each state-action pair
# returns_to_go (`torch.FloatTensor` of shape `(batch_size, episode_length, 1)`):
#     The returns for each state in the trajectory
# timesteps (`torch.LongTensor` of shape `(batch_size, episode_length)`):
#     The timestep for each step in the trajectory
# attention_mask (`torch.FloatTensor` of shape `(batch_size, episode_length)`):
#     Masking, used to mask the actions when performing autoregressive prediction
batch_size = 2
episode_length = 16
state_dim = model.config.state_dim
act_dim = model.config.act_dim

states = torch.randn((batch_size, episode_length, state_dim))
actions = torch.randn((batch_size, episode_length, act_dim))
rewards = torch.randn((batch_size, episode_length, 1))
returns_to_go = torch.randn((batch_size, episode_length, 1))
timesteps = torch.tensor(0, dtype=torch.long).reshape(1, 1)  # single timestep index; broadcast at export time
attention_mask = torch.randn((batch_size, episode_length))

# 3. Define outputs
# last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
#     Sequence of hidden-states at the output of the last layer of the model.
# state_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, state_dim)`):
#     Environment state predictions
# action_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, action_dim)`):
#     Model action predictions
# return_preds (`torch.FloatTensor` of shape `(batch_size, sequence_length, 1)`):
#     Predicted returns for each state
# hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
#     Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
#     shape `(batch_size, sequence_length, hidden_size)`.

#     Hidden-states of the model at the output of each layer plus the initial embedding outputs.
# attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
#     Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
#     sequence_length)`.

#     Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
#     heads.

dynamic_axes = {0: 'batch_size', 1: 'episode_length'}

# 4. Export the model
torch.onnx.export(model, # model being run
                  (states, actions, rewards, returns_to_go, timesteps, attention_mask), # model input (or a tuple for multiple inputs)
                  "model.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=13,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['states', 'actions', 'rewards', 'returns_to_go', 'timesteps', 'attention_mask'],  # the model's input names
                  output_names=['state_preds', 'action_preds', 'return_preds', 'last_hidden_state'],  # the model's output names
                  dynamic_axes={
                      'states' : dynamic_axes,
                      'actions' : dynamic_axes,
                      'rewards' : dynamic_axes,
                      'returns_to_go' : dynamic_axes,
                      'timesteps' : dynamic_axes,
                      'attention_mask' : dynamic_axes,

                      'state_preds' : dynamic_axes,
                      'action_preds' : dynamic_axes,
                      'return_preds' : dynamic_axes,
                      'last_hidden_state' : dynamic_axes,
                  }
)
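
To sanity-check the exported graph outside of transformers.js, a minimal sketch (assuming onnxruntime-node is installed; the shapes hard-code hopper's state_dim = 11 and act_dim = 3 from the output above):

import * as ort from 'onnxruntime-node';

// Helper to build a random float32 tensor of the given shape
const rand = (dims) => new ort.Tensor(
    'float32',
    Float32Array.from({ length: dims.reduce((a, b) => a * b) }, () => Math.random()),
    dims,
);

const session = await ort.InferenceSession.create('model.onnx');
const results = await session.run({
    states: rand([2, 16, 11]),       // state_dim = 11 for hopper
    actions: rand([2, 16, 3]),       // act_dim = 3 for hopper
    rewards: rand([2, 16, 1]),
    returns_to_go: rand([2, 16, 1]),
    timesteps: new ort.Tensor('int64', new BigInt64Array([0n]), [1, 1]),
    attention_mask: rand([2, 16]),
});
console.log(results.action_preds.dims); // expect [2, 16, 3]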
HuggingFaceDocBuilderDev commented 2 weeks ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.