dreamingtulpa / replicate-ruby

Replicate Ruby client

analog-diffusion-1.0.ckpt: failed to allocate memory (NoMemoryError) #2

Closed: basicfeatures closed this issue 1 year ago

basicfeatures commented 1 year ago

Any clues what to do? Cheers!

run_dreambooth.rb

require "replicate"
# gem install --user-install replicate-ruby

Replicate.configure do |config|
  config.api_token = "XXX"
end

model = Replicate.client.retrieve_model("replicate/dreambooth")
version = model.latest_version

# https://replicate.com/replicate/dreambooth/api
inputs = {
  "instance_prompt": "a photo of a dsd man",
  "class_prompt": "a photo of a man",
  "instance_data": IO.binread("dsd.zip"),
  "num_class_images": 25,
  "save_sample_prompt": "...",
  "save_sample_negative_prompt": "...",
  "n_save_sample": 4,
  "save_guidance_scale": 7.5,
  "save_infer_steps": 50,
  "pad_tokens": false,
  "with_prior_preservation": true,
  "prior_loss_weight": 1,
  "seed": 1337,
  "resolution": 512,
  "center_crop": false,
  "train_text_encoder": true,
  "train_batch_size": 1,
  "sample_batch_size": 4,
  "num_train_epochs": 1,
  "max_train_steps": 2000,
  "gradient_accumulation_steps": 1,
  "gradient_checkpointing": false,
  "learning_rate": 1e-06,
  "scale_lr": false,
  "lr_scheduler": "constant",
  "lr_warmup_steps": 0,
  "use_8bit_adam": false,
  "adam_beta1": 0.9,
  "adam_beta2": 0.999,
  "adam_weight_decay": 0.01,
  "adam_epsilon": 1e-08,
  "max_grad_norm": 1,
  "ckpt_base": IO.binread("analog-diffusion-1.0.ckpt")
}

prediction = version.predict(inputs)
puts "Check status at https://replicate.com/predictions..."

dreamingtulpa commented 1 year ago

I haven't used the replicate/dreambooth model directly myself, but I've added support for the experimental dreambooth endpoint.

Check out the README.

Regarding your error, does that one get raised on your machine or on Replicate's side? It looks to me like loading the ckpt file into memory causes the OOM error (which wouldn't be surprising, as the file is probably 4 GB+).
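For context on the OOM hypothesis: IO.binread(path) builds a single String as large as the whole file. A minimal sketch of the usual alternative, reading the file in fixed-size chunks so memory stays roughly bounded by the chunk size, is shown below. This is not something replicate-ruby does for you; the chunks here only feed a SHA-256 digest, and the 1 MiB chunk size is an arbitrary choice.

```ruby
require "digest"
require "tempfile"

CHUNK_SIZE = 1024 * 1024 # 1 MiB per read; an arbitrary choice

# SHA-256 of a file computed chunk by chunk, so memory use stays roughly bounded
# by the chunk size rather than the file size. Contrast with IO.binread(path),
# which builds one String as large as the whole file.
def streaming_sha256(path, chunk_size: CHUNK_SIZE)
  digest = Digest::SHA256.new
  File.open(path, "rb") do |file|
    while (chunk = file.read(chunk_size)) # read returns nil at end of file
      digest.update(chunk)
    end
  end
  digest.hexdigest
end

# Demo on a small throwaway file spanning several chunks.
Tempfile.create(["stream-demo", ".bin"]) do |tmp|
  tmp.binmode
  tmp.write(Random.new(42).bytes(3 * CHUNK_SIZE + 123))
  tmp.flush
  # Digest::SHA256.file serves as a reference result here.
  puts streaming_sha256(tmp.path) == Digest::SHA256.file(tmp.path).hexdigest # prints true
end
```

The same chunked loop is the pattern to reach for whenever a file might not fit comfortably in RAM; whether an API accepts data that way is a separate question.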

basicfeatures commented 1 year ago

Coolio!

> Regarding your error, does that one get raised on your machine or on Replicate's side?

The file is 2 GB and I have 4 GB available. I tried things like IO.binread("analog-diffusion-1.0.ckpt", 20), but to no avail.
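For reference, IO.binread's optional second and third arguments are a byte count and an offset, so IO.binread("analog-diffusion-1.0.ckpt", 20) returns just the first 20 bytes of the checkpoint; it does not read the whole file more frugally. A small self-contained demonstration on a made-up 100-byte temporary file:

```ruby
require "tempfile"

# Report what IO.binread returns for a whole read, a length-limited read, and a
# read with an offset. IO.binread(path, length, offset) reads at most `length`
# bytes starting at byte `offset`; the rest of the file is never read.
def binread_summary(path)
  {
    whole_size: IO.binread(path).bytesize,            # entire file loaded into memory
    head_size: IO.binread(path, 20).bytesize,         # only the first 20 bytes
    offset_byte: IO.binread(path, 20, 50).getbyte(0)  # first byte of the slice at offset 50
  }
end

# Demo on a throwaway 100-byte file whose byte values are 0, 1, ..., 99.
Tempfile.create(["binread-demo", ".bin"]) do |tmp|
  tmp.binmode
  tmp.write((0...100).to_a.pack("C*"))
  tmp.flush
  puts binread_summary(tmp.path).values.inspect # prints [100, 20, 50]
end
```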

dreamingtulpa commented 1 year ago

From the versions tab:

You must provide a url for ckpt_base, training starts from an existing fine-tuned checkpoint

So I guess you need to pass a public URL for ckpt_base, like this:

training = Replicate.client.create_training(
  input: {
    instance_prompt: "zwx style",
    class_prompt: "style",
    instance_data: upload.serving_url, # upload: your zipped training images, uploaded first (see the README)
    max_train_steps: 5000,
    ckpt_base: 'https://domain.com/path/to/your/model.ckpt'
  },
  trainer_version: '9c41656f8ae2e3d2af4c1b46913d7467cd891f2c1c5f3d97f1142e876e63ed7a',
  model: 'yourusername/yourmodel'
)

I haven't tested it, though! Please confirm or let me know what works.
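Since the trainer wants a URL for ckpt_base rather than the checkpoint's bytes, a cheap client-side sanity check can catch the IO.binread mistake before any request is made. The helper below is hypothetical (not part of replicate-ruby) and only checks the shape of the string with Ruby's standard URI library; it does not verify that the URL is public or reachable.

```ruby
require "uri"

# Hypothetical pre-flight check, not part of replicate-ruby: ckpt_base should be
# an http(s) URL pointing at the checkpoint, not the file's contents.
# Only the string's shape is checked; reachability is not tested.
def valid_ckpt_url?(value)
  return false unless value.is_a?(String)
  uri = URI.parse(value)
  uri.is_a?(URI::HTTP) && !uri.host.to_s.empty? # URI::HTTPS is a subclass of URI::HTTP
rescue URI::Error, ArgumentError
  false
end

inputs = {
  ckpt_base: "https://huggingface.co/wavymulder/Analog-Diffusion/resolve/main/analog-diffusion-1.0.ckpt"
}
raise ArgumentError, "ckpt_base must be an http(s) URL" unless valid_ckpt_url?(inputs[:ckpt_base])
puts valid_ckpt_url?("analog-diffusion-1.0.ckpt") # prints false (a local path, not a URL)
```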

basicfeatures commented 1 year ago

Not much luck so far, but this approach does seem a lot more convenient than the previous one:

$ ruby run_dreambooth.rb
run_dreambooth.rb:19:in `<main>': undefined method `create_upload' for #<Replicate::Client:0x00000635ee428310 @api_token="XXX", @api_endpoint="https://api.replicate.com/v1", @webhook_url=nil> (NoMethodError)

upload = Replicate.client.create_upload
                         ^^^^^^^^^^^^^^
Did you mean?  create_prediction

run_dreambooth.rb

require "replicate"

Replicate.configure do |config|
  config.api_token = "XXX"
end

upload = Replicate.client.create_upload
upload.attach("dsd.zip")

training = Replicate.client.create_training(
  input: {
    instance_prompt: "a photo of a dsd man",
    class_prompt: "a photo of a man",
    instance_data: upload.serving_url,
    num_class_images: 25,
    max_train_steps: 5000,
    ckpt_base: "https://huggingface.co/wavymulder/Analog-Diffusion/resolve/main/analog-diffusion-1.0.ckpt"
  },
  trainer_version: "9c41656f8ae2e3d2af4c1b46913d7467cd891f2c1c5f3d97f1142e876e63ed7a",
  model: ("replicate/dreambooth")
)

dreamingtulpa commented 1 year ago

You need to run bundle update; the dreambooth endpoint was added in gem version 0.2. I also just pushed a new helper method that makes uploading a zip file easier:

upload = Replicate.client.upload_zip('path/to/file.zip')

... is enough now. No need to call #attach anymore.
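The NoMethodError above is the classic symptom of running an older gem than the one the README describes. As a sketch, RubyGems' own Gem::Version can do the "is 0.2 or newer?" comparison correctly; the helper name is hypothetical, and how you obtain the installed version string (for example from `gem list replicate-ruby`) is left to you.

```ruby
require "rubygems" # provides Gem::Version; already loaded in a normal Ruby process

# Per the maintainer, the dreambooth/trainings endpoint arrived in gem version 0.2.
MIN_TRAINING_VERSION = Gem::Version.new("0.2")

# Hypothetical helper: `installed` is the replicate-ruby version string you have.
def supports_trainings?(installed)
  Gem::Version.new(installed) >= MIN_TRAINING_VERSION
end

puts supports_trainings?("0.1.0")  # prints false
puts supports_trainings?("0.2.0")  # prints true
puts supports_trainings?("0.10.0") # prints true, although "0.10.0" < "0.2" as plain strings
```

Comparing Gem::Version objects instead of raw strings avoids the "0.10" vs "0.2" trap. Inside a running script, Replicate.client.respond_to?(:upload_zip) should also tell you whether the loaded gem already has the new helper.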

basicfeatures commented 1 year ago

I'm not getting any errors, but nothing shows up at https://replicate.com/predictions:

require "replicate"

# gem install --user-install specific_install
# gem git_install --user-install https://github.com/dreamingtulpa/replicate-ruby

Replicate.configure do |config|
  config.api_token = "XXX"
end

upload = Replicate.client.upload_zip("dsd.zip")

training = Replicate.client.create_training(
  input: {
    instance_prompt: "a photo of a dsd man",
    class_prompt: "a photo of a man",
    instance_data: upload.serving_url,
    num_class_images: 25,
    max_train_steps: 5000,
    ckpt_base: "https://huggingface.co/wavymulder/Analog-Diffusion/resolve/main/analog-diffusion-1.0.ckpt"
  },
  trainer_version: "9c41656f8ae2e3d2af4c1b46913d7467cd891f2c1c5f3d97f1142e876e63ed7a",
  model: ("replicate/dreambooth")
)

puts "Check status at https://replicate.com/predictions..."

dreamingtulpa commented 1 year ago

The model parameter is wrong: it needs to be yourusername/yourmodelname, not the trainer's own model.
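A small hypothetical helper (not part of replicate-ruby) that builds the `model:` argument from your own username and model name, and flags references you don't own such as "replicate/dreambooth". The allowed character set is an assumption (lowercase letters, digits, "-", "_", "."), not Replicate's documented rule.

```ruby
# Hypothetical helpers, not part of replicate-ruby.
# Assumed naming rule: lowercase letters, digits, "-", "_" and ".".
NAME_PART = /\A[a-z0-9][a-z0-9._-]*\z/

# Build "<username>/<model_name>" for create_training's `model:` argument.
def destination_model(username, model_name)
  [username, model_name].each do |part|
    raise ArgumentError, "invalid name part: #{part.inspect}" unless NAME_PART.match?(part.to_s)
  end
  "#{username}/#{model_name}"
end

# True when the reference points at a model owned by `username`,
# i.e. somewhere your training output can be pushed to.
def owned_by?(ref, username)
  owner, name = ref.split("/", 2)
  owner == username && !name.to_s.empty?
end

puts destination_model("yourusername", "yourmodelname")   # prints yourusername/yourmodelname
puts owned_by?("replicate/dreambooth", "yourusername")    # prints false
```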

basicfeatures commented 1 year ago

So I missed the most essential part 🤣

This Cog/Docker stuff is mad confusing though.

basicfeatures commented 1 year ago

I'd say this one is safe to close. Thanks a bunch for the help!