Sampling-free Uncertainty Estimation in Gated Recurrent Units with Applications to Normative Modeling in Neuroimaging
Seong Jae Hwang, Ronak R. Mehta, Hyunwoo J. Kim, Sterling C. Johnson, Vikas Singh

MOTIVATION
1. Given a visually 'good looking' sequence prediction, how can we tell that its trajectory is correct?
2. If it is, can we derive the degree of uncertainty in its prediction?

Figure: Image sequence prediction with uncertainty. Given the first 10 frames of an input sequence (t = 1, ..., 10, left), our model SP-GRU makes the Output Prediction (t = 11, ..., 20) and the pixel-level Model Uncertainty Map, where bright regions indicate high uncertainty. SP-GRU estimates the uncertainty deterministically, without sampling model parameters.

GOAL
Derive a recurrent neural network architecture capable of estimating uncertainty with the following properties:
1. Uncertainties are estimated deterministically, in a sampling-free manner (i.e., without Monte Carlo sampling).
2. The uncertainty of every intermediate neuron can be expressed in terms of a distribution.

PRELIMINARIES
• Gated Recurrent Unit (GRU):
  Reset Gate:      r_t = σ(U_r x_t + W_r h_{t−1} + b_r)
  Update Gate:     z_t = σ(U_z x_t + W_z h_{t−1} + b_z)
  State Candidate: ĥ_t = tanh(U_ĥ x_t + W_ĥ (r_t ⊙ h_{t−1}) + b_ĥ)
  Cell State:      h_t = (1 − z_t) ⊙ ĥ_t + z_t ⊙ h_{t−1}

• Exponential Families in Neural Networks: Let x ∈ X be a random variable with probability density/mass function (pdf/pmf) f_X. Then f_X is an exponential family distribution if
  f_X(x | η) = h(x) exp(η^T T(x) − A(η))
with natural parameters η, base measure h(x), and sufficient statistics T(x). The log-partition function A(η) ensures that the distribution normalizes to 1.

Figure: A single exponential family neuron. The weights W_l are learned, and the output of the neuron is a sample a_{l+1} ∼ ExpFam(g_l(W_l a_l)), drawn from the exponential family defined a priori with natural parameters g_l(W_l a_l).

MOMENT MATCHING
• Linear Moment Matching (LMM): Propagate (1) the mean a_m, following the standard linearity of expectations, and (2) the variance a_s:
  o_m = W_m a_m + b_m,
  o_s = W_s a_s + b_s + (W_m ⊙ W_m) a_s + W_s (a_m ⊙ a_m)    (1)

• Nonlinear Moment Matching (NMM): Using the fact that σ(x) ≈ Φ(ζx), where Φ(·) is the probit function and ζ = √(π/8) is a constant, approximate the sigmoid function for a_m and a_s:
  a_m ≈ σ_m(o_m, o_s) = σ( o_m / (1 + ζ² o_s)^{1/2} ),
  a_s ≈ σ_s(o_m, o_s) = σ( ν(o_m + ω) / (1 + ζ² ν² o_s)^{1/2} ) − a_m²    (2)
where ν = 4 − 2√2 and ω = −log(√2 + 1). The hyperbolic tangent follows from tanh(x) = 2σ(2x) − 1.

Figure: Linear Moment Matching (LMM) and Nonlinear Moment Matching (NMM) are performed at the weight/bias sums, (a^{l−1}_m, a^{l−1}_s) → (o^l_m, o^l_s), and at the activations, (o^l_m, o^l_s) → (a^l_m, a^l_s), respectively.

SAMPLING-FREE PROBABILISTIC GRU (SP-GRU)
Figure: SP-GRU cell structure. Solid lines/boxes and red dotted lines/boxes correspond to the operations and variables for the mean m and the variance s, respectively. Circles are element-wise operators.
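The LMM and NMM rules in Equations (1) and (2) can be sketched directly in NumPy. This is a minimal illustration, not the authors' implementation: the function names (`lmm`, `nmm_sigmoid`, `nmm_tanh`) and the toy weight/input distributions are our own, and only the formulas themselves come from the poster.

```python
import numpy as np

ZETA = np.sqrt(np.pi / 8)           # probit approximation constant, zeta = sqrt(pi/8)
NU = 4 - 2 * np.sqrt(2)             # nu  = 4 - 2*sqrt(2)
OMEGA = -np.log(np.sqrt(2) + 1)     # omega = -log(sqrt(2) + 1)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lmm(a_m, a_s, W_m, W_s, b_m, b_s):
    """Linear Moment Matching (Eq. 1): propagate mean and variance through W a + b."""
    o_m = W_m @ a_m + b_m
    o_s = W_s @ a_s + b_s + (W_m * W_m) @ a_s + W_s @ (a_m * a_m)
    return o_m, o_s

def nmm_sigmoid(o_m, o_s):
    """Nonlinear Moment Matching for the sigmoid (Eq. 2)."""
    a_m = sigmoid(o_m / np.sqrt(1 + ZETA**2 * o_s))
    a_s = sigmoid(NU * (o_m + OMEGA) / np.sqrt(1 + ZETA**2 * NU**2 * o_s)) - a_m**2
    return a_m, a_s

def nmm_tanh(o_m, o_s):
    """tanh(x) = 2*sigmoid(2x) - 1: scaling x by 2 scales its variance by 4."""
    a_m, a_s = nmm_sigmoid(2 * o_m, 4 * o_s)
    return 2 * a_m - 1, 4 * a_s

# Toy example: a 3-unit layer with a 2-dimensional input (all numbers illustrative).
rng = np.random.default_rng(0)
a_m, a_s = rng.normal(size=2), rng.uniform(0.01, 0.1, size=2)
W_m, W_s = rng.normal(size=(3, 2)), rng.uniform(0.01, 0.1, size=(3, 2))
b_m, b_s = np.zeros(3), np.zeros(3)

o_m, o_s = lmm(a_m, a_s, W_m, W_s, b_m, b_s)
act_m, act_s = nmm_sigmoid(o_m, o_s)
print(act_m, act_s)
```

A quick sanity check on such a sketch is to draw Monte Carlo samples x ∼ N(o_m, o_s) and compare the empirical mean of σ(x) against `act_m`; the probit-based approximation matches it closely for moderate variances, which is exactly what makes the sampling-free propagation viable.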
SP-GRU operations (mean m and variance s):

Reset Gate (linear):
  o^t_{r,m} = U_{r,m} x^t_m + W_{r,m} h^{t−1}_m + b_{r,m}
  o^t_{r,s} = U_{r,s} x^t_s + W_{r,s} h^{t−1}_s + b_{r,s} + [U_{r,m}]² x^t_s + U_{r,s} [x^t_m]² + [W_{r,m}]² h^{t−1}_s + W_{r,s} [h^{t−1}_m]²
Reset Gate (nonlinear):
  r^t_m = σ_m(o^t_{r,m}, o^t_{r,s}),  r^t_s = σ_s(o^t_{r,m}, o^t_{r,s})

Update Gate (linear):
  o^t_{z,m} = U_{z,m} x^t_m + W_{z,m} h^{t−1}_m + b_{z,m}
  o^t_{z,s} = U_{z,s} x^t_s + W_{z,s} h^{t−1}_s + b_{z,s} + [U_{z,m}]² x^t_s + U_{z,s} [x^t_m]² + [W_{z,m}]² h^{t−1}_s + W_{z,s} [h^{t−1}_m]²
Update Gate (nonlinear):
  z^t_m = σ_m(o^t_{z,m}, o^t_{z,s}),  z^t_s = σ_s(o^t_{z,m}, o^t_{z,s})

State Candidate (linear):
  o^t_{ĥ,m} = U_{ĥ,m} x^t_m + W_{ĥ,m} h^{t−1}_m + b_{ĥ,m}
  o^t_{ĥ,s} = U_{ĥ,s} x^t_s + W_{ĥ,s} h^{t−1}_s + b_{ĥ,s} + [U_{ĥ,m}]² x^t_s + U_{ĥ,s} [x^t_m]² + [W_{ĥ,m}]² h^{t−1}_s + W_{ĥ,s} [h^{t−1}_m]²
State Candidate (nonlinear):
  ĥ^t_m = tanh_m(o^t_{ĥ,m}, o^t_{ĥ,s}),  ĥ^t_s = tanh_s(o^t_{ĥ,m}, o^t_{ĥ,s})

Cell State (linear only):
  h^t_m = (1 − z^t_m) ⊙ ĥ^t_m + z^t_m ⊙ h^{t−1}_m
  h^t_s = (1 − z^t_m)² ⊙ ĥ^t_s + (z^t_m)² ⊙ h^{t−1}_s

Table: SP-GRU operations in mean and variance. ⊙ and [A]² denote the Hadamard product and A ⊙ A of a matrix/vector A, respectively. Note the Cell State does not involve nonlinear operations.

EXPERIMENT 1: MOVING MNIST
Figure: (a) Angle deviation trajectories. (b) Speed deviation trajectories.

Figure: Predictions and uncertainties (frames 11, 15, and 20) from testing varying deviations from the trained trajectories (first of four rows, blue). Top: angle (θ = 20°, 25°, 30°, 35°). Middle: speed (v = 5.0%, 5.5%, 6.0%, 6.5%). Bottom: pixel-level noise. Right: the average sum of per-frame pixel-level variance using SP-GRU and MC-LSTM.

Figure: SP-GRU predictor results. Left 3 rows: 2 moving digits (top: ground truth; middle: mean prediction; bottom: uncertainty estimate). Right 3 rows: 3 moving digits, which are out of domain (i.e., not seen in training).

EXPERIMENT 2: NORMATIVE MODELING IN NEUROIMAGING
1. Brain Connectivity Sequence Sample Generation
Figure: The preprocessing procedure used to generate sample data for SP-GRU. The original subjects (unordered, PiB positive and negative) are ordered and binned by RAVLT progression to generate N samples (i = 1, ..., N). Left: global connectivity sequence samples. Right: PiB+/− connectivity sequence samples.

2. Normative Probability Map (NPM) (Marquand et al., 2016)
For each subject i, the true response at time j for connectivity k is y_ijk, with a bin-level variance of σ²_njk. SP-GRU predicts a mean response ȳ_ijk and a variance σ²_ijk.
Normative Probability Map (NPM): z_ijk = (y_ijk − ȳ_ijk) / √(σ²_ijk + σ²_njk).

3. Normative Modeling in Neuroimaging: Pipeline
Figure: Normative modeling pipeline for preclinical AD. (1) Given a set of test inputs (t = 1, 2, 3, 4; 1761 connectivities for each subject i and time t), (2) use the pretrained SP-GRU to make mean and variance predictions for each connectivity at t = 5, 6, 7, 8. (3) Compute the NPM for each prediction, and (4) derive an extreme value statistic (EVS) for each sample i and t, a robust summary of the 1761 NPMs (the mean of the top 5%). (5) Construct a histogram of the EVS for each t, fit generalized extreme value (GEV) distributions, and derive confidence intervals based on the N EVS for each t. (6) Given a new subject N', derive its EVS following (1)-(4), and (7) check it against the confidence intervals from (5) to determine heterogeneity: an outlier EVS in at least one t ⇒ an outlier subject.

4. Outlier Detection: Cognitively Healthy (PiB−) vs. At-Risk (PiB+)
Detected outliers: 9 of 100 samples in PiB− and 19 of 100 samples in PiB+.
Implication: Larger absolute fluctuations in DWI connectivity may be a good indicator of disease risk as measured by amyloid burden.

Research supported in part by NIH (R01AG040396, R01AG021155, R01AG027161, P50AG033514, R01AG059312, R01EB022883, R01AG062336), the Center for Predictive and Computational Phenotyping (U54AI117924), NSF CAREER Award (1252725), and a predoctoral fellowship to RRM via T32LM012413.
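Steps (3)-(7) of the normative modeling pipeline (NPM, EVS, GEV fit, outlier check) can be sketched with NumPy and SciPy. This is a toy illustration under stated assumptions, not the study's code: the cohort data are synthetic, the EVS is taken here over absolute NPM values, the 95% GEV quantile stands in for the confidence interval, and `npm`/`evs` are hypothetical helper names.

```python
import numpy as np
from scipy.stats import genextreme

def npm(y, y_hat, var_pred, var_bin):
    """Normative probability map: z = (y - y_hat) / sqrt(var_pred + var_bin)."""
    return (y - y_hat) / np.sqrt(var_pred + var_bin)

def evs(z, top=0.05):
    """Extreme value statistic: mean of the top 5% of |NPM| values (assumption)."""
    z = np.sort(np.abs(z))[::-1]
    k = max(1, int(round(top * z.size)))
    return z[:k].mean()

rng = np.random.default_rng(1)
n_subj, n_conn = 100, 1761          # 1761 connectivities per subject and time point

# Toy normative cohort: truth matches the prediction up to unit-variance noise.
y     = rng.normal(size=(n_subj, n_conn))
y_hat = np.zeros((n_subj, n_conn))
var_p = np.full((n_subj, n_conn), 0.5)   # SP-GRU predictive variance (illustrative)
var_b = np.full(n_conn, 0.5)             # bin-level variance (illustrative)

evs_cohort = np.array([evs(npm(y[i], y_hat[i], var_p[i], var_b))
                       for i in range(n_subj)])

# Fit a GEV distribution to the cohort EVS and take its 95% quantile as the bound.
c, loc, scale = genextreme.fit(evs_cohort)
upper = genextreme.ppf(0.95, c, loc=loc, scale=scale)

# A new subject is flagged as an outlier if its EVS exceeds the bound
# (in the full pipeline, at any of the predicted time points t = 5, ..., 8).
y_new = rng.normal(loc=1.5, size=n_conn)     # shifted subject -> larger |z|
z_new = npm(y_new, np.zeros(n_conn), np.full(n_conn, 0.5), var_b)
print(evs(z_new) > upper)
```

Summarizing the 1761 NPMs by an extreme value statistic, rather than an average, is what makes the test sensitive to a few strongly abnormal connectivities while staying robust to diffuse noise, which is the design rationale behind step (4).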