simonw / llm-replicate

LLM plugin for models hosted on Replicate
Apache License 2.0

Command to fetch existing Replicate predictions and store them in SQLite #11

Closed simonw closed 1 year ago

simonw commented 1 year ago

I just spotted https://replicate.com/docs/reference/http#predictions.list in their docs:

Get a paginated list of predictions that you've created with your account. This includes predictions created from the API and the Replicate website. Returns 100 records per page.

curl -s \
  -H "Authorization: Token r8_..." \
  https://api.replicate.com/v1/predictions
The response is a JSON object in the following format:
{
  "previous": null,
  "next": "https://api.replicate.com/v1/predictions?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw",
  "results": [{}, {}, {}]
}

The results key is a list of prediction objects in the following format:

{
  "id": "jpzd7hm5gfcapbfyt4mqytarku",
  "version": "b21cbe271e65c1718f2999b038c18b45e21e4fba961181fbfae9342fc53b9e05",
  "urls": {
    "get": "https://api.replicate.com/v1/predictions/jpzd7hm5gfcapbfyt4mqytarku",
    "cancel": "https://api.replicate.com/v1/predictions/jpzd7hm5gfcapbfyt4mqytarku/cancel"
  },
  "created_at": "2022-04-26T20:00:40.658234Z",
  "started_at": "2022-04-26T20:00:84.583803Z",
  "completed_at": "2022-04-26T20:02:27.648305Z",
  "source": "web",
  "status": "succeeded"
}
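
A minimal sketch of walking that paginated list by following the "next" URL until it is null. The names fetch_json, iter_predictions and get_page are hypothetical, not part of the plugin; only the endpoint, the Authorization header and the "next"/"results" keys come from the docs quoted above.

```python
import json
import urllib.request
from typing import Callable, Iterator, Optional

API_URL = "https://api.replicate.com/v1/predictions"


def fetch_json(url: str, token: str) -> dict:
    """GET a URL using the Authorization header shown in the curl example."""
    request = urllib.request.Request(
        url, headers={"Authorization": f"Token {token}"}
    )
    with urllib.request.urlopen(request) as response:
        return json.load(response)


def iter_predictions(
    get_page: Callable[[str], dict], start_url: str = API_URL
) -> Iterator[dict]:
    """Yield prediction objects page by page, following the "next" cursor."""
    url: Optional[str] = start_url
    while url:
        page = get_page(url)
        yield from page["results"]
        # "next" is null on the last page, which ends the loop
        url = page.get("next")
```

Passing the page fetcher in as an argument keeps the loop testable without network access, e.g. iter_predictions(lambda url: fetch_json(url, token)).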
simonw commented 1 year ago

Then: https://replicate.com/docs/reference/http#predictions.get

GET https://api.replicate.com/v1/predictions/{prediction_id}

Returning:

{
  "id": "rrr4z55ocneqzikepnug6xezpe",
  "version": "be04660a5b93ef2aff61e3668dedb4cbeb14941e62a3fd5998364a32d613e35e",
  "urls": {
    "get": "https://api.replicate.com/v1/predictions/rrr4z55ocneqzikepnug6xezpe",
    "cancel": "https://api.replicate.com/v1/predictions/rrr4z55ocneqzikepnug6xezpe/cancel"
  },
  "created_at": "2022-09-13T22:54:18.578761Z",
  "started_at": "2022-09-13T22:54:19.438525Z",
  "completed_at": "2022-09-13T22:54:23.236610Z",
  "source": "api",
  "status": "succeeded",
  "input": {
    "prompt": "oak tree with boletus growing on its branches"
  },
  "output": [
    "https://replicate.com/api/models/stability-ai/stable-diffusion/files/9c3b6fe4-2d37-4571-a17a-83951b1cb120/out-0.png"
  ],
  "error": null,
  "logs": "Using seed: 36941...",
  "metrics": {
    "predict_time": 4.484541
  }
}

I'm going to put these in a replicate_predictions table.
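
A rough sketch of what that table could look like, using only the standard library sqlite3 module. The column layout and the save_prediction helper are assumptions for illustration; the issue only fixes the table name, and the real command may store the data differently.

```python
import json
import sqlite3


def save_prediction(conn: sqlite3.Connection, prediction: dict) -> None:
    """Insert or update one prediction, keyed on its id."""
    # Assumed schema: a few queryable columns plus the full JSON document
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS replicate_predictions (
            id TEXT PRIMARY KEY,
            version TEXT,
            status TEXT,
            created_at TEXT,
            completed_at TEXT,
            data TEXT
        )
        """
    )
    conn.execute(
        "INSERT OR REPLACE INTO replicate_predictions "
        "(id, version, status, created_at, completed_at, data) "
        "VALUES (?, ?, ?, ?, ?, ?)",
        (
            prediction["id"],
            prediction.get("version"),
            prediction.get("status"),
            prediction.get("created_at"),
            prediction.get("completed_at"),
            json.dumps(prediction),
        ),
    )
    conn.commit()
```

Keying on id means saving the same prediction again replaces the earlier row rather than duplicating it.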

simonw commented 1 year ago

It's not clear to me if I can get the crucial model name from this. The version on its own is no good without the model - see https://replicate.com/docs/reference/http#models.versions.get which needs this:

GET https://api.replicate.com/v1/models/{model_owner}/{model_name}/versions/{version_id}
simonw commented 1 year ago

This is blocked on:

I may ship it anyway since it's still useful, especially for users who don't use many different models on Replicate.

simonw commented 1 year ago

I could even have LLM take a guess at the model name by looking up the version ID in its replicate JSON files.

I will store that as _model_guess to avoid clashing with any fields they add to their JSON in the future.

simonw commented 1 year ago

Another way to guess: look for URLs in the output field:

  "output": [
    "https://replicate.com/api/models/stability-ai/stable-diffusion/files/9c3b6fe4-2d37-4571-a17a-83951b1cb120/out-0.png"
  ]
simonw commented 1 year ago

I'm going to put _model_guess after id before saving to the DB.
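
One way to do that reordering, as a sketch: Python dicts preserve insertion order, so the record can be rebuilt with the new key slotted in right after id. The helper name with_model_guess is hypothetical.

```python
from typing import Optional


def with_model_guess(prediction: dict, model_guess: Optional[str]) -> dict:
    """Return a copy of the prediction with _model_guess placed after id."""
    ordered = {}
    for key, value in prediction.items():
        if key == "_model_guess":
            # Drop any stale guess; the new one is added below
            continue
        ordered[key] = value
        if key == "id":
            ordered["_model_guess"] = model_guess
    # Fall back to appending if the record had no id key
    ordered.setdefault("_model_guess", model_guess)
    return ordered
```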

simonw commented 1 year ago

Looks like guessing based on output doesn't work any more because they redesigned those URLs to look like this:

https://replicate.delivery/mgxm/e2b8944a-7aaa-4b19-b9a6-180fe6fa6ca6/out-0.png

Here's the code I had written that doesn't work:

def guess_model(data, version_to_model):
    version = data["version"]
    if version in version_to_model:
        return version_to_model[version]
    # Try to guess from the output
    output = data.get("output") or []
    if isinstance(output, list) and output:
        first_url = output[0]
        if (
            first_url.startswith("https://replicate.com/api/models/")
            and "/files/" in first_url
        ):
            # https://replicate.com/api/models/stability-ai/stable-diffusion/files/...
            parts = first_url.split("/")
            owner = parts[5]
            name = parts[6]
            return "{}/{}".format(owner, name)
    return None
simonw commented 1 year ago

I have 140 predictions right now, which means 142 API calls (2 to paginate through the list, 140 because I have to fetch details for each one).

I don't think the Replicate API has a rate limit on this, but in case it does I should make sure it never fetches the same prediction twice. That way you can run the command again if it hits a rate limit error.

A progress bar would be really nice here.

Could I have a feature where it starts where you last left off? That depends on the order in which the predictions endpoint returns results. If it starts from the most recent, I could always fetch the first page but only fetch the second page if there are predictions on the first page I haven't seen before.

If it starts at the beginning then I'll always need to fetch every page of predictions. I can still save on the prediction details calls though.
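
The two cases could be sketched like this. The function name new_prediction_ids and the newest_first flag are hypothetical, and the actual ordering of the endpoint is still the open question here; pages is assumed to be a lazy iterable so that unread pages cost no API calls.

```python
from typing import Iterable, List, Set


def new_prediction_ids(
    pages: Iterable[List[dict]], seen_ids: Set[str], newest_first: bool = True
) -> List[str]:
    """Collect ids that still need a details call, skipping ones already stored."""
    new_ids: List[str] = []
    for results in pages:
        unseen = [p["id"] for p in results if p["id"] not in seen_ids]
        new_ids.extend(unseen)
        if newest_first and not unseen:
            # Nothing new on this page, so older pages can be skipped
            break
    return new_ids
```

With newest_first=False every page is still read, but only unseen ids are returned, so the per-prediction detail calls are saved either way.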

simonw commented 1 year ago

The other case to consider is that a prediction is sometimes "pending" or "running", in which case I should fetch it again on subsequent runs of the command. Only predictions that have a completed_at of null should be re-fetched each time (unless they have status "failed", I think).
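
As a sketch, that rule could be a small predicate over the stored record. The name needs_fetch is hypothetical, and the "failed" exemption is the tentative part of the rule above rather than confirmed API behaviour.

```python
from typing import Optional


def needs_fetch(stored: Optional[dict]) -> bool:
    """Decide whether a prediction's details should be fetched on this run."""
    if stored is None:
        # Never seen before, so it always needs fetching
        return True
    if stored.get("status") == "failed":
        # Assumption from the comment above: failed predictions will not change
        return False
    # Still pending or running if it has no completion timestamp
    return stored.get("completed_at") is None
```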

simonw commented 1 year ago

Moving to a PR.