vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[issue tracker] make quantization compatible with dynamo dynamic shape #9234

Open youkaichao opened 1 day ago

youkaichao commented 1 day ago

Anything you want to discuss about vllm.

Here is a simple demo:

import torch
from torch.utils.cpp_extension import load_inline

custom_library = torch.library.Library("custom", "DEF")
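# Note: the schema declares y as a plain `int`; the SymInt variant later in
# the thread avoids dynamo specializing the dynamic dimension.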
custom_library.define("add_cpp(Tensor x, int y) -> Tensor")

cpp_source = """
#include <torch/extension.h>

torch::Tensor custom_add(torch::Tensor x, int64_t y) {
    return x + y;
}

TORCH_LIBRARY_IMPL(custom, CPU, m) {
    m.impl("add_cpp", custom_add);
}
"""

custom_op = load_inline(
    name="custom_op",
    cpp_sources=cpp_source,
    extra_cflags=[],
    functions=["custom_add"]
)

@torch.library.register_fake("custom::add_cpp")
def _(x: torch.Tensor, y: int) -> torch.Tensor:
    return torch.empty((y,), dtype=torch.float32)

@torch.library.custom_op("custom::add_py", mutates_args=[])
def add_py(x: torch.Tensor, y: int) -> torch.Tensor:
    return x + y

@add_py.register_fake
def _(x: torch.Tensor, y: int) -> torch.Tensor:
    return torch.empty((y,), dtype=torch.float32)

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    # return torch.ops.custom.add_py(x, x.shape[0]) # passes
    return torch.ops.custom.add_cpp(x, x.shape[0]) # errors with `Not all values of RelaxedUnspecConstraint(L['x'].size()[0]) are valid because L['x'].size()[0] was inferred to be a constant (2).`

x = torch.ones(2, 4)
torch._dynamo.mark_dynamic(x, 0)
print(f(x)[0])

When we register the custom op from the C++ side, the dynamic shape is specialized directly to an integer and compilation fails. When we register the custom op from the Python side, dynamic shapes work as expected.

We should change the way we register the quantization kernels as custom ops, from the C++ side to the Python side.
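
As a rough sketch of that direction (the op name and signature below are hypothetical, and a plain matmul stands in for a real quantized kernel, so this is illustrative only), a quantization-style op registered via torch.library.custom_op can describe its output shape symbolically in the fake implementation, which lets dynamo keep the batch dimension dynamic:

import torch

# Hypothetical Python-side registration for a quantization-style GEMM.
@torch.library.custom_op("myquant::qgemm", mutates_args=[])
def qgemm(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # A real implementation would dispatch to the compiled C++/CUDA kernel here.
    return x @ weight

@qgemm.register_fake
def _(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # The output shape is written in terms of the (possibly symbolic)
    # input sizes, so the batch dimension stays dynamic under dynamo.
    return x.new_empty((x.shape[0], weight.shape[1]))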

There's also one complicated object, ScalarType (https://github.com/vllm-project/vllm/blob/f3a507f1d31e13a99c4fc8ac02738a73c3e3136f/vllm/scalar_type.py#L15), that appears in the custom op parameters:

https://github.com/vllm-project/vllm/blob/f3a507f1d31e13a99c4fc8ac02738a73c3e3136f/vllm/_custom_ops.py#L315-L321

We can use strings to represent the type and look up the actual object to pass into the C++ function.
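
A minimal sketch of that idea (the names below are hypothetical; plain ints stand in for the real ScalarType instances and a toy dequantization stands in for the real kernel): the Python-registered op accepts a plain string, which dynamo can treat as a constant, and resolves it back into the object only at the boundary where the C++ function would be called:

import torch

# Hypothetical string -> scalar-type lookup table; in vLLM the values would
# be the ScalarType instances from vllm/scalar_type.py.
_SCALAR_TYPE_REGISTRY = {
    "uint4b8": 4,
    "uint8b128": 8,
}

@torch.library.custom_op("myquant::dequant", mutates_args=[])
def dequant(q: torch.Tensor, weight_type: str) -> torch.Tensor:
    # Resolve the string into the actual object only where the C++ kernel
    # would be invoked; this toy math stands in for that kernel.
    bits = _SCALAR_TYPE_REGISTRY[weight_type]
    return q.to(torch.float32) / (2 ** bits)

@dequant.register_fake
def _(q: torch.Tensor, weight_type: str) -> torch.Tensor:
    return q.new_empty(q.shape, dtype=torch.float32)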

youkaichao commented 1 day ago

cc @bnellnm

youkaichao commented 1 day ago

There's a related issue in PyTorch, https://github.com/pytorch/pytorch/issues/112883, and the comments there suggest that PyTorch will not fix it in the near future.

I tested with the PyTorch nightly (2.6.0.dev20241004), and it still has this problem.

bnellnm commented 13 hours ago

I was able to work around the problem by modifying the schemas to take SymInts. I'll look into the scalar_type issue.

import torch
from torch.utils.cpp_extension import load_inline

custom_library = torch.library.Library("custom", "DEF")
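# The only change from the original repro: y is declared as SymInt instead of
# int (and the fake impl's annotation below is updated to match).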
custom_library.define("add_cpp(Tensor x, SymInt y) -> Tensor")

cpp_source = """                                                                                                                                    
#include <torch/extension.h>                                                                                                                        

torch::Tensor custom_add(torch::Tensor x, int64_t y) {                                                                                              
    return x + y;                                                                                                                                   
}                                                                                                                                                   

TORCH_LIBRARY_IMPL(custom, CPU, m) {                                                                                                                
    m.impl("add_cpp", custom_add);                                                                                                                  
}                                                                                                                                                   
"""

custom_op = load_inline(
    name="custom_op",
    cpp_sources=cpp_source,
    extra_cflags=[],
    functions=["custom_add"]
)

@torch.library.register_fake("custom::add_cpp")
def _(x: torch.Tensor, y: torch.SymInt) -> torch.Tensor:
    return torch.empty((y,), dtype=torch.float32)

@torch.library.custom_op("custom::add_py", mutates_args=[])
def add_py(x: torch.Tensor, y: int) -> torch.Tensor:
    return x + y

@add_py.register_fake
def _(x: torch.Tensor, y: int) -> torch.Tensor:
    return torch.empty((y,), dtype=torch.float32)

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    # return torch.ops.custom.add_py(x, x.shape[0]) # passes
    return torch.ops.custom.add_cpp(x, x.shape[0]) # with the SymInt schema this now passes as well

x = torch.ones(2, 4)
torch._dynamo.mark_dynamic(x, 0)
print(f(x)[0])

bnellnm commented 12 hours ago

This also works.

import torch
from torch.utils.cpp_extension import load_inline

custom_library = torch.library.Library("custom", "DEF")

cpp_source = """                                                                                                                                    
#include <torch/extension.h>                                                                                                                        

torch::Tensor custom_add(torch::Tensor x, int64_t y) {                                                                                              
    return x + y;                                                                                                                                   
}                                                                                                                                                   

TORCH_LIBRARY_FRAGMENT(custom, m)                                                                                                                   
{                                                                                                                                                   
    m.def("add_cpp(Tensor x, SymInt y) -> Tensor");                                                                                                 
    m.impl("add_cpp", torch::kCPU, custom_add);                                                                                                     
}                                                                                                                                                   
"""

custom_op = load_inline(
    name="custom_op",
    cpp_sources=cpp_source,
    extra_cflags=[],
    functions=["custom_add"]
)

@torch.library.register_fake("custom::add_cpp")
def _(x: torch.Tensor, y: torch.SymInt) -> torch.Tensor:
    return torch.empty((y,), dtype=torch.float32)
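
For reference, the same driver from the earlier snippets can be used to exercise this version (reusing the definitions above); with the SymInt schema the dynamic batch dimension is expected to flow through instead of being specialized:

@torch.compile(backend="eager", fullgraph=True)
def f(x):
    return torch.ops.custom.add_cpp(x, x.shape[0])

x = torch.ones(2, 4)
torch._dynamo.mark_dynamic(x, 0)
print(f(x)[0])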

I think the ScalarType problem is orthogonal to the SymInt problem.

youkaichao commented 11 hours ago

I think the ScalarType problem is orthogonal to the SymInt problem.

Yes, they are two separate problems. For dynamo's dynamic shapes to propagate through the quantization ops, both problems need to be solved.