replicate / dreambooth

A Cog model that takes training images as input and generates custom Stable Diffusion model weights as output
https://replicate.com/replicate/dreambooth
Apache License 2.0

validate training data and respond with errors before running inference code #5

Closed zeke closed 1 year ago

zeke commented 1 year ago

Using this code:

import replicate

model = replicate.models.get("replicate/cog-dreambooth-trainer")
version = model.versions.get("b9a7267b10bb9e5fb19eca853ee9582e8838c8cfe1bf174f4cd50991c857992e")
output = version.predict(instance_data=open("data.zip", "rb"))

The prediction runs for a few minutes, then fails with this error:

Traceback (most recent call last):
  File "/Users/z/git/replicate/cog-dreambooth/train.py", line 39, in <module>
    train()
  File "/Users/z/git/replicate/cog-dreambooth/train.py", line 32, in train
    output = version.predict(instance_data=open("data.zip", "rb"))
  File "/Users/z/.pyenv/versions/python-3.10.4/lib/python3.10/site-packages/replicate/version.py", line 31, in predict
    raise ModelError(prediction.error)
replicate.exceptions.ModelError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'NoneType'>

cc @chenxwh @bfirsh

chenxwh commented 1 year ago

The minimum input arguments are instance_prompt, class_prompt, and instance_data. I added an example to the page.

zeke commented 1 year ago

[Screenshot: Screen Shot 2022-11-07 at 4 31 38 PM]

Where does sks come from? Is that the default keyword for the trained concept?

zeke commented 1 year ago

It would be good to check that those required arguments are set, then return a clearer error.

chenxwh commented 1 year ago

This is like the special token in textual inversion. The tokenizer links sks, or any other identifier you give it (preferably not a common word that already has a meaning), to your concept. Later you give the model a prompt like "photo of sks dog in a bucket".
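
A minimal sketch of how the identifier token could be threaded through the prompts. The helper name build_prompts and the exact prompt templates are assumptions made for illustration; only the token sks and the example prompt come from this thread.

```python
def build_prompts(identifier: str, class_noun: str, scene: str = "") -> dict:
    """Build DreamBooth-style prompts around a rare identifier token.

    Hypothetical helper: the templates below are illustrative, not the
    trainer's own.
    """
    suffix = f" {scene}" if scene else ""
    return {
        # Used during training, paired with your own images.
        "instance_prompt": f"a photo of {identifier} {class_noun}",
        # Describes the generic class, without the identifier.
        "class_prompt": f"a photo of a {class_noun}",
        # Used later, when generating images with the trained weights.
        "inference_prompt": f"photo of {identifier} {class_noun}{suffix}",
    }


prompts = build_prompts("sks", "dog", "in a bucket")
print(prompts["inference_prompt"])
```

The identifier only needs to be a string that is unlikely to carry an existing meaning; sks is just the one used in the example on the model page.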

chenxwh commented 1 year ago

> It would be good to check that those required arguments are set, then return a clearer error.

True. I guess removing default=None in pydantic will make them required?
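
A rough sketch of the kind of up-front check discussed here, written as plain Python rather than against Cog's input API. The function names are hypothetical; the three argument names are the ones listed earlier in this thread.

```python
from typing import Any, List, Mapping

# The three inputs named in this thread as the minimum required set.
REQUIRED_INPUTS = ("instance_prompt", "class_prompt", "instance_data")


def missing_inputs(inputs: Mapping[str, Any]) -> List[str]:
    """Return the names of required inputs that are absent, None, or empty strings."""
    missing = []
    for name in REQUIRED_INPUTS:
        value = inputs.get(name)
        if value is None or value == "":
            missing.append(name)
    return missing


def validate_inputs(inputs: Mapping[str, Any]) -> None:
    """Raise a readable error before any training code runs."""
    missing = missing_inputs(inputs)
    if missing:
        raise ValueError("Missing required input(s): " + ", ".join(missing))


try:
    validate_inputs({"instance_data": b"zip bytes"})
except ValueError as error:
    print(error)
```

Failing here would replace the default_collate error from the first traceback with a message that names the missing arguments.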

zeke commented 1 year ago

Hit another error after adding the three required parameters:

Code:

  output = version.predict(
      instance_prompt="a photo of sks wearing a top hat",
      class_prompt="a photo of a person wearing a top hat",
      instance_data=open("data.zip", "rb"),
  )

Error:

python-3.10.4 ∴ python train.py
Training model...
Checking training data
Zipping training data
Training on Replicate using model replicate/cog-dreambooth-trainer@b9a7267b10bb9e5fb19eca853ee9582e8838c8cfe1bf174f4cd50991c857992e
https://replicate.com/replicate/cog-dreambooth-trainer/versions/b9a7267b10bb9e5fb19eca853ee9582e8838c8cfe1bf174f4cd50991c857992e
Traceback (most recent call last):
  File "/Users/z/git/replicate/cog-dreambooth/train.py", line 42, in <module>
    train()
  File "/Users/z/git/replicate/cog-dreambooth/train.py", line 32, in train
    output = version.predict(
  File "/Users/z/.pyenv/versions/python-3.10.4/lib/python3.10/site-packages/replicate/version.py", line 31, in predict
    raise ModelError(prediction.error)
replicate.exceptions.ModelError: integer division or modulo by zero

chenxwh commented 1 year ago

Not sure if this causes it, but I see it is using Python 3.10? Can it install the Python version specified in cog.yaml? The memory optimization only works with 3.7.

zeke commented 1 year ago

I ran it in an Actions workflow with 3.7 but got the same result. See https://github.com/replicate/cog-dreambooth/actions/runs/3423219665/jobs/5701559360#step:5:25

chenxwh commented 1 year ago

It is hard to tell what went wrong. Is "line 31, in predict" referring to the 31st line of predict() from here? That does not seem to be a line that can cause a zero division error. I guess we need to debug why predict() works on cog/replicate but not through this flow.

zeke commented 1 year ago

Must be something in the Python client. I'll investigate.

zeke commented 1 year ago

Python client is doing the right thing: spitting out the error message attached to the prediction object coming back from the API (though it is a bit confusing or misleading that the stack trace makes it look like the problem might be coming from replicate-python itself).

I compared a successful prediction from @chenxwh (7mru3txpqzbazbak5m2bafyep4) with a failed prediction from @zeke (q4ckycy3gbcdffn7xno3ietxse). See https://gist.github.com/zeke/685f5a995b9e5d26dfd5fd289498840f

I can see a couple differences between these two predictions:

zeke commented 1 year ago

Tried running with all the inputs, but got the same error: https://github.com/replicate/cog-dreambooth/actions/runs/3430880187/jobs/5718361697#step:5:25

So... it's probably the data URL vs. HTTPS URL. Trying that now.

zeke commented 1 year ago

Still getting integer division or modulo by zero even when using an HTTPS URL.

https://github.com/replicate/cog-dreambooth/actions/runs/3431008226/jobs/5718638153#step:5:21

Maybe it doesn't like my training data: https://zeke.github.io/files/dreambooth-training-data.zip

Next up, gonna try it with @chenxwh's known good training data: https://replicate.delivery/pbxt/HkTENyq0Ph5hzI3HgumA1HTZzycUNp8qWL2n2SnzuaHZpIP4/Archive.zip

zeke commented 1 year ago

Progress! https://github.com/replicate/cog-dreambooth/actions/runs/3431254041/jobs/5719172576

I was able to get it working using @chenxwh's existing zip of training data. So there's something about the training data I was supplying that it didn't like. Maybe the file size? Maybe the PNG? Maybe the file extensions jpeg vs jpg?
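
One way to compare the two archives against those guesses is to summarize what each zip actually contains. This is a sketch using only the standard library; summarize_zip is a hypothetical helper, and the in-memory archive stands in for a real data.zip.

```python
import io
import zipfile
from collections import Counter
from pathlib import PurePosixPath


def summarize_zip(zip_file) -> dict:
    """Report file count, extensions, top-level entries, and largest file size."""
    with zipfile.ZipFile(zip_file) as archive:
        files = [info for info in archive.infolist() if not info.is_dir()]
    paths = [PurePosixPath(info.filename) for info in files]
    return {
        "file_count": len(files),
        "extensions": dict(Counter(path.suffix.lower() for path in paths)),
        "top_level": sorted({path.parts[0] for path in paths}),
        "largest_bytes": max((info.file_size for info in files), default=0),
    }


# Stand-in archive with images at the top level (illustrative content only).
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, "w") as archive:
    archive.writestr("one.jpeg", b"12345")
    archive.writestr("two.png", b"123")
buffer.seek(0)

summary = summarize_zip(buffer)
print(summary)
```

Running the same summary on a known good archive and a failing one shows differences in layout, extensions, or size side by side.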

zeke commented 1 year ago

Ran it on a Codespace with Cog. Here's the failure:

    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/src/dreambooth.py", line 322, in __getitem__
    instance_path, instance_prompt = self.instance_images_path[index % self.num_instance_images]
ZeroDivisionError: integer division or modulo by zero
ⅹ /predictions call returned status 500

chenxwh commented 1 year ago

This is solved by making sure the training images are placed under /data before zipping, if you are not using the github.com/replicate/dreambooth-template flow. This is noted in the README here too.
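
The traceback above suggests the dataset found zero instance images, so index % self.num_instance_images divided by zero. A minimal sketch of the validation this issue asks for follows. It assumes the images must sit in a top-level data folder inside the zip, and the accepted extensions are a guess; neither is confirmed against the trainer's code, and all function names are hypothetical.

```python
import io
import zipfile
from pathlib import PurePosixPath

DATA_DIR = "data"  # assumption, based on the fix described in this thread
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}  # assumption, not taken from the trainer


def find_training_images(zip_file, data_dir: str = DATA_DIR) -> list:
    """Return archive members that look like images under the data folder."""
    with zipfile.ZipFile(zip_file) as archive:
        names = archive.namelist()
    images = []
    for name in names:
        if name.endswith("/"):
            continue  # skip directory entries
        path = PurePosixPath(name)
        if path.parts and path.parts[0] == data_dir and path.suffix.lower() in IMAGE_EXTENSIONS:
            images.append(name)
    return images


def validate_training_zip(zip_file) -> list:
    """Fail early with a readable message instead of a ZeroDivisionError later."""
    images = find_training_images(zip_file)
    if not images:
        raise ValueError(
            f"No training images found under '{DATA_DIR}/' in the zip file. "
            "Put the images in that folder before zipping."
        )
    return images


def make_zip(names) -> io.BytesIO:
    """Build a small in-memory zip for the demonstration below."""
    buffer = io.BytesIO()
    with zipfile.ZipFile(buffer, "w") as archive:
        for name in names:
            archive.writestr(name, b"fake image bytes")
    buffer.seek(0)
    return buffer


print(validate_training_zip(make_zip(["data/a.jpg", "data/b.png"])))
try:
    validate_training_zip(make_zip(["a.jpg", "b.png"]))
except ValueError as error:
    print(error)
```

Running a check like this before training starts would report the layout problem within seconds rather than after a few minutes of a running prediction.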