Stability-AI / stability-sdk

SDK for interacting with stability.ai APIs (e.g. stable diffusion inference)
https://platform.stability.ai/
MIT License

Adaptive sdxl #244

Closed satsumas closed 1 year ago

satsumas commented 1 year ago

Adds parameters used to enable T2I requests.

To test:

  1. Check out the following repositories at these branches:

     - generator_server: branch adaptive_sdxl
     - stable-diffusion: branch ENG-811-PLATFORM-408-hatch
     - Interfaces: branch adaptive_sdxl

  2. Get the following SDXL 2.2.2 T2I adapters from the server at cw.hpc.stability.ai:/mnt/nvme/home/nitrosocke/t2i-adapter-trainer/training_logs/, and put them in models/t2iadapters/models:

     - sdxl-sketch_e0_gs16000.pth
     - sdxl-canny_e0_gs10000.pth
     - sdxl-depth_e1_gs24000.pth

  3. Get the example images and the 3rd-party models by running the scripts as instructed in stable_diffusion/src/stablecore/modules/t2iadapter/README.md.

  4. Then, from inside generator_server/, run the following. Note that MODELS_DIR is specific to this advanced inference use case:

    
    hatch -e sdxl shell
    pip install -e src/generator_server/lib/stable_diffusion/
    pip install -e .
    pip install -e /data/workspace/sandbox/stability-sdk/ #ENG-811-PLATFORM-408-hatch branch
    export MODELS_DIR="/path/to/generator_server/deploy/models/stable-diffusion-xl-beta-v2-2-2-advanced"
    export $(grep -v '^#' $MODELS_DIR/model.env | xargs -d '\n')
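The `export $(grep -v '^#' $MODELS_DIR/model.env | xargs -d '\n')` line above loads every non-comment `KEY=VALUE` pair from `model.env` into the shell environment. The same technique can be sketched in Python; note that the file path and the `ENGINE_ID` variable below are hypothetical stand-ins for illustration, not the real contents of `model.env`:

```python
import os

def load_env_file(path):
    """Export every non-comment KEY=VALUE line of an env file,
    mirroring: export $(grep -v '^#' file | xargs -d '\n')."""
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue  # skip blank lines and comments, like grep -v '^#'
            key, _, value = line.partition("=")
            os.environ[key] = value

# Hypothetical model.env contents, for illustration only:
with open("/tmp/model.env", "w") as f:
    f.write("# engine configuration\nENGINE_ID=stable-diffusion-xl-beta-v2-2-2\n")

load_env_file("/tmp/model.env")
print(os.environ["ENGINE_ID"])  # stable-diffusion-xl-beta-v2-2-2
```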

    ADAPTER_PATH="/path/to/generator_server/src/generator_server/lib/stable_diffusion/models"  # inside here are your adapters and 3rd-party models
    SD_REPO_PATH="/path/to/stable-diffusion-staging"
    ONEFLOW_PATH="/path/to/generator_server/src/generator_server/lib/stable_diffusion/outputs/txt2img-samples/serialized"
    python src/generator_server/grpc_server.py


Now that the server is running, you can send T2I requests using the following code:

    import io
    import os
    import warnings

    from PIL import Image
    from stability_sdk import client
    import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation

    # Our host URL should not be prepended with "https", nor should it have a trailing slash.
    os.environ['STABILITY_HOST'] = 'localhost:50051'
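If you are unsure whether a host string is in the right form, a small helper can normalize it. This is a hypothetical convenience function for illustration, not part of the SDK:

```python
def normalize_host(host):
    """Strip an http(s) scheme and any trailing slash, since the gRPC
    host must be a bare host:port (hypothetical helper, not in the SDK)."""
    for scheme in ("https://", "http://"):
        if host.startswith(scheme):
            host = host[len(scheme):]
    return host.rstrip("/")

print(normalize_host("https://localhost:50051/"))  # localhost:50051
```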

    # Sign up for an account at https://dreamstudio.ai/ to get an API key.
    # Once you have created an account, https://dreamstudio.ai/account takes you to your API key.
    # Paste your API key below (for a local server, any value will do).
    os.environ['STABILITY_KEY'] = 'anything-will-do'

    # Set up our connection to the API.
    stability_api = client.StabilityInference(
        key=os.environ['STABILITY_KEY'],    # API key reference.
        host=os.environ['STABILITY_HOST'],  # Host reference.
        verbose=True,                       # Print debug messages.
        engine="stable-diffusion-xl-beta-v2-2-2",  # Engine to use for generation; for SD 2.0 use "stable-diffusion-v2-0".
        # Available engines: stable-diffusion-v1, stable-diffusion-v1-5,
        # stable-diffusion-512-v2-0, stable-diffusion-768-v2-0,
        # stable-diffusion-512-v2-1, stable-diffusion-768-v2-1,
        # stable-diffusion-xl-beta-v2-2-2, stable-inpainting-v1-0,
        # stable-inpainting-512-v2-0
    )

    # Set up our initial generation parameters.
    answers = stability_api.generate(
        prompt="A car with wings",
        seed=82675780,
        adapter_type=generation.SKETCH,
        adapter_strength=0.75,
        adapter_init_type=generation.ADAPTER_IMAGE,
        sampler=generation.SAMPLER_K_DPMPP_SDE,
        init_image=Image.open("path/to/generator_server/src/generator_server/lib/stable_diffusion/examples/sketch/car.png"),
    )

    # Warn on the console if the adult content classifier is tripped.
    # If it is not tripped, save the generated image.
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                warnings.warn(
                    "Your request activated the API's safety filters and could not be processed. "
                    "Please modify the prompt and try again.")
            if artifact.type == generation.ARTIFACT_IMAGE:
                img = Image.open(io.BytesIO(artifact.binary))
                img.save(str(artifact.seed) + ".png")  # Save the generated image with its seed number as the filename.
                print(f"Generated image saved as {artifact.seed}.png")
sonarcloud[bot] commented 1 year ago

Kudos, SonarCloud Quality Gate passed!

- Bugs: 0 (rated A)
- Vulnerabilities: 0 (rated A)
- Security Hotspots: 0 (rated A)
- Code Smells: 0 (rated A)
- Coverage: no coverage information
- Duplication: 0.0%