Open saifmodan opened 3 months ago
It's much better for me to understand what you mean if you post the error stack you get on your end. I've tested the code on 2 different machines and it works fine, so I believe it could be an issue with your environment. If you can post the exact command you are running and the error you see, maybe I can help you. Also, remember you need a graphics card with at least 6-8 GB of RAM. I can add an option to run on the CPU, but it will be slow.
Please describe the code in a video explanation; that way I will understand it.
Assuming you have a graphics card and installed the dependencies, and you want it to write code locally, you can start the server like so: python server.py --model mistralai/Mistral-7B-Instruct-v0.2
- this will start downloading the Mistral 7B Instruct v0.2 model. When the download finishes, the server will start up. After that, you can start the client on the same computer like so: python client.py
- this will connect to the server at localhost on port 8000. Now you should see the chat interface and you can ask it to write code and it will be highlighted.
Please solve the problem in the code.
import torch import uvicorn import gc import asyncio import argparse import io from fastapi import FastAPI, WebSocket, Depends from fastapi.responses import Response from threading import Thread from sqlalchemy.orm import Session as DBSession from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig, TextIteratorStreamer, pipeline from diffusers import StableDiffusionXLPipeline, DPMSolverSinglestepScheduler, AutoencoderTiny
from session import Session, SessionManager, SessionDB, SessionImageDB, get_db
# FastAPI application and the shared chat-session registry.
app = FastAPI()
session_manager = SessionManager()

# Command-line configuration for the server.
parser = argparse.ArgumentParser()
parser.add_argument('--model', action='store',
                    default="meta-llama/Meta-Llama-3-8B-Instruct")
# type=int: without it a CLI-supplied --port arrives as a string, which
# uvicorn.run rejects (the default 8000 was the only working value).
parser.add_argument('--port', action='store', type=int, default=8000)
parser.add_argument('--image_generation', action='store_true', default=False)
parser.add_argument('--image_model', action='store',
                    default="sd-community/sdxl-flash")
parser.add_argument('--image_cpu_offload', action='store_true', default=False)
args = parser.parse_args()
# 4-bit NF4 quantization so the chat model fits in roughly 6-8 GB of VRAM.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
config = AutoConfig.from_pretrained(args.model)
model = AutoModelForCausalLM.from_pretrained(
    args.model,
    device_map='auto',
    config=config,
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(args.model)
# Stop tokens for generation: the tokenizer's EOS plus Llama-3's end-of-turn
# marker.  The pasted code had convert_tokens_to_ids("") — an empty string,
# which maps to the unk id and never terminates anything; "<|eot_id|>" is the
# end-of-turn token of the default Meta-Llama-3-8B-Instruct model and was
# almost certainly eaten by markdown formatting.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
# Lightweight BART summarizer used to produce short (2-10 token) chat titles.
_summary_options = {
    "min_length": 2,
    "max_length": 10,
    "do_sample": True,
    "temperature": 0.6,
    "top_p": 0.9,
}
summarizer = pipeline(
    task="summarization",
    model="facebook/bart-large-cnn",
    **_summary_options,
)
# Optional SDXL image-generation pipeline (enabled with --image_generation).
if args.image_generation:
    sdxl_pipe = StableDiffusionXLPipeline.from_pretrained(args.image_model, torch_dtype=torch.float16)
    # "trailing" timestep spacing pairs with the few-step sdxl-flash model.
    sdxl_pipe.scheduler = DPMSolverSinglestepScheduler.from_config(sdxl_pipe.scheduler.config, timestep_spacing="trailing")
    # Tiny VAE (taesdxl) trades a little decode quality for much lower memory.
    sdxl_pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
    sdxl_pipe.enable_vae_tiling()
    sdxl_pipe.enable_vae_slicing()
    # Sequential CPU offload cuts VRAM use further at a large speed cost.
    if args.image_cpu_offload:
        sdxl_pipe.enable_sequential_cpu_offload()
async def stream_tokens(streamer: TextIteratorStreamer):
    """Relay tokens from *streamer* as they arrive, then yield a None sentinel.

    The trailing None lets the consumer detect end-of-stream with a simple
    identity check instead of handling StopAsyncIteration.
    """
    for piece in streamer:
        yield piece
    yield None
async def generate_response(prompt: str):
    """Prepare a streaming text generation for *prompt*.

    NOTE(review): as pasted, this function only builds ``generation_kwargs``
    and never launches ``model.generate`` or consumes the streamer — the rest
    of the body (presumably a Thread running ``model.generate(**generation_kwargs)``
    plus iteration over ``streamer``) appears to have been lost when the code
    was pasted into the issue.  Code left byte-identical here.
    """
    # Reclaim GPU/host memory from previous requests before tokenizing.
    torch.cuda.empty_cache()
    gc.collect()
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    # skip_prompt avoids echoing the input; special tokens are stripped.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "streamer": streamer,
        "do_sample": True,
        "temperature": 0.6,
        "top_p": 0.9,
        "max_length": config.max_position_embeddings,
    }
def make_title(session: Session):
    """Summarize the session's last two messages into a short title."""
    recent = session.get_messages()[-2:]
    joined = "\n".join(msg["content"] for msg in recent)
    return summarizer(joined)
def make_prompt(session: Session):
    """Render the session's chat history as an untokenized model prompt.

    Drops the oldest messages (via ``session.truncate_messages``) while the
    tokenized chat template exceeds 90% of the model's context window, then
    returns the templated prompt as text.
    """
    budget = int(config.max_position_embeddings * 0.9)
    while True:
        token_ids = tokenizer.apply_chat_template(
            session.get_messages(),
            add_generation_prompt=True,
            return_tensors="pt",
            tokenize=True,
        )
        if token_ids.shape[-1] <= budget:
            break
        session.truncate_messages()
    return tokenizer.apply_chat_template(
        session.get_messages(),
        add_generation_prompt=True,
        tokenize=False,
    )
def generate_image(session_id: int, prompt: str, db: DBSession):
    """Generate a PNG image for *prompt* with the SDXL pipeline.

    NOTE(review): as pasted, the PNG bytes are built but never stored or
    returned, and ``session_id``/``db`` are unused — the persistence tail
    of this function (presumably an insert into SessionImageDB) appears to
    have been lost in the paste.  Code left byte-identical here.
    """
    # Free GPU memory left over from text generation before diffusing.
    torch.cuda.empty_cache()
    gc.collect()
    # 6 steps / low guidance matches the few-step sdxl-flash model.
    image = sdxl_pipe(prompt, num_inference_steps=6, guidance_scale=3).images[0]
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='PNG')
    img_byte_arr = img_byte_arr.getvalue()
@app.websocket("/stream/{session_id}")
async def stream(websocket: WebSocket, session_id: int, db: DBSession = Depends(get_db)):
    """WebSocket endpoint intended to stream a model reply for one message.

    NOTE(review): as pasted, this accepts the socket, reads one message and
    loads the session, then ends — the generation/streaming body (presumably
    make_prompt + generate_response + websocket.send_text) appears to have
    been lost in the paste.  Code left byte-identical here.
    """
    await websocket.accept()
    message = await websocket.receive_text()
    session = session_manager.get_session(session_id, db)
@app.get("/session")
async def create_session(db: DBSession = Depends(get_db)):
    """Create a new chat session and return its id.

    Renamed from ``get_session``: the module defined two functions with that
    name, so the later definition silently shadowed this one at module level.
    The route path ("/session") and response are unchanged for clients.
    """
    session = session_manager.get_new_session(db)
    return session.id
@app.get("/session/{session_id}")
async def get_session(session_id: int, db: DBSession = Depends(get_db)):
    """Return the session with the given id."""
    return session_manager.get_session(session_id, db)
@app.get("/session-list")
async def get_session_list(db: DBSession = Depends(get_db)):
    """Return the list of all stored sessions."""
    return session_manager.get_session_list(db)
@app.delete("/session/{session_id}")
async def delete_session(session_id: int, db: DBSession = Depends(get_db)):
    """Remove the session with the given id and persist the deletion."""
    session_manager.remove_session(session_id, db)
    db.commit()
@app.get("/session/{session_id}/title")
async def get_session_title(session_id: int, db: DBSession = Depends(get_db)):
    """Summarize the session into a short title, persist and return it.

    NOTE: this GET mutates state (it stores the generated title); kept as a
    GET so existing clients keep working.
    """
    session = session_manager.get_session(session_id, db)
    summary_response = make_title(session)
    session.title = summary_response[0]["summary_text"]
    db_session = db.query(SessionDB).filter(SessionDB.id == session.id).first()
    # Guard: previously a missing row raised AttributeError (HTTP 500).
    if db_session is not None:
        db_session.title = session.title
        # The row is already tracked by this DB session; assigning the
        # attribute and committing suffices (the old db.add was a no-op).
        db.commit()
    return session.title
@app.get("/image/{image_id}")
async def get_image(image_id: int, db: DBSession = Depends(get_db)):
    """Return the stored PNG with the given id, or 404 if it does not exist.

    Previously a missing row raised AttributeError (HTTP 500), and the bytes
    were pointlessly round-tripped through io.BytesIO before being returned.
    """
    image_db = db.query(SessionImageDB).filter(SessionImageDB.id == image_id).first()
    if image_db is None:
        return Response(status_code=404)
    return Response(image_db.image, media_type="image/png")
# The paste had `if name == "main":`, which raises NameError at runtime —
# markdown swallowed the dunder underscores.  Restore the standard guard.
if __name__ == "__main__":
    # int() is defensive: a CLI-supplied --port parses as a string unless
    # the argparse definition sets type=int.
    uvicorn.run(app, host="0.0.0.0", port=int(args.port))