mistralai / client-python

Python client library for Mistral AI platform
Apache License 2.0

[BUG CLIENT]: Incomplete batch api result #159

Closed: GabrielVidal1 closed this issue 2 days ago

GabrielVidal1 commented 2 days ago

Python -VV

Python 3.10.4 (main, May 25 2022, 21:26:26) [GCC 11.2.1 20220219]

Pip Freeze

annotated-types==0.7.0
anyio==4.6.2.post1
anytree==2.12.1
appdirs==1.4.4
asgiref==3.3.4
astroid==3.3.5
attrs==24.2.0
autopep8==2.3.1
azure-ad-verify-token==0.2.0
boto3==1.35.55
botocore==1.35.55
cached-property==2.0.1
certifi==2020.6.20
cffi==1.14.1
chardet==3.0.4
coreapi==2.3.3
coreschema==0.0.4
coverage==7.6.4
cryptography==2.9.2
defusedxml==0.7.1
dill==0.3.9
Django==3.2.25
django-annoying==0.10.6
django-cors-headers==3.4.0
django-deprecate-fields==0.1.2
django-dia==0.5.0
django-elasticsearch-dsl==7.1.4
django-elasticsearch-dsl-drf==0.20.8
django-extensions==3.0.3
django-factory-boy==1.0.0
django-filter==23.5
django-fixture-magic==0.1.5
django-linear-migrations==2.3.0
django-multiselectfield==0.1.12
django-nine==0.2.3
django-nose==1.4.5
django-ordered-model==3.5
django-phonenumber-field==5.1.0
django-redis-cache==3.0.1
django-silk==4.4.1
django-slowtests==1.1.1
djangorestframework==3.11.1
djangorestframework-jwt==1.11.0
docutils==0.20.1
drf-spectacular==0.24.2
drf-yasg==1.17.1
elasticsearch==7.8.0
elasticsearch-dsl==7.2.1
et_xmlfile==1.0.1
eval_type_backport==0.2.0
exceptiongroup==1.2.2
execnet==2.1.1
factory-boy==2.12.0
Faker==4.1.1
future==0.18.2
geographiclib==2.0
geopy==2.4.1
gevent==24.10.3
gprof2dot==2024.6.6
greenlet==3.1.1
gunicorn==20.0.4
h11==0.14.0
httpcore==1.0.7
httpx==0.27.2
idna==2.10
inflection==0.5.0
iniconfig==2.0.0
isodate==0.7.2
isort==5.13.2
itypes==1.2.0
jdcal==1.4.1
jeepney==0.8.0
Jinja2==2.11.1
jmespath==1.0.1
jsonpath-python==1.0.6
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
keyring==21.1.1
lxml==4.9.1
mailjet-rest==1.3.4
MarkupSafe==1.1.1
mccabe==0.7.0
mistralai==1.2.3
mpmath==1.3.0
munkres==1.1.4
mypy-extensions==1.0.0
nose==1.3.7
numpy==1.22.4
openpyxl==3.1.2
packaging==20.4
pandas==2.2.1
phonenumbers==8.13.49
phonenumberslite==8.13.49
pillow==10.4.0
platformdirs==4.3.6
pluggy==1.5.0
progressbar==2.5
psycopg2==2.9.9
psycopg2-binary==2.8.5
py==1.8.2
pycodestyle==2.12.1
pycparser==2.20
pycryptodome==3.12.0
pydantic==2.10.1
pydantic_core==2.27.1
PyJWT==1.7.1
pylint==3.3.1
pylint-django==2.5.5
pylint-plugin-utils==0.8.2
pyparsing==2.4.7
pytest==8.3.3
pytest-forked==1.6.0
pytest-xdist==1.34.0
python-dateutil==2.8.2
python-decouple==3.3
pytz==2020.1
PyYAML==5.3.1
redis==3.5.3
referencing==0.35.1
requests==2.25.1
requests-file==2.1.0
requests-toolbelt==1.0.0
rpds-py==0.21.0
ruamel.yaml==0.18.6
ruamel.yaml.clib==0.2.12
s3transfer==0.10.3
SecretStorage==3.3.3
sentry-sdk==1.39.2
six==1.15.0
sniffio==1.3.1
sqlparse==0.3.1
sympy==1.9
tblib==1.7.0
text-unidecode==1.3
tomli==2.0.2
tomlkit==0.13.2
tqdm==4.60.0
typing-inspect==0.9.0
typing_extensions==4.12.2
tzdata==2024.2
Unidecode==1.1.1
uritemplate==3.0.1
urllib3==1.26.11
watchtower==3.3.1
xlrd==2.0.1
XlsxWriter==3.2.0
zeep==4.0.0
zope.event==5.0
zope.interface==7.1.1

Reproduction Steps

Using mostly parts of the example script provided for the batch API (found here), I get an error because the output_file downloaded by download_file does not contain all of the responses.

Restarting the download process of the batch result a second time works as expected.

Here are the three functions from the example that I used, with an input_file of ~4000 samples:

def run_batch_job(client, input_file, model):
    """
    Run a batch job using the provided input file and model.

    Args:
        client (Mistral): The Mistral client instance.
        input_file (File): The input file object.
        model (str): The model to use for the batch job.

    Returns:
        BatchJob: The completed batch job object.
    """
    batch_job = client.batch.jobs.create(
        input_files=[input_file.id],
        model=model,
        endpoint="/v1/chat/completions",
        metadata={"job_type": "testing"}
    )

    while batch_job.status in ["QUEUED", "RUNNING"]:
        batch_job = client.batch.jobs.get(job_id=batch_job.id)
        print_stats(batch_job)
        time.sleep(1)

    print(f"Batch job {batch_job.id} completed with status: {batch_job.status}")
    return batch_job

def download_file(client, file_id, output_path):
    """
    Download a file from the Mistral server.

    Args:
        client (Mistral): The Mistral client instance.
        file_id (str): The ID of the file to download.
        output_path (str): The path where the file will be saved.
    """
    if file_id is not None:
        print(f"Downloading file to {output_path}")
        output_file = client.files.download(file_id=file_id)
        with open(output_path, "w") as f:
            for chunk in output_file.stream:
                f.write(chunk.decode("utf-8"))
        print(f"Downloaded file to {output_path}")

def main(num_samples, success_path, error_path, model):
    """
    Main function to run the batch job.

    Args:
        num_samples (int): Number of samples to process.
        success_path (str): Path to save successful outputs.
        error_path (str): Path to save error outputs.
        model (str): Model name to use.
    """
    client = create_client()
    input_file = create_input_file(client, num_samples)
    print(f"Created input file {input_file}")

    batch_job = run_batch_job(client, input_file, model)
    print(f"Job duration: {batch_job.completed_at - batch_job.created_at} seconds")
    download_file(client, batch_job.error_file, error_path)
    download_file(client, batch_job.output_file, success_path)

With ~4500 samples, the job completes without errors, but I get the following error from the download_file function:

    result = self.client.process_batch(prompts)
  File "/usr/src/SERVER_DJANGO/impact_calculator/emissions/services/mistral.py", line 165, in process_batch
    self._download_file(batch_job.output_file, success_path)
  File "/usr/src/SERVER_DJANGO/impact_calculator/emissions/services/mistral.py", line 123, in _download_file
    f.write(chunk.decode("utf-8"))
UnicodeDecodeError: 'utf-8' codec can't decode bytes in position 16374-16375: unexpected end of data

After checking the file downloaded from batch_job.output_file, I find that it stops unexpectedly at around line 2000:

...
{"id":"79-5e5093bb-4efe-407e-a526-d04f39df1b67","custom_id":"1912","response":{"status_code":200,"body":{"id":"17730faf50de4b4a93d9a953c7605d21","object":"chat.completion","model":"mistral-large-latest","usage":{"prompt_tokens":199,"completion_tokens":19,"total_tokens":218},"created":1732273082,"choices":[{"index":0,"finish_reason":"stop","message":{"role":"assistant","content":"Cladding in natural stone ALBAMIEL From 2 to 15 cm thick","tool_calls":null}}]}},"error":null}
{"id":"80-a2ab6000-6c46-4232-9cef-e4d6767ce1de","custom_id":"1913","response":{"status_code":200,"body":{"id":"4786718687604deaa08cdb77bcfa55e8","object":"chat.com

Expected Behavior

I expect to be able to download the complete result file of a batch with the provided function, without any errors.

def download_file(client, file_id, output_path):
    """
    Download a file from the Mistral server.

    Args:
        client (Mistral): The Mistral client instance.
        file_id (str): The ID of the file to download.
        output_path (str): The path where the file will be saved.
    """
    if file_id is not None:
        print(f"Downloading file to {output_path}")
        output_file = client.files.download(file_id=file_id)
        with open(output_path, "w") as f:
            for chunk in output_file.stream:
                f.write(chunk.decode("utf-8"))
        print(f"Downloaded file to {output_path}")

Additional Context

No response

Suggested Solutions

This problem seems to happen once every 2-3 tries, so should I just add a small delay (e.g. 5 s) before trying to download the results of a batch, to make sure I get the full file?

jean-malo commented 2 days ago

Hi @GabrielVidal1. The docs need to be updated: the server returns a byte stream, so you should write the file as bytes and decode it later. Something like this will work:

output_file = client.files.download(file_id=file_id)
with open(output_path, "wb") as f:
    for chunk in output_file.stream:
        f.write(chunk)

If you don't do this, there's a chance you land on a chunk that ends with an incomplete byte sequence, which will raise the error you see.
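
As a minimal illustration of that failure mode (the string below is hypothetical, not taken from the actual batch output), decoding a byte sequence that is cut in the middle of a multi-byte character raises the same error:

data = "café".encode("utf-8")   # b'caf\xc3\xa9' ("é" is the two bytes 0xc3 0xa9)

# Simulate a chunk boundary falling inside the multi-byte character.
truncated_chunk = data[:-1]     # b'caf\xc3'

truncated_chunk.decode("utf-8")
# UnicodeDecodeError: 'utf-8' codec can't decode byte 0xc3 in position 3:
# unexpected end of data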

jean-malo commented 2 days ago

If you only want to write a decoded file, you could use an incremental decoder that handles the buffering for you (the codecs module has an implementation, I believe).
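
A minimal sketch of that approach, assuming the same client and file_id as in the snippet above (the name download_file_decoded is just for this sketch, not part of the SDK):

import codecs

def download_file_decoded(client, file_id, output_path):
    """Stream a file and decode it incrementally while writing text."""
    decoder = codecs.getincrementaldecoder("utf-8")()
    output_file = client.files.download(file_id=file_id)
    with open(output_path, "w", encoding="utf-8") as f:
        for chunk in output_file.stream:
            # The incremental decoder buffers a trailing partial byte
            # sequence until the next chunk completes it.
            f.write(decoder.decode(chunk))
        # Flush anything still buffered at the end of the stream.
        f.write(decoder.decode(b"", final=True))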

GabrielVidal1 commented 2 days ago

Hi @jean-malo,

Thanks for your quick response! Writing the file as bytes and decoding it later works for me 👍