Ruthotto ML meet OT @ Oct 2020
Machine Learning ↔ Optimal Transport
Sayas Numerics Seminar Lars Ruthotto
Departments of Mathematics and Computer Science Emory University lruthotto@emory.edu @lruthotto
Title ML → OT Lag NN Exp OT→CNF Σ 1
◮ ML → OT: New Tricks from Learning
  ◮ based on relaxed dynamical optimal transport
  ◮ combine macroscopic / microscopic / HJB equations
  ◮ neural networks for the value function
  ◮ combine analytic gradients and automatic differentiation
  ◮ generalization to mean field games and control problems
◮ OT → ML: Learning from Old Tricks
  ◮ variational inference via continuous normalizing flows
  ◮ applications: density estimation, generative modeling
  ◮ OT gives uniqueness and regularity of the dynamics
  ◮ HJB penalty, solid numerics, and efficient implementation
  ◮ orders-of-magnitude speedup in training and inference
LR, S Osher, W Li, L Nurbekyan, S Wu Fung: A Machine Learning Framework for Solving High-Dimensional Mean Field Game and Mean Field Control Problems. PNAS 117(17), 9183-9193, 2020.
D Onken, S Wu Fung, X Li, LR: OT-Flow: Fast and Accurate Continuous Normalizing Flows via Optimal Transport. arXiv:2006.00104, 2020.
Funding: ◮ NSF DMS 1751636 ◮ BSF 2018209 ◮ AFOSR FA9550-20-1-0372
Special thanks: ◮ organizers and staff of the IPAM long program MLP 2019 ◮ Osher's funding (AFOSR MURI and ONR)
Collaborators: Onken, Wu Fung, Li, Nurbekyan
Figure: initial density ρ0, target density ρ1, and density evolution.
Given the initial density, ρ0, and the target density, ρ1, find the velocity v that renders the push-forward of ρ0 equal to ρ1 and minimizes the transport costs, i.e.,

minimize_{v,ρ}  ∫_0^1 ∫ ½ |v(x, t)|² ρ(x, t) dx dt
subject to  ∂tρ + ∇·(ρv) = 0,  ρ(·, 0) = ρ0(·),  ρ(·, 1) = ρ1(·).
Figure: initial density ρ0, target density ρ1, and density evolution.
Given the initial density, ρ0, and the target density, ρ1, find the velocity v that minimizes the sum of the transport costs and the discrepancy between the push-forward of ρ0 and ρ1, i.e.,

minimize_{v,ρ}  JMFG(ρ, v) := ∫_0^1 ∫ ½ |v(x, t)|² ρ(x, t) dx dt + G(ρ(·, 1), ρ1)
subject to  ∂tρ + ∇·(ρv) = 0,  ρ(·, 0) = ρ0(·)   (CE)

Examples for the terminal cost G: L², Kullback-Leibler divergence, ...
Side note: the relaxed OT problem is a potential mean field game (MFG).
A single agent with initial position x ∈ Ω aims to choose v that minimizes

Jx,0(v) = ∫_0^1 ½ |v(s)|² ds + G(z(1), ρ(z(1), 1)),

where their position changes according to ∂t z(s) = v(s), 0 ≤ s ≤ 1, z(0) = x.
◮ G(x, ρ) = δG(ρ, ρ1)/δρ (x) (variational derivative of G)
◮ the agent interacts with the population through ρ and G
◮ z(·) is the characteristic curve of (CE) starting at x
It is useful to define the value of an agent's state (x, t) as Φ(x, t) = inf_v Jx,t(v).
Figure: initial density ρ0, value function, density evolution, and target density ρ1.
Lasry & Lions '06: the first-order optimality conditions of the relaxed OT problem are

−∂tΦ(x, t) + ½ |∇Φ(x, t)|² = 0,  Φ(x, 1) = G(x, ρ(x, 1))   (HJB)

and the optimal strategy is v(x, t) = −∇Φ(x, t), which gives

∂tρ(x, t) − ∇·(ρ(x, t)∇Φ(x, t)) = 0,  ρ(x, 0) = ρ0(x)   (CE)

Challenges: the forward-backward structure and the high dimensionality of the PDE system.
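The HJB equation can be checked by hand in the simplest case of translating a standard normal: for ρ0 = N(0, I) and ρ1 = N(μ, I), the value function Φ(x, t) = −μ⊤x + ½|μ|²t solves (HJB), and v = −∇Φ = μ is a constant translation. A minimal numerical sketch (this illustrative Φ is an assumption for the example, not from the talk):

```python
import numpy as np

# Illustrative example: translate rho0 = N(0, I) to rho1 = N(mu, I).
# The value function Phi(x,t) = -mu.x + 0.5|mu|^2 t solves the HJB
# equation, and v = -grad Phi = mu is a constant translation velocity.
d = 3
mu = np.array([1.0, -2.0, 0.5])

def Phi(x, t):
    return -mu @ x + 0.5 * (mu @ mu) * t

def hjb_residual(x, t, h=1e-5):
    # -dPhi/dt + 0.5*|grad_x Phi|^2, both by central finite differences
    dPhi_dt = (Phi(x, t + h) - Phi(x, t - h)) / (2 * h)
    grad = np.array([(Phi(x + h * e, t) - Phi(x - h * e, t)) / (2 * h)
                     for e in np.eye(d)])
    return -dPhi_dt + 0.5 * grad @ grad

x = np.array([0.3, 0.7, -1.2])
print(abs(hjb_residual(x, 0.5)) < 1e-6)   # True: Phi solves (HJB)
```

Since Φ is affine in x and linear in t, the finite differences are exact up to roundoff, and the residual vanishes for every (x, t).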
Three options for solving the problem
Idea: combine the advantages of the above to tackle the curse of dimensionality
◮ formulate as a variational problem: minimize JMFG(ρ, −∇Φ)
◮ eliminate (CE) with a Lagrangian PDE solver → mesh-free, parallel
◮ parameterize Φ with a neural network → universal approximator, mesh-free, cheap(?)
◮ penalize violations of (HJB) → regularity, global convergence(?)
Assume Φ is given. Then the solution to

∂tρ(x, t) − ∇·(ρ(x, t)∇Φ(x, t)) = 0,  ρ(x, 0) = ρ0(x)

satisfies ρ(z(x, t), t) det ∇z(x, t) = ρ0(x) along the characteristic curve

∂t z(x, t) = −∇Φ(z(x, t), t),  z(x, 0) = x.

Instead of computing det ∇z(x, t) (cost O(d³) flops), use

l(x, t) := log det ∇z(x, t) = −∫_0^t ∆Φ(z(x, s), s) ds.

Hint: compute z and l in one ODE solve (parallelize over x1, x2, ...).
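The identity l = log det ∇z is easy to exercise numerically. A minimal sketch with an illustrative anisotropic quadratic potential Φ(x) = ½ x⊤diag(a)x (an assumption for this example, not a learned Φ): integrate z and l in one ODE solve, then compare l(x, 1) against log det ∇z(x, 1) computed by finite differences over x:

```python
import numpy as np

# Illustrative potential Phi(x) = 0.5 * x^T diag(a) x (assumption):
# grad Phi = a*x, Lap Phi = sum(a). Then z(x,t) = x*exp(-a*t) and
# log det grad_x z(x,1) = -sum(a), which the ODE for l reproduces.
a = np.array([0.5, 2.0])
d, nt = 2, 100
h = 1.0 / nt

def rhs(state):
    z = state[:d]
    return np.concatenate([-a * z, [-a.sum()]])   # (-grad Phi, -Lap Phi)

def solve(x):
    state = np.concatenate([x, [0.0]])            # z(x,0)=x, l(x,0)=0
    for _ in range(nt):                           # RK4 time stepping
        k1 = rhs(state); k2 = rhs(state + 0.5 * h * k1)
        k3 = rhs(state + 0.5 * h * k2); k4 = rhs(state + h * k3)
        state = state + (h / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
    return state

x = np.array([1.0, -0.5])
z1, l1 = solve(x)[:d], solve(x)[d]
eps = 1e-5                                        # FD Jacobian of x -> z(x,1)
Jz = np.column_stack([(solve(x + eps * e)[:d] - solve(x - eps * e)[:d]) / (2 * eps)
                      for e in np.eye(d)])
print(np.isclose(l1, np.log(np.linalg.det(Jz)), atol=1e-6))   # True
```

One d-dimensional ODE solve replaces the O(d³) determinant, and the loop over samples x is embarrassingly parallel.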
minimize_Φ E_{x∼ρ0} [terminal, transport, and HJB-penalty costs], where along each characteristic

∂t z(x, t) = −∇Φ(z(x, t), t)
∂t l(x, t) = −∆Φ(z(x, t), t)
∂t cL(x, t) = ½ |∇Φ(z(x, t), t)|²
∂t cH(x, t) = | ∂tΦ(z(x, t), t) − ½ |∇Φ(z(x, t), t)|² |

for t ∈ (0, 1], with z(x, 0) = x and l(x, 0) = cL(x, 0) = cH(x, 0) = 0.

◮ z and l = log det ∇z are needed to solve the continuity equation (CE)
◮ cL and cH accumulate cost along the characteristic
◮ α1, α2: penalty parameters for the HJB violation
◮ discretize the dynamics with nt steps of fourth-order Runge-Kutta
◮ discretize E with Monte Carlo
◮ can use SA (SGD, ADAM, ...) or SAA (BFGS, Newton, ...) methods
◮ no grid needed; the computation can be parallelized over x

Next, parameterize Φ with a neural network. Needed: ∇Φ and ∆Φ.
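To make the bookkeeping concrete, here is a minimal sketch of one characteristic of this four-component ODE system, again with an illustrative time-independent potential Φ(x) = ½|x|² (an assumption, not the learned Φ from the talk), so ∇Φ = z, ∆Φ = d, ∂tΦ = 0, and the HJB residual rate equals the transport cost rate:

```python
import numpy as np

# One characteristic of the (z, l, cL, cH) system for the illustrative
# potential Phi(x) = 0.5*|x|^2: grad Phi = z, Lap Phi = d, dPhi/dt = 0.
d = 2

def rhs(state):
    z = state[:d]
    gradPhi, lapPhi, dtPhi = z, float(d), 0.0
    running = 0.5 * gradPhi @ gradPhi            # transport cost rate
    hjb_res = abs(dtPhi - 0.5 * gradPhi @ gradPhi)  # HJB residual rate
    return np.concatenate([-gradPhi, [-lapPhi, running, hjb_res]])

x = np.array([1.0, 2.0])
state = np.concatenate([x, np.zeros(3)])         # z=x, l=cL=cH=0
nt = 100
h = 1.0 / nt
for _ in range(nt):                              # RK4 time stepping
    k1 = rhs(state); k2 = rhs(state + 0.5 * h * k1)
    k3 = rhs(state + 0.5 * h * k2); k4 = rhs(state + h * k3)
    state = state + (h / 6) * (k1 + 2 * k2 + 2 * k3 + k4)

z1, l1, cL1, cH1 = state[:d], state[d], state[d + 1], state[d + 2]
# closed form: z(1) = x e^{-1}, l(1) = -d, cL(1) = |x|^2 (1 - e^{-2})/4
print(np.isclose(cL1, (x @ x) * (1 - np.exp(-2)) / 4))   # True
```

In the actual method the same integration runs over a Monte Carlo batch of samples x ∼ ρ0 in parallel, with ∇Φ and ∆Φ supplied by the network.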
Y_{j+1} = σ(K_j Y_j + b_j)
Y_{j+1} = Y_j + σ(K_j Y_j + b_j)
Y_{j+1} = Y_j + σ(K_{j,2} σ(K_{j,1} Y_j + b_{j,1}) + b_{j,2})
...
(Notation: Y_j: features, K_j, b_j: weights, σ: activation)
◮ deep learning: neural networks (from ≈ the 1950s) with many hidden layers
◮ able to "learn" complicated patterns from data
◮ applications: classification, face recognition, segmentation, driverless cars, ...
◮ recent success fueled by massive data sets and computing power
◮ a few recent references:
  ◮ Data Scientist: The Sexiest Job of the 21st Century, Harvard Business Review '17
  ◮ A radical new neural network design could overcome big challenges in AI, MIT Tech Review '18
Let s = (x, t) ∈ R^{d+1} and use an (NN + quadratic) model for the value function

Φ(s, θ) = w^⊤N(s, θN) + ½ s^⊤A s + c^⊤s + b,   θ = (w, θN, vec(A), c, b),

where N(s, θN) is an M-layer ResNet with weights θN = (vec(K0), ..., vec(KM), b0, ..., bM).

Forward propagation:
u0 = σ(K0 s + b0)
u1 = u0 + h σ(K1 u0 + b1)
...
uM = u_{M−1} + h σ(KM u_{M−1} + bM)
Output: w^⊤uM = w^⊤N(s, θN)

Backward propagation:
z_{M+1} = w
zM = z_{M+1} + h K_M^⊤ diag(σ′(KM u_{M−1} + bM)) z_{M+1}
...
z1 = z2 + h K_1^⊤ diag(σ′(K1 u0 + b1)) z2
z0 = K_0^⊤ diag(σ′(K0 s + b0)) z1
Output: z0 = ∇s(w^⊤N(s, θN))

Next: compute ∆Φ(s, θ) = tr(E^⊤(∇s²(w^⊤N(s, θN)) + A) E).
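The forward and backward sweeps above translate almost line by line into code. A minimal numpy sketch (random weights and σ = tanh are assumptions for illustration), with the analytic gradient z0 checked against finite differences:

```python
import numpy as np

rng = np.random.default_rng(0)
d, m, M, h = 3, 8, 4, 0.25            # input dim, width, layers, step size
s = rng.standard_normal(d + 1)        # s = (x, t)
w = rng.standard_normal(m)
K = [rng.standard_normal((m, d + 1))] + \
    [rng.standard_normal((m, m)) for _ in range(M)]
b = [rng.standard_normal(m) for _ in range(M + 1)]
sig = np.tanh
dsig = lambda a: 1.0 - np.tanh(a) ** 2      # sigma'

def forward(s):
    u = [sig(K[0] @ s + b[0])]                        # u_0
    for i in range(1, M + 1):                         # u_i = u_{i-1} + h*sig(.)
        u.append(u[-1] + h * sig(K[i] @ u[-1] + b[i]))
    return u

def grad_wN(s):
    u = forward(s)
    z = w.copy()                                      # z_{M+1} = w
    for i in range(M, 0, -1):                         # backward sweep
        z = z + h * K[i].T @ (dsig(K[i] @ u[i - 1] + b[i]) * z)
    return K[0].T @ (dsig(K[0] @ s + b[0]) * z)       # z_0 = grad_s(w^T N)

g = grad_wN(s)
eps, fd = 1e-6, np.zeros(d + 1)                       # finite-difference check
for i in range(d + 1):
    e = np.zeros(d + 1); e[i] = eps
    fd[i] = (w @ forward(s + e)[-1] - w @ forward(s - e)[-1]) / (2 * eps)
print(np.allclose(g, fd, atol=1e-5))                  # True
```

The diag(σ′(·)) products become elementwise multiplications, so each backward step costs one matrix-vector product, mirroring the forward pass.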
∆Φ(s, θ) = tr(E^⊤(∇s²(w^⊤N(s, θN)) + A) E),   E = [e1, ..., ed] ∈ R^{(d+1)×d}

The second term is trivial. Focus on the NN part and use forward mode; for the first layer,

t0 = tr(E^⊤K_0^⊤ diag(σ″(K0 s + b0) ⊙ z1) K0 E) = (σ″(K0 s + b0) ⊙ z1)^⊤((K0 E) ⊙ (K0 E)) 1

(⊙: Hadamard product, 1 = ones(d, 1)). This gives

∆(w^⊤N(s, θN)) = t0 + h Σ_{i=1}^{M} ti,   where for i ≥ 1

ti = tr(J_{i−1}^⊤ K_i^⊤ diag(σ″(Ki u_{i−1}(s) + bi) ⊙ z_{i+1}) Ki J_{i−1}).

Here, J_{i−1} = ∇s u_{i−1}^⊤ ∈ R^{m×d} is a Jacobian matrix (updated during the forward pass).
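To see the first-layer term in isolation: for a single-layer model (M = 0, so z1 = w and the sum over i is empty), t0 already is the exact Laplacian of w^⊤σ(K0 s + b0) with respect to x. A minimal numpy sketch (random weights and σ = tanh are assumptions for illustration), checked against a finite-difference Laplacian:

```python
import numpy as np

rng = np.random.default_rng(1)
d, m = 3, 8
s = rng.standard_normal(d + 1)                 # s = (x, t)
K0 = rng.standard_normal((m, d + 1))
b0 = rng.standard_normal(m)
w = rng.standard_normal(m)
sig = np.tanh
sig2 = lambda a: -2.0 * np.tanh(a) * (1.0 - np.tanh(a) ** 2)   # sigma''

f = lambda s: w @ sig(K0 @ s + b0)             # single-layer w^T N(s)

# t0 = (sigma''(K0 s + b0) * z1)^T ((K0 E) * (K0 E)) 1, with z1 = w and
# E the first d columns of the identity (Laplacian w.r.t. x only)
K0E = K0[:, :d]
t0 = (sig2(K0 @ s + b0) * w) @ (K0E * K0E).sum(axis=1)

eps, lap_fd = 1e-4, 0.0                        # finite-difference Laplacian
for i in range(d):
    e = np.zeros(d + 1); e[i] = eps
    lap_fd += (f(s + e) - 2 * f(s) + f(s - e)) / eps ** 2
print(np.isclose(t0, lap_fd, atol=1e-5))       # True
```

Note the trace never materializes an m×d matrix product: the Hadamard form reduces it to one elementwise square and two inner products.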
Figure: convergence of the mean field objective JMFG over iterations, and the resulting ρ0, ρ1, pull-back, push-forward, and characteristics for three settings: with CHJB, nt = 2; no CHJB, nt = 2; no CHJB, nt = 8.
HJB penalty improves accuracy and(!) lowers computational costs
Eulerian scheme:
◮ dynamical OT formulation
◮ conservative finite-volume discretization
◮ leads to a convex optimization problem
◮ solved to high accuracy with Newton's method

Comparison:

                   # parameters   JMFG
Eulerian, fine     3,080,448      1.066e+01 (100.00%)
Eulerian, coarse   376,960        1.082e+01 (101.47%)
MFGnet (nt = 2)    637            1.072e+01 (100.59%)
MFGnet (nt = 8)    637            1.063e+01 (99.69%)

E Haber, R Horesh: A Multilevel Method for the Solution of Time Dependent Optimal Transport. NM-TMA 8(1), 2015.

Figure: ρ0, ρ1, pull-back, push-forward with characteristics (Lagrangian, ML) vs. Eulerian finite volume.
Figure: ρ0, ρ1, and the value function at initial (t = 0) and final (t = 1) time: Φ_Lag(·, t) (Lagrangian ML), Φ_Eul(·, t) (Eulerian FV), and the error |Φ_Lag(·, t) − Φ_Eul(·, t)|.
Take-away: the Eulerian scheme (≈ 3M parameters) and the Lagrangian ML scheme (637 parameters) give comparable accuracy.
Model large populations of rational agents playing a non-cooperative differential game:

minimize_{v,ρ}  JMFG(v, ρ) := ∫_0^1 ∫ ½ |v(x, t)|² ρ(x, t) dx dt + ∫_0^1 F(ρ(·, t)) dt + G(ρ(·, 1))
subject to  ∂tρ(x, t) + ∇·(ρ(x, t)v(x, t)) = 0,  ρ(x, 0) = ρ0(x).

Use the running cost F to model, e.g.,
◮ congestion, FE(ρ)
◮ spatio-temporal preferences, FP(ρ)
Levon Nurbekyan @ IPAM opening workshop: Computational methods for mean-field games, https://bit.ly/3cELBmW
Samy Wu Fung @ Emory Scientific Computing Seminar: A GAN-based Approach for High-Dimensional Stochastic Mean Field Games, https://bit.ly/2TcqvVp
Given samples x1, x2, ..., xN ∈ R^d, find a velocity v that maximizes the likelihood of the samples w.r.t. the push-forward of the standard normal distribution ρ1, i.e.,

maximize_{v,z}  (1/N) Σ_{k=1}^{N} ρ1(z(xk, 1)) · det ∇z(xk, 1)
subject to  ∂t z(xk, t) = v(z(xk, t), t),  z(xk, 0) = xk for all k.

W Grathwohl et al.: FFJORD: Free-form Continuous Dynamics for Scalable Reversible Generative Models.
Given samples x1, x2, ..., xN ∈ R^d, find a velocity v that maximizes the likelihood of the samples w.r.t. the push-forward of the standard normal distribution ρ1; taking the negative log of the likelihood, this becomes

minimize_{v,z}  GCNF(v, z) := (1/N) Σ_{k=1}^{N} [ ½ |z(xk, 1)|² − l(xk, 1) ]
where  ∂t z(xk, s) = v(z(xk, s), s),  ∂t l(xk, s) = trace(∇v(z(xk, s), s)).

Recall: l(xk, 1) = log det ∇z(xk, 1).
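The objective can be evaluated with the same ODE machinery as before. A minimal sketch with an illustrative linear velocity v(z, t) = a·z in d = 1 (an assumption for this example, not a trained flow), for which z(x, 1) = x·eᵃ and l(x, 1) = a can be checked in closed form:

```python
import numpy as np

# Illustrative linear velocity v(z,t) = a*z in d=1 (assumption, not a
# trained flow): trace(grad v) = a, so z(x,1) = x*e^a and l(x,1) = a.
a, nt = 0.7, 100
h = 1.0 / nt

def rhs(state):
    z, l = state
    return np.array([a * z, a])          # (v(z,t), trace(grad v))

def flow(x):
    state = np.array([x, 0.0])           # z(x,0) = x, l(x,0) = 0
    for _ in range(nt):                  # RK4 time stepping
        k1 = rhs(state); k2 = rhs(state + 0.5 * h * k1)
        k3 = rhs(state + 0.5 * h * k2); k4 = rhs(state + h * k3)
        state = state + (h / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
    return state                         # (z(x,1), l(x,1))

xs = np.array([-1.0, 0.2, 1.5])
zl = np.array([flow(x) for x in xs])
# G_CNF: average negative log-likelihood, up to the constant (d/2)log(2*pi)
G_cnf = np.mean(0.5 * zl[:, 0] ** 2 - zl[:, 1])
print(np.allclose(zl[:, 0], xs * np.exp(a)))   # True
print(np.isclose(zl[0, 1], a))                 # True: l = log det grad z
```

Each sample contributes one (z, l) trajectory, so the likelihood evaluation parallelizes over the batch exactly as in the MFG solver.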
Given samples x1, x2, ..., xN ∈ R^d, find the value function Φ such that the flow given by v = −∇Φ maximizes the likelihood:

min_{v,z}  (1/N) Σ_{k=1}^{N} [ ½ |z(xk, 1)|² − l(xk, 1) + β1 cL(xk, 1) + β2 cH(xk, 1) ]
subject to  ∂t z(xk, t) = v(z(xk, t), t),  z(xk, 0) = xk for all k.

◮ provides uniqueness
◮ more efficient time integration

L Yang, GE Karniadakis: Potential Flow Generator with L2 Optimal Transport Regularity for Generative Models. arXiv:1908.11462, 2019.
L Zhang, Weinan E, L Wang: Monge-Ampère Flow for Generative Modeling. arXiv:1809.10188, 2018.
C Finlay, JH Jacobsen, L Nurbekyan, AM Oberman: How to train your neural ODE. arXiv:2002.02798, 2020.
◮ Exact computation with automatic differentiation (AD):
trace(∇v(x)) = Σ_{i=1}^{d} ei^⊤(∇v(x)^⊤ei), exact, O(m·d²) FLOPS
◮ Stochastic trace estimator (Hutchinson) with AD:
trace(∇v(x)) = E_w[w^⊤∇v(x)^⊤w] ≈ (1/S) Σ_{k=1}^{S} (w^k)^⊤(∇v(x)^⊤w^k), inexact, O(m·S·d) FLOPS
◮ OT-Flow: exact trace computation (highly parallel) using O(m²·d) FLOPS.
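The two options are easy to compare on a random matrix standing in for the Jacobian ∇v(x) (an illustrative stand-in, not the talk's network): the basis-vector sum recovers the trace exactly, while the Hutchinson-style estimate is unbiased but noisy.

```python
import numpy as np

# Exact trace via d basis-vector products vs. a Hutchinson-style
# stochastic estimate, on a random matrix J standing in for grad v(x).
rng = np.random.default_rng(2)
d = 50
J = rng.standard_normal((d, d))

# exact: sum_i e_i^T (J^T e_i), one vector product per basis vector
exact = sum(e @ (J.T @ e) for e in np.eye(d))

# inexact: (1/S) sum_k (w^k)^T (J^T w^k) with Rademacher probes w^k
S = 50000
W = rng.choice([-1.0, 1.0], size=(S, d))
est = np.mean(np.einsum('si,ij,sj->s', W, J.T, W))

print(np.isclose(exact, np.trace(J)))    # True: basis-vector sum is exact
print(abs(est - exact))                  # small but nonzero sampling error
```

The estimator's standard deviation decays only like 1/√S, which is why OT-Flow prefers the exact, parallelizable computation.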
Figure: samples and density estimates for the moons, circles, pinwheel, and checkerboard toy datasets.
Figure: maximum mean discrepancy vs. number of network parameters, and testing time vs. training time.
◮ OT-Flow yields competitive accuracy w.r.t. MMD
◮ FFJORD and RNODE use between 2× and 22× more weights
◮ OT-Flow is considerably faster in training and testing
◮ let y1, y2, ... ∈ R^784 be MNIST images
◮ train an encoder E: R^784 → R^128 and a decoder D: R^128 → R^784 s.t. D(E(y)) ≈ y
◮ latent-space representation of the data: xj = E(yj) for all j
◮ train an OT-Flow f that maps {xj}j to ρ1 ∼ N(0, I_128)
◮ interpolate between two images y1, y2 in latent space and get a new image
y(λ) = D(f⁻¹(λ f(E(y1)) + (1 − λ) f(E(y2))))
Figure: red-boxed images (y1, y2, y3, y4) are originals; the others are interpolated in ρ1 space.
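The interpolation formula itself is easy to exercise with toy stand-ins (an identity encoder/decoder and a linear flow f(x) = 2x are assumptions for illustration; for any invertible linear f the recipe reduces to linear interpolation, while a trained nonlinear flow bends the path through ρ1 space):

```python
import numpy as np

# Toy stand-ins (assumptions for illustration): identity encoder/decoder
# and a linear flow f(x) = 2x with inverse f_inv(x) = x/2. A trained
# OT-Flow would replace f with the learned normalizing flow.
E = D = lambda y: y
f = lambda x: 2.0 * x
f_inv = lambda x: 0.5 * x

def interpolate(y1, y2, lam):
    # y(lambda) = D(f^{-1}(lam*f(E(y1)) + (1-lam)*f(E(y2))))
    return D(f_inv(lam * f(E(y1)) + (1.0 - lam) * f(E(y2))))

y1 = np.array([0.0, 0.0])
y2 = np.array([4.0, 8.0])
print(interpolate(y1, y2, 1.0))   # endpoint: recovers y1
print(interpolate(y1, y2, 0.0))   # endpoint: recovers y2
print(interpolate(y1, y2, 0.5))   # midpoint of the path in flow space
```

At λ ∈ {0, 1} the flow and its inverse cancel, so the endpoints always reproduce the original images regardless of f.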
Code: https://github.com/EmoryMLIP/OT-Flow
Julia implementation for more general MFGs: https://github.com/EmoryMLIP/MFGnet.jl
Machine Learning → Optimal Transport
◮ ML is attractive for high-dimensional PDEs, control, ...
◮ MFGnet: a mesh-free solver for the variational problem, combining...
  ◮ microscopic: Lagrangian method for the continuity and HJB equations
  ◮ macroscopic: variational problem, new penalties for the HJB equation
◮ details matter: models, numerics, architecture, training, ...
◮ surprise: the ML solution is competitive with convex programming

Optimal Transport → Continuous Normalizing Flows
◮ OT regularization: well-posedness simplifies time integration
◮ discretize-then-optimize + HJB penalty → very few time steps
◮ don't take chances: use exact trace computation
◮ OT-Flow speeds up training and testing by ≈ 10×

LR, S Osher, W Li, L Nurbekyan, S Wu Fung: A Machine Learning Framework for Solving High-Dimensional Mean Field Game and Mean Field Control Problems. PNAS 117(17), 9183-9193, 2020.
D Onken, S Wu Fung, X Li, LR: OT-Flow: Fast and Accurate Continuous Normalizing Flows via Optimal Transport. arXiv:2006.00104, 2020.