long8v / PTIR

Paper Today I Read
19 stars 0 forks source link

[11] DALL-E : Zero-Shot Text-to-Image Generation #11

Open long8v opened 2 years ago

long8v commented 2 years ago

image paper, article, code

TL;DR

Details

Step One: Learning the Visual Codebook

$\phi$와 $\theta$에 대해 ELB를 최대화하는 식으로 dVAE를 학습한다. code의 크기는 32 x 32이며 $K$=8192고 $p_\psi$는 uniform분포이다. code가 discrete 해서 미분 불가한 부분은 gumbel softmax를 사용해서 gradient를 흘려줬다.

$p\theta$는 log-laplace(정규 분포 지수 부분에 제곱대신 절대값) 분포로 평가됐다.

Stage Two: Learning the Prior

텍스트는 BPE encode해서 최대 256길이로 만들었고, 이미지는 dVAE encoder logit에서 argmax해서 1024개의 토큰을 얻었다. 두개의 인코딩을 concat해서 트랜스포머 디코더에 넣어줬고, 텍스트가 256보다 작을 경우에는 256개의 position에 따라 각각의 [PAD]토큰을 학습시켜줬다. -> OOD catpion에 더 강건했다. cross-entropy loss를 사용했고 텍스트와 이미지의 loss는 1/8, 7/8로 곱해줬다.

long8v commented 2 years ago

ELBO

어떤 random variable Z가 있고, Z와 X가 theta로 표현되는 p(Z, X | theta) 분포를 따른다고 하자. 보통 우리가 계산하는 건 1) theta와 X가 주어졌을 때 posterior Z 구하기 p(Z|X, theta) 2) P(theta | X) : likelihood를 최대화하는 theta 찾기이다. evidence란 theta가 주어졌을 때 likelihood를 부르는 다른 말이다.

image

이때 우리가 Z가 q라는 distribution을 따른다는 것을 알고 있을 때, p(X, Z| theta)는 p(X| Z, theta) * q(Z)로 표현할 수 있다. 이 때, log p(X | theta)의 최소값은 아래와 같이 표현할 수 있다.

image

이때 오른쪽 항을 ELBO라고 부른다

image

이 ELBO를 구하는 식은 아래와 같다.

image

마지막의 부등호는 jensen's inequality를 사용했다.

image

ELBO와 evidence의 차이는 KL divergence와 일치한다.

image

https://mbernste.github.io/posts/elbo/ https://www.youtube.com/watch?v=A8wf7QmmlUM

long8v commented 2 years ago

ELBO 2

https://www.youtube.com/watch?v=GbCAwVVKaHY https://long8v.notion.site/ELB-f7267ffc5301422c9965e3e6f0619958 image image

long8v commented 2 years ago

encoder.py

attr.s 라는 decorator로 init에 대한 validate등을 하는듯

@attr.s(eq=False)
class Conv2d(nn.Module):
    n_in:  int = attr.ib(validator=lambda i, a, x: x >= 1)
    n_out: int = attr.ib(validator=lambda i, a, x: x >= 1)
    kw:    int = attr.ib(validator=lambda i, a, x: x >= 1 and x % 2 == 1)

    use_float16:   bool         = attr.ib(default=True)
    device:        torch.device = attr.ib(default=torch.device('cpu'))
    requires_grad: bool         = attr.ib(default=False)

    def __attrs_post_init__(self) -> None:
        super().__init__()

        w = torch.empty((self.n_out, self.n_in, self.kw, self.kw), dtype=torch.float32,
            device=self.device, requires_grad=self.requires_grad)
        w.normal_(std=1 / math.sqrt(self.n_in * self.kw ** 2))

        b = torch.zeros((self.n_out,), dtype=torch.float32, device=self.device,
            requires_grad=self.requires_grad)
        self.w, self.b = nn.Parameter(w), nn.Parameter(b)