Closed: fagemx closed this issue 3 months ago.
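Have you considered adding support (or releasing a separate version) that accepts Replicate's lora_urls? That would be very convenient! Thanks.
I don't quite understand your specific use case. Do you mean supporting custom LoRAs in the model on replicate.com?
Yes. Training on replicate.com produces lora_urls, and I'd like to be able to use those LoRAs in the model on Replicate as well.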
Hello, I'm also looking for a way to use custom LoRAs in the Cog version / on Replicate. The use case is generating images that include special objects or fantasy figures, which can be achieved through LoRAs.
I already experimented a bit with the code, but on an RTX 4060 the cog predict command never finished generating an image, even after running for several minutes. So I wasn't able to fully test my code. It does correctly download the file into the folder and add it to the ImageGenerationParams.
I'm also not able to push my feature branch, so I'm posting the code here. It would be great if @konieshadow could implement it and update the version on Replicate so we can use custom LoRAs. I'm happy to help if any additional changes are required or if there are any open questions.
lora_manager.py
import hashlib
import os

import requests


def _hash_url(url):
    """Generates a hash value for a given URL (used as a stable cache file name)."""
    return hashlib.md5(url.encode('utf-8')).hexdigest()


class LoraManager:
    def __init__(self, cache_dir="/models/loras/"):
        self.cache_dir = cache_dir
        # Make sure the cache directory exists before downloading into it
        os.makedirs(self.cache_dir, exist_ok=True)

    def _download_lora(self, url):
        """Downloads a LoRA from a URL and saves it in the cache."""
        url_hash = _hash_url(url)
        file_name = f"{url_hash}.safetensors"
        filepath = os.path.join(self.cache_dir, file_name)
        if not os.path.exists(filepath):
            print(f"Starting download: {url}")
            try:
                response = requests.get(url, timeout=10, stream=True)
                response.raise_for_status()
                with open(filepath, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
                print(f"Download successful, saved as {file_name}")
            except Exception as e:
                # Remove a partially written file so a retry does not treat it as cached
                if os.path.exists(filepath):
                    os.remove(filepath)
                raise Exception(f"Error downloading {url}: {e}") from e
        else:
            print(f"LoRA already downloaded: {url}")
        return file_name

    def check(self, urls):
        """Manages the specified LoRAs: downloads missing ones and returns their file names."""
        return [self._download_lora(url) for url in urls]
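The process can be interrupted mid-download, for example when a prediction is cancelled or the container stops. In that case a partially written .safetensors file can remain in the cache, and the next call would treat it as already downloaded. A common remedy is to stream into a temporary file and rename it into place only once it is complete.

The sketch below shows that pattern under a few assumptions. It keeps the MD5-of-URL naming used by _hash_url. It swaps requests for the standard-library urllib.request so it has no dependencies. The function names cached_lora_name and download_lora_atomic are hypothetical and not part of the project.

```python
import hashlib
import os
import shutil
import tempfile
import urllib.request


def cached_lora_name(url: str) -> str:
    """Stable cache file name for a URL: MD5 hex digest of the URL + '.safetensors'."""
    return hashlib.md5(url.encode("utf-8")).hexdigest() + ".safetensors"


def download_lora_atomic(cache_dir: str, url: str, timeout: float = 60.0) -> str:
    """Download url into cache_dir unless it is already cached; return the file name.

    Hypothetical helper: uses urllib.request instead of requests (an assumption).
    """
    os.makedirs(cache_dir, exist_ok=True)
    file_name = cached_lora_name(url)
    final_path = os.path.join(cache_dir, file_name)
    if os.path.exists(final_path):
        return file_name  # cache hit: same URL -> same file, no network access

    # Stream into a temporary file in the same directory first ...
    fd, tmp_path = tempfile.mkstemp(dir=cache_dir, suffix=".part")
    try:
        with os.fdopen(fd, "wb") as out, urllib.request.urlopen(url, timeout=timeout) as response:
            shutil.copyfileobj(response, out, 8192)
        # ... and only move it into place once the download has completed.
        # os.replace is atomic when source and target are on the same filesystem.
        os.replace(tmp_path, final_path)
    except BaseException:
        # Never leave a half-written file behind that could be mistaken for a cached LoRA
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return file_name
```

The temporary file is created in the cache directory itself rather than in the system temp folder. That keeps the final os.replace on one filesystem, which is what makes the rename atomic.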
Changes in predict.py
from lora_manager import LoraManager
In the predict() arguments:
        use_default_loras: bool = Input(default=True, description="Use default LoRAs"),
        loras_custom_urls: str = Input(
            default="",
            description="Custom LoRA URLs in the format 'url,weight'; separate multiple entries with ';' (example: 'url1,0.3;url2,0.1')"),
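The loras_custom_urls input packs several entries into one string. A standalone parser for that 'url,weight;url,weight' format is sketched below under a few assumptions:

- Empty entries are ignored.
- The weight is split off at the last comma.
- A missing weight defaults to 1.0, matching the fallback in the snippet.

parse_lora_spec is a hypothetical helper name, not something the project defines.

```python
from typing import List, Tuple


def parse_lora_spec(spec: str, default_weight: float = 1.0) -> List[Tuple[str, float]]:
    """Parse 'url1,0.3;url2,0.1' into [('url1', 0.3), ('url2', 0.1)] (hypothetical helper)."""
    result: List[Tuple[str, float]] = []
    for entry in spec.split(";"):
        entry = entry.strip()
        if not entry:
            continue  # skip empty entries, e.g. from a trailing ';'
        # Split at the LAST comma so commas inside the URL itself survive
        url, sep, weight = entry.rpartition(",")
        if not sep:
            url, weight = entry, ""  # no comma at all: the whole entry is the URL
        url, weight = url.strip(), weight.strip()
        if not url:
            raise ValueError(f"Missing URL in LoRA entry {entry!r}")
        try:
            value = float(weight) if weight else default_weight
        except ValueError:
            raise ValueError(f"Invalid LoRA weight {weight!r} in entry {entry!r}") from None
        result.append((url, value))
    return result
```

One ambiguity remains in this format. A URL that contains a comma but has no weight (for example 'https://host/a,b') cannot be told apart from 'url,weight', so the parser rejects it as an invalid weight instead of guessing.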
The existing line
        loras = copy.copy(default_loras)
is replaced with:
        lora_manager = LoraManager()
        # Use the default LoRAs only if selected
        loras = copy.copy(default_loras) if use_default_loras else []
        # Add custom user LoRAs if provided
        if loras_custom_urls:
            # Split 'url,weight;url,weight', skipping empty entries (e.g. a trailing ';')
            entries = [entry.strip() for entry in loras_custom_urls.split(';') if entry.strip()]
            # Split off the weight at the last comma so commas inside the URL are preserved
            loras_with_weights = [entry.rsplit(',', 1) for entry in entries]
            total_loras_count = len(loras) + len(loras_with_weights)
            if total_loras_count > 4:
                raise ValueError("The total number of LoRAs (default and custom) cannot exceed 4")
            custom_lora_paths = lora_manager.check([lw[0].strip() for lw in loras_with_weights])
            # A missing or empty weight falls back to 1.0
            custom_loras = [[path, float(lw[1]) if len(lw) > 1 and lw[1].strip() else 1.0]
                            for path, lw in zip(custom_lora_paths, loras_with_weights)]
            loras.extend(custom_loras)
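The selection logic above does four things: it optionally keeps the defaults, appends the custom entries, enforces a maximum, and only then downloads. It can be isolated into a pure function and unit-tested without network access by passing the downloader in as a callable.

build_lora_list is a hypothetical name. The limit of 4 is taken from the snippet above and may differ between Fooocus versions.

```python
from typing import Callable, List, Sequence, Tuple


def build_lora_list(default_loras: Sequence[Tuple[str, float]],
                    custom: Sequence[Tuple[str, float]],
                    download: Callable[[str], str],
                    use_default_loras: bool = True,
                    max_loras: int = 4) -> List[list]:
    """Return [[file_name, weight], ...] for the (optional) defaults plus custom LoRAs.

    custom holds (url, weight) pairs; download maps a URL to a local file name.
    The size limit is checked first, so an over-long request fails fast
    without downloading anything.
    """
    loras = [[name, float(w)] for name, w in default_loras] if use_default_loras else []
    if len(loras) + len(custom) > max_loras:
        raise ValueError(
            f"The total number of LoRAs (default and custom) cannot exceed {max_loras}")
    for url, weight in custom:
        loras.append([download(url), float(weight)])
    return loras
```

In predict(), download could be lora_manager._download_lora, and custom could be the (url, weight) pairs parsed from loras_custom_urls.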
@TechnikMax Thank you for the code. I will implement the feature based on it soon.
@konieshadow Any news about the implementation?