ashawkey / stable-dreamfusion

Text-to-3D & Image-to-3D & Mesh Exportation with NeRF + Diffusion.
Apache License 2.0

Question: Generate image using SDS #96

Open jenkspt opened 1 year ago

jenkspt commented 1 year ago

(Awesome work here -- I saw there is already a paper out based on your work: latent-nerf)

As a sanity check, I'm trying to generate an image with a "differentiable image parameterization" (DIP) and the SDS algorithm. Here's the MWE:

import math
from tqdm import tqdm
import torch
import torch.nn as nn
from nerf.sd import StableDiffusion, seed_everything
from torch.optim.lr_scheduler import LambdaLR

import matplotlib.pyplot as plt

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles: float = 0.5):

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, -1)

device = 'cuda:0'
guidance = StableDiffusion(device)
# limited memory here, don't need the decoder
guidance.vae.decoder = None

prompt = '3D texture of pebbles'
text_embeddings = guidance.get_text_embeds(prompt, '')
guidance.text_encoder.to('cpu')
torch.cuda.empty_cache()

seed_everything(42)
# put parameters approximately in the range [0, 1], since this is what `encode_imgs` expects
rgb = nn.Parameter(torch.randn(1, 3, 512, 512, device=device) / 2  + .5)
optimizer = torch.optim.AdamW([rgb], lr=1e-1, weight_decay=0)
num_steps = 5000
scheduler = get_cosine_schedule_with_warmup(optimizer, 100, int(num_steps*1.5))

for step in tqdm(range(num_steps)):
    optimizer.zero_grad()
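    # train_step encodes the image to latents, adds noise, and applies the SDS gradient
    # internally (it calls backward itself), so no explicit loss.backward() is needed here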
    guidance.train_step(text_embeddings, rgb, guidance_scale=100)

    optimizer.step()
    scheduler.step()

plt.imshow(rgb.detach().clamp(0, 1).squeeze(0).permute(1,2,0).cpu())
plt.axis('off')
plt.show()

I ended up with this: "3D texture of pebbles" image

I've tried using a sigmoid activation for the image and various learning rates, but the images still come out oversaturated.

The DreamFusion authors claim they were able to get DIP results similar to DDPM ancestral sampling:

SDS produces detail comparable to ancestral sampling, but enables new transfer learning applications because it operates in parameter space.

My question is: were you able to get this working? Or do you have any suggestions/ideas to get quality similar to DDPM?

ashawkey commented 1 year ago

@jenkspt This is very interesting! I failed to make RGB space optimization work too, but it seems latent space optimization can work well:

RGB space: tmp_img_300

Latent space: tmp_lat_img_300

import math
from tqdm import tqdm
import torch
import torch.nn as nn
from nerf.sd import StableDiffusion, seed_everything
from torch.optim.lr_scheduler import LambdaLR

import matplotlib.pyplot as plt

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles: float = 0.5):

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, -1)

device = 'cuda:0'
guidance = StableDiffusion(device)
# optimizing the latents directly, so the VAE encoder isn't needed (saves memory)
guidance.vae.encoder = None

prompt = 'pineapple'
text_embeddings = guidance.get_text_embeds(prompt, '')
guidance.text_encoder.to('cpu')
torch.cuda.empty_cache()

seed_everything(42)
latents = nn.Parameter(torch.randn(1, 4, 64, 64, device=device))
optimizer = torch.optim.AdamW([latents], lr=1e-1, weight_decay=0)
num_steps = 1000
scheduler = get_cosine_schedule_with_warmup(optimizer, 100, int(num_steps*1.5))

for step in tqdm(range(num_steps)):
    optimizer.zero_grad()

    t = torch.randint(guidance.min_step, guidance.max_step + 1, [1], dtype=torch.long, device=guidance.device)
    with torch.no_grad():
        # add noise
        noise = torch.randn_like(latents)
        latents_noisy = guidance.scheduler.add_noise(latents, noise, t)
        # pred noise
        latent_model_input = torch.cat([latents_noisy] * 2)
        noise_pred = guidance.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

    # perform guidance (high scale from paper!)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + 100 * (noise_pred_text - noise_pred_uncond)

    # SDS update: weight the noise residual by w(t) and inject it directly as the
    # gradient of the latents, skipping backprop through the UNet (as in DreamFusion)
    w = (1 - guidance.alphas[t])
    grad = w * (noise_pred - noise)

    latents.backward(gradient=grad, retain_graph=True)

    optimizer.step()
    scheduler.step()

    if step > 0 and step % 100 == 0:
        rgb = guidance.decode_latents(latents)
        img = rgb.detach().squeeze(0).permute(1,2,0).cpu().numpy()
        print('[INFO] save image', img.shape, img.min(), img.max())
        plt.imsave(f'tmp_lat_img_{step}.jpg', img)

The RGB-to-latent transition in stable-diffusion is different from imagen/ediffi, and it seems that operating directly on the latent space is the better choice, as in the recent latent-nerf.

jenkspt commented 1 year ago

Thanks! This is really helpful.

The latent space images are definitely better. I've realized that the 'saturated' look is from the high guidance scale -- and we can generate better images using the equivalent of static or dynamic thresholding by adding the following:

    with torch.no_grad():
        # Static threshold
        latents.data = latents.data.clip(-1, 1)
        # Dynamic thresholding (requires `import numpy as np`)
        #s = torch.as_tensor(np.percentile(latents.abs().cpu().numpy(), 90, axis=(1,2,3)), dtype=latents.dtype).to(device)
        #latents.data = latents.clip(-s, s) / s
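
For reference, a torch-only variant of the dynamic-thresholding branch would look something like this (a sketch; the clamp of the scale to at least 1 follows Imagen's formulation and is my addition, not part of the snippet above):

    with torch.no_grad():
        # dynamic thresholding via torch.quantile instead of np.percentile
        # (batch size is 1 here, so a scalar scale is fine)
        s = torch.quantile(latents.detach().abs().flatten(1), 0.9, dim=1).clamp(min=1.0)
        latents.data = latents.data.clamp(-s.item(), s.item()) / s.item()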

(static) pineapple_static_threshold

(dynamic) pineapple_dynamic_threshold

I also tried using a tanh activation in latent space, but found it harder to get good samples: (tanh) pineapple_tanh

I'm wondering if any of these findings can translate to the NeRF model, or if there are other experiments that can be run on the 'DIP' parameterization and translated to the NeRF model.

Anyways thanks for the help!

ashawkey commented 1 year ago

Wow, are these generated by optimizing the RGB space? Could you give me a hint about where you add the latents clamping?

jenkspt commented 1 year ago

These are from optimizing the latent space. I added the clamping after the scheduler.step() line in your example.
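
Concretely, the tail of the loop from that example becomes (minimal sketch):

    optimizer.step()
    scheduler.step()

    # static thresholding of the optimized latents, applied after every update
    with torch.no_grad():
        latents.data = latents.data.clip(-1, 1)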

ashawkey commented 1 year ago

Thanks!

thuwzy commented 1 year ago
latents.data = latents.data.clip(-1, 1)

A very interesting finding! I find it very useful. Thank you very much!

phymhan commented 1 year ago

Thanks! Very interesting results! I briefly tried this with the NeRF model; I'm not sure whether it is clearly helpful. Here are some results I got:

Hamburger, with -O option, baseline

with dynamic thresholding

Pineapple, with -O2 option, baseline

with dynamic thresholding

For all these experiments, I perform a one-step gradient descent on noise_pred in the direction that minimizes the MSE loss between the predicted x0 and its thresholded version. Alternatively, I also tried computing a new noise_pred from the thresholded predicted x0, but that does not seem to work.

jenkspt commented 1 year ago

What exactly did you try? It's unclear to me how to translate static/dynamic thresholding to the NeRF model.

phymhan commented 1 year ago

This is what I tried: https://github.com/phymhan/stable-dreamfusion/blob/dcd7fc0557d32d611fc228f890bdaf1245880395/nerf/sd.py#L129 Basically, I use noise_pred to get the current estimate of x0, then perform one gradient-descent step on noise_pred to make it closer to the thresholded x0.
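
In rough pseudocode, the idea is something like the following (a sketch, not the exact code at the link; `eta` is a placeholder step size):

    # estimate x0 from the noisy latents and the predicted noise
    alpha_bar = guidance.scheduler.alphas_cumprod.to(latents.device)[t]
    noise_pred = noise_pred.detach().requires_grad_(True)
    x0_pred = (latents_noisy - (1 - alpha_bar).sqrt() * noise_pred) / alpha_bar.sqrt()
    x0_target = x0_pred.detach().clamp(-1, 1)  # thresholded x0

    # one gradient-descent step on noise_pred toward the thresholded estimate
    loss = torch.nn.functional.mse_loss(x0_pred, x0_target)
    loss.backward()
    eta = 1.0  # placeholder step size
    noise_pred = (noise_pred - eta * noise_pred.grad).detach()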

jenkspt commented 1 year ago

Interesting, I would expect this:

I also tried to compute a new noise_pred from the thresholded predicted x0, but that does not seem to work

to produce samples at least comparable to the original.

phymhan commented 1 year ago

Yep, that is weird. Will do more tests on this.

huanranchen commented 10 months ago

It seems impossible to do this in RGB space. Has anyone succeeded in doing this?

qiminchen commented 1 month ago

Based on the latest repo, I tried adding the azimuth to the text prompt. For example, azimuth = 0, 90, 180 gives the front, side, and back views of the chair; however, if the azimuth is in between, e.g. azimuth = 120, the result still looks like a side view instead of something like 2/3 side and 1/3 back based on the interpolation. I wonder what the reason is here?

import math
from tqdm import tqdm
import torch
import torch.nn as nn
from guidance.sd_utils import StableDiffusion, seed_everything
from torch.optim.lr_scheduler import LambdaLR

import matplotlib.pyplot as plt

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles: float = 0.5):

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, -1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
guidance = StableDiffusion(device, False, False)
guidance.vae.encoder = None

azimuth = 120
prompt = 'a photo of an awesome chair'
embeddings = {}
for d in ['front', 'side', 'back']:
    embeddings[d] = guidance.get_text_embeds([f"{prompt}, {d} view"])

if -90 <= azimuth < 90:
    if azimuth >= 0:
        r = 1 - azimuth / 90
    else:
        r = 1 + azimuth / 90
    start_z = embeddings['front']
    end_z = embeddings['side']
else:
    if azimuth >= 0:
        r = 1 - (azimuth - 90) / 90
    else:
        r = 1 + (azimuth + 90) / 90
    start_z = embeddings['side']
    end_z = embeddings['back']

cond_embeddings = r * start_z + (1 - r) * end_z
# cond_embeddings = guidance.get_text_embeds([prompt])
uncond_embeddings = guidance.get_text_embeds([''])
text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])
guidance.text_encoder.to('cpu')
torch.cuda.empty_cache()

seed_everything(42)
latents = nn.Parameter(torch.randn(1, 4, 64, 64, device=device))
optimizer = torch.optim.AdamW([latents], lr=1e-1, weight_decay=0)
num_steps = 1000
scheduler = get_cosine_schedule_with_warmup(optimizer, 100, int(num_steps*1.5))

for step in tqdm(range(num_steps)):
    optimizer.zero_grad()

    t = torch.randint(guidance.min_step, guidance.max_step + 1, (latents.shape[0],), dtype=torch.long, device=guidance.device)
    with torch.no_grad():
        # add noise
        noise = torch.randn_like(latents)
        latents_noisy = guidance.scheduler.add_noise(latents, noise, t)
        # pred noise
        latent_model_input = torch.cat([latents_noisy] * 2)
        tt = torch.cat([t] * 2)
        noise_pred = guidance.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample

    # perform guidance (high scale from paper!)
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + 100 * (noise_pred_text - noise_pred_uncond)

    w = (1 - guidance.alphas[t])
    grad = w * (noise_pred - noise)

    latents.backward(gradient=grad, retain_graph=True)

    optimizer.step()
    scheduler.step()

    with torch.no_grad():
        latents.data = latents.data.clip(-1, 1)

    if step > 0 and (step + 1) % 100 == 0:
        rgb = guidance.decode_latents(latents)
        img = rgb.detach().squeeze(0).permute(1, 2, 0).cpu().numpy()
        print('[INFO] save image', img.shape, img.min(), img.max())
        plt.imsave(f'test/tmp_lat_img_{step+1}.jpg', img)

azimuth=0: tmp_lat_img_1000

azimuth=90: tmp_lat_img_1000

azimuth=180: tmp_lat_img_1000

azimuth=120: tmp_lat_img_1000