Stochastic Interpolants: A Unifying Framework for Flows and Diffusions

1. Stochastic Interpolants: A Unifying Framework for Flows and Diffusions

Notation: We denote probability density functions as $\rho_0(x)$, $\rho_1(x)$, and $\rho(t,x)$, with $t\in[0,1]$ and $x\in\mathbb{R}^d$, omitting the function arguments when clear from the context. $C^1([0,1])$ is the space of continuously differentiable functions from $[0,1]$ to $\mathbb{R}$, $(C^2(\mathbb{R}^d))^d$ is the space of twice continuously differentiable functions from $\mathbb{R}^d$ to $\mathbb{R}^d$, and $C_0^p(\mathbb{R}^d)$ is the space of compactly supported functions from $\mathbb{R}^d$ to $\mathbb{R}$ that are continuously differentiable $p$ times.

1.1. Stochastic Interpolants

Stochastic interpolant: Given two probability density functions $\rho_0,\rho_1:\mathbb{R}^d\to\mathbb{R}_{\ge 0}$, a stochastic interpolant between $\rho_0$ and $\rho_1$ is a stochastic process $x_t$ defined as

$$x_t=I(t,x_0,x_1)+\gamma(t)z,\qquad t\in[0,1],$$

where

  1. $I\in C^2([0,1],(C^2(\mathbb{R}^d\times\mathbb{R}^d))^d)$ satisfies the boundary conditions $I(0,x_0,x_1)=x_0$ and $I(1,x_0,x_1)=x_1$, as well as

    $$\exists C_1<\infty:\quad|\partial_t I(t,x_0,x_1)|\le C_1|x_0-x_1|\quad\forall(t,x_0,x_1)\in[0,1]\times\mathbb{R}^d\times\mathbb{R}^d.$$

    We can think of $I$ as a smooth planned path from $x_0$ to $x_1$. The bound states that $I$ does not move too fast along the way from $x_0$ at $t=0$ to $x_1$ at $t=1$, and as a result does not wander too far from either endpoint. This assumption is made for convenience and is not necessary for most arguments below.

  2. $\gamma:[0,1]\to\mathbb{R}$ satisfies $\gamma(0)=\gamma(1)=0$, $\gamma(t)>0$ for all $t\in(0,1)$, and $\gamma^2\in C^2([0,1])$.
  3. The pair $(x_0,x_1)$ is drawn from a probability measure $\nu$ that marginalizes on $\rho_0$ and $\rho_1$, i.e. $\nu(dx_0,\mathbb{R}^d)=\rho_0(x_0)dx_0$ and $\nu(\mathbb{R}^d,dx_1)=\rho_1(x_1)dx_1$. The measure $\nu$ allows for a coupling between the two densities $\rho_0$ and $\rho_1$, which affects the properties of the stochastic interpolant, but a simple choice is to take the product measure $\nu(dx_0,dx_1)=\rho_0(x_0)\rho_1(x_1)dx_0dx_1$, in which case $x_0$ and $x_1$ are independent.
  4. $z$ is a Gaussian random variable independent of $(x_0,x_1)$, i.e. $z\sim N(0,\mathrm{Id})$ and $z\perp(x_0,x_1)$.
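The definition can be made concrete with a small sampler. The sketch below is only an illustration in one dimension: the linear path `linear_path`, the choice $\gamma(t)=\sqrt{2t(1-t)}$, the standard normal endpoints, and all function names are assumptions picked for simplicity, not choices made in the text. With these choices the variance of $x_t$ is $(1-t)^2+t^2+2t(1-t)=1$ at every $t$, which the last lines check empirically at $t=0.5$.

```python
import math
import random


def gamma(t):
    """Assumed noise amplitude gamma(t) = sqrt(2 t (1 - t)); zero at t = 0 and t = 1."""
    return math.sqrt(2.0 * t * (1.0 - t))


def linear_path(t, x0, x1):
    """Assumed interpolant I(t, x0, x1) = (1 - t) x0 + t x1 (satisfies the bound with C_1 = 1)."""
    return (1.0 - t) * x0 + t * x1


def sample_interpolant(t, x0, x1, z):
    """Return x_t = I(t, x0, x1) + gamma(t) z for scalar inputs."""
    return linear_path(t, x0, x1) + gamma(t) * z


def empirical_variance(values):
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


rng = random.Random(0)
n = 20000
# Product coupling: x0 ~ N(0, 1) and x1 ~ N(0, 1) are independent, z ~ N(0, 1).
samples_mid = [
    sample_interpolant(0.5, rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0))
    for _ in range(n)
]
var_mid = empirical_variance(samples_mid)
print(f"Var[x_0.5] is approximately {var_mid:.3f}")
```

Because $\gamma$ vanishes at both endpoints, `sample_interpolant(0.0, x0, x1, z)` returns `x0` and `sample_interpolant(1.0, x0, x1, z)` returns `x1` regardless of `z`, which is exactly the role of the boundary conditions on $I$ and $\gamma$.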

Given the above definition, we want to characterize the properties of the time-dependent probability distribution $\mu(t,dx)$¹ such that

$$\forall t\in[0,1]:\quad\int_{\mathbb{R}^d}\varphi(x)\,\mu(t,dx)=E[\varphi(x_t)]\quad\text{for any test function }\varphi\in C_b^\infty(\mathbb{R}^d),$$

and we have the following property:

$$\int_{\mathbb{R}^d}E[f(t,x_0,x_1,z)\mid x_t=x]\,\mu(t,dx)=E[f(t,x_0,x_1,z)].$$

1.2. Stochastic Interpolant Properties

The most important property of the probability distribution of the stochastic interpolant $x_t$ is:

Stochastic interpolant properties: The probability distribution of the stochastic interpolant $x_t$ is absolutely continuous with respect to the Lebesgue measure at all times $t\in[0,1]$ and its time-dependent density $\rho(t)$ satisfies $\rho(0)=\rho_0$ and $\rho(1)=\rho_1$, $\rho\in C^1([0,1];C^p(\mathbb{R}^d))$ for any $p\in\mathbb{N}$, and $\rho(t,x)>0$ for all $(t,x)\in[0,1]\times\mathbb{R}^d$. In addition, $\rho$ solves the transport equation (TE)

$$\partial_t\rho+\nabla\cdot(b\rho)=0,$$

where we defined the velocity

$$b(t,x)=E[\dot x_t\mid x_t=x]=E[\partial_t I(t,x_0,x_1)+\dot\gamma(t)z\mid x_t=x].$$

This velocity is in $C^0([0,1];(C^p(\mathbb{R}^d))^d)$ for any $p\in\mathbb{N}$, and such that

$$\forall t\in[0,1]:\quad\int_{\mathbb{R}^d}|b(t,x)|^2\rho(t,x)\,dx<\infty.$$
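As a sanity check, the transport equation can be verified by hand in a simple Gaussian case. The setup here is an illustrative assumption, not taken from the text: $d=1$, $\rho_0=N(0,1)$, $\rho_1=N(m,1)$ with independent endpoints, $I(t,x_0,x_1)=(1-t)x_0+tx_1$, and $\gamma(t)=\sqrt{2t(1-t)}$. Writing $\phi$ for the standard normal density:

$$
\begin{aligned}
E[x_t]&=tm,\qquad \mathrm{Var}[x_t]=(1-t)^2+t^2+2t(1-t)=1
\quad\Longrightarrow\quad \rho(t,x)=\phi(x-tm),\\
\mathrm{Cov}(\dot x_t,x_t)&=t-(1-t)+\gamma(t)\dot\gamma(t)=(2t-1)+(1-2t)=0
\quad\Longrightarrow\quad b(t,x)=E[\dot x_t]=m,\\
\partial_t\rho+\partial_x(b\rho)&=-m\,\phi'(x-tm)+m\,\phi'(x-tm)=0.
\end{aligned}
$$

The second line uses the fact that $(\dot x_t,x_t)$ is jointly Gaussian, so zero covariance makes the conditional expectation equal to the unconditional one. In this special case the probability flow is a pure translation at speed $m$.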
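
The velocity and the score can each be estimated by minimizing a quadratic objective:

  1. For flow-based models (Objective), the objective is

    $$\mathcal{L}_b[\hat b]=\int_0^1\mathbb{E}\left(\tfrac12|\hat b(t,x_t)|^2-\left(\partial_t I(t,x_0,x_1)+\dot\gamma(t)z\right)\cdot\hat b(t,x_t)\right)dt$$

  2. For score-based/diffusion models (Score), the score is given by

    $$s(t,x)=\nabla\log\rho(t,x)=-\gamma^{-1}(t)E(z\mid x_t=x)\quad\forall(t,x)\in(0,1)\times\mathbb{R}^d$$

    and the objective is

    $$\mathcal{L}_s[\hat s]=\int_0^1\mathbb{E}\left(\tfrac12|\hat s(t,x_t)|^2+\gamma^{-1}(t)z\cdot\hat s(t,x_t)\right)dt$$

  3. For energy-based models (Energy), if we model $\hat s(t,x)=-\nabla\hat E(t,x)$, substituting into $\mathcal{L}_s$ gives

    $$\mathcal{L}_E[\hat E]=\int_0^1\mathbb{E}\left(\tfrac12|\nabla\hat E(t,x_t)|^2-\gamma^{-1}(t)z\cdot\nabla\hat E(t,x_t)\right)dt$$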

Having access to the score immediately allows us to rewrite the TE as forward and backward Fokker-Planck equations, which we state as:

Fokker-Planck equations (FPE): For any $\varepsilon\in C^0([0,1])$ with $\varepsilon(t)\ge 0$ for all $t\in[0,1]$, the probability density $\rho$ satisfies:

  1. The forward Fokker-Planck equation

    $$\partial_t\rho+\nabla\cdot(b_F\rho)=\varepsilon(t)\Delta\rho,\qquad\rho(0)=\rho_0,$$

    where we defined the forward drift

    $$b_F(t,x)=b(t,x)+\varepsilon(t)s(t,x).$$

    The forward Fokker-Planck equation is well-posed when solved forward in time from $t=0$ to $t=1$, and its solution for the initial condition $\rho(t=0)=\rho_0$ satisfies $\rho(t=1)=\rho_1$.

  2. The backward Fokker-Planck equation

    $$\partial_t\rho+\nabla\cdot(b_B\rho)=-\varepsilon(t)\Delta\rho,\qquad\rho(1)=\rho_1,$$

    where we defined the backward drift

    $$b_B(t,x)=b(t,x)-\varepsilon(t)s(t,x).$$

    The backward Fokker-Planck equation is well-posed when solved backward in time from $t=1$ to $t=0$, and its solution for the final condition $\rho(1)=\rho_1$ satisfies $\rho(0)=\rho_0$.

We design generative models using the stochastic processes associated with the TE, the forward FPE, and the backward FPE:

At any time $t\in[0,1]$, the law of the stochastic interpolant $x_t$ coincides with the law of the three processes $X_t$, $X_t^F$, and $X_t^B$, respectively defined as:

  1. The solutions of the probability flow associated with the transport equation

    $$\frac{d}{dt}X_t=b(t,X_t),$$

    solved either forward in time from the initial data $X_{t=0}\sim\rho_0$ or backward in time from the final data $X_{t=1}=x_1\sim\rho_1$.

  2. The solutions of the forward SDE associated with the FPE

    $$dX_t^F=b_F(t,X_t^F)dt+\sqrt{2\varepsilon(t)}\,dW_t,$$

    solved forward in time from the initial data $X_{t=0}^F\sim\rho_0$ independent of $W$.

  3. The solutions of the backward SDE associated with the backward FPE

    $$dX_t^B=b_B(t,X_t^B)dt+\sqrt{2\varepsilon(t)}\,dW_t^B,\qquad W_t^B=-W_{1-t},$$

    solved backward in time from the final data $X_{t=1}^B\sim\rho_1$ independent of $W^B$; the solution is by definition $X_t^B=Z_{1-t}^F$ where $Z_t^F$ satisfies

    $$dZ_t^F=-b_B(1-t,Z_t^F)dt+\sqrt{2\varepsilon(1-t)}\,dW_t,$$

    solved forward in time from the initial data $Z_{t=0}^F\sim\rho_1$ independent of $W$.²
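The first two processes in the list above can be simulated directly once $b$ and $s$ are known. The sketch below assumes a one-dimensional Gaussian example where both are available in closed form: $\rho_0=N(0,1)$, $\rho_1=N(m,1)$, linear $I$, and $\gamma(t)=\sqrt{2t(1-t)}$, for which $\rho(t)=N(tm,1)$, $b(t,x)=m$, and $s(t,x)=-(x-tm)$. These choices, the constant $\varepsilon$, and every name in the code are illustrative assumptions; in a real generative model `velocity` and `score` would be learned networks. It integrates the probability flow ODE with explicit Euler and the forward SDE with Euler-Maruyama, then compares the law at $t=1$ with $\rho_1$.

```python
import math
import random

MEAN1 = 2.0  # assumed mean m of rho_1 = N(m, 1); rho_0 = N(0, 1)
EPS = 1.0    # assumed constant diffusion coefficient epsilon(t)


def velocity(t, x):
    """Closed-form b(t, x) = m for the assumed Gaussian example."""
    return MEAN1


def score(t, x):
    """Closed-form s(t, x) = d/dx log rho(t, x) with rho(t) = N(t m, 1)."""
    return -(x - t * MEAN1)


def forward_drift(t, x):
    """b_F = b + epsilon * s."""
    return velocity(t, x) + EPS * score(t, x)


def simulate(stochastic, n_paths=3000, n_steps=100, seed=0):
    """Integrate from t = 0 to t = 1 and return the final positions.

    stochastic=True uses Euler-Maruyama on the forward SDE,
    stochastic=False uses explicit Euler on the probability flow ODE.
    """
    rng = random.Random(seed)
    h = 1.0 / n_steps
    finals = []
    for _ in range(n_paths):
        x = rng.gauss(0.0, 1.0)  # X_0 ~ rho_0
        for k in range(n_steps):
            t = k * h
            if stochastic:
                noise = math.sqrt(2.0 * EPS * h) * rng.gauss(0.0, 1.0)
                x += forward_drift(t, x) * h + noise
            else:
                x += velocity(t, x) * h
        finals.append(x)
    return finals


def mean_and_variance(values):
    m = sum(values) / len(values)
    return m, sum((v - m) ** 2 for v in values) / len(values)


sde_mean, sde_var = mean_and_variance(simulate(stochastic=True))
ode_mean, ode_var = mean_and_variance(simulate(stochastic=False))
print(f"SDE at t=1: mean {sde_mean:.3f}, variance {sde_var:.3f}")
print(f"ODE at t=1: mean {ode_mean:.3f}, variance {ode_var:.3f}")
```

Both samplers should land close to mean $m$ and variance $1$, illustrating that the deterministic and stochastic dynamics share the same marginal law even though their individual trajectories differ.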

1.3. Instantiation

We connect the diffusion bridge perspective to the stochastic interpolant perspective by setting $x_t^d=I(t,x_0,x_1)+\sqrt{2a(t)}\,B_t$, where $B_t$ is a standard Brownian bridge process³, independent of $x_0$ and $x_1$. A short calculation shows that $a=\varepsilon$ and $\gamma(t)=\sqrt{2a(t)\,t(1-t)}$, i.e.

$$x_t=I(t,x_0,x_1)+\sqrt{2a(t)\,t(1-t)}\,z.$$
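The identification of $\gamma$ rests on the fact that a standard Brownian bridge has variance $t(1-t)$ at time $t$, so the bridge noise and $\gamma(t)z$ have the same Gaussian law at each fixed $t$. A quick numerical check, assuming a constant $a=0.5$ and the representation $B_t=W_t-tW_1$ (the value of `A`, the time $t=0.3$, and the function names are illustrative choices):

```python
import math
import random

A = 0.5  # assumed constant value of a, for illustration only


def brownian_bridge_sample(t, rng):
    """Draw B_t = W_t - t W_1 using W_t ~ N(0, t) and an independent
    increment W_1 - W_t ~ N(0, 1 - t)."""
    w_t = math.sqrt(t) * rng.gauss(0.0, 1.0)
    w_1 = w_t + math.sqrt(1.0 - t) * rng.gauss(0.0, 1.0)
    return w_t - t * w_1


def gamma_squared(t):
    """gamma(t)^2 = 2 a t (1 - t) for constant a = A."""
    return 2.0 * A * t * (1.0 - t)


rng = random.Random(0)
t = 0.3
n = 20000
noise = [math.sqrt(2.0 * A) * brownian_bridge_sample(t, rng) for _ in range(n)]
mean = sum(noise) / n
empirical_var = sum((v - mean) ** 2 for v in noise) / n
print(f"empirical variance {empirical_var:.4f}, gamma^2 {gamma_squared(t):.4f}")
```

The match holds only for the marginal law at each fixed $t$; the bridge is correlated across times while a single draw of $z$ is reused along the whole interpolant path.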

Using Itô calculus, we can obtain the drift $u(t,x)$, but this requires many tedious calculations.

Given the stochastic interpolant perspective, and taking $a$ constant in time, we can instead write out

$$b(t,x)=E\left(\partial_t I(t,x_0,x_1)+\frac{\sqrt{a}\,(1-2t)\,z}{\sqrt{2t(1-t)}}\,\middle|\,x_t=x\right),\qquad s(t,x)=\nabla\log\rho(t,x)=-\frac{1}{\sqrt{2at(1-t)}}E(z\mid x_t=x).$$

Then, using $u=b+as$, we have

$$u(t,x)=E\left(\partial_t I(t,x_0,x_1)-\frac{\sqrt{2at}\,z}{\sqrt{1-t}}\,\middle|\,x_t=x\right).$$

    1. We use the probability measure $\mu(t,dx)$ instead of the density function $\rho(t,x)$ because the latter is not well defined when the distribution has no smooth density (for example, a Dirac delta). In most cases, however, you can think of $\mu(t,dx)$ as $\rho(t,x)dx$.
    2. To avoid repeated applications of the transformation $t\mapsto 1-t$, it is convenient to directly use the reversed Itô calculus rules stated in the following lemma:

      Reverse Itô Calculus: If $X_t^B$ solves the backward SDE:

      1. For any $f\in C^1([0,1];C_0^2(\mathbb{R}^d))$ and $t\in[0,1]$, the backward Itô formula holds

        $$df(t,X_t^B)=\partial_t f(t,X_t^B)dt+\nabla f(t,X_t^B)\cdot dX_t^B-\varepsilon(t)\Delta f(t,X_t^B)dt.$$
      2. For any $g\in C^0([0,1];(C_0(\mathbb{R}^d))^d)$ and $t\in[0,1]$, the backward Itô isometries hold:

        $$E_B^x\left[\int_t^1 g(s,X_s^B)\cdot dW_s^B\right]=0;\qquad E_B^x\left[\left|\int_t^1 g(s,X_s^B)\cdot dW_s^B\right|^2\right]=\int_t^1 E_B^x\left[|g(s,X_s^B)|^2\right]ds,$$

        where $E_B^x$ denotes expectation conditioned on the event $X_{t=1}^B=x$.

    3. A Brownian bridge is a stochastic process that describes a random path, similar to a standard Brownian motion, but with the crucial constraint that it is "pinned" to a specific value (usually zero) at both its start and end times. It can be written as $B_t=W_t-tW_1$. Consequently its variance, $t(1-t)$, is zero at the beginning and end and reaches its maximum in the middle of the time interval.

References

  1. Stochastic Interpolants: A Unifying Framework for Flows and Diffusions