Continuous Normalizing Flows

1. Continuous Normalizing Flows (CNF)

CNFs are a special case of Neural ODEs, with additional machinery to compute the likelihood so that they can be trained by maximum likelihood. Given a data point $x^{(i)}$, we want to compute $p_1(x^{(i)})$, the model density at time $t=1$.

Computing $p_1(x^{(i)})$ directly is intractable, so we use an approach similar to the change-of-variables formula. We start from the transport equation:

$$\frac{\partial p_t(x)}{\partial t} = -\big(\nabla \cdot (u_\theta\, p_t)\big)(x)$$
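The transport equation can be checked numerically on a toy case. The sketch below assumes a hypothetical one-dimensional linear field $u(x) = a x$ in place of a neural network. Under this field, a Gaussian base density $p_0 = \mathcal{N}(0, \sigma_0^2)$ stays Gaussian with standard deviation $\sigma_0 e^{a t}$. The code uses central finite differences to verify that $\partial_t p_t + \partial_x (u\, p_t)$ is numerically zero. The constants `A` and `SIGMA0` are illustrative choices, not values from the text.

```python
import math

A = 0.7       # hypothetical rate of the linear field u(x) = A * x
SIGMA0 = 1.0  # hypothetical std of the Gaussian base density p_0


def u(x):
    return A * x


def p(t, x):
    # Under dx/dt = A x, x_t = x_0 * exp(A t), so p_t = N(0, (SIGMA0 * exp(A t))^2)
    s = SIGMA0 * math.exp(A * t)
    return math.exp(-0.5 * (x / s) ** 2) / (s * math.sqrt(2 * math.pi))


def transport_residual(t, x, h=1e-5):
    # dp/dt + d(u p)/dx, both by central finite differences; ~0 if the equation holds
    dp_dt = (p(t + h, x) - p(t - h, x)) / (2 * h)
    d_flux_dx = (u(x + h) * p(t, x + h) - u(x - h) * p(t, x - h)) / (2 * h)
    return dp_dt + d_flux_dx


for t, x in [(0.0, 0.3), (0.5, -1.2), (1.0, 2.0)]:
    print(f"t={t}, x={x}: residual={transport_residual(t, x):.2e}")
```

Flipping the sign of the flux term makes the residual clearly nonzero. This shows that the minus sign in the equation matters.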

Following the Lagrangian perspective (tracking individual particles)¹, we obtain:

Instantaneous Change of Variables: Let $x_t$ be a finite continuous random variable with probability density $p_t(x_t)$ dependent on time. Let $\frac{dx_t}{dt} = u_\theta(x_t, t)$ be a differential equation describing a continuous-in-time transformation of $x_t$. Assuming that $u_\theta$ is uniformly Lipschitz continuous in $x$ and continuous in $t$, the change in log probability also follows a differential equation:

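$$\frac{d \log p_t(x_t)}{dt} = -\operatorname{tr}\left(\frac{d u_\theta}{d x_t}\right) = -\operatorname{tr}\big(J_{u_\theta}(x_t)\big) = -(\nabla \cdot u_\theta)(x_t)$$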
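The formula can also be verified numerically. The sketch below assumes a hypothetical diagonal linear field $u(x) = (a x_1, b x_2)$ and a base density $p_0 = \mathcal{N}(0, I)$. For this setup, $p_t = \mathcal{N}\big(0, \operatorname{diag}(e^{2at}, e^{2bt})\big)$ and the trajectories are known in closed form. The code differentiates $\log p_t(x_t)$ along one trajectory with finite differences. It then compares the result with minus the Jacobian trace, which is also estimated by finite differences. The helpers `divergence`, `trajectory` and `log_p` are written only for this illustration.

```python
import math

A, B = 0.5, -0.3  # hypothetical rates of the diagonal field u(x) = (A x_1, B x_2)


def u(x, t):
    return [A * x[0], B * x[1]]


def divergence(f, x, t, h=1e-6):
    # Trace of the Jacobian of f at x, one central difference per coordinate
    tr = 0.0
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        tr += (f(xp, t)[i] - f(xm, t)[i]) / (2 * h)
    return tr


def trajectory(x0, t):
    # Exact solution of dx/dt = u(x, t) starting from x0 at t = 0
    return [x0[0] * math.exp(A * t), x0[1] * math.exp(B * t)]


def log_p(t, x):
    # With p_0 = N(0, I), p_t = N(0, diag(exp(2 A t), exp(2 B t)))
    s1, s2 = math.exp(A * t), math.exp(B * t)
    return (-math.log(2 * math.pi) - math.log(s1 * s2)
            - 0.5 * ((x[0] / s1) ** 2 + (x[1] / s2) ** 2))


x0, t, h = [0.8, -1.5], 0.4, 1e-5
# Left side: total derivative of log p_t(x_t) along the trajectory
lhs = (log_p(t + h, trajectory(x0, t + h))
       - log_p(t - h, trajectory(x0, t - h))) / (2 * h)
# Right side: minus the divergence of u at x_t
rhs = -divergence(u, trajectory(x0, t), t)
print(lhs, rhs)  # both approximately -(A + B) = -0.2
```
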

Thus, for a given data point $x^{(i)}$ at time $t=1$, we can compute its log-likelihood by solving the following system of ODEs backwards in time:

$$\begin{cases} \dfrac{dx_t}{dt} = u_\theta(x_t, t) \\[6pt] \dfrac{d \log p_t(x_t)}{dt} = -(\nabla \cdot u_\theta)(x_t) \end{cases}$$

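Starting from $x_1 = x^{(i)}$ and integrating from $t=1$ to $t=0$, we obtain $\log p_0(x_0) = \log p_1(x^{(i)}) + \int_0^1 (\nabla \cdot u_\theta)(x_t)\,dt$. This is equivalent to $\log p_1(x^{(i)}) = \log p_0(x_0) - \int_0^1 (\nabla \cdot u_\theta)(x_t)\,dt$. The base density $p_0$ is a simple distribution that can be evaluated in closed form, typically a standard Gaussian. The log-likelihood of $x^{(i)}$ therefore follows from this equation.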
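Here is a minimal sketch of this backward solve. It uses a hypothetical 2D field $u(x) = Mx$ with $M = \begin{pmatrix}\alpha & -\omega\\ \omega & \alpha\end{pmatrix}$ (expansion plus rotation) instead of a trained network. It also assumes a standard Gaussian base $p_0$, a fixed-step RK4 solver and a finite-difference divergence. The state is augmented with $\ell(t) = \log p_t(x_t) - \log p_1(x_1)$, so that $\ell(1) = 0$ and $\log p_1(x_1) = \log p_0(x_0) - \ell(0)$. For this field $p_1 = \mathcal{N}(0, e^{2\alpha} I)$, which gives an exact log-density to compare against. All names and constants are illustrative.

```python
import math

ALPHA, OMEGA = 0.3, 1.2  # hypothetical expansion rate and rotation speed


def u(x, t):
    # Hypothetical linear field u(x) = M x with M = [[ALPHA, -OMEGA], [OMEGA, ALPHA]]
    return [ALPHA * x[0] - OMEGA * x[1], OMEGA * x[0] + ALPHA * x[1]]


def divergence(f, x, t, h=1e-6):
    # Trace of the Jacobian of f at x by central finite differences
    tr = 0.0
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        tr += (f(xp, t)[i] - f(xm, t)[i]) / (2 * h)
    return tr


def log_p0(x):
    # Standard Gaussian base density in 2D
    return -math.log(2 * math.pi) - 0.5 * (x[0] ** 2 + x[1] ** 2)


def augmented(state, t):
    # state = [x_1, x_2, l], where dl/dt = -(div u)(x_t) tracks the change in log p_t
    x = state[:2]
    return u(x, t) + [-divergence(u, x, t)]


def rk4_step(f, y, t, dt):
    k1 = f(y, t)
    k2 = f([yi + 0.5 * dt * ki for yi, ki in zip(y, k1)], t + 0.5 * dt)
    k3 = f([yi + 0.5 * dt * ki for yi, ki in zip(y, k2)], t + 0.5 * dt)
    k4 = f([yi + dt * ki for yi, ki in zip(y, k3)], t + dt)
    return [yi + dt / 6 * (a + 2 * b + 2 * c + d)
            for yi, a, b, c, d in zip(y, k1, k2, k3, k4)]


def log_likelihood(x1, n_steps=100):
    # Integrate backwards from t = 1 (data) to t = 0 (base), starting with l(1) = 0
    state, t, dt = list(x1) + [0.0], 1.0, -1.0 / n_steps
    for _ in range(n_steps):
        state = rk4_step(augmented, state, t, dt)
        t += dt
    x0, l0 = state[:2], state[2]
    # l(0) = log p_0(x_0) - log p_1(x_1), hence log p_1(x_1) = log p_0(x_0) - l(0)
    return log_p0(x0) - l0


x = [1.0, -0.5]
# For this field p_1 = N(0, exp(2 ALPHA) I), which gives a closed-form reference
exact = (-math.log(2 * math.pi) - 2 * ALPHA
         - 0.5 * (x[0] ** 2 + x[1] ** 2) / math.exp(2 * ALPHA))
print(log_likelihood(x), exact)
```

In a practical CNF, $u_\theta$ is a neural network and the divergence comes from automatic differentiation or a stochastic trace estimator. An adaptive ODE solver usually replaces fixed-step RK4.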

The main benefits of continuous NF are:

  • The constraints on $u_\theta$ are much less stringent than those on the transformation in the discrete case²
  • The flow can be inverted simply by solving the ODE in reverse (backwards in time)
  • Computing the likelihood requires neither inverting the flow nor computing a log-determinant. Only the trace of the Jacobian is needed, and it can be estimated with the Hutchinson trick.³
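The sketch below makes the Hutchinson trick concrete. It estimates $\nabla \cdot u$ at a point by averaging $v^\top J v$ over Rademacher probes $v$ (entries $\pm 1$). Each product $Jv$ is obtained from two evaluations of $u$ by central differences, so the full Jacobian is never formed. The one-layer field $u(x) = \tanh(Wx)$ and the weights `W` are hypothetical stand-ins for $u_\theta$. For comparison, the code also computes the exact divergence $\sum_i \big(1 - \tanh^2((Wx)_i)\big) W_{ii}$.

```python
import math
import random

# Hypothetical one-layer "network" u(x) = tanh(W x), standing in for u_theta
W = [[0.5, -1.0, 0.3],
     [0.8, 0.2, -0.6],
     [-0.4, 0.9, 1.1]]


def u(x):
    return [math.tanh(sum(w * xj for w, xj in zip(row, x))) for row in W]


def jvp(f, x, v, h=1e-5):
    # Jacobian-vector product J(x) v by central differences; J itself is never built
    fp = f([xi + h * vi for xi, vi in zip(x, v)])
    fm = f([xi - h * vi for xi, vi in zip(x, v)])
    return [(a - b) / (2 * h) for a, b in zip(fp, fm)]


def hutchinson_divergence(f, x, n_samples, rng):
    total = 0.0
    for _ in range(n_samples):
        # Rademacher probe: zero mean, identity covariance
        v = [rng.choice((-1.0, 1.0)) for _ in x]
        total += sum(vi * ji for vi, ji in zip(v, jvp(f, x, v)))  # v^T J v
    return total / n_samples


def exact_divergence(x):
    # tr(J) for tanh(Wx): sum_i (1 - tanh^2((Wx)_i)) * W_ii
    z = [sum(w * xj for w, xj in zip(row, x)) for row in W]
    return sum((1 - math.tanh(zi) ** 2) * W[i][i] for i, zi in enumerate(z))


x = [0.3, -0.7, 1.2]
rng = random.Random(0)
print(exact_divergence(x), hutchinson_divergence(u, x, 1000, rng))
```

The estimate is unbiased but noisy, and its variance comes only from the off-diagonal Jacobian entries. CNF implementations typically use very few probes per step and compute the products with automatic differentiation rather than finite differences.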

However, training a neural ODE by maximum likelihood scales poorly to high-dimensional spaces, since every training step requires numerically solving the ODE. The process also tends to be unstable, likely due to numerical approximations and to the infinite number of possible probability paths.

    1. Starting from the transport equation in the Eulerian perspective:

      $$\frac{\partial p_t(x)}{\partial t} = -\big(\nabla \cdot (u_\theta\, p_t)\big)(x)$$

      Expanding the divergence:

      $$\frac{\partial p_t(x)}{\partial t} = -p_t(x)\,(\nabla \cdot u_\theta)(x) - u_\theta(x) \cdot \nabla p_t(x)$$

      Dividing by $p_t(x)$:

      $$\frac{\partial \log p_t(x)}{\partial t} = -(\nabla \cdot u_\theta)(x) - u_\theta(x) \cdot \nabla \log p_t(x)$$

      For the Lagrangian perspective, we consider the total derivative along a particle trajectory $x_t$ satisfying $\frac{dx_t}{dt} = u_\theta(x_t, t)$:

      $$\frac{d}{dt} \log p_t(x_t) = \left.\frac{\partial \log p_t}{\partial t}\right|_{x_t} + \nabla \log p_t(x_t) \cdot \frac{dx_t}{dt}$$

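      Substituting $\frac{dx_t}{dt} = u_\theta(x_t, t)$:

      $$\frac{d}{dt} \log p_t(x_t) = \left.\frac{\partial \log p_t}{\partial t}\right|_{x_t} + \nabla \log p_t(x_t) \cdot u_\theta(x_t, t)$$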

      Using the Eulerian result above:

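      $$\frac{d}{dt} \log p_t(x_t) = -(\nabla \cdot u_\theta)(x_t) - u_\theta(x_t) \cdot \nabla \log p_t(x_t) + \nabla \log p_t(x_t) \cdot u_\theta(x_t, t) = -(\nabla \cdot u_\theta)(x_t)$$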

      The last two terms cancel out, yielding the instantaneous change of variables formula.

    2. In the discrete case, the transformation $f$ must be invertible, with a tractable Jacobian determinant, which is a strong constraint.
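    3. The Hutchinson trick estimates the trace of a matrix $A$ by averaging $v^\top A v$ over random vectors $v$ with zero mean and identity covariance ($\mathbb{E}[v v^\top] = I$), such as Rademacher or standard Gaussian vectors. For $A = J_{u_\theta}$, each term only needs a Jacobian-vector product, which avoids computing the full Jacobian.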

References

  1. Neural Ordinary Differential Equations
  2. A Visual Dive into Conditional Flow Matching