Latent Variable Models

1. Using a Latent Variable Model

𝑝(π‘₯)=βˆ«π‘(π‘₯|𝑧)𝑝(𝑧)π‘‘π‘§π‘πœƒ(π‘₯|𝑧)→𝑝(π‘₯|𝑧)

$p(z)$ is the prior distribution, a fixed, predefined density function.

𝑝(π‘₯)=βˆ«π‘(π‘₯|𝑧)𝑝(𝑧)π‘‘π‘§π‘ž(π‘₯|𝑧)→𝑝(π‘₯|𝑧)

𝑝(𝑧) is the prior distribution, will be a predefined density function.

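Our goal is to learn $p_\theta(x \mid z)$ to approximate $p(x \mid z)$, with the mismatch usually measured by the KL divergence. That is hard to work with directly, so we approximate the joint $p(x,z)$ with $p_\theta(x,z) = p_\theta(x \mid z)\, p(z)$ instead: since $p(x,z) = p(x \mid z)\, p(z)$ with the same prior, matching the joints matches the conditionals. In the derivation below, the data-side joint is factored as $p(x,z) = p(x)\, p(z \mid x)$.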

KL(𝑝(π‘₯,𝑧)β€–π‘πœƒ(π‘₯,𝑧))=βˆ¬π‘(π‘₯,𝑧)log(𝑝(π‘₯,𝑧)π‘πœƒ(π‘₯,𝑧))𝑑π‘₯𝑑𝑧=βˆ«π‘(π‘₯)[βˆ«π‘(𝑧|π‘₯)log𝑝(π‘₯)𝑝(𝑧|π‘₯)π‘πœƒ(π‘₯,𝑧)𝑑𝑧]𝑑π‘₯=𝐸π‘₯βˆΌπ‘(π‘₯)[βˆ«π‘(𝑧|π‘₯)(log𝑝(π‘₯)+log𝑝(𝑧|π‘₯)π‘πœƒ(π‘₯,𝑧))𝑑𝑧]=𝐸π‘₯βˆΌπ‘(π‘₯)[log𝑝(π‘₯)βˆ«π‘(𝑧|π‘₯)𝑑𝑧]+𝐸π‘₯βˆΌπ‘(π‘₯)[βˆ«π‘(𝑧|π‘₯)log𝑝(𝑧|π‘₯)π‘πœƒ(π‘₯,𝑧)𝑑𝑧]=𝐸π‘₯βˆΌπ‘(π‘₯)[log𝑝(π‘₯)]βˆ’πΈπ‘₯βˆΌπ‘(π‘₯)[βˆ«π‘(𝑧|π‘₯)logπ‘πœƒ(π‘₯,𝑧)𝑝(𝑧|π‘₯)𝑑𝑧]

The first term is a constant (it depends only on the data distribution), so minimizing the KL divergence is equivalent to maximizing the expectation in the second term:

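$$
\begin{aligned}
\mathcal{L} &= \mathbb{E}_{x \sim p(x)} \left[ \int p(z \mid x) \log \frac{p_\theta(x,z)}{p(z \mid x)}\, dz \right]
= \mathbb{E}_{x \sim p(x)} \left[ \int p(z \mid x) \log \frac{p_\theta(x \mid z)\, p(z)}{p(z \mid x)}\, dz \right] \\
&= \mathbb{E}_{x \sim p(x)} \left[ \int p(z \mid x) \log p_\theta(x \mid z)\, dz + \int p(z \mid x) \log \frac{p(z)}{p(z \mid x)}\, dz \right] \\
&= \mathbb{E}_{x \sim p(x)} \Big[ \mathbb{E}_{z \sim p(z \mid x)} \big[ \log p_\theta(x \mid z) \big] - \mathrm{KL}\big(p(z \mid x) \,\|\, p(z)\big) \Big]
\end{aligned}
$$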

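The quantity inside the outer expectation, $\mathrm{ELBO}(x) = \mathbb{E}_{z \sim p(z \mid x)}[\log p_\theta(x \mid z)] - \mathrm{KL}(p(z \mid x) \,\|\, p(z))$, is the evidence lower bound (ELBO). Is maximizing the ELBO close to doing maximum likelihood estimation (MLE)? Yes, because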

𝐸π‘₯βˆΌπ‘(π‘₯)[ELBO(π‘₯)]=𝐸π‘₯βˆΌπ‘(π‘₯)[βˆ«π‘(𝑧|π‘₯)logπ‘πœƒ(π‘₯,𝑧)𝑝(𝑧|π‘₯)𝑑𝑧]=𝐸π‘₯βˆΌπ‘(π‘₯)[βˆ«π‘(𝑧|π‘₯)logπ‘πœƒ(π‘₯)π‘πœƒ(𝑧|π‘₯)𝑝(𝑧|π‘₯)𝑑𝑧]=𝐸π‘₯βˆΌπ‘(π‘₯)[logπ‘πœƒ(π‘₯)]βˆ’KL(𝑝(𝑧|π‘₯)β€–π‘πœƒ(𝑧|π‘₯))]

1.1. VAE

We choose $p(z) = \mathcal{N}(0, I)$ and use neural networks to approximate $p(z \mid x)$ (the encoder) and $p(x \mid z)$ (the decoder).

(πœ‡,𝜎2)=Β EncoderNetworkΒ πœ‘(π‘₯),π‘žπœ‘(𝑧|π‘₯)=𝒩(𝑧|πœ‡,Β diag(𝜎2))

𝑝(π‘₯|𝑧) is one-to-one mapping. We use π‘πœƒ(π‘₯|𝑧)=𝛿(π‘₯βˆ’π‘“πœƒ(𝑧)) to approximate 𝑝(π‘₯|𝑧).

π‘“πœƒ(𝑧)=Β DecoderNetworkΒ πœƒ(𝑧),π‘πœƒ(π‘₯|𝑧)=𝒩(π‘₯|π‘“πœƒ(𝑧),𝜎 decΒ 2𝐼)

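where $\sigma_{\text{dec}}$ is a hyperparameter; as $\sigma_{\text{dec}} \to 0$ the Gaussian approaches the delta. The training loss is the negative ELBO, and its first (reconstruction) term is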

𝐸π‘₯βˆΌπ‘(π‘₯)[πΈπ‘§βˆΌπ‘žπœ‘(𝑧|π‘₯)[βˆ’logπ‘πœƒ(π‘₯|𝑧)]]=𝐸π‘₯βˆΌπ‘(π‘₯)[πΈπ‘§βˆΌπ‘žπœ‘(𝑧|π‘₯)[βˆ’log12πœ‹πœŽΒ decexp(βˆ’(π‘₯βˆ’π‘“πœƒ(𝑧))22𝜎 dec2)]]=12𝜎 dec2𝐸π‘₯βˆΌπ‘(π‘₯)[πΈπ‘§βˆΌπ‘žπœ‘(𝑧|π‘₯)[(π‘₯βˆ’π‘“πœƒ(𝑧))2]]βˆ’log12πœ‹πœŽΒ dec

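Since, for two $k$-dimensional Gaussians $\mathcal{N}_0 = \mathcal{N}(\mu_0, \Sigma_0)$ and $\mathcal{N}_1 = \mathcal{N}(\mu_1, \Sigma_1)$,

$$\mathrm{KL}(\mathcal{N}_0 \,\|\, \mathcal{N}_1) = \frac{1}{2}\left( \operatorname{tr}\big(\Sigma_1^{-1} \Sigma_0\big) + (\mu_1 - \mu_0)^\top \Sigma_1^{-1} (\mu_1 - \mu_0) + \log \frac{\lvert \Sigma_1 \rvert}{\lvert \Sigma_0 \rvert} - k \right),$$

the second term (with $\mu_0 = \mu$, $\Sigma_0 = \operatorname{diag}(\sigma^2)$, $\mu_1 = 0$, $\Sigma_1 = I$) is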

KL(π‘žπœ‘(𝑧|π‘₯)‖𝑝(𝑧))=12(βˆ’log𝜎2+πœ‡2+𝜎2βˆ’1)

We are trying to find $(\varphi, \theta) = \arg\max_{\varphi, \theta}\, \mathbb{E}_{x \sim p(x)}\big[\mathrm{ELBO}(x)\big]$.
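Putting the two terms together, a single-sample Monte Carlo estimate of the loss $-\mathrm{ELBO}(x)$ for one data point might look like the sketch below. `decoder` is a hypothetical stand-in for $f_\theta$, log-variance encoder outputs are again assumed, and in practice this would run inside an autodiff framework and be averaged over a minibatch.

```python
import math
import random

def neg_elbo_one_sample(x, mu, log_var, decoder, sigma_dec, rng):
    """Single-sample Monte Carlo estimate of -ELBO(x) for a Gaussian VAE (sketch).

    `decoder` is a hypothetical stand-in for f_theta; mu and log_var are encoder outputs.
    """
    # Reparameterized draw z ~ q_phi(z|x) = N(mu, diag(exp(log_var))).
    z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]
    x_hat = decoder(z)
    # Reconstruction term: -log N(x | f_theta(z), sigma_dec^2 I).
    sq = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    recon = sq / (2 * sigma_dec ** 2) + len(x) * math.log(math.sqrt(2 * math.pi) * sigma_dec)
    # KL term: closed form against the N(0, I) prior.
    kl = 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, log_var))
    return recon + kl

# Toy usage: an identity "decoder" and a nearly deterministic encoder (tiny variance).
rng = random.Random(0)
loss = neg_elbo_one_sample([0.5, -0.5], [0.5, -0.5], [-50.0, -50.0], lambda z: z, 1.0, rng)
```

The large loss in the toy call comes almost entirely from the KL term: the $-\log \sigma_j^2$ part penalizes collapsing $q_\varphi(z \mid x)$ to a point.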

1.1.1. Conditional VAE (CVAE)

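We define the CVAE objective

$$\mathcal{L}_{\text{CVAE}} = \mathbb{E}_{(x,y) \sim p(x,y)} \Big[ \mathbb{E}_{z \sim q_\varphi(z \mid x,y)} \big[ \log p_\theta(y \mid x,z) \big] - \mathrm{KL}\big(q_\varphi(z \mid x,y) \,\|\, p_\theta(z \mid x)\big) \Big]$$

and also the Gaussian stochastic neural network (GSNN) objective, which samples $z$ from the conditional prior $p_\theta(z \mid x)$, as is done at test time when $y$ is unavailable:

$$\mathcal{L}_{\text{GSNN}} = \mathbb{E}_{(x,y) \sim p(x,y)} \Big[ \mathbb{E}_{z \sim p_\theta(z \mid x)} \big[ \log p_\theta(y \mid x,z) \big] \Big]$$

The total (hybrid) objective to maximize is $\mathcal{L}_{\text{hybrid}} = \alpha\, \mathcal{L}_{\text{CVAE}} + (1 - \alpha)\, \mathcal{L}_{\text{GSNN}}$.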
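The hybrid objective is a convex combination of two per-example estimates. A minimal sketch, assuming the three inputs have already been estimated (for example by Monte Carlo sampling of $z$); the function and argument names are hypothetical.

```python
def cvae_hybrid_objective(loglik_post, kl_post_prior, loglik_test, alpha):
    """alpha * L_CVAE + (1 - alpha) * L_GSNN for one example (higher is better).

    loglik_post:   estimate of E_{z ~ q_phi(z|x,y)}[log p_theta(y|x,z)]
    kl_post_prior: KL(q_phi(z|x,y) || p_theta(z|x))
    loglik_test:   estimate of E_z[log p_theta(y|x,z)] with z drawn without access to y,
                   as it would be at test time (the GSNN term)
    """
    l_cvae = loglik_post - kl_post_prior
    l_gsnn = loglik_test
    return alpha * l_cvae + (1.0 - alpha) * l_gsnn
```

With `alpha = 1` the objective reduces to the plain CVAE bound; smaller `alpha` puts more weight on the train/test-consistent GSNN term.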

1.1.2. β-VAE

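$$\mathcal{L} = \mathbb{E}_{x \sim p(x)} \Big[ \mathbb{E}_{z \sim q_\varphi(z \mid x)} \big[ \log p_\theta(x \mid z) \big] - \beta\, \mathrm{KL}\big(q_\varphi(z \mid x) \,\|\, p(z)\big) \Big]$$

When $\beta > 1$, the stronger pull toward the factorized prior $p(z) = \mathcal{N}(0, I)$ forces the dimensions of $z \sim q_\varphi(z \mid x)$ to be more independent (disentangled), usually at some cost in reconstruction quality.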
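In code, the only change from the plain VAE objective is the weight on the KL term. A minimal sketch, assuming a Gaussian encoder with log-variance outputs and a precomputed reconstruction estimate; the names are hypothetical.

```python
import math

def beta_vae_objective(recon_loglik, mu, log_var, beta):
    """E_q[log p_theta(x|z)] - beta * KL(q_phi(z|x) || N(0, I)) for one example (maximize).

    recon_loglik: a precomputed (e.g. Monte Carlo) estimate of the reconstruction term.
    """
    kl = 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, log_var))
    return recon_loglik - beta * kl

mu, log_var = [0.5, -0.3], [0.2, -0.1]
plain = beta_vae_objective(-3.0, mu, log_var, beta=1.0)   # ordinary ELBO
strong = beta_vae_objective(-3.0, mu, log_var, beta=4.0)  # heavier KL penalty
```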

1.1.3. VAE with Discrete Latents

1.1.3.1. Gumbel-Softmax

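The Gumbel-Max trick is a way to sample from a categorical distribution. If category $i$ has probability $p_i$, then

$$\arg\max_i \big( \log p_i - \log(-\log \varepsilon_i) \big), \qquad \varepsilon_i \sim U(0, 1)$$

is distributed exactly as a sample from that categorical distribution, because $-\log(-\log \varepsilon_i)$ is standard Gumbel noise. Moving the randomness into $\varepsilon$ in this way is a reparameterization trick.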
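The Gumbel-Max trick can be checked empirically: draw many samples and compare the category frequencies with $p_i$. A minimal sketch using only the standard library; `gumbel_max_sample` is a hypothetical name.

```python
import math
import random

def gumbel_max_sample(probs, rng):
    """Return argmax_i (log p_i - log(-log eps_i)), eps_i ~ U(0, 1): one categorical draw."""
    best_i, best_val = 0, -math.inf
    for i, p in enumerate(probs):
        eps = rng.random()
        while eps == 0.0:  # random() returns values in [0, 1); log(0) is undefined
            eps = rng.random()
        val = math.log(p) - math.log(-math.log(eps))
        if val > best_val:
            best_i, best_val = i, val
    return best_i

rng = random.Random(0)
probs = [0.2, 0.5, 0.3]
n = 20000
counts = [0] * len(probs)
for _ in range(n):
    counts[gumbel_max_sample(probs, rng)] += 1
freqs = [c / n for c in counts]
```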

But argmax is not differentiable, so we use softmax to approximate it:

$$\operatorname{softmax}\left( \frac{\log p_i - \log(-\log \varepsilon_i)}{\tau} \right), \qquad \varepsilon_i \sim U(0, 1)$$

where $\tau > 0$ is a temperature parameter. The smaller $\tau$, the closer the result is to a one-hot vector (it becomes one-hot as $\tau \to 0$); larger $\tau$ gives smoother outputs.
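A minimal sketch of the relaxation, reusing the same Gumbel noise at two temperatures to show the effect of $\tau$. The helper names are hypothetical; frameworks such as PyTorch ship their own implementation.

```python
import math
import random

def gumbel_noise(k, rng):
    """k i.i.d. standard Gumbel samples g = -log(-log eps), eps ~ U(0, 1)."""
    noise = []
    while len(noise) < k:
        eps = rng.random()
        if eps > 0.0:  # skip the (rare) exact 0.0 from random()
            noise.append(-math.log(-math.log(eps)))
    return noise

def gumbel_softmax(probs, tau, noise):
    """softmax((log p_i + g_i) / tau): a differentiable relaxation of the Gumbel-Max sample."""
    logits = [(math.log(p) + g) / tau for p, g in zip(probs, noise)]
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

rng = random.Random(1)
probs = [0.2, 0.5, 0.3]
noise = gumbel_noise(len(probs), rng)      # share the same noise across temperatures
sharp = gumbel_softmax(probs, 0.1, noise)  # close to one-hot
smooth = gumbel_softmax(probs, 5.0, noise) # close to uniform
```

Because dividing by $\tau$ does not change the ordering of the logits, both outputs peak at the same index, which is the Gumbel-Max sample for that noise.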

Using Gumbel-Softmax, the latent can be categorical over $k$ classes with a uniform prior $p(z) = 1/k$ for $z \in \{0, \dots, k-1\}$, instead of $p(z) = \mathcal{N}(0, I)$.
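With a uniform categorical prior, the KL term of the ELBO reduces to $\log k - H(q)$, where $H$ is the entropy of the encoder's categorical distribution. A minimal sketch, assuming the KL is computed on the encoder's class probabilities rather than on the relaxed Gumbel-Softmax samples; `kl_to_uniform` is a hypothetical name.

```python
import math

def kl_to_uniform(q):
    """KL(q || Uniform{0, ..., k-1}) = log k - H(q) for a categorical q over k classes."""
    k = len(q)
    entropy = -sum(qi * math.log(qi) for qi in q if qi > 0)
    return math.log(k) - entropy
```

The penalty is zero for a uniform encoder output and reaches its maximum, $\log k$, for a one-hot output.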

1.1.3.2. Vector-Quantized VAE (VQ-VAE)

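VQ-VAE compresses an image into a small grid of discrete codes (reducing dimensionality) and then uses an autoregressive model such as PixelCNN over those codes to generate images.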

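Concretely, the encoder maps $x$ to an $m \times m$ grid of $d$-dimensional vectors $z$, and each vector is replaced by its nearest codebook entry: $z_q = e_k$ with $k = \arg\min_j \lVert z - e_j \rVert_2$. But argmin is not differentiable, so we use the straight-through estimator to define our own gradient and change the reconstruction loss as follows, where $\mathrm{sg}[\cdot]$ is the stop-gradient operator (identity in the forward pass, zero gradient in the backward pass):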

β€–π‘₯βˆ’decoder(π‘§π‘ž)β€–22β†’β€–π‘₯βˆ’decoder(𝑧+Β sg[π‘§π‘žβˆ’π‘§])β€–22

To make π‘§π‘ž more similar to 𝑧, we can add β€–π‘§βˆ’π‘§π‘žβ€–22 to the loss function. Decompose β€–π‘§π‘žβˆ’π‘§β€–22 into β€–sg[𝑧]βˆ’π‘§π‘žβ€–22+β€–π‘§βˆ’sg[𝑧]β€–22. The first term fixes 𝑧 and makes π‘§π‘ž closer to 𝑧 and the second term makes 𝑧 closer to π‘§π‘ž. Since π‘§π‘ž is more free to change, so the loss function is:

β€–π‘₯βˆ’decoder(𝑧+Β sg[π‘§π‘žβˆ’π‘§])β€–22+𝛽‖sg[𝑧]βˆ’π‘§π‘žβ€–22+π›Ύβ€–π‘§βˆ’sg[π‘§π‘ž]β€–22

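where $\gamma < \beta$. The prior over codes is not learned at this stage; after training, we fit an autoregressive model such as PixelCNN to the grids of code indices produced for the training data, then sample new code grids from it and decode them for better sampling.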
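Stop-gradients only matter under automatic differentiation, so the plain-Python sketch below writes the gradients of the two auxiliary terms out by hand and takes one gradient-descent step. The values $\beta = 1$ and $\gamma = 0.25$ are illustrative defaults satisfying $\gamma < \beta$, not values taken from the text; `vq_aux_grads` is a hypothetical name.

```python
def vq_aux_grads(z, z_q, beta=1.0, gamma=0.25):
    """Hand-derived gradients of beta * ||sg[z] - z_q||^2 + gamma * ||z - sg[z_q]||^2.

    Returns (grad_z, grad_zq). The codebook term treats z as a constant, so it only
    yields a gradient for z_q (pulling it toward z); the commitment term treats z_q as
    a constant, so it only yields a gradient for the encoder output z.
    """
    grad_zq = [2.0 * beta * (q - e) for e, q in zip(z, z_q)]
    grad_z = [2.0 * gamma * (e - q) for e, q in zip(z, z_q)]
    return grad_z, grad_zq

z, z_q = [1.0, 2.0], [0.0, 0.0]
grad_z, grad_zq = vq_aux_grads(z, z_q)
# One plain gradient-descent step (learning rate 0.1) on each side.
lr = 0.1
z_new = [e - lr * g for e, g in zip(z, grad_z)]
z_q_new = [q - lr * g for q, g in zip(z_q, grad_zq)]
move_z = sum(abs(a - b) for a, b in zip(z_new, z))
move_q = sum(abs(a - b) for a, b in zip(z_q_new, z_q))
```

With $\beta = \gamma = 1$ the two gradients coincide with those of $\lVert z - z_q \rVert_2^2$; with $\gamma < \beta$ the codebook entry moves toward the encoder output faster than the encoder output moves toward it.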

1.1.3.3. VQ-VAE 2

A two-level hierarchical VQ-VAE: the top-level codes capture global structure, and the bottom-level codes are conditioned on the top level.

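1.1.3.4. DALL-E

A discrete VAE (dVAE) with ResNet-style encoder and decoder, a codebook of 8192 entries, and 1024 (32 × 32) image tokens per image, combined with a Transformer prior over text and image tokens trained on large-scale text-image paired data.

1.1.3.5. DALL-E 2/3

Image generation over image embeddings; for example, in DALL-E 2 a prior maps the caption to a CLIP image embedding, and a diffusion decoder generates the image from that embedding.

1.1.3.6. Latent Diffusion Models (LDM)

An autoencoder compresses images into a lower-dimensional latent space, and a diffusion model is trained in that latent space, with conditioning such as text injected through cross-attention.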