$p(z)$ is the prior distribution, and $p_\theta(x \mid z)$ is a predefined density function.
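What we want is to learn $q_\phi(z \mid x)$ to approximate the true posterior $p_\theta(z \mid x)$, which is usually measured by the KL divergence $D_{\mathrm{KL}}(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x))$. But the posterior is hard to deal with directly, so we work with the joint $p_\theta(x, z)$ instead, since $p_\theta(z \mid x) = p_\theta(x, z) / p_\theta(x)$:
$$D_{\mathrm{KL}}(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)) = \log p_\theta(x) - \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x, z) - \log q_\phi(z \mid x)\right]$$
The first term is a constant with respect to $q_\phi$, so we only need to maximize the expectation in the second term:
$$\mathrm{ELBO} = \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x, z) - \log q_\phi(z \mid x)\right] = \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] - D_{\mathrm{KL}}(q_\phi(z \mid x) \,\|\, p(z))$$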
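This is the Evidence Lower Bound (ELBO). But is maximizing the ELBO similar to doing maximum likelihood estimation (MLE)? Yes, since we can show that
$$\log p_\theta(x) = \mathrm{ELBO} + D_{\mathrm{KL}}(q_\phi(z \mid x) \,\|\, p_\theta(z \mid x)) \ge \mathrm{ELBO}$$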
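As a sanity check, the relation $\log p(x) = \mathrm{ELBO} + D_{\mathrm{KL}}(q(z \mid x) \,\|\, p(z \mid x))$ can be verified numerically on a toy model. The sketch below assumes a binary latent and made-up probabilities (every number is hypothetical), so that the evidence and the true posterior can be computed exactly.

```python
import math

# Hypothetical toy model: binary latent z, one fixed observation x.
p_z = [0.6, 0.4]            # prior p(z)
p_x_given_z = [0.2, 0.7]    # likelihood p(x | z) of the observed x
q_z = [0.5, 0.5]            # an arbitrary variational posterior q(z | x)

# Exact quantities, available only because z takes two values.
p_xz = [pz * px for pz, px in zip(p_z, p_x_given_z)]   # joint p(x, z)
p_x = sum(p_xz)                                        # evidence p(x)
post = [j / p_x for j in p_xz]                         # true posterior p(z | x)

# ELBO = E_q[log p(x, z) - log q(z | x)]
elbo = sum(q * (math.log(j) - math.log(q)) for q, j in zip(q_z, p_xz))
# KL(q(z | x) || p(z | x))
kl = sum(q * (math.log(q) - math.log(p)) for q, p in zip(q_z, post))

print(math.log(p_x), elbo + kl)  # the two printed values agree
```

Because the KL term is non-negative, the ELBO never exceeds $\log p(x)$, and the gap closes exactly when $q$ equals the true posterior.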
1.1. VAE
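We choose $p(z)$ to be $\mathcal{N}(0, I)$, and use networks to approximate $p_\theta(x \mid z)$ and $q_\phi(z \mid x)$.
The decoder $f_\theta$ is a one-to-one mapping from $z$ to $x$. We use $\mathcal{N}(f_\theta(z), \sigma^2 I)$ to approximate $p_\theta(x \mid z)$,
where $\sigma$ is a hyperparameter. Thus the first term of the loss function is
$$-\mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] = \frac{1}{2\sigma^2}\, \mathbb{E}_{q_\phi(z \mid x)}\left[\|x - f_\theta(z)\|_2^2\right] + \text{const}$$
Since $q_\phi(z \mid x) = \mathcal{N}(\mu_\phi(x), \mathrm{diag}(\sigma_\phi^2(x)))$, the second term is
$$D_{\mathrm{KL}}(q_\phi(z \mid x) \,\|\, p(z)) = \frac{1}{2} \sum_i \left(\mu_{\phi,i}^2(x) + \sigma_{\phi,i}^2(x) - \log \sigma_{\phi,i}^2(x) - 1\right)$$
We are trying to find $\arg\max_{\theta, \phi} \mathrm{ELBO}$, which is the same as minimizing the sum of the two terms above.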
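The two loss terms can be sketched for a single data point as follows. This is a minimal illustration rather than a training loop: it assumes a standard normal prior, a diagonal Gaussian posterior parameterized by `mu` and `log_var`, and a Gaussian decoder with fixed noise `sigma`. The function names and all numeric values are hypothetical.

```python
import math
import random

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL( N(mu, diag(exp(log_var))) || N(0, I) )."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0 for m, lv in zip(mu, log_var))

def reparameterize(mu, log_var, rng):
    """Draw z = mu + sigma * eps with eps ~ N(0, 1), so z depends smoothly on mu and log_var."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]

def negative_elbo(x, x_hat, mu, log_var, sigma=1.0):
    """Reconstruction term (up to an additive constant) plus the KL term."""
    recon = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / (2.0 * sigma ** 2)
    return recon + kl_to_standard_normal(mu, log_var)

rng = random.Random(0)
mu, log_var = [0.5, -1.0], [0.0, -0.5]   # hypothetical encoder outputs
z = reparameterize(mu, log_var, rng)     # latent sample that a decoder would consume
x, x_hat = [1.0, 2.0], [0.8, 2.1]        # hypothetical input and decoder output
loss = negative_elbo(x, x_hat, mu, log_var)
print(loss)
```

A smaller `sigma` weights the reconstruction error more heavily relative to the KL term, which is why it acts as a hyperparameter of the loss.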
1.1.1. Conditioned VAE (CVAE)
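We define the CVAE objective $\mathcal{L}_{\mathrm{CVAE}} = \mathbb{E}_{q_\phi(z \mid x, y)}\left[\log p_\theta(y \mid x, z)\right] - D_{\mathrm{KL}}(q_\phi(z \mid x, y) \,\|\, p_\theta(z \mid x))$ and also a Gaussian stochastic neural network (GSNN), which samples $z$ from the conditional prior, with objective $\mathcal{L}_{\mathrm{GSNN}} = \mathbb{E}_{p_\theta(z \mid x)}\left[\log p_\theta(y \mid x, z)\right]$. The total objective is $\alpha \mathcal{L}_{\mathrm{CVAE}} + (1 - \alpha) \mathcal{L}_{\mathrm{GSNN}}$, and the loss is its negative.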
1.1.2. $\beta$-VAE
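The objective is $\mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] - \beta\, D_{\mathrm{KL}}(q_\phi(z \mid x) \,\|\, p(z))$. When $\beta > 1$, the dimensions of $z$ are forced to be more independent (disentangled).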
1.1.3. VAE with Discrete Latent
1.1.3.1. Gumbel-Softmax
Gumbel-Max is a way to sample from a categorical distribution. Assume the probability of each category is $\pi_i$. Then $z = \arg\max_i (\log \pi_i + g_i)$, with $g_i \sim \mathrm{Gumbel}(0, 1)$ drawn independently, is equivalent to sampling from the categorical distribution, which is a reparameterization trick.
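The equivalence can be checked empirically. The sketch below draws Gumbel(0, 1) noise by inverse-CDF sampling, takes the argmax of the perturbed log-probabilities, and compares the observed frequencies with the target probabilities. The probabilities, seed, and sample count are arbitrary choices for illustration.

```python
import math
import random

def sample_gumbel(rng):
    """Gumbel(0, 1) sample via inverse CDF: g = -log(-log(u)), u ~ Uniform(0, 1)."""
    u = rng.random()
    while u == 0.0:  # avoid log(0)
        u = rng.random()
    return -math.log(-math.log(u))

def gumbel_max_sample(probs, rng):
    """Return the index argmax_i (log pi_i + g_i)."""
    scores = [math.log(p) + sample_gumbel(rng) for p in probs]
    return max(range(len(probs)), key=scores.__getitem__)

rng = random.Random(0)
probs = [0.1, 0.6, 0.3]
n = 20000
counts = [0] * len(probs)
for _ in range(n):
    counts[gumbel_max_sample(probs, rng)] += 1
freqs = [c / n for c in counts]
print(freqs)  # empirical frequencies, expected to be close to probs
```

All randomness lives in the noise `g_i`, while the category probabilities enter only through the deterministic `log` and `argmax`, which is what makes this a reparameterization.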
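But $\arg\max$ is not differentiable, so we use softmax to approximate it:
$$y_i = \frac{\exp\left((\log \pi_i + g_i) / \tau\right)}{\sum_j \exp\left((\log \pi_j + g_j) / \tau\right)}$$
where $\pi_i$ are the category probabilities, $g_i$ is Gumbel(0, 1) noise, and $\tau > 0$ is a temperature parameter. The smaller $\tau$ is, the closer the result is to one-hot.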
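The effect of the temperature can be seen by applying the softmax relaxation to the same noise at two temperatures. This is a minimal sketch, with hypothetical probabilities and temperature values; reusing the seed is only a device to hold the Gumbel noise fixed.

```python
import math
import random

def gumbel_softmax_sample(probs, tau, rng):
    """Relaxed one-hot sample: softmax((log pi_i + g_i) / tau)."""
    scores = []
    for p in probs:
        u = max(rng.random(), 1e-12)          # guard against log(0)
        g = -math.log(-math.log(u))           # Gumbel(0, 1) noise
        scores.append((math.log(p) + g) / tau)
    top = max(scores)                         # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

probs = [0.1, 0.6, 0.3]
# Reusing the seed gives both calls the same Gumbel noise, so only tau differs.
y_sharp = gumbel_softmax_sample(probs, tau=0.1, rng=random.Random(0))
y_smooth = gumbel_softmax_sample(probs, tau=5.0, rng=random.Random(0))
print(y_sharp)   # more peaked
print(y_smooth)  # flatter
```

Both outputs sum to one and share the same largest entry; lowering `tau` only concentrates more mass on it.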
Using Gumbel-Softmax, we can backpropagate through the relaxed sample $y$ instead of the non-differentiable one-hot sample $z$.
1.1.3.2. Vector-Quantization VAE (VQ-VAE)
Reduce dimensions and use PixelCNN to generate images.
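In practice, we encode $x$ into a grid of $d$-dimensional vectors $z_e$ and replace each one with its nearest codebook vector $z_q = e_k$, where $k = \arg\min_j \|z_e - e_j\|_2$. But $\arg\min$ is not differentiable, so we use the Straight-Through Estimator to define our own gradient and change the loss function:
$$\|x - \mathrm{decoder}(z_e + \mathrm{sg}[z_q - z_e])\|_2^2$$
where $\mathrm{sg}[\cdot]$ is the stop-gradient operator: the identity in the forward pass, with zero gradient in the backward pass.
To make $z_q$ more similar to $z_e$, we can add $\|z_e - z_q\|_2^2$ to the loss function. In terms of gradients, decompose it into $\|\mathrm{sg}[z_e] - z_q\|_2^2 + \|z_e - \mathrm{sg}[z_q]\|_2^2$. The first term fixes $z_e$ and makes $z_q$ closer to $z_e$, and the second term fixes $z_q$ and makes $z_e$ closer to $z_q$. Since $z_q$ is more free to change, the first term gets the larger weight, so the loss function is:
$$\|x - \mathrm{decoder}(z_e + \mathrm{sg}[z_q - z_e])\|_2^2 + \beta \|\mathrm{sg}[z_e] - z_q\|_2^2 + \gamma \|z_e - \mathrm{sg}[z_q]\|_2^2$$
where $\gamma < \beta$. After training, we can use the discrete code indices to train auto-regressive models like PixelCNN for better sampling.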
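The quantization step and the forward value of the loss can be sketched in plain Python. This is only an illustration under stated assumptions: the weight names `beta` and `gamma`, their default values, and the toy codebook are hypothetical, and since plain Python has no automatic differentiation, the stop-gradient is indicated in comments only (it changes the backward pass, not the forward value).

```python
def sq_dist(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def quantize(z_e, codebook):
    """Return (index, vector) of the codebook entry nearest to z_e."""
    k = min(range(len(codebook)), key=lambda j: sq_dist(z_e, codebook[j]))
    return k, codebook[k]

def vq_vae_loss_value(x, x_hat, z_e, z_q, beta=1.0, gamma=0.25):
    """Forward value of the three-term loss; sg[] only matters for gradients."""
    recon = sq_dist(x, x_hat)
    codebook_term = sq_dist(z_e, z_q)  # ||sg[z_e] - z_q||^2: would update only the codebook
    commit_term = sq_dist(z_e, z_q)    # ||z_e - sg[z_q]||^2: would update only the encoder
    return recon + beta * codebook_term + gamma * commit_term

codebook = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]   # toy codebook with 3 entries, d = 2
z_e = [0.9, 0.1]                                  # hypothetical encoder output
k, z_q = quantize(z_e, codebook)
loss = vq_vae_loss_value(x=[1.0, 2.0], x_hat=[1.0, 2.0], z_e=z_e, z_q=z_q)
print(k, z_q, loss)
```

The index `k` is the discrete token that an auto-regressive prior would later be trained on. In a framework with automatic differentiation, the two distance terms would differ only in which argument is wrapped in the stop-gradient.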
1.1.3.3. VQ-VAE 2
Bi-level VQ-VAE, bottom level conditions on top level.
1.1.3.4. DALL-E
Discrete VAE (dVAE) using ResNet with 8192 codebook size & 1024 image tokens, plus a Transformer prior over large-scale text-image paired data.
1.1.3.5. DALL-E 2/3
Image generation model over image embeddings.
1.1.3.6. Latent Diffusion Models (LDM)
Diffusion model run in the latent space of a pretrained autoencoder.