predibase / lorax

Multi-LoRA inference server that scales to 1000s of fine-tuned LLMs
https://loraexchange.ai
Apache License 2.0

Throughput and Latency degradation with a single LoRA adapter on A100 40 GB #670

Open kaushikmitr opened 5 days ago

kaushikmitr commented 5 days ago

System Info


Setup Summary for LoRAX Benchmarking with Llama-2 Model:

Benchmark Metrics: We measured:

You can view detailed results in the benchmark document: Benchmark 1 server - LoRAX.pdf


Observations and Questions:

Information

Tasks

Reproduction

Sample Query:

curl -i ${IP}:${PORT}/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
        "inputs": "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]",
        "parameters": {
            "max_new_tokens": 10,
            "adapter_ids": "vineetsharma/qlora-adapter-Llama-2-7b-hf-TweetSumm"
        }
    }'
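
For the no-adapter baseline, the same request is sent with the adapter parameter dropped, so LoRAX serves the base Llama-2-7b model; this is the comparison point for the degradation numbers. A minimal sketch of that baseline query, identical to the one above except for the missing adapter parameter:

curl -i ${IP}:${PORT}/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{
        "inputs": "[INST] Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? [/INST]",
        "parameters": {
            "max_new_tokens": 10
        }
    }'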

Deployment YAML Configuration:

apiVersion: v1
kind: Service
metadata:
  name: lorax-llama2-7b-pool
spec:
  selector:
    app: lorax-llama2-7b-pool
  ports:
  - protocol: TCP
    port: 8000
    targetPort: 8000
  type: LoadBalancer

---
apiVersion: apps/v1
kind: Deployment
metadata:
  name: lorax-llama2-7b-pool
spec:
  replicas: 1
  selector:
    matchLabels:
      app: lorax-llama2-7b-pool
  template:
    metadata:
      labels:
        app: lorax-llama2-7b-pool
    spec:
      containers:
        - name: lora
          image: "ghcr.io/predibase/lorax:latest"
          imagePullPolicy: Always
          #command: ["python3", "-m", "lorax.entrypoints.openai.api_server"]
          args:
          - "--model-id"
          - "meta-llama/Llama-2-7b-hf"
          env:
            - name: PORT
              value: "8000"
            - name: HUGGING_FACE_HUB_TOKEN
              valueFrom:
                secretKeyRef:
                  name: hf-token
                  key: token
          ports:
            - containerPort: 8000
              name: http
              protocol: TCP
          livenessProbe:
            failureThreshold: 240
            httpGet:
              path: /health
              port: http
              scheme: HTTP
            initialDelaySeconds: 5
            periodSeconds: 5
            successThreshold: 1
            timeoutSeconds: 1
          readinessProbe:
            failureThreshold: 600
            httpGet:
              path: /health
              port: http
              scheme: HTTP
            initialDelaySeconds: 5
            periodSeconds: 5
            successThreshold: 1
            timeoutSeconds: 1
          resources:
            limits:
              nvidia.com/gpu: 1
            requests:
              nvidia.com/gpu: 1
          volumeMounts:
            - mountPath: /data
              name: data
            - mountPath: /dev/shm
              name: shm
      restartPolicy: Always
      schedulerName: default-scheduler
      terminationGracePeriodSeconds: 30
      volumes:
        - name: data
          emptyDir: {}
        - name: shm
          emptyDir:
            medium: Memory
        - name: adapters
          emptyDir: {}

---
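
The container args above leave batching and adapter-scheduling options at the launcher defaults. As a sketch of where tuning could happen, here is the same args block with a few lorax-launcher flags added; the flag names and values are assumptions on my part and should be verified against lorax-launcher --help for the image in use:

          args:
          - "--model-id"
          - "meta-llama/Llama-2-7b-hf"
          # Assumed flags -- confirm they exist in this lorax image version
          # via `lorax-launcher --help`; the values are placeholders.
          - "--max-batch-prefill-tokens"   # cap on prefill tokens per batch
          - "4096"
          - "--max-active-adapters"        # adapters kept resident on the GPU at once
          - "128"
          - "--adapter-cycle-time-s"       # seconds before switching between adapter batches
          - "2"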

Expected Behavior

After reviewing the LoRAX blog, particularly the statement:

"Processing 1M tokens spread evenly across 32 different fine-tuned models takes just about as much time as processing the same number of tokens for 1 fine-tuned model due to the near-optimal multi-adapter batching throughput associated with LoRAX."

I anticipated less degradation in throughput and latency when serving requests through a LoRA adapter. Could you please clarify how the savings in Figure 1 were calculated? It would also be helpful to understand whether this level of performance degradation is typical and whether there are any specific tuning options that might help mitigate it. Thank you for your guidance.
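
By degradation here I mean the relative drop in throughput (and the corresponding rise in per-token latency) with the adapter enabled versus the base model with no adapter. A minimal sketch of that calculation with placeholder numbers, not the measured results:

# Placeholder throughput values (tokens/s) for illustration only.
awk -v base_tps=1000 -v adapter_tps=700 \
    'BEGIN { printf "throughput degradation: %.1f%%\n", 100 * (base_tps - adapter_tps) / base_tps }'
# -> throughput degradation: 30.0%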

ahg-g commented 2 days ago

/cc