KL Divergence and its variants

1. KL Divergence and its variants

$$\mathrm{KL}(P \| Q) = E_{x \sim P}\left[\log \frac{P(x)}{Q(x)}\right]$$

Forward (inclusive) KL: $\mathrm{KL}(P \| Q)$, where $P$ is the true distribution and $Q$ is the approximating distribution.

  • Mode covering: Since we sample from $P$ and penalize when $Q(x)$ is small, $Q$ tends to cover all modes of $P$ (even if it means being over-dispersed). If $P(x) > 0$ but $Q(x) \to 0$, the penalty is large.
  • Typical use: Maximum likelihood estimation, VAEs (ensures the decoder explains all data)
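
The link between forward KL and maximum likelihood can be made explicit with a short derivation. The parametric family $Q_\theta$ is notation introduced here for illustration and does not appear elsewhere in these notes.

$$\mathrm{KL}(P \| Q_\theta) = E_{x \sim P}[\log P(x)] - E_{x \sim P}[\log Q_\theta(x)]$$

The first term does not depend on $\theta$, so minimizing forward KL over $\theta$ is equivalent to maximizing the expected log-likelihood $E_{x \sim P}[\log Q_\theta(x)]$, which in practice is estimated by the average log-likelihood of samples drawn from $P$.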

Reverse (exclusive) KL: $\mathrm{KL}(Q \| P)$

  • Mode seeking: Since we sample from $Q$ and penalize when $P(x)$ is small, $Q$ tends to concentrate on a single mode of $P$ (under-dispersed but sharper). If $Q(x) > 0$ but $P(x) \to 0$, the penalty is large.
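  • Reduces exposure bias: the objective is evaluated on samples drawn from $Q$ itself, so the model is trained on its own outputs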
  • Typical use: Variational inference, policy optimization in RL
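
The contrast between mode-covering and mode-seeking behavior can be illustrated numerically. The following is a minimal sketch, assuming a bimodal target on a 1D grid and a brute-force search over single Gaussians; the grid, the mixture parameters, and the helper names `gaussian`, `kl`, and `fit` are all choices made for this illustration.

```python
import numpy as np

# Discretization grid; the range and resolution are assumptions of this sketch.
x = np.linspace(-10.0, 10.0, 2001)
dx = x[1] - x[0]

def gaussian(x, mu, sigma):
    """Density of N(mu, sigma^2) evaluated on the grid."""
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))

def kl(p, q, eps=1e-300):
    """Riemann-sum approximation of KL(p || q); eps guards against log(0)."""
    mask = p > 0
    return float(np.sum(p[mask] * (np.log(p[mask]) - np.log(q[mask] + eps))) * dx)

# Target P: an equal mixture of two well-separated Gaussians.
p = 0.5 * gaussian(x, -3.0, 0.7) + 0.5 * gaussian(x, 3.0, 0.7)

def fit(objective):
    """Brute-force search over single Gaussians Q = N(mu, sigma^2)."""
    best = (np.inf, None, None)
    for mu in np.linspace(-5.0, 5.0, 101):
        for sigma in np.linspace(0.3, 5.0, 95):
            value = objective(gaussian(x, mu, sigma))
            if value < best[0]:
                best = (value, mu, sigma)
    return best

_, mu_fwd, sigma_fwd = fit(lambda q: kl(p, q))  # forward KL(P || Q)
_, mu_rev, sigma_rev = fit(lambda q: kl(q, p))  # reverse KL(Q || P)

print(f"forward KL fit: mu={mu_fwd:.2f}, sigma={sigma_fwd:.2f}")
print(f"reverse KL fit: mu={mu_rev:.2f}, sigma={sigma_rev:.2f}")
```

Under these assumptions the forward fit should land near the midpoint of the two modes with a wide standard deviation, while the reverse fit should sit on one of the two modes with a standard deviation close to that of a single component.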

1.1. Jensen-Shannon Divergence

$$\mathrm{JSD}(P \| Q) = \frac{1}{2}\left(\mathrm{KL}(P \| M) + \mathrm{KL}(Q \| M)\right), \quad M = \frac{1}{2}(P + Q)$$

JSD measures divergence relative to the mixture distribution $M$, which makes it:

  • Symmetric: $\mathrm{JSD}(P \| Q) = \mathrm{JSD}(Q \| P)$ (unlike KL)
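  • Bounded: $0 \le \mathrm{JSD}(P \| Q) \le \log 2$ with natural logarithms (equivalently, at most 1 with base-2 logarithms); the upper bound is attained when the distributions have disjoint support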
  • Nearly a metric: $\sqrt{\mathrm{JSD}(P \| Q)}$ satisfies the triangle inequality
  • Balances between mode-covering and mode-seeking behavior of KL variants
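
These properties can be checked on small discrete distributions. The sketch below assumes natural logarithms and three hand-picked distributions `p`, `q`, `r` (chosen for illustration only); it checks symmetry, the value at disjoint support, and the triangle inequality for the square root of JSD.

```python
import numpy as np

def kl(p, q):
    """KL(p || q) in nats for discrete distributions; 0 * log(0 / q) is taken as 0."""
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def jsd(p, q):
    """Jensen-Shannon divergence via the mixture m = (p + q) / 2."""
    m = 0.5 * (p + q)
    return 0.5 * (kl(p, m) + kl(q, m))

p = np.array([0.5, 0.5, 0.0, 0.0])
q = np.array([0.0, 0.0, 0.5, 0.5])      # disjoint support from p
r = np.array([0.25, 0.25, 0.25, 0.25])  # overlaps with both p and q

jsd_pq, jsd_pr, jsd_rq = jsd(p, q), jsd(p, r), jsd(r, q)

print("JSD(p, q) =", jsd_pq, "log 2 =", np.log(2.0))
print("symmetric:", np.isclose(jsd(p, r), jsd(r, p)))
print("sqrt triangle inequality holds:",
      np.sqrt(jsd_pq) <= np.sqrt(jsd_pr) + np.sqrt(jsd_rq))
print("raw JSD triangle inequality holds:", jsd_pq <= jsd_pr + jsd_rq)
```

For this particular triple the raw JSD values violate the triangle inequality while their square roots satisfy it, which is why the square root is the quantity usually treated as a distance.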

1.2. Wasserstein Distance

Given a map $T$ and $x \sim P$, the distribution of $T(x)$ is called the push-forward of $P$, denoted $T_\# P$, where $T_\# P(A) = P(\{x : T(x) \in A\}) = P(T^{-1}(A))$.

The Monge version of the optimal transport distance is $\inf_T \int \|x - T(x)\|^p \, dP(x)$, where the infimum is over all $T$ such that $T_\# P = Q$. Intuitively, this measures how far you have to move the mass of $P$ to turn it into $Q$. A minimizer $T^*$, if one exists, is called the optimal transport map.

Let $\Pi(P, Q)$ denote all joint distributions $\gamma$ for $(X, Y)$ that have marginals $P$ and $Q$. In other words, $T_{X\#}\gamma = P$ and $T_{Y\#}\gamma = Q$, where $T_X(x, y) = x$ and $T_Y(x, y) = y$. Then the Wasserstein distance is

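$$W_p(P, Q) = \left(\inf_{\gamma \in \Pi(P, Q)} \int \|x - y\|^p \, d\gamma(x, y)\right)^{1/p} = \left(\inf_{\gamma \in \Pi(P, Q)} E_{(x, y) \sim \gamma}\left[\|x - y\|^p\right]\right)^{1/p}$$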

where $p \ge 1$. When $p = 1$, this is also called the Earth Mover's Distance.
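
In one dimension the infimum over couplings has a simple solution for two equally sized samples: matching the points in sorted order is optimal. The sketch below relies on that fact and on the equal-size assumption; the function name `wasserstein_1d` and the test data are illustrative.

```python
import numpy as np

def wasserstein_1d(x, y, p=1):
    """Empirical W_p between two equally sized 1D samples.

    Sorting both samples and pairing them in order gives the optimal
    coupling in one dimension; this does not generalize to higher dimensions.
    """
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    if x.shape != y.shape:
        raise ValueError("this sketch assumes equally sized samples")
    return float(np.mean(np.abs(x - y) ** p) ** (1.0 / p))

rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, size=5000)
b = a + 2.0  # the same sample shifted by 2

w1 = wasserstein_1d(a, b, p=1)
w2 = wasserstein_1d(a, b, p=2)
print(w1, w2)
```

Because `b` is a pure translation of `a`, every unit of mass moves by the same amount, so both distances should equal the size of the shift up to floating-point rounding.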

From Kantorovich duality (known as Kantorovich-Rubinstein duality in the case $p = 1$), it can be shown that

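$$W_p^p(P, Q) = \sup_{\psi, \varphi} \left( \int \psi(y) \, dQ(y) - \int \varphi(x) \, dP(x) \right)$$

where $\psi(y) - \varphi(x) \le \|x - y\|^p$. When $p = 1$, we have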

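$$W_1(P, Q) = \sup_{\|T\|_L \le 1} \left( E_{x \sim P}[T(x)] - E_{y \sim Q}[T(y)] \right)$$

where $\|T\|_L \le 1$ means $|T(x) - T(y)| \le \|x - y\|$, that is, $T$ is 1-Lipschitz. Here $T$ denotes a real-valued test function, not the transport map used earlier.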

When to use Wasserstein Distance instead of KL:

  • Non-overlapping distributions: KL divergence becomes infinite (or undefined) when the distributions have non-overlapping support, while WD remains finite and meaningful. This matters in high-dimensional spaces, where the supports of two distributions rarely overlap.
  • Meaningful gradients: Even when distributions barely overlap, WD provides useful gradients for optimization. This is why Wasserstein GAN (WGAN) often trains more stably than a vanilla GAN: it can still learn when the generator distribution is far from the real data distribution.
  • True metric: WD is a proper distance metric (satisfies triangle inequality), making it more suitable for geometric interpretations and certain theoretical analyses.
  • Weak topology: Convergence in WD is weaker than convergence in KL. $W_p(P_n, P) \to 0$ implies convergence in distribution, so a sequence can converge in WD even when its KL divergence to the limit stays infinite, which is often more natural for generative modeling.
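
The first two points can be seen in a toy setting: a point mass at position 0 compared with a point mass at position `theta` on an integer grid. This is a minimal sketch; the grid size, the helper names, and the use of the 1D identity that $W_1$ equals the integrated absolute difference of the two CDFs are assumptions of the example.

```python
import numpy as np

def kl(p, q):
    """KL(p || q) in nats; returns inf when q is zero somewhere p is positive."""
    mask = p > 0
    with np.errstate(divide="ignore"):
        return float(np.sum(p[mask] * (np.log(p[mask]) - np.log(q[mask]))))

def jsd(p, q):
    """Jensen-Shannon divergence via the mixture m = (p + q) / 2."""
    m = 0.5 * (p + q)
    return 0.5 * (kl(p, m) + kl(q, m))

def w1_grid(p, q, spacing=1.0):
    """W_1 on a uniform 1D grid, computed from the difference of the CDFs."""
    return float(np.sum(np.abs(np.cumsum(p) - np.cumsum(q))) * spacing)

def point_mass(position, size=11):
    """Distribution with all mass on one grid cell."""
    d = np.zeros(size)
    d[position] = 1.0
    return d

p = point_mass(0)
results = {}
for theta in (1, 3, 5):
    q = point_mass(theta)
    results[theta] = (kl(p, q), jsd(p, q), w1_grid(p, q))
    print(theta, results[theta])
```

KL is infinite and JSD is stuck at $\log 2$ for every nonzero shift, so neither tells an optimizer which direction reduces the mismatch; $W_1$ grows linearly with the shift and therefore carries that information.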

1.3. Fisher Divergence

$$F(P \| Q) = \frac{1}{2} E_{x \sim P}\left[\left\|\nabla_x \log P(x) - \nabla_x \log Q(x)\right\|_2^2\right]$$

Unlike KL divergence, which compares probability values, Fisher divergence compares the score functions (gradients of log probabilities). This makes it particularly useful when:

  • Dealing with unnormalized distributions (only need score functions, not normalization constants)
  • Training score-based generative models
  • The score function is better behaved than the density itself
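
The first point can be illustrated with a Monte Carlo estimate in one dimension. The sketch below assumes $P = N(0, 1)$ and an unnormalized $Q \propto N(1, 1)$, and approximates the scores by central finite differences; the constant `7.0` is an arbitrary stand-in for an unknown log-normalizer, and all function names are illustrative.

```python
import numpy as np

def score(log_density, x, h=1e-5):
    """Central finite-difference estimate of d/dx log_density(x)."""
    return (log_density(x + h) - log_density(x - h)) / (2.0 * h)

def log_p(x):
    """Normalized log-density of N(0, 1)."""
    return -0.5 * x**2 - 0.5 * np.log(2.0 * np.pi)

def log_q_unnormalized(x):
    """Log of an unnormalized N(1, 1) density; 7.0 is an arbitrary offset."""
    return -0.5 * (x - 1.0) ** 2 + 7.0

rng = np.random.default_rng(0)
samples = rng.normal(0.0, 1.0, size=10_000)  # x ~ P

# The additive constant in log_q_unnormalized vanishes under differentiation.
diff = score(log_p, samples) - score(log_q_unnormalized, samples)
fisher = 0.5 * float(np.mean(diff**2))
print(fisher)
```

For two unit-variance Gaussians the scores differ by the constant gap between the means, so the estimate should match the closed-form value of one half of the squared mean difference regardless of the offset added to the log-density.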

1.4. Applications

  • Variational Inference: Reverse KL (mode-seeking behavior prevents over-dispersed approximations)
  • GAN: JSD (symmetric, bounded measure between real and generated distributions)
  • WGAN: Wasserstein Distance (stable training with meaningful gradients)
  • VAE: Forward KL (mode-covering ensures all data modes are explained)
  • RL (e.g., PPO, TRPO): Reverse KL (prevents policy from assigning probability to bad actions)
  • Score-based models (diffusion): Fisher Divergence (training without normalized densities)