srush opened this issue 6 months ago
@srush, thank you for your interest! Your paper seems like a very interesting use case. Your errors were caused by a bug on T4 GPUs, but we've just released a patch to resolve it (#16). It's important to note that for T4 GPUs, you should run `pip3 install "triton>=2.2.0"` and manually set a memory parameter for your model (TheBloke/Mistral-7B-v0.1-AWQ).
You can find an example on Colab here: https://colab.research.google.com/drive/13lOJt8uFYZJetqQIudAlK8oJJX8PENNk?usp=sharing.
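The T4 setup can also be driven from Python instead of a `!nohup` cell. This is a minimal sketch, assuming the "memory parameter" is SGLang's `--mem-fraction-static` server flag; the reply does not name the flag, and the value `0.8` below is a hypothetical placeholder, not the maintainers' Colab setting. `build_launch_cmd` and `launch_in_background` are hypothetical helper names.

```python
import subprocess
import sys


def build_launch_cmd(model_path, port=30000, mem_fraction_static=None):
    """Build the argv for `python -m sglang.launch_server` (hypothetical helper)."""
    cmd = [sys.executable, "-m", "sglang.launch_server",
           "--model-path", model_path, "--port", str(port)]
    if mem_fraction_static is not None:
        # Assumption: this is the memory knob the reply refers to for T4 GPUs.
        cmd += ["--mem-fraction-static", str(mem_fraction_static)]
    return cmd


def launch_in_background(cmd, log_path="nohup.out"):
    """Start the server without blocking the notebook, like `nohup ... &`.

    Output goes to log_path so it can be inspected with `!tail nohup.out`.
    The log file handle stays open for the lifetime of the server process.
    """
    log = open(log_path, "w")
    return subprocess.Popen(cmd, stdout=log, stderr=subprocess.STDOUT)


# Usage (not run here; it downloads and serves the model):
# proc = launch_in_background(
#     build_launch_cmd("TheBloke/Mistral-7B-v0.1-AWQ", mem_fraction_static=0.8))
```

Returning an argv list rather than a shell string avoids quoting problems, and the `Popen` handle lets the notebook call `proc.terminate()` later.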
By utilizing the inference endpoint, our system will automatically handle KV cache reuse for you. As our project is still in its early stages, we're eager to test more use cases to evaluate our cache and scheduling policies. Once you have your implementation set up, we'd be very interested in reviewing it to assess our system's performance and identify potential improvements. Please let us know if you have any questions.
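To see why the automatic KV cache reuse favors programs that put the shared prompt first, here is a toy model of prefix caching. It is not SGLang's implementation: SGLang's RadixAttention keeps a radix tree over tokens and evicts entries under memory pressure, while this sketch uses a never-evicting character trie. `ToyPrefixCache`, `total_reuse`, and the sample prompts and questions are all hypothetical.

```python
class ToyPrefixCache:
    """Character-level trie standing in for a token-level KV cache (toy: no eviction)."""

    def __init__(self):
        self.root = {}

    def insert(self, text):
        """Add text to the cache; return how many leading characters were already cached."""
        node, hits = self.root, 0
        for ch in text:
            if ch in node:
                hits += 1  # still walking an existing path: this "KV entry" is reusable
            # After the first miss every node is freshly created, so no further hits occur.
            node = node.setdefault(ch, {})
        return hits


def total_reuse(texts):
    """Sum of cached-prefix hits when texts are sent in order to one shared cache."""
    cache = ToyPrefixCache()
    return sum(cache.insert(t) for t in texts)


# Hypothetical few-shot prompts and inputs.
prompts = ["Classify the sentiment of each review.\n", "Answer positive or negative.\n"]
questions = ["great film", "boring plot", "loved it"]

prompt_first = [p + "Input:" + q for q in questions for p in prompts]
question_first = ["Input:" + q + p for q in questions for p in prompts]

print("prompt first:  ", total_reuse(prompt_first))
print("question first:", total_reuse(question_first))
```

With the prompt first (as in `text_qa` in this thread), every later request with the same prompt reuses the whole prompt plus `"Input:"`; with the question first, only the short question is shared.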
Wow! This library is so good. I implemented the entire training procedure in 1/10 the lines of code and it is way faster (oh actually maybe spoke too soon, progress bar was maybe off).
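```python
from sglang import function, gen, set_default_backend, RuntimeEndpoint


@function
def text_qa(s, question, prompt):
    # Shared prompt goes first, so requests with the same prompt share a common prefix.
    s += prompt
    # Constrain the answer to exactly "positive" or "negative".
    s += "Input:" + question + gen("answer", max_tokens=5, regex="(positive)|(negative)", stop=None)


# Point the frontend at the server started with sglang.launch_server.
set_default_backend(RuntimeEndpoint("http://localhost:30000"))

# X_train_text and pt.state.prompts come from the surrounding notebook.
states = text_qa.run_batch(
    [{"question": x, "prompt": pt.state.prompts[j]}
     for x in X_train_text
     for j in range(len(pt.state.prompts))],
    progress_bar=True,
)
print(states)
```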
Going to try to do inference now.
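The `regex`-constrained `gen` above decodes token by token under a mask; the maintainers later in this thread describe `choices` as an alternative that scores each option by its normalized log probability. The sketch below contrasts the two selection rules on hypothetical numbers. It assumes "normalized" means the mean per-token log probability (an assumption about the exact normalization), and `pick_choice` and `mask_logits` are illustrative helpers, not SGLang functions.

```python
import math


def pick_choice(choice_logprobs):
    """Return the option with the highest mean per-token log probability."""
    return max(choice_logprobs,
               key=lambda c: sum(choice_logprobs[c]) / len(choice_logprobs[c]))


def mask_logits(logits, allowed_ids):
    """Logit-bias masking: tokens the constraint forbids get -inf, so they are never picked."""
    return [x if i in allowed_ids else -math.inf for i, x in enumerate(logits)]


# Hypothetical per-token log probabilities (not real model output).
scores = {
    "positive": [-0.9],        # one token
    "negative": [-0.2, -1.4],  # two tokens: sum -1.6, mean -0.8
}
print(pick_choice(scores))                        # -> negative (mean-normalized rule)
print(max(scores, key=lambda c: sum(scores[c])))  # -> positive (raw sum, for contrast)

# One greedy decoding step under a constraint that only allows token ids 0 and 2.
masked = mask_logits([1.0, 2.0, 3.0], allowed_ids={0, 2})
print(max(range(len(masked)), key=masked.__getitem__))  # -> 2
```

Normalizing by length keeps multi-token options from being penalized just for being longer, which is why the two rules can disagree.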
The bug I was running into is that if you start a run_batch job in Colab, it seems to keep running even after you interrupt it. I guess this is because it spawns threads behind the scenes? Is there an easy way to kill it for real?
In my example, when I killed it and started again, two different progress bars showed up.
@srush We haven't tested extensively in the Colab environment. I just noticed that Colab only has two CPU cores, so run_batch will only use 2 threads to submit the requests, which can cause a huge bottleneck.
Could you try the following code? It should make execution much faster. Please let us know what performance you get and what your baseline system is. We expect a significant speedup compared to HF Transformers and vLLM. If you don't see one, we can take a closer look at your notebook if you can share it.
```python
# Colab's two CPU cores otherwise limit run_batch to 2 submitting threads.
states = text_qa.run_batch(
    [{"question": x, "prompt": pt.state.prompts[j]}
     for x in X_train_text
     for j in range(len(pt.state.prompts))],
    num_threads=64,  # you can also try 32 or 128
    progress_bar=True,
)
```
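Raising `num_threads` helps because each request mostly waits on the server, so client threads spend their time blocked on I/O rather than using CPU; two cores do not cap useful concurrency. The toy timing below assumes `run_batch` submits requests from a thread pool, as described in the reply above, and replaces real requests with `time.sleep`. `fake_request` and `elapsed` are hypothetical names, and the numbers say nothing about SGLang's real throughput.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def fake_request(_):
    # Stand-in for one HTTP call: the thread just waits, like waiting on the server.
    time.sleep(0.02)


def elapsed(num_threads, n_requests=32):
    """Wall-clock time to push n_requests fake requests through a thread pool."""
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        list(pool.map(fake_request, range(n_requests)))
    return time.perf_counter() - start


for t in (2, 16):
    print(f"{t:>3} threads: {elapsed(t):.2f}s")
```

With 2 threads the 32 waits run in 16 sequential rounds; with 16 threads they finish in 2 rounds.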
As for the progress bar issue, you are right. I don't have a good solution for it yet; I think Colab is not friendly to this kind of multi-threaded/multi-process execution.
On our dev machines, we use a script to kill all Python processes, but it would kill the Colab runtime as well.
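On why interrupting the cell does not stop the job: in Python, a notebook interrupt raises `KeyboardInterrupt` only in the main thread, and worker threads that are running or queued in an executor carry on. A generic pattern, not something `run_batch` is known to expose, is to catch the interrupt, set a stop flag the workers check, and call `ThreadPoolExecutor.shutdown(cancel_futures=True)` (Python 3.9+) so queued requests are dropped. The sketch simulates the interrupt; `fake_request`, `run_batch_cancellable`, and `simulated_interrupt` are hypothetical.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor


def fake_request(i, stop, finished):
    # Hypothetical stand-in for one request; a real one would call the server.
    if stop.is_set():
        return None       # cooperative cancellation: don't start work after a stop
    time.sleep(0.05)
    finished.append(i)    # list.append is atomic under CPython's GIL
    return i


def run_batch_cancellable(n_tasks, num_threads, wait_fn):
    """Run n_tasks on a pool; on KeyboardInterrupt, stop workers and drop queued tasks."""
    stop = threading.Event()
    finished = []
    pool = ThreadPoolExecutor(max_workers=num_threads)
    futures = [pool.submit(fake_request, i, stop, finished) for i in range(n_tasks)]
    try:
        wait_fn(futures)  # normally: wait for results; Ctrl-C raises KeyboardInterrupt here
    except KeyboardInterrupt:
        # Only the main thread sees the interrupt, so workers must be told to stop.
        stop.set()
        pool.shutdown(wait=True, cancel_futures=True)
        # Real code would re-raise here; the demo returns counts instead.
    else:
        pool.shutdown(wait=True)
    cancelled = sum(f.cancelled() for f in futures)
    return len(finished), cancelled


def simulated_interrupt(futures):
    time.sleep(0.12)
    raise KeyboardInterrupt  # what pressing "interrupt" does to the main thread


done, cancelled = run_batch_cancellable(100, 2, simulated_interrupt)
print(f"finished {done}, cancelled {cancelled} of 100")
```

Python threads cannot be force-killed, so cooperative flags plus cancelling queued futures is the usual way to make an interrupt actually stop a threaded batch.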
Thanks! I will move to a better machine and do the comparison for real. Just wanted to make sure I wasn't crazy.
@srush For your use case, you can also try `sgl.gen(choices=["positive", "negative"])`. I am not sure which one gives higher accuracy or better speed.
- The `choices` argument in `sgl.gen` is implemented by computing the normalized log probabilities of all choices and selecting the one with the highest probability.
- The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex.

Hey, I've been trying to use the Colab notebook linked in this issue (I'm on a Mac, so I can't run it on my own device), but I get an error connecting to the backend when executing:
```python
import requests

response = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 32,
        },
    },
)
print(response.json())
```
I get:
```
ConnectionRefusedError                    Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/urllib3/connection.py in _new_conn(self)
    202         try:
--> 203             sock = connection.create_connection(
    204                 (self._dns_host, self.port),
...
ConnectionRefusedError: [Errno 111] Connection refused

The above exception was the direct cause of the following exception:

NewConnectionError: <urllib3.connection.HTTPConnection object at 0x79158fea2bf0>: Failed to establish a new connection: [Errno 111] Connection refused

The above exception was the direct cause of the following exception:

MaxRetryError: HTTPConnectionPool(host='localhost', port=30000): Max retries exceeded with url: /generate (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x79158fea2bf0>: Failed to establish a new connection: [Errno 111] Connection refused'))

During handling of the above exception, another exception occurred:

ConnectionError                           Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/requests/adapters.py in send(self, request, stream, timeout, verify, cert, proxies)
    517                 raise SSLError(e, request=request)
    518
--> 519             raise ConnectionError(e, request=request)
    520
    521         except ClosedPoolError as e:

ConnectionError: HTTPConnectionPool(host='localhost', port=30000): Max retries exceeded with url: /generate (Caused by NewConnectionError('<urllib3.connection.HTTPConnection object at 0x79158fea2bf0>: Failed to establish a new connection: [Errno 111] Connection refused'))
```
How do I get this to work on Colab?
@Kamakshi8104 Please wait until the server has been launched successfully. It may take some time to download the weights and set up. You can use the last cell, `!tail nohup.out`, to monitor the progress.
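Instead of re-running `!tail nohup.out` by hand, a notebook can poll the server until it answers before sending any requests. A minimal sketch using only the standard library; the `/get_model_info` path is an assumption (any route works, since any HTTP response, even an error status, is taken to mean the server is up), and `wait_for_server` is a hypothetical helper.

```python
import time
import urllib.error
import urllib.request


def wait_for_server(base_url="http://localhost:30000", path="/get_model_info",
                    timeout_s=600.0, interval_s=5.0):
    """Poll base_url + path until the server answers or timeout_s elapses.

    Returns True once any HTTP response arrives, False on timeout.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(base_url + path, timeout=interval_s):
                return True
        except urllib.error.HTTPError:
            return True  # the server answered, just not with a 2xx status
        except OSError:
            pass  # connection refused or timed out: server still starting
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)
```

For example, call `wait_for_server()` before the `requests.post(...)` call to `/generate` shown earlier in the thread; if it returns `False`, check `nohup.out` for download or startup errors.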
Awesome project. We have a paper https://arxiv.org/abs/2310.14034 with really complicated KV caching that I would love to go back and implement in SGLang.
I tried to get an example working in Colab for a demo, but I got kind of stuck getting the server running.
This runs fine:
```
!nohup python -m sglang.launch_server --model-path TheBloke/Mistral-7B-v0.1-AWQ --port 30000
```
But then when I run the following,
I just get this.
Any ideas?