Flow map matching

1. Flow map matching (FMM)

The central object in our method is the flow map, which maps points along trajectories of solutions to an ordinary differential equation (ODE).

1.1. Stochastic interpolants and probability flows

Stochastic interpolant: $I_t = \alpha_t x_0 + \beta_t x_1 + \gamma_t z$, where $\alpha_0 = \beta_1 = 1$, $\alpha_1 = \beta_0 = 0$, and $\gamma_0 = \gamma_1 = 0$.
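
As a minimal sketch of the definition above, the snippet below builds one scalar interpolant sample. The coefficient choices (`alpha`, `beta`, `gamma`), the two Gaussian endpoint densities, and all function names are assumptions made for illustration; any coefficients meeting the boundary conditions would do.

```python
import random

# Hypothetical coefficients satisfying the boundary conditions
# alpha_0 = beta_1 = 1, alpha_1 = beta_0 = 0, gamma_0 = gamma_1 = 0.
def alpha(t):
    return 1.0 - t

def beta(t):
    return t

def gamma(t):
    return t * (1.0 - t)

def interpolant(t, x0, x1, z):
    """I_t = alpha_t * x0 + beta_t * x1 + gamma_t * z (scalar case)."""
    return alpha(t) * x0 + beta(t) * x1 + gamma(t) * z

def interpolant_velocity(t, x0, x1, z):
    """Time derivative of I_t for the coefficients chosen above."""
    return -x0 + x1 + (1.0 - 2.0 * t) * z

rng = random.Random(0)
x0 = rng.gauss(0.0, 1.0)   # sample from an assumed base density
x1 = rng.gauss(3.0, 0.5)   # sample from an assumed target density
z = rng.gauss(0.0, 1.0)    # latent Gaussian noise

start = interpolant(0.0, x0, x1, z)  # coincides with x0
end = interpolant(1.0, x0, x1, z)    # coincides with x1
```

The noise term vanishes at both endpoints, so the interpolant connects the two samples exactly while remaining random in between.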

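Probability flow: The probability density of $I_t$ coincides with the density of the solution $x_t$ to the ODE

$$\dot{x}_t = b_t(x_t), \qquad x_{t=0} = x_0 \sim \rho_0,$$

where $b_t(x) = \mathbb{E}[\dot{I}_t \mid I_t = x]$. The drift $b$ can be learned efficiently in practice by solving a square-loss regression problem:

$$b = \arg\min_{\hat{b}} \int_0^1 \mathbb{E}\big[|\hat{b}_t(I_t) - \dot{I}_t|^2\big]\, dt.$$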

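To make the regression concrete, here is a small sketch restricted to a single time $t$ and to affine candidates $\hat{b}_t(x) = a + c\,x$. It assumes a scalar Gaussian toy problem ($x_0 \sim N(0,1)$, $x_1 \sim N(0,\sigma^2)$, independent, with $I_t = (1-t)x_0 + t x_1$ and no noise term), for which the conditional expectation is linear and known in closed form. The function names are hypothetical.

```python
import random

def fit_drift_at_time(t, n_samples=50_000, sigma=2.0, seed=0):
    """Least-squares fit of b_hat_t(x) = a + c * x to samples of (I_t, dI_t/dt)."""
    rng = random.Random(seed)
    xs, ys = [], []
    for _ in range(n_samples):
        x0 = rng.gauss(0.0, 1.0)
        x1 = rng.gauss(0.0, sigma)
        xs.append((1.0 - t) * x0 + t * x1)  # I_t
        ys.append(x1 - x0)                  # time derivative of I_t
    mean_x = sum(xs) / n_samples
    mean_y = sum(ys) / n_samples
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys)) / n_samples
    var_x = sum((x - mean_x) ** 2 for x in xs) / n_samples
    slope = cov_xy / var_x
    intercept = mean_y - slope * mean_x
    return intercept, slope

def exact_slope(t, sigma=2.0):
    """Closed-form slope of b_t(x) = E[dI_t/dt | I_t = x] for this Gaussian toy."""
    return (t * sigma**2 - (1.0 - t)) / ((1.0 - t) ** 2 + (t * sigma) ** 2)

intercept, slope = fit_drift_at_time(0.5)
```

Up to Monte Carlo error, the fitted slope should agree with `exact_slope(0.5)` and the intercept should be close to zero, which is the sense in which square-loss regression on the noisy velocity recovers the conditional mean. The full objective additionally integrates over $t$ and uses a richer function class.
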
1.2. Flow map: definition and characterizations

Flow map: The flow map $X_{s,t} : \mathbb{R}^d \to \mathbb{R}^d$ is the unique map such that

$$X_{s,t}(x_s) = x_t \quad \text{for all } (s,t) \in [0,1]^2,$$

where $(x_t)_{t \in [0,1]}$ is any solution to the ODE.

Tangent condition: Let $X_{s,t}$ denote the flow map. Then

$$\lim_{s \to t} \partial_t X_{s,t}(x) = b_t(x) \quad \forall t \in [0,1],\ \forall x \in \mathbb{R}^d.$$

We define $v_{s,t}$ as the exact remainder obtained by truncating a Taylor expansion of $X_{s,t}(x)$ in $t - s$ at first order:

$$X_{s,t}(x) = x + (t - s)\, v_{s,t}(x), \qquad v_{t,t}(x) = b_t(x).$$

Geometrically, $v_{s,t}$ describes the "slope" of the line drawn between $x_s$ and $x_t$ on a single ODE trajectory.

Some of its useful properties: The flow map $X_{s,t}(x)$ is the unique solution to the Lagrangian equation

$$\partial_t X_{s,t}(x) = b_t(X_{s,t}(x)), \qquad X_{s,s}(x) = x,$$

for all $(s,t,x) \in [0,1]^2 \times \mathbb{R}^d$. In addition, it satisfies the composition property

$$X_{t,\tau}(X_{s,t}(x)) = X_{s,\tau}(x)$$

for all $(s,t,\tau,x) \in [0,1]^3 \times \mathbb{R}^d$. In particular, $X_{s,t}(X_{t,s}(x)) = x$ for all $(s,t,x) \in [0,1]^2 \times \mathbb{R}^d$, i.e., the flow map is invertible with inverse $X_{t,s}$.

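The flow map $X_{s,t}$ is also the unique solution of the Eulerian equation, in which the derivative is taken with respect to the starting time $s$:

$$\partial_s X_{s,t}(x) + b_s(x) \cdot \nabla X_{s,t}(x) = 0, \qquad X_{s,s}(x) = x,$$

for all $(s,t,x) \in [0,1]^2 \times \mathbb{R}^d$.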
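
The characterizations in this subsection can be checked numerically on a case where the flow map is available in closed form. The sketch below assumes a scalar Gaussian toy problem ($x_0 \sim N(0,1)$, $x_1 \sim N(0,\sigma^2)$, independent, $I_t = (1-t)x_0 + t x_1$), for which the drift is linear in $x$ and the flow map is a rescaling. All names are hypothetical and derivatives are approximated by central finite differences.

```python
import math

SIGMA = 2.0  # assumed target standard deviation for the toy example

def variance(t):
    """Variance of I_t = (1 - t) x0 + t x1 with x0 ~ N(0, 1), x1 ~ N(0, SIGMA^2)."""
    return (1.0 - t) ** 2 + (t * SIGMA) ** 2

def drift(t, x):
    """b_t(x) = E[dI_t/dt | I_t = x], which is linear in x for this toy."""
    return (t * SIGMA**2 - (1.0 - t)) / variance(t) * x

def flow_map(s, t, x):
    """X_{s,t}(x): value at time t of the ODE solution equal to x at time s."""
    return math.sqrt(variance(t) / variance(s)) * x

def average_velocity(s, t, x):
    """v_{s,t}(x) = (X_{s,t}(x) - x) / (t - s), with v_{t,t} = b_t."""
    return drift(t, x) if s == t else (flow_map(s, t, x) - x) / (t - s)

s, t, tau, x, h = 0.2, 0.7, 0.9, 1.3, 1e-5

# Lagrangian equation: d/dt X_{s,t}(x) = b_t(X_{s,t}(x)).
dX_dt = (flow_map(s, t + h, x) - flow_map(s, t - h, x)) / (2 * h)
lagrangian_residual = dX_dt - drift(t, flow_map(s, t, x))

# Eulerian equation: d/ds X_{s,t}(x) + b_s(x) * d/dx X_{s,t}(x) = 0.
dX_ds = (flow_map(s + h, t, x) - flow_map(s - h, t, x)) / (2 * h)
dX_dx = (flow_map(s, t, x + h) - flow_map(s, t, x - h)) / (2 * h)
eulerian_residual = dX_ds + drift(s, x) * dX_dx

# Composition property and invertibility.
semigroup_gap = flow_map(t, tau, flow_map(s, t, x)) - flow_map(s, tau, x)
inverse_gap = flow_map(s, t, flow_map(t, s, x)) - x
```

All four residuals should vanish up to finite-difference and rounding error, and `average_velocity(s, t, x)` approaches `drift(t, x)` as `s` tends to `t`, in line with the tangent condition.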

1.3. Flow map training

1.3.1. Distillation of a known velocity field

Lagrangian map distillation: Let $w_{s,t} \in L^1([0,1]^2)$ be a weight function satisfying $w_{s,t} > 0$ and let $I_s$ be the stochastic interpolant. Then the flow map is the global minimizer over $\hat{X}$ of the loss

$$\mathcal{L}_{\mathrm{LMD}}(\hat{X}) = \int_{[0,1]^2} w_{s,t}\, \mathbb{E}\big[|\partial_t \hat{X}_{s,t}(I_s) - b_t(\hat{X}_{s,t}(I_s))|^2\big]\, ds\, dt,$$

subject to the boundary condition that $\hat{X}_{s,s}(x) = x$ for all $x \in \mathbb{R}^d$ and $s \in [0,1]$. Here $\mathbb{E}$ denotes an expectation over the coupling $(x_0, x_1) \sim \rho(x_0, x_1)$ and $z \sim N(0, I_d)$, where $I_d$ is the $d \times d$ identity matrix.

Eulerian map distillation: The flow map is the global minimizer over $\hat{X}$ of the loss

$$\mathcal{L}_{\mathrm{EMD}}(\hat{X}) = \int_{[0,1]^2} w_{s,t}\, \mathbb{E}\big[|\partial_s \hat{X}_{s,t}(I_s) + b_s(I_s) \cdot \nabla \hat{X}_{s,t}(I_s)|^2\big]\, ds\, dt.$$

1.3.1.1. From distillation to direct training: why stopgrad is needed

The distillation losses $\mathcal{L}_{\mathrm{LMD}}$ and $\mathcal{L}_{\mathrm{EMD}}$ assume that we have access to the true, smooth drift field $b_t$. A natural question arises: what if $b_t$ is unknown and we only have access to samples from the stochastic interpolant, including the noisy velocity $\dot{I}_t$?

A naive approach might be to simply replace the true drift $b_s$ with its single-sample, noisy estimate $\dot{I}_s$ in the loss function. For example, the Eulerian loss would become:

$$\mathcal{L}_{\mathrm{Naive}}(\hat{X}) = \int_{[0,1]^2} w_{s,t}\, \mathbb{E}\big[|\partial_s \hat{X}_{s,t}(I_s) + \dot{I}_s \cdot \nabla \hat{X}_{s,t}(I_s)|^2\big]\, ds\, dt.$$

However, this naive objective is flawed: its minimizer is not the correct flow map $X$.

The issue lies in the relationship $b_s(x) = \mathbb{E}[\dot{I}_s \mid I_s = x]$. The term $\dot{I}_s$ is a random variable, while $b_s(I_s)$ is its conditional mean. By the identity $\mathbb{E}[Y^2] = (\mathbb{E}[Y])^2 + \mathrm{Var}(Y)$, applied conditionally on $I_s$, the naive loss implicitly contains an extra variance term:

$$\mathbb{E}\big[|\partial_s \hat{X} + \dot{I}_s \cdot \nabla \hat{X}|^2\big] = \mathbb{E}\big[|\partial_s \hat{X} + b_s \cdot \nabla \hat{X}|^2\big] + \mathbb{E}\big[\mathrm{Var}(\dot{I}_s \cdot \nabla \hat{X} \mid I_s)\big].$$

This extra variance term acts as a penalty that depends on $\nabla \hat{X}$. To minimize the total loss, the optimizer is pushed toward a solution $\hat{X}$ with an artificially small gradient $\nabla \hat{X}$, which biases the result away from the true flow map.
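
The decomposition can be verified in a few lines. The following is a sketch for a scalar residual (for vector-valued residuals, read the variance as the trace of the conditional covariance). The shorthand $a$ and $Y$ is introduced here only for the derivation: $a = \partial_s \hat{X}_{s,t}(I_s)$ and $Y = \dot{I}_s \cdot \nabla \hat{X}_{s,t}(I_s)$, so that $a$ is a function of $I_s$ alone and $\mathbb{E}[Y \mid I_s] = b_s(I_s) \cdot \nabla \hat{X}_{s,t}(I_s)$.

$$\mathbb{E}\big[(a + Y)^2 \mid I_s\big] = \big(a + \mathbb{E}[Y \mid I_s]\big)^2 + 2\big(a + \mathbb{E}[Y \mid I_s]\big)\, \mathbb{E}\big[Y - \mathbb{E}[Y \mid I_s] \mid I_s\big] + \mathbb{E}\big[(Y - \mathbb{E}[Y \mid I_s])^2 \mid I_s\big]$$

$$= \big(a + b_s(I_s) \cdot \nabla \hat{X}_{s,t}(I_s)\big)^2 + \mathrm{Var}(Y \mid I_s).$$

The cross term vanishes because $Y - \mathbb{E}[Y \mid I_s]$ has zero conditional mean. Taking the expectation over $I_s$ gives the stated identity, and the second term scales quadratically with $\nabla \hat{X}$.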

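To counteract this, a common technique is to use a stop-gradient operator. The operator $\operatorname{stopgrad}(z)$ returns $z$ unchanged during the forward pass but blocks gradients from flowing back through it during optimization. A corrected Eulerian loss would look like:

$$\mathcal{L}_{\mathrm{EE}}(\hat{X}) = \int_{[0,1]^2} w_{s,t}\, \mathbb{E}\big[|\partial_s \hat{X}_{s,t}(I_s) + \operatorname{stopgrad}(\dot{I}_s \cdot \nabla \hat{X}_{s,t}(I_s))|^2\big]\, ds\, dt.$$

With the gradient through the noisy term blocked, the loss gradient is linear in $\dot{I}_s$ and its remaining factor depends only on $I_s$, so taking the conditional expectation given $I_s$ replaces $\dot{I}_s$ by $b_s(I_s)$. The expected gradient therefore vanishes at the true flow map, where the Eulerian equation holds, which makes the true flow map a stationary point of this objective.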
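
The effect can be seen on a one-parameter toy problem. The sketch below assumes the scalar Gaussian setting $x_0 \sim N(0,1)$, $x_1 \sim N(0,\sigma^2)$, $I_s = (1-s)x_0 + s x_1$, and a hypothetical candidate family $\hat{X}^\theta_{s,t}(x) = (1 + \theta(r - 1))x$ that equals the exact flow map at $\theta = 1$. Gradients are written out by hand, and the stop-gradient is emulated by omitting the derivative of the wrapped term (what an operator such as `jax.lax.stop_gradient` would do inside an autodiff framework).

```python
import math
import random

SIGMA = 2.0  # assumed target standard deviation for the toy example

def variance(t):
    return (1.0 - t) ** 2 + (t * SIGMA) ** 2

def drift_slope(t):
    """c_t such that b_t(x) = c_t * x in the Gaussian toy."""
    return (t * SIGMA**2 - (1.0 - t)) / variance(t)

def loss_gradients(theta, s, t, n_samples=100_000, seed=0):
    """Monte Carlo d/dtheta of the naive and stopgrad Eulerian losses at fixed (s, t).

    Candidate map: X_hat_{s,t}(x) = (1 + theta * (r - 1)) * x with
    r = sqrt(variance(t) / variance(s)); theta = 1 recovers the exact flow map.
    """
    rng = random.Random(seed)
    r = math.sqrt(variance(t) / variance(s))
    dr_ds = -r * drift_slope(s)  # derivative of r with respect to s
    naive, stopped = 0.0, 0.0
    for _ in range(n_samples):
        x0, x1 = rng.gauss(0.0, 1.0), rng.gauss(0.0, SIGMA)
        I_s = (1.0 - s) * x0 + s * x1
        I_dot = x1 - x0
        time_term = theta * dr_ds * I_s                 # d/ds X_hat(I_s)
        noisy_term = I_dot * (1.0 + theta * (r - 1.0))  # I_dot * grad X_hat(I_s)
        residual = time_term + noisy_term
        d_time = dr_ds * I_s         # d(time_term)/dtheta
        d_noisy = I_dot * (r - 1.0)  # d(noisy_term)/dtheta
        naive += 2.0 * residual * (d_time + d_noisy)
        stopped += 2.0 * residual * d_time  # stopgrad: d_noisy is not propagated
    return naive / n_samples, stopped / n_samples

naive_grad, stopgrad_grad = loss_gradients(theta=1.0, s=0.3, t=0.8)
```

At the exact flow map the stop-gradient estimate should be close to zero, while the naive gradient stays clearly positive and would push $\theta$ below 1, i.e. toward a smaller $\nabla \hat{X}$.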

This challenge of handling noisy velocities directly is a primary motivation for developing more sophisticated objectives like Flow Map Matching (FMM), which we introduce next. FMM provides an alternative, well-posed loss function for direct training.

1.3.2. Direct training with flow map matching (FMM)

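Flow map matching: The flow map is the global minimizer over $\hat{X}$ of the loss

$$\mathcal{L}_{\mathrm{FMM}}(\hat{X}) = \int_{[0,1]^2} w_{s,t} \Big( \mathbb{E}\big[|\partial_t \hat{X}_{s,t}(\hat{X}_{t,s}(I_t)) - \dot{I}_t|^2\big] + \mathbb{E}\big[|\hat{X}_{s,t}(\hat{X}_{t,s}(I_t)) - I_t|^2\big] \Big)\, ds\, dt.$$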

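As an illustration of the two terms, the sketch below estimates the FMM integrand at a single pair $(s,t)$ by Monte Carlo. It assumes the scalar Gaussian toy setting ($x_0 \sim N(0,1)$, $x_1 \sim N(0,\sigma^2)$, $I_t = (1-t)x_0 + t x_1$) and a hypothetical one-parameter family of linear candidate maps that equals the exact flow map at $\theta = 1$. The function names are illustrative only.

```python
import math
import random

SIGMA = 2.0  # assumed target standard deviation for the toy example

def variance(t):
    return (1.0 - t) ** 2 + (t * SIGMA) ** 2

def drift_slope(t):
    """c_t such that b_t(x) = c_t * x in the Gaussian toy."""
    return (t * SIGMA**2 - (1.0 - t)) / variance(t)

def scale(theta, s, t):
    """Candidate map X_hat_{s,t}(x) = scale * x; theta = 1 gives the exact map."""
    r = math.sqrt(variance(t) / variance(s))
    return 1.0 + theta * (r - 1.0)

def dscale_dt(theta, s, t):
    """Derivative of scale(theta, s, t) with respect to the end time t."""
    r = math.sqrt(variance(t) / variance(s))
    return theta * r * drift_slope(t)

def fmm_integrand(theta, s, t, n_samples=20_000, seed=0):
    """Monte Carlo estimate of the FMM integrand at fixed (s, t)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        x0, x1 = rng.gauss(0.0, 1.0), rng.gauss(0.0, SIGMA)
        I_t = (1.0 - t) * x0 + t * x1
        I_dot = x1 - x0
        pulled_back = scale(theta, t, s) * I_t  # X_hat_{t,s}(I_t)
        velocity_term = dscale_dt(theta, s, t) * pulled_back - I_dot
        cycle_term = scale(theta, s, t) * pulled_back - I_t
        total += velocity_term**2 + cycle_term**2
    return total / n_samples

loss_exact = fmm_integrand(1.0, 0.3, 0.8)
loss_shrunk = fmm_integrand(0.5, 0.3, 0.8)
loss_stretched = fmm_integrand(1.5, 0.3, 0.8)
```

Within this family the estimate should be smallest at $\theta = 1$. Note that the minimum is not zero: the first term retains the conditional variance of $\dot{I}_t$ given $I_t$, while the second (cycle-consistency) term vanishes at the exact map.
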
1.3.3. Progressive distillation

Progressive flow map matching: Let $\hat{X}$ be a two-time flow map. Given $K \in \mathbb{N}$, let $t_k = s + \frac{k-1}{K-1}(t - s)$ for $k = 1, \ldots, K$. Then the unique minimizer over $\check{X}$ of the objective

$$\mathcal{L}_{\mathrm{PFMM}}(\check{X}) = \int_{[0,1]^2} w_{s,t}\, \mathbb{E}\big[|\check{X}_{s,t}(I_s) - (\hat{X}_{t_{K-1},t_K} \circ \hat{X}_{t_{K-2},t_{K-1}} \circ \cdots \circ \hat{X}_{t_1,t_2})(I_s)|^2\big]\, ds\, dt$$

produces the same output in one step as the $K$-step iterated map $\hat{X}$.
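
The regression target in this objective is just the teacher map applied along a uniform time grid. A minimal sketch follows, assuming $K \ge 2$ and using a closed-form scalar flow map from a Gaussian toy problem as a stand-in teacher; `teacher_flow_map`, `time_grid`, and `iterate_teacher` are hypothetical names.

```python
import math

SIGMA = 2.0  # assumed target standard deviation for the toy example

def variance(t):
    return (1.0 - t) ** 2 + (t * SIGMA) ** 2

def teacher_flow_map(s, t, x):
    """Stand-in teacher: exact flow map of the Gaussian toy problem."""
    return math.sqrt(variance(t) / variance(s)) * x

def time_grid(s, t, K):
    """t_k = s + (k - 1) / (K - 1) * (t - s) for k = 1, ..., K (needs K >= 2)."""
    return [s + (k - 1) / (K - 1) * (t - s) for k in range(1, K + 1)]

def iterate_teacher(flow_map, s, t, x, K):
    """Apply X_{t_{K-1},t_K} o ... o X_{t_1,t_2} to x along the grid."""
    grid = time_grid(s, t, K)
    for t_from, t_to in zip(grid[:-1], grid[1:]):
        x = flow_map(t_from, t_to, x)
    return x

target = iterate_teacher(teacher_flow_map, 0.1, 0.9, 1.5, K=5)
one_step = teacher_flow_map(0.1, 0.9, 1.5)
```

Because the stand-in teacher is exact, the iterated target coincides with the one-step map by the composition property. With an imperfect teacher, the iterated map is typically more accurate than a single large step, which is what the student is trained to reproduce.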

1.3.4. Self-distillation

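Self-distillation: The flow map $X_{s,t}$ is given for all $0 \le s \le t \le 1$ by $X_{s,t}(x) = x + (t - s)\, v_{s,t}(x)$, where $v_{s,t}(x)$ is the unique minimizer over $\hat{v}$ of

$$\mathcal{L}_{\mathrm{SD}}(\hat{v}) = \mathcal{L}_b(\hat{v}) + \mathcal{L}_D(\hat{v}),$$

where $\mathcal{L}_b(\hat{v})$ is given by

$$\mathcal{L}_b(\hat{v}) = \int_0^1 \mathbb{E}_{x_0,x_1}\big[|\hat{v}_{t,t}(I_t) - \dot{I}_t|^2\big]\, dt,$$

and where $\mathcal{L}_D(\hat{v})$ is any linear combination of the following three objectives: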

(i) The Lagrangian self-distillation (LSD) objective,

$$\mathcal{L}_D^{\mathrm{LSD}}(\hat{v}) = \int_0^1 \int_0^t \mathbb{E}_{x_0,x_1}\big[|\partial_t \hat{X}_{s,t}(I_s) - \hat{v}_{t,t}(\hat{X}_{s,t}(I_s))|^2\big]\, ds\, dt;$$

(ii) The Eulerian self-distillation (ESD) objective,

$$\mathcal{L}_D^{\mathrm{ESD}}(\hat{v}) = \int_0^1 \int_0^t \mathbb{E}_{x_0,x_1}\big[|\partial_s \hat{X}_{s,t}(I_s) + \nabla \hat{X}_{s,t}(I_s)\, \hat{v}_{s,s}(I_s)|^2\big]\, ds\, dt;$$

(iii) The progressive self-distillation (PSD) objective,

$$\mathcal{L}_D^{\mathrm{PSD}}(\hat{v}) = \int_0^1 \int_0^t \int_s^t \mathbb{E}_{x_0,x_1}\big[|\hat{X}_{s,t}(I_s) - \hat{X}_{u,t}(\hat{X}_{s,u}(I_s))|^2\big]\, du\, ds\, dt.$$

Above, $\hat{X}_{s,t}(x) = x + (t - s)\, \hat{v}_{s,t}(x)$ and $\mathbb{E}_{x_0,x_1}$ denotes an expectation over the random draw of $(x_0, x_1)$.
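
A short worked step shows why the diagonal term $\mathcal{L}_b$ pairs naturally with the distillation terms. Differentiating the parameterization $\hat{X}_{s,t}(x) = x + (t - s)\, \hat{v}_{s,t}(x)$ in $t$ gives

$$\partial_t \hat{X}_{s,t}(x) = \hat{v}_{s,t}(x) + (t - s)\, \partial_t \hat{v}_{s,t}(x), \qquad \text{so that} \qquad \partial_t \hat{X}_{s,t}(x)\big|_{s = t} = \hat{v}_{t,t}(x).$$

By the tangent condition, the exact flow map must therefore satisfy $\hat{v}_{t,t} = b_t$, which is what the regression of $\hat{v}_{t,t}(I_t)$ onto $\dot{I}_t$ in $\mathcal{L}_b$ enforces. Substituting the same expression into the LSD residual gives, as a sketch,

$$\partial_t \hat{X}_{s,t}(I_s) - \hat{v}_{t,t}(\hat{X}_{s,t}(I_s)) = \hat{v}_{s,t}(I_s) + (t - s)\, \partial_t \hat{v}_{s,t}(I_s) - \hat{v}_{t,t}\big(I_s + (t - s)\, \hat{v}_{s,t}(I_s)\big),$$

which makes explicit that the off-diagonal values $\hat{v}_{s,t}$ are trained against the model's own diagonal $\hat{v}_{t,t}$ rather than against an external teacher.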

1.3.5. Align your flow (AYF)

The first training objective aims to ensure that, for a fixed $s$, the output of the flow map remains constant as we move $(x_t, t)$ along the probability flow ODE (PF-ODE).

AYF-Eulerian Map Distillation (AYF-EMD): Let $f_\theta(x_t, t, s)$ be the flow map, with $\theta^-$ denoting a stop-gradient copy of $\theta$. Consider the loss function defined between two adjacent starting timesteps $t$ and $t' = t + \varepsilon(s - t)$ for a small $\varepsilon > 0$,

$$\mathbb{E}_{x_t,t,s}\big[w(t,s)\, \|f_\theta(x_t, t, s) - f_{\theta^-}(x_{t'}, t', s)\|_2^2\big],$$

where $x_{t'}$ is obtained by applying a 1-step Euler solver to the PF-ODE from $t$ to $t'$. In the limit as $\varepsilon \to 0$, the gradient of this loss function with respect to $\theta$ converges to

$$\nabla_\theta\, \mathbb{E}_{x_t,t,s}\Big[w'(t,s)\, \operatorname{sign}(t - s) \cdot f_\theta^\top(x_t, t, s) \cdot \frac{d f_{\theta^-}(x_t, t, s)}{dt}\Big],$$

where $w'(t,s) = w(t,s) \times |t - s|$. The AYF-EMD loss naturally generalizes the loss used to train continuous-time consistency models, as it reduces to the same objective when $s = 0$.
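
The limiting gradient can be recovered with a first-order expansion. The following is a sketch, writing $v(x_t, t)$ for the PF-ODE velocity (a symbol introduced here for the derivation) and dropping a constant factor $2\varepsilon$ that is absorbed into the weight. The Euler step gives $x_{t'} = x_t + \varepsilon(s - t)\, v(x_t, t)$, hence

$$f_{\theta^-}(x_{t'}, t', s) = f_{\theta^-}(x_t, t, s) + \varepsilon (s - t)\, \frac{d f_{\theta^-}(x_t, t, s)}{dt} + O(\varepsilon^2), \qquad \frac{d f_{\theta^-}}{dt} = \partial_t f_{\theta^-} + \nabla_x f_{\theta^-} \cdot v.$$

Since $f_\theta$ and $f_{\theta^-}$ agree in value, differentiating the squared norm with respect to $\theta$ yields

$$\nabla_\theta \big\|f_\theta - f_{\theta^-}(x_{t'}, t', s)\big\|_2^2 = 2\varepsilon\, (t - s)\, \nabla_\theta \Big[f_\theta^\top \frac{d f_{\theta^-}}{dt}\Big] + O(\varepsilon^2) = 2\varepsilon\, |t - s|\, \operatorname{sign}(t - s)\, \nabla_\theta \Big[f_\theta^\top \frac{d f_{\theta^-}}{dt}\Big] + O(\varepsilon^2),$$

which matches the stated expression with $w'(t,s) = w(t,s)\,|t - s|$.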

The second approach ensures consistency at timestep $s$ instead. This method tries to ensure that, for a fixed $(x_t, t)$, the trajectory $f_\theta(x_t, t, \cdot)$ is aligned with that point's PF-ODE.

AYF-Lagrangian Map Distillation (AYF-LMD): Let $f_\theta(x_t, t, s)$ be the flow map. Consider the loss function defined between two adjacent ending timesteps $s$ and $s' = s + \varepsilon(t - s)$ for a small $\varepsilon > 0$,

$$\mathbb{E}_{x_t,t,s}\big[w(t,s)\, \|f_\theta(x_t, t, s) - \mathrm{ODE}_{s' \to s}(f_{\theta^-}(x_t, t, s'))\|_2^2\big],$$

where $\mathrm{ODE}_{t \to s}(x)$ refers to running a 1-step Euler solver on the PF-ODE starting from $x$ at timestep $t$ to timestep $s$. In the limit as $\varepsilon \to 0$, the gradient of this objective with respect to $\theta$ converges to

$$\nabla_\theta\, \mathbb{E}_{x_t,t,s}\Big[w'(t,s)\, \operatorname{sign}(s - t) \cdot f_\theta^\top(x_t, t, s) \cdot \Big(\frac{d f_{\theta^-}(x_t, t, s)}{ds} - v_\phi(f_{\theta^-}(x_t, t, s), s)\Big)\Big],$$

where $w'(t,s) = w(t,s) \times |t - s|$ and $v_\phi$ denotes the velocity field of the PF-ODE.
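
The same first-order argument applies here. As a sketch (with the constant factor $2\varepsilon$ again absorbed into the weight), the Euler step from $s'$ back to $s$ is $\mathrm{ODE}_{s' \to s}(y) = y - \varepsilon (t - s)\, v_\phi(y, s')$, and expanding $f_{\theta^-}(x_t, t, s')$ around $s$ gives the target

$$\mathrm{ODE}_{s' \to s}\big(f_{\theta^-}(x_t, t, s')\big) = f_{\theta^-}(x_t, t, s) + \varepsilon (t - s) \Big(\frac{d f_{\theta^-}(x_t, t, s)}{ds} - v_\phi(f_{\theta^-}(x_t, t, s), s)\Big) + O(\varepsilon^2).$$

Differentiating the squared norm with respect to $\theta$, and using that $f_\theta$ and $f_{\theta^-}$ agree in value, leaves

$$-2\varepsilon\, (t - s)\, \nabla_\theta \Big[f_\theta^\top \Big(\frac{d f_{\theta^-}}{ds} - v_\phi(f_{\theta^-}, s)\Big)\Big] + O(\varepsilon^2) = 2\varepsilon\, |t - s|\, \operatorname{sign}(s - t)\, \nabla_\theta \Big[f_\theta^\top \Big(\frac{d f_{\theta^-}}{ds} - v_\phi(f_{\theta^-}, s)\Big)\Big] + O(\varepsilon^2).$$

The bracketed difference is the residual of the Lagrangian equation in the end time $s$, so the gradient vanishes when the model's trajectory follows the PF-ODE.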

References

  1. Flow map matching with stochastic interpolants: A mathematical framework for consistency models
  2. How to build a consistency model: Learning flow maps via self-distillation
  3. Align Your Flow: Scaling Continuous-Time Flow Map Distillation