I'm running the OpenVLA policy on a machine with an RTX 4090 (24 GB of GPU memory), but inference only reaches about 4.4 actions/s after warmup. Am I doing something wrong here?
GPU info:
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:81:00.0 Off |                  Off |
|  0%   47C    P2             226W / 450W |  15296MiB / 24564MiB |     67%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
CPU info:
lscpu
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 43 bits physical, 48 bits virtual
CPU(s): 48
On-line CPU(s) list: 0-47
Thread(s) per core: 2
Core(s) per socket: 24
Socket(s): 1
NUMA node(s): 1
Vendor ID: AuthenticAMD
CPU family: 23
Model: 49
Model name: AMD EPYC 7402P 24-Core Processor
Stepping: 0
Frequency boost: enabled
CPU MHz: 1500.000
CPU max MHz: 2800.0000
CPU min MHz: 1500.0000
BogoMIPS: 5600.33
Virtualization: AMD-V
L1d cache: 768 KiB
L1i cache: 768 KiB
L2 cache: 12 MiB
L3 cache: 128 MiB
NUMA node0 CPU(s): 0-47
Evaluation code:
from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image
from tqdm import tqdm
import requests
import torch
processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b",
    attn_implementation="flash_attention_2",  # [Optional] Requires `flash_attn`
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="cuda:0",
    use_safetensors=True,
).to("cuda")

image_url = "https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_30.jpg"
image = Image.open(requests.get(image_url, stream=True).raw).resize((256, 256))
prompt = "In: What action should the robot take to {<INSTRUCTION>}?\nOut:"

# Predict Action (7-DoF; un-normalize for BridgeData V2)
for i in tqdm(range(50)):
    inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
    action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
    # tqdm.write(str(action))
    # print(action)
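To see whether the ~4.4 actions/s is dominated by autoregressive generation on the GPU or by image/text preprocessing on the CPU, it may help to time the two stages separately. Below is a minimal timing sketch; the `benchmark` helper and its `sync` hook are my own names, not part of OpenVLA or transformers. Passing `torch.cuda.synchronize` as `sync` when timing GPU work ensures asynchronous kernels are counted:

```python
import time

def benchmark(fn, warmup=3, iters=20, sync=None):
    """Return calls per second for `fn`.

    `sync` is an optional callable (e.g. torch.cuda.synchronize) invoked
    before starting and before stopping the clock, so asynchronous GPU
    work launched by `fn` is fully included in the measurement.
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()
    return iters / (time.perf_counter() - start)

# Hypothetical usage against the script above:
#   pre_rate = benchmark(lambda: processor(prompt, image).to("cuda:0", dtype=torch.bfloat16))
#   gen_rate = benchmark(lambda: vla.predict_action(**inputs, unnorm_key="bridge_orig",
#                                                   do_sample=False),
#                        sync=torch.cuda.synchronize)
```

If the `predict_action` rate alone is already close to 4.4 it/s, the bottleneck is the token-by-token decoding of the 7B model itself rather than anything in the evaluation loop.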
Output:
/root/miniconda3/envs/openvla/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
/root/miniconda3/envs/openvla/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:04<00:00, 1.35s/it]
/root/miniconda3/envs/openvla/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:12<00:00, 4.03it/s]