pytorch / torchchat


MPS: chat command is much slower than generate on Mac #999

Open iseeyuan opened 3 months ago

iseeyuan commented 3 months ago

🐛 Describe the bug

For generate on llama3.1 I got 9.1 tok/s, but chat is much slower: around 1.4 tok/s. Test laptop: MacBook Pro with M1 Max, 64 GB memory, macOS Sonoma 14.5.

Details for both generate and chat:

(torchchat) myuan@myuan-mbp torchchat % python3 torchchat.py generate llama3.1 --prompt "write me a story about a boy and his bear"
NumExpr defaulting to 10 threads.
PyTorch version 2.5.0.dev20240710 available.
Using device=mps
Loading model...
Time to load model: 32.03 seconds
-----------------------------------------------------------
write me a story about a boy and his bear
Once upon a time, in a dense forest, there lived a young boy named Timmy. Timmy was a kind and gentle soul, with a heart full of love for all creatures. He had grown up in the forest, surrounded by the sights, sounds, and smells of nature. As he wandered through the woods, he would often talk to the trees, the rabbits, and the birds, and they seemed to respond to his gentle voice.
One day, while exploring a hidden glade, Timmy stumbled upon a big, fluffy bear. The bear was unlike any other he had seen before - its fur was a soft, golden hue, and its eyes twinkled like the stars on a clear night. Timmy was both startled and fascinated by the bear, and he slowly reached out a hand to touch its fur.
To his surprise, the bear nuzzled his hand, and Timmy felt a deep connection to the creature. The bear, whose name was Orion, had
Time for inference 1: 21.87 sec total, time to first token 5.20 sec with parallel prefill, 199 tokens, 9.10 tokens/sec, 109.91 ms/token
Bandwidth achieved: 146.12 GB/s
*** This first iteration will include cold start effects for dynamic import, hardware caches. ***

========================================

Average tokens/sec: 9.10
Memory used: 0.00 GB
(torchchat) myuan@myuan-mbp torchchat % python3 torchchat.py chat llama3.1
NumExpr defaulting to 10 threads.
PyTorch version 2.5.0.dev20240710 available.
Using device=mps
Loading model...
Time to load model: 29.87 seconds
-----------------------------------------------------------
Starting Interactive Chat
Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of 8192 tokens is hit or until the user says /bye
Do you want to enter a system prompt? Enter y for yes and anything else for no.
y
What is your system prompt?
software engineer
User: Write a python code to solve a maze problem
Model: Here's a simple implementation of a maze solver using Breadth-First Search (BFS) algorithm in Python:

```python
from collections import deque

def solve_maze(maze, start, end):
    # Define the possible movements (up, down, left, right)
    movements = [(0, 1), (0, -1), (1, 0), (-1, 0)]

    # Create a queue for BFS
    queue = deque([(start, [])])

    # Create a set to store visited cells
    visited = set([start])

    while queue:
        (x, y), path = queue.popleft()

        # Check if the current cell is the destination
        if (x, y) == end:
            return path + [(x, y)]

        # Explore neighbors
        for dx, dy in movements:
            nx, ny = x + dx, y + dy

            # Check if the neighbor is inside the
```
Time for inference 1: 142.65 sec total, time to first token 4.86 sec with parallel prefill, 199 tokens, 1.40 tokens/sec, 716.83 ms/token
Bandwidth achieved: 22.40 GB/s
*** This first iteration will include cold start effects for dynamic import, hardware caches. ***
User: How to optimize it?
Model: Here are some ways to optimize the maze solver:

1. **Using a more efficient data structure**: Currently, we are using a list to represent the maze. We can use a matrix (2D array) to represent the maze, which would be more efficient.

2. **Avoiding repeated computations**: We are recalculating the `is_valid` function on each iteration. We can pre-calculate the maze and store it in a separate matrix to avoid repeated computations.

3. **Using a more efficient algorithm**: BFS has a time complexity of O(n*m), where n is the number of rows and m is the number of columns. We can use other algorithms like Dijkstra's algorithm or A\* algorithm, which have better time complexities.

4. **Using multithreading or multiprocessing**: If we are dealing with a very large maze, we can use multithreading or multiprocessing to speed up the computation.

Here's an optimized version of the code:

```python
from collections
```
Time for inference 2: 138.47 sec total, time to first token 1.54 sec with parallel prefill, 199 tokens, 1.44 tokens/sec, 695.81 ms/token
Bandwidth achieved: 23.08 GB/s
User: Can you show me a code example?
Model: Here's an optimized version of the maze solver using a matrix to represent the maze and avoiding repeated computations:

```python
from collections import deque

def solve_maze(maze):
    rows, cols = len(maze), len(maze[0])
    directions = [(0, 1), (0, -1), (1, 0), (-1, 0)]

    # Create a matrix to store the distances from the start to each cell
    distances = [[float('inf') for _ in range(cols)] for _ in range(rows)]
    distances[0][0] = 0

    # Create a queue for BFS
    queue = deque([(0, 0)])

    while queue:
        x, y = queue.popleft()

        # Explore neighbors
        for dx, dy in directions:
            nx, ny = x + dx, y + dy

            # Check if the neighbor is within the maze and is not a wall
```
Time for inference 3: 177.72 sec total, time to first token 0.85 sec with parallel prefill, 199 tokens, 1.12 tokens/sec, 893.08 ms/token
Bandwidth achieved: 17.98 GB/s
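
The generated code above is cut off each time by the 199-token cap. For reference, here is a minimal, runnable completion of the BFS maze solver the model was sketching; the 0/1 grid encoding (0 = open, 1 = wall) and the example maze are assumptions for illustration, not part of the model's output:

```python
from collections import deque

def solve_maze(maze, start, end):
    """Return a list of (row, col) cells from start to end, or None if unreachable."""
    rows, cols = len(maze), len(maze[0])
    movements = [(0, 1), (0, -1), (1, 0), (-1, 0)]  # right, left, down, up in (row, col) terms

    queue = deque([(start, [])])  # (cell, path taken to reach it)
    visited = {start}

    while queue:
        (x, y), path = queue.popleft()
        if (x, y) == end:
            return path + [(x, y)]
        for dx, dy in movements:
            nx, ny = x + dx, y + dy
            # Stay inside the grid, skip walls and already-visited cells
            if 0 <= nx < rows and 0 <= ny < cols and maze[nx][ny] == 0 and (nx, ny) not in visited:
                visited.add((nx, ny))
                queue.append(((nx, ny), path + [(x, y)]))
    return None

if __name__ == "__main__":
    maze = [
        [0, 0, 1],
        [1, 0, 1],
        [1, 0, 0],
    ]
    print(solve_maze(maze, (0, 0), (2, 2)))  # [(0, 0), (0, 1), (1, 1), (2, 1), (2, 2)]
```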

Versions

(torchchat) myuan@myuan-mbp torchchat % python collect_env.py
Collecting environment information...
PyTorch version: 2.5.0.dev20240710
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.0.40.1)
CMake version: version 3.30.1
Libc version: N/A

Python version: 3.10.0 (default, Mar 3 2022, 03:54:28) [Clang 12.0.0 ] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Apple M1 Max

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.5.0.dev20240710
[pip3] torchao==0.3.1
[conda] numpy 1.26.4 pypi_0 pypi
[conda] torch 2.5.0.dev20240710 pypi_0 pypi
[conda] torchao 0.3.1 pypi_0 pypi

Jack-Khuu commented 3 months ago

Seems like a bug with max_seq_len: https://github.com/pytorch/torchchat/blob/0b001b9dc74c12e136f5ee9b3c19427b9acd24ff/generate.py#L631

When I hack it to 200 (compared to the default 8192), the perf for chat is close to that of generate.
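
A rough way to see why that helps (purely an illustrative sketch, not torchchat's actual cache code, with made-up head counts and dims): if the KV cache is pre-allocated to max_seq_len and each decode step runs attention over the whole buffer instead of just the filled prefix, per-token cost tracks the 8192-slot cache rather than the ~200 tokens actually generated:

```python
# Illustrative micro-benchmark only; this is NOT torchchat's implementation.
# It times one decode step attending over a 200-entry KV cache versus an
# 8192-entry one, on MPS when available. Shapes are made up for illustration.
import time
import torch
import torch.nn.functional as F

device = "mps" if torch.backends.mps.is_available() else "cpu"
batch, heads, head_dim = 1, 32, 128

def time_decode_step(cache_len, iters=50):
    # One new query token attending over a KV cache of length `cache_len`.
    q = torch.randn(batch, heads, 1, head_dim, device=device)
    k = torch.randn(batch, heads, cache_len, head_dim, device=device)
    v = torch.randn(batch, heads, cache_len, head_dim, device=device)
    if device == "mps":
        torch.mps.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        F.scaled_dot_product_attention(q, k, v)
    if device == "mps":
        torch.mps.synchronize()
    return (time.perf_counter() - start) / iters

print(f"cache_len=200:  {time_decode_step(200) * 1e3:.2f} ms/step")
print(f"cache_len=8192: {time_decode_step(8192) * 1e3:.2f} ms/step")
```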

Jack-Khuu commented 3 months ago

This also seems isolated to MPS (I don't see as significant a drop on CUDA).

@manuelcandales Looks like your PR might solve this problem for free

iseeyuan commented 3 months ago

@Jack-Khuu, wouldn't a shorter max_seq_length just be a hack? If the conversation goes beyond the limit, the chat will stop, which limits the user experience for long chat histories.

malfet commented 3 months ago

See good old https://github.com/pytorch/torchchat/issues/783

Jack-Khuu commented 3 months ago

> wouldn't a shorter max_seq_length just be a hack?

It would be, which is why we're lucky to have https://github.com/pytorch/torchchat/pull/964. Manuel's changes in PyTorch get picked up by the pin bump and will hopefully resolve the seq_length issues.

byjlw commented 2 months ago

This has now been fixed with the pin bump in #1029

Jack-Khuu commented 2 months ago

Seems like I'm still seeing it... Can someone else ack that they see similar behavior?

byjlw commented 1 month ago

> Seems like I'm still seeing it... Can someone else ack that they see similar behavior?

You used the same max tokens flag?

byjlw commented 1 month ago

Actually, I just did a test: ~15 t/s for generate and 1.5 t/s for chat using similar params.