openvla / openvla

OpenVLA: An open-source vision-language-action model for robotic manipulation.
MIT License

Unable to Reproduce 6 Actions/s Inference on RTX4090 #66

Open Depetrol opened 1 month ago

Depetrol commented 1 month ago

I'm running the OpenVLA policy on a machine with RTX4090 with 24GB GPU memory, but the inference is only about 4.4 actions/s after warmup. Am I doing something wrong here?

GPU info:

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 4090        On  | 00000000:81:00.0 Off |                  Off |
|  0%   47C    P2             226W / 450W |  15296MiB / 24564MiB |     67%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

CPU info:

lscpu
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      43 bits physical, 48 bits virtual
CPU(s):                             48
On-line CPU(s) list:                0-47
Thread(s) per core:                 2
Core(s) per socket:                 24
Socket(s):                          1
NUMA node(s):                       1
Vendor ID:                          AuthenticAMD
CPU family:                         23
Model:                              49
Model name:                         AMD EPYC 7402P 24-Core Processor
Stepping:                           0
Frequency boost:                    enabled
CPU MHz:                            1500.000
CPU max MHz:                        2800.0000
CPU min MHz:                        1500.0000
BogoMIPS:                           5600.33
Virtualization:                     AMD-V
L1d cache:                          768 KiB
L1i cache:                          768 KiB
L2 cache:                           12 MiB
L3 cache:                           128 MiB
NUMA node0 CPU(s):                  0-47

Evaluation code:

from transformers import AutoModelForVision2Seq, AutoProcessor
from PIL import Image
from tqdm import tqdm
import requests
import torch

processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b",
    attn_implementation="flash_attention_2",  # [Optional] Requires `flash_attn`
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="cuda:0",  # already places the model on GPU, so a trailing .to("cuda") is redundant
    use_safetensors=True,
)

image_url = "https://rail.eecs.berkeley.edu/datasets/bridge_release/raw/bridge_data_v2/datacol2_toykitchen7/drawer_pnp/01/2023-04-19_09-18-15/raw/traj_group0/traj0/images0/im_30.jpg"
image = Image.open(requests.get(image_url, stream=True).raw).resize((256, 256))
prompt = "In: What action should the robot take to {<INSTRUCTION>}?\nOut:"

# Predict Action (7-DoF; un-normalize for BridgeData V2)
for i in tqdm(range(50)):
    inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
    action = vla.predict_action(**inputs, unnorm_key="bridge_orig", do_sample=False)
    # tqdm.write(str(action))
# print(action)

Output:

/root/miniconda3/envs/openvla/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:04<00:00,  1.35s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:12<00:00,  4.03it/s]
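As a side note, the tqdm rate above also includes `processor(...)` preprocessing and host-side overhead, not just model inference. A minimal, model-agnostic timing harness (a hypothetical helper, not part of OpenVLA; `step` would wrap the `vla.predict_action(...)` call, and on CUDA you would also call `torch.cuda.synchronize()` before reading the clock) might look like:

```python
import time

def measure_actions_per_sec(step, n_warmup=5, n_iters=50):
    """Return throughput (calls/s) of `step` after a warmup phase.

    `step` is any zero-argument callable, e.g. one that runs a single
    predict_action. For GPU work, synchronize inside `step` so the timer
    measures completed work rather than kernel launches (assumption:
    caller handles device synchronization).
    """
    for _ in range(n_warmup):
        step()
    start = time.perf_counter()
    for _ in range(n_iters):
        step()
    elapsed = time.perf_counter() - start
    return n_iters / elapsed
```

For example, `measure_actions_per_sec(lambda: time.sleep(0.01))` should report a bit under 100 calls/s, since each "action" costs roughly 10 ms.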
marneneha commented 3 weeks ago

You are not putting an actual instruction into the prompt. Try: `prompt = "In: What action should the robot take to push the black handle back?\nOut:"`
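That is, the literal `{<INSTRUCTION>}` placeholder in the template should be replaced with the task description. A quick sketch (the instruction string here is just an example):

```python
# Example task description; substitute your own instruction.
instruction = "push the black handle back"

# Fill the OpenVLA prompt template with the concrete instruction.
prompt = f"In: What action should the robot take to {instruction}?\nOut:"
```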

Depetrol commented 3 weeks ago

I tried this, but the inference speed still tops out at about 4.4 actions/s.