
Neural Ordinary Differential Equations
Ricky T. Q. Chen∗, Yulia Rubanova∗, Jesse Bettencourt∗, David Duvenaud

∗Equal contribution. University of Toronto, Vector Institute.

Contributions

• Black-box ODE solvers as a differentiable modeling component.

• Continuous-time recurrent neural nets and continuous-depth feedforward nets.

• Adaptive computation with explicit control over the tradeoff between speed and numerical precision.

• ODE-based change of variables for automatically-invertible normalizing flows.

• Open-sourced ODE solvers with O(1)-memory backprop: https://github.com/rtqichen/torchdiffeq

ODE Solvers: How Do They Work?

• z(t) changes in time; the dynamics define an infinite set of trajectories, one per initial value.

• Define a differential equation: dz/dt = f(z(t), t, θ).

• Initial-value problem: given z(t0), find z(t1) = z(t0) + ∫_{t0}^{t1} f(z(t), t, θ) dt.

• Approximate the solution with discrete steps, e.g. Euler's method z(t + h) = z(t) + h f(z(t), t) (sketched below).

• Higher-order solvers are more accurate and use larger step sizes.

• Can adapt step size h given error tolerance level.
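To make the stepping concrete, here is a minimal fixed-step Euler solver (a sketch, not the adaptive solver used in the paper; the names euler_solve and f are illustrative, and adaptive methods such as Runge-Kutta 4(5) additionally pick h from a local error estimate):

import numpy as np

def euler_solve(f, z0, t0, t1, h=0.01):
    # Approximate z(t1) from z(t0) by repeating z <- z + h*f(z, t).
    z, t = np.asarray(z0, dtype=float), t0
    while t < t1:
        step = min(h, t1 - t)          # don't overshoot t1
        z = z + step * f(z, t)         # first-order Euler update
        t += step
    return z

# Example: dz/dt = -z has the exact solution z(t1) = z(t0) * exp(-(t1 - t0)).
print(euler_solve(lambda z, t: -z, z0=1.0, t0=0.0, t1=1.0), np.exp(-1.0))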

Continuous version of ResNets

ODE-Net replaces ResNet blocks with ODESolve(f, z(t0), t0, t1, θ), where f is a neural net with parameters θ.

z(t1) = z(t0) + ∫_{t0}^{t1} f(z(t), t, θ) dt = ODESolve(z(t0), f, t0, t1, θ)

Figure: Residual Network vs. ODE-Net, hidden state plotted against depth (z(0) at the input, z(1) at the output).
ResNet: h_{t+1} = h_t + f(h_t, θ_t)
ODE-Net: dh(t)/dt = f(h(t), t, θ)

ResNet:

def resnet(x, θ):
    h1 = x + NeuralNet(x, θ[0])
    h2 = h1 + NeuralNet(h1, θ[1])
    h3 = h2 + NeuralNet(h2, θ[2])
    h4 = h3 + NeuralNet(h3, θ[3])
    return h4

ODE-Net:

def f(z, t, θ):
    return NeuralNet([z, t], θ)

def ODEnet(x, θ):
    return ODESolve(f, x, 0, 1, θ)

‘Depth’ is automatically chosen by an adaptive ODE solver.
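A runnable sketch of this block using the open-sourced torchdiffeq package linked above (the hidden width, the concatenation of t onto z, and the class names ODEFunc/ODEBlock are illustrative choices, not the exact architecture from the paper):

import torch
import torch.nn as nn
from torchdiffeq import odeint

class ODEFunc(nn.Module):
    # f(z, t, θ): a small neural net defining the dynamics dz/dt.
    def __init__(self, dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim + 1, hidden), nn.Tanh(),
                                 nn.Linear(hidden, dim))
    def forward(self, t, z):                      # torchdiffeq calls f(t, z)
        t_col = t * torch.ones(z.shape[0], 1)     # broadcast the scalar time
        return self.net(torch.cat([z, t_col], dim=1))

class ODEBlock(nn.Module):
    # Replaces a stack of residual blocks with one call to an ODE solver.
    def __init__(self, func):
        super().__init__()
        self.func = func
        self.t = torch.tensor([0.0, 1.0])         # integrate from t0=0 to t1=1
    def forward(self, x):
        z = odeint(self.func, x, self.t)          # shape: [len(t), batch, dim]
        return z[-1]                              # z(t1)

block = ODEBlock(ODEFunc(dim=2))
print(block(torch.randn(5, 2)).shape)             # torch.Size([5, 2])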

Computing gradients for ODE solutions

O(1) memory cost when training. Don't store activations; follow the dynamics in reverse. No backpropagation through the ODE solver: compute the gradient through another call to ODESolve.

Adjoint State

Define the adjoint state: a(t) = −∂L/∂z(t)

Adjoint state dynamics: da(t)/dt = −a(t)ᵀ ∂f(z(t), t, θ)/∂z

Solve the ODE backwards in time: dL/dθ = ∫_{t0}^{t1} a(t)ᵀ ∂f(z(t), t, θ)/∂θ dt

def f_and_a([z, a, grad], t):               # augmented dynamics
    return [f, -a*df/dz, -a*df/dθ]

[z0, dL/dx, dL/dθ] = ODESolve(f_and_a, [z(t1), dL/dz(t1), 0], t1, t0)
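In the released torchdiffeq library this reverse-time augmented solve is packaged as odeint_adjoint, which has the same call signature as odeint. A minimal sketch (the Dynamics module below is illustrative) shows gradients w.r.t. both the input and θ being computed without storing solver activations:

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint

class Dynamics(nn.Module):
    # odeint_adjoint requires the dynamics to be an nn.Module so it can collect θ.
    def __init__(self, dim=2, hidden=16):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(),
                                 nn.Linear(hidden, dim))
    def forward(self, t, z):
        return self.net(z)

func = Dynamics()
z0 = torch.randn(5, 2, requires_grad=True)
z1 = odeint_adjoint(func, z0, torch.tensor([0.0, 1.0]))[-1]
# dL/dz(t0) and dL/dθ come from a second ODESolve run backwards in time
# (the augmented adjoint state above), not from backprop through solver steps.
z1.sum().backward()
print(z0.grad.shape, func.net[0].weight.grad.shape)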

ODE Nets for Supervised Learning

Adaptive computation: can adjust speed vs. precision. We can specify the error tolerance of the ODE solution:

ODESolve(f_θ, z(t0), t0, t1, θ, rtol, atol)

Figure: For a trained ODE-Net, panels plot numerical error vs. error tolerance, # forward evals vs. error tolerance, # backward evals vs. # forward evals, and forward NFE (number of function evaluations) vs. training epoch.
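A brief sketch of this knob in torchdiffeq (the counting wrapper CountedDynamics is an illustrative helper, not part of the library): loosening rtol/atol cuts the number of function evaluations, i.e. the effective depth, at the cost of precision.

import torch
import torch.nn as nn
from torchdiffeq import odeint

class CountedDynamics(nn.Module):
    # Wraps simple dynamics and counts how often the adaptive solver calls them.
    def __init__(self):
        super().__init__()
        self.nfe = 0
    def forward(self, t, z):
        self.nfe += 1
        return torch.sin(t) * z        # illustrative dynamics dz/dt = sin(t) z

z0, t = torch.ones(1), torch.tensor([0.0, 1.0])
for tol in (1e-3, 1e-6):
    func = CountedDynamics()
    z1 = odeint(func, z0, t, rtol=tol, atol=tol)[-1]
    print(f"tol={tol:.0e}  NFE={func.nfe}  z(t1)={z1.item():.6f}")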

Performance on MNIST

              Test Error   # Params   Memory   Time
1-Layer MLP   1.60%        0.24 M     -        -
ResNet        0.41%        0.60 M     O(L)     O(L)
RK-Net        0.47%        0.22 M     O(L̃)     O(L̃)
ODE-Net       0.42%        0.22 M     O(1)     O(L̃)

Instantaneous Change of Variables

Change of variables theorem to compute exact changes in probability of samples transformed through a bijective F:

z1 = z0 + f(z0)  ⟹  log p(z1) − log p(z0) = −log |det ∂F/∂z0|

Requires invertible F. Cost O(D³).

Theorem: Assuming that f is uniformly Lipschitz continuous in z and continuous in t, then:

dz/dt = f(z(t), t)  ⟹  ∂ log p(z(t))/∂t = −tr(df/dz(t))

Function f does not have to be invertible. Cost O(D²).
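A sketch of how the theorem is used in practice: augment the state with log p(z(t)) and integrate both together, computing tr(df/dz) with autograd. The exact trace below costs one vector-Jacobian product per dimension, i.e. O(D²); Grathwohl et al. (2018) replace it with Hutchinson's stochastic estimator. The dynamics f here are illustrative.

import torch

def trace_df_dz(f_val, z):
    # Exact tr(df/dz): one autograd call per dimension of z.
    trace = 0.0
    for i in range(z.shape[1]):
        trace = trace + torch.autograd.grad(f_val[:, i].sum(), z,
                                            create_graph=True)[0][:, i]
    return trace

# One evaluation of the augmented right-hand side [dz/dt, d log p(z(t))/dt].
f = lambda z, t: torch.tanh(z @ torch.randn(2, 2))   # need not be invertible
z = torch.randn(8, 2, requires_grad=True)
dz_dt = f(z, 0.0)
dlogp_dt = -trace_df_dz(dz_dt, z)                    # theorem: -tr(df/dz)
print(dz_dt.shape, dlogp_dt.shape)                   # [8, 2] and [8]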

Continuous Normalizing Flows (CNF)

Automatically-invertible normalizing flows. Planar CNF is smooth and much easier to train than planar NF.

Planar normalizing flow (Rezende and Mohamed, 2015):
z(t + 1) = z(t) + u h(wᵀz(t) + b)
log p(z(t + 1)) = log p(z(t)) − log |1 + uᵀ ∂h/∂z|

Continuous analog of planar flow:
dz(t)/dt = u h(wᵀz(t) + b)
∂ log p(z(t))/∂t = −uᵀ ∂h/∂z(t)
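For the planar case the trace is available in closed form, so the density update is cheap: tr(∂f/∂z) = (uᵀw) h'(wᵀz + b). A hedged sketch with h = tanh (parameter shapes follow the standard planar flow; the module name is illustrative):

import torch
import torch.nn as nn

class PlanarCNF(nn.Module):
    # dz/dt = u h(wᵀz + b);  d log p/dt = -tr(∂f/∂z) = -(uᵀw) h'(wᵀz + b)
    def __init__(self, dim):
        super().__init__()
        self.u = nn.Parameter(0.1 * torch.randn(dim))
        self.w = nn.Parameter(0.1 * torch.randn(dim))
        self.b = nn.Parameter(torch.zeros(1))
    def forward(self, t, state):
        z, logp = state                              # integrate z and log p jointly
        pre = z @ self.w + self.b                    # wᵀz + b, shape [batch]
        dz = torch.tanh(pre).unsqueeze(1) * self.u   # u h(wᵀz + b)
        dlogp = -(self.u @ self.w) * (1.0 - torch.tanh(pre) ** 2)
        return dz, dlogp

dz, dlogp = PlanarCNF(dim=2)(0.0, (torch.randn(8, 2), torch.zeros(8)))
print(dz.shape, dlogp.shape)                         # [8, 2] and [8]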

(a) Two Circles, (b) Two Moons: density and samples shown at 5%, 20%, 40%, 60%, 80%, and 100% of the transformation, alongside NF and the target.

Figure: Visualizing the transformation from noise to data. Continuous normalizing flows are efficiently reversible, so we can train on a density estimation task and still be able to sample from the learned density efficiently.

Figure: CNF samples vs. data. Have since scaled CNFs to images using Hutchinson's estimator (Grathwohl et al., 2018).

Continuous-time Generative Model for Time Series

Time series with irregular observation times. No discretization of the timeline is needed.

Figure: Latent ODE model for time series. An RNN encoder runs over observations x_{t0} ... x_{tN} (hidden states h_{t0}, h_{t1}, ..., h_{tN}) and outputs μ, σ of q(z_{t0} | x_{t0} ... x_{tN}); a sampled z_{t0} is unrolled in the latent space by ODESolve(z_{t0}, f, θ_f, t0, ..., tM) into z_{t1}, ..., z_{tN}, z_{tN+1}, ..., z_{tM}, which decode to x̂(t) in the data space over both observed times t0 ... tN (prediction) and unobserved times tN+1 ... tM (extrapolation).

Figure panels: Recurrent Neural Network vs. Latent ODE reconstructions (ground truth, observations, prediction, extrapolation) and the learned latent space.

Figure: Latent ODE learns smooth latent dynamics from noisy observations.
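Irregular observation times need no special handling: the decoder simply evaluates the latent trajectory at whatever times were observed or requested. A minimal sketch (the RNN encoder producing z_t0 is omitted; latent size, dynamics, and decoder below are illustrative):

import torch
import torch.nn as nn
from torchdiffeq import odeint

latent_dim, obs_dim = 4, 2

class LatentDynamics(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(latent_dim, 32), nn.Tanh(),
                                 nn.Linear(32, latent_dim))
    def forward(self, t, z):
        return self.net(z)

decoder = nn.Linear(latent_dim, obs_dim)

# z_t0 would be sampled from the RNN encoder's q(z_t0 | x_t0 ... x_tN); here it is random.
z_t0 = torch.randn(1, latent_dim)
# Irregularly spaced observation times t0 ... tN plus extrapolation times up to tM.
times = torch.tensor([0.0, 0.13, 0.31, 0.90, 1.70, 2.50])
z_t = odeint(LatentDynamics(), z_t0, times)   # latent state at every requested time
x_hat = decoder(z_t)                          # predicted observations x̂(t)
print(x_hat.shape)                            # torch.Size([6, 1, 2])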

Prior Works on ODE+DL

LeCun. "A theoretical framework for back-propagation." (1988)
Pearlmutter. "Gradient calculations for dynamic recurrent neural networks: a survey." (1993)
Haber & Ruthotto. "Stable Architectures for Deep Neural Networks." (2017)
Chang et al. "Multi-level Residual Networks from Dynamical Systems View." (2018)