Beyond Sparsity: Tree Regularization of Deep Models for Interpretability

Mike Wu¹, Michael C. Hughes², Sonali Parbhoo³, Maurizio Zazzi⁴, Volker Roth³, and Finale Doshi-Velez²
¹ Stanford University
² Harvard University SEAS
³ University of Basel, {sonali.parbhoo,volker.roth}@unibas.ch
⁴ University of Siena

arXiv:1711.06178v1 [stat.ML] 16 Nov 2017
A version of this work will appear in AAAI 2018 (https://aaai.org/Conferences/AAAI-18/). This paper includes an extended appendix with supplementary material.

Abstract
The lack of interpretability remains a key barrier to the adoption of deep models in many applications. In this work, we explicitly regularize deep models so human users might step through the process behind their predictions in little time. Specifically, we train deep time-series models so their class-probability predictions have high accuracy while being closely modeled by decision trees with few nodes. Using intuitive toy examples as well as medical tasks for treating sepsis and HIV, we demonstrate that this new tree regularization yields models that are easier for humans to simulate than models trained with simpler L1 or L2 penalties, without sacrificing predictive power.

1 Introduction
Deep models have become the de-facto approach for prediction in a variety of applications such as image classification (e.g. (Krizhevsky, Sutskever, and Hinton 2012)) and machine translation (e.g. (Bahdanau, Cho, and Bengio 2014; Sutskever, Vinyals, and Le 2014)). However, many practitioners are reluctant to adopt deep models because their predictions are difficult to interpret. In this work, we seek a specific form of interpretability known as human-simulability. A human-simulatable model is one in which a human user can “take in input data together with the parameters of the model and in reasonable time step through every calculation required to produce a prediction” (Lipton 2016). For example, small decision trees with only a few nodes are easy for humans to simulate and thus understand and trust. In contrast, even simple deep models like multi-layer perceptrons with a few dozen units can have far too many parameters and connections for a human to easily step through. Deep models for sequences are even more challenging. Of course, decision trees with too many nodes are also hard to simulate. Our key research question is: can we create deep models that are well-approximated by compact, human-simulatable models?

The question of creating accurate yet human-simulatable models is an important one, because in many domains simulatability is paramount. For example, despite advances in deep learning for clinical decision support (e.g. (Miotto et al. 2016; Choi et al. 2016; Che et al. 2015)), the clinical community remains skeptical of machine learning systems (Chen and Asch 2017). Simulatability allows clinicians to audit predictions easily. They can manually inspect changes to outputs under slightly perturbed inputs, check substeps against their expert knowledge, and identify when predictions are made due to systemic bias in the data rather than real causes. Similar needs for simulatability exist in many decision-critical domains such as disaster response or recidivism prediction.

To address this need for interpretability, a number of works have been developed to assist in the interpretation of already-trained models. Craven and Shavlik (1996) train decision trees that mimic the predictions of a fixed, pretrained neural network, but do not train the network itself to be simpler. Other post-hoc interpretations typically evaluate the sensitivity of predictions to local perturbations of inputs or the input gradient (Ribeiro, Singh, and Guestrin 2016; Selvaraju et al. 2016; Adler et al. 2016; Lundberg and Lee 2016; Erhan et al. 2009). In parallel, research efforts have emphasized that simple lists of (perhaps locally) important features are not sufficient: Singh, Ribeiro, and Guestrin (2016) provide explanations in the form of programs; Lakkaraju, Bach, and Leskovec (2016) learn decision sets and show benefits over other rule-based methods.

These techniques focus on understanding already learned models, rather than finding models that are more interpretable. However, it is well known that deep models often have multiple optima of similar predictive accuracy (Goodfellow, Bengio, and Courville 2016), and thus one might hope to find more interpretable models with equal predictive accuracy. Yet the field of optimizing deep models for interpretability remains nascent. Ross, Hughes, and Doshi-Velez (2017) penalize input sensitivity to features marked as less relevant. Lei, Barzilay, and Jaakkola (2016) train deep models that make predictions from text and simultaneously highlight contiguous subsets of words, called a “rationale,” to justify each prediction. While both works optimize their deep models to expose relevant features, lists of features are not sufficient to simulate the prediction.

Contributions. In this work, we take steps toward optimizing deep models for human-simulatability via a new model complexity penalty function we call tree regularization. Tree regularization favors models whose decision boundaries can
be well-approximated by small decision trees, thus penalizing models that would require many calculations to simulate predictions. We first demonstrate how this technique can be used to train simple multi-layer perceptrons to have tree-like decision boundaries. We then focus on time-series applications and show that gated recurrent unit (GRU) models trained with strong tree regularization reach a high-accuracy-at-low-complexity sweet spot that is not possible with any strength of L1 or L2 regularization. Prediction quality can be further boosted by training new hybrid models, GRU-HMMs, which explain the residuals of interpretable discrete HMMs via tree-regularized GRUs. We further show that the approximate decision trees for our tree-regularized deep models are useful for human simulation and interpretability. We demonstrate our approach on a speech recognition task and two medical treatment prediction tasks: one for patients with sepsis in the intensive care unit (ICU) and one for patients with human immunodeficiency virus (HIV). Throughout, we also show that standalone decision trees as a baseline are noticeably less accurate than our tree-regularized deep models. We have released an open-source Python toolbox to allow others to experiment with tree regularization¹.
Related work. While there is little work (as mentioned above) on optimizing models for interpretability, there are some related threads. The first is model compression, which trains smaller models that perform similarly to large, black-box models (e.g. (Hinton, Vinyals, and Dean 2015; Balan et al. 2015; Han et al. 2015)). Other efforts specifically train very sparse networks via L1 penalties (Zhang, Lee, and Jordan 2016) or even binary neural networks (Tang, Hua, and Wang 2017; Rastegari et al. 2016) with the goal of faster computation. Edge and node regularization is commonly used to improve prediction accuracy (Drucker and Le Cun 1992; Ochiai et al. 2017), and recently Hu et al. (2016) improve prediction accuracy by training neural networks so that predictions match a small list of known domain-specific first-order logic rules. Sometimes these regularizations, which all smooth or simplify decision boundaries, can also improve interpretability. However, there is no guarantee that they will; we emphasize that specifically training deep models to have easily simulatable decision boundaries is (to our knowledge) novel.
2 Background and Notation
We consider supervised learning tasks given datasets of $N$ labeled examples, where each example (indexed by $n$) has an input feature vector $x_n$ and a target output vector $y_n$. We shall assume the targets $y_n$ are binary, though it is simple to extend to other types. When modeling time series, each example sequence $n$ contains $T_n$ timesteps indexed by $t$, each of which has a feature vector $x_{nt}$ and an output $y_{nt}$. Formally, we write $x_n = [x_{n1} \dots x_{nT_n}]$ and $y_n = [y_{n1} \dots y_{nT_n}]$. Each value $y_{nt}$ could be a prediction about the next timestep (e.g. the character at time $t+1$) or some other task-related annotation (e.g. whether the patient became septic at time $t$).
Simple neural networks. A multi-layer perceptron (MLP) makes predictions $\hat{y}_n$ of the target $y_n$ via a function $\hat{y}_n(x_n, W)$, where the vector $W$ represents all parameters of the network. Given a data set $\{(x_n, y_n)\}_{n=1}^N$, our goal is to learn the parameters $W$ to minimize the objective

$$\min_W \; \lambda \Psi(W) + \sum_{n=1}^{N} \mathrm{loss}\big(y_n, \hat{y}_n(x_n, W)\big) \qquad (1)$$
For binary targets $y_n$, the logistic loss (binary cross entropy) is an effective choice. The regularization term $\Psi(W)$ can represent L1 or L2 penalties (e.g. (Drucker and Le Cun 1992; Goodfellow, Bengio, and Courville 2016; Ochiai et al. 2017)) or our new regularization.
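As a minimal, dependency-free sketch of Eq. (1), the snippet below assembles the objective from a logistic loss and a pluggable penalty $\Psi$. The helper names (`logistic_loss`, `l1_penalty`, `l2_penalty`, `objective`) and the one-feature logistic "network" are illustrative assumptions, not the paper's toolbox; following Section 4, the L2 penalty is written as the norm $\|W\|_2$ rather than its square.

```python
import math


def logistic_loss(y, p, eps=1e-12):
    """Binary cross entropy for a target y in {0, 1} and predicted probability p."""
    p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


def l1_penalty(W):
    return sum(abs(w) for w in W)


def l2_penalty(W):
    # Norm (not squared), matching Psi(W) = ||W||_2 as written in Section 4.
    return math.sqrt(sum(w * w for w in W))


def objective(W, X, Y, predict, psi, lam):
    """lam * psi(W) + sum_n loss(y_n, yhat_n(x_n, W)), i.e. Eq. (1)."""
    data_term = sum(logistic_loss(y, predict(x, W)) for x, y in zip(X, Y))
    return lam * psi(W) + data_term


def predict(x, W):
    """Hypothetical one-feature logistic model standing in for an MLP."""
    return 1.0 / (1.0 + math.exp(-(W[0] * x + W[1])))


X, Y = [0.0, 1.0], [0, 1]
W = [0.0, 0.0]  # every prediction is 0.5, so each example costs log(2)
print(objective(W, X, Y, predict, l2_penalty, lam=0.1))
```

Swapping `psi` between `l1_penalty`, `l2_penalty`, or a tree-regularization term is the only change needed to move between the regularizers compared in this paper.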
Recurrent Neural Networks with Gated Recurrent Units. A recurrent neural network (RNN) takes as input an arbitrary-length sequence $x_n = [x_{n1} \dots x_{nT_n}]$ and produces a “hidden state” sequence $h_n = [h_{n1} \dots h_{nT_n}]$ of the same length as the input. Each hidden state vector at timestep $t$ represents a location in a (possibly low-dimensional) “state space” with $K$ dimensions: $h_{nt} \in \mathbb{R}^K$. RNNs perform a sequential nonlinear embedding of the form $h_{nt} = f(x_{nt}, h_{n,t-1})$ in the hope that the state space location $h_{nt}$ is a useful summary statistic for making predictions of the target $y_{nt}$ at timestep $t$.
Many different variants of the transition function architecture $f$ have been proposed to solve the challenge of capturing long-term dependencies. In this paper, we use gated recurrent units (GRUs) (Cho et al. 2014), which are simpler than other alternatives such as long short-term memory units (LSTMs) (Hochreiter and Schmidhuber 1997). While GRUs are convenient, any differentiable RNN architecture is compatible with our new tree-regularization approach.
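Below we describe the evolution of a single GRU sequence, dropping the sequence index $n$ for readability. The GRU transition function $f$ produces the state vector $h_t = [h_{t1} \dots h_{tK}]$ from a previous state $h_{t-1}$ and an input vector $x_t$, via the following feed-forward architecture:

$$
\begin{aligned}
\text{output state:} \quad & h_{tk} = (1 - z_{tk})\, h_{t-1,k} + z_{tk}\, \tilde{h}_{tk} \qquad (2)\\
\text{candidate state:} \quad & \tilde{h}_{tk} = \tanh\big(V^h_k x_t + U^h_k (r_t \odot h_{t-1})\big)\\
\text{update gate:} \quad & z_{tk} = \sigma\big(V^z_k x_t + U^z_k h_{t-1}\big)\\
\text{reset gate:} \quad & r_{tk} = \sigma\big(V^r_k x_t + U^r_k h_{t-1}\big)
\end{aligned}
$$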
The internal network nodes include candidate state gates $\tilde{h}$, update gates $z$, and reset gates $r$, which have the same cardinality as the state vector $h$. Reset gates allow the network to forget past state vectors when set near zero via the logistic sigmoid nonlinearity $\sigma(\cdot)$. Update gates allow the network to either pass along the previous state vector unchanged or use the new candidate state vector instead. This architecture is diagrammed in Figure 1.
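The transition in Eq. (2) can be sketched for a single timestep in plain Python, with parameter matrices stored as nested lists. The dictionary keys (`Vh`, `Uh`, `Vz`, `Uz`, `Vr`, `Ur`) are our own naming for $V^h, U^h, \dots$; like Eq. (2), the sketch has no bias terms, and $r_t \odot h_{t-1}$ is taken to be an elementwise product.

```python
import math


def sigmoid(a):
    return 1.0 / (1.0 + math.exp(-a))


def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector (list)."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def gru_step(x_t, h_prev, p):
    """One GRU transition h_t = f(x_t, h_{t-1}) following Eq. (2)."""
    z = [sigmoid(a + b) for a, b in zip(matvec(p["Vz"], x_t), matvec(p["Uz"], h_prev))]
    r = [sigmoid(a + b) for a, b in zip(matvec(p["Vr"], x_t), matvec(p["Ur"], h_prev))]
    reset_h = [r_k * h_k for r_k, h_k in zip(r, h_prev)]  # r_t ⊙ h_{t-1}
    h_cand = [math.tanh(a + b) for a, b in zip(matvec(p["Vh"], x_t), matvec(p["Uh"], reset_h))]
    # The update gate interpolates between the previous state and the candidate state.
    return [(1.0 - z_k) * hp + z_k * hc for z_k, hp, hc in zip(z, h_prev, h_cand)]


def run_gru(xs, p, K):
    """Unroll over a sequence starting from h_0 = 0; returns [h_1, ..., h_T]."""
    h, states = [0.0] * K, []
    for x_t in xs:
        h = gru_step(x_t, h, p)
        states.append(h)
    return states


# Usage: with all-zero weights, z = r = 0.5 and the candidate is tanh(0) = 0,
# so each step simply halves the previous state.
K, D = 2, 3
zeros = {name: [[0.0] * (D if name[0] == "V" else K) for _ in range(K)]
         for name in ("Vh", "Uh", "Vz", "Uz", "Vr", "Ur")}
print(gru_step([1.0, 2.0, 3.0], [1.0, -2.0], zeros))  # [0.5, -1.0]
```

Because each new state is a convex combination of the previous state and a tanh output, states started from $h_0 = 0$ stay inside $(-1, 1)$.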
The predicted probability of the binary label $y_t$ for time $t$ is a sigmoid transformation of the state at time $t$:

$$\hat{y}_t = \sigma(w^T h_t) \qquad (3)$$
[Figure 1 graphic omitted: GRU cell with sigmoid gates $r_t$ and $z_t$, tanh candidate state $\tilde{h}_t$, previous state $h_{t-1}$, new state $h_t$, and output $\hat{y}_t$.]
Figure 1: Diagram of the gated recurrent unit (GRU) used for each timestep of our neural time-series model. The orange triangle indicates the predicted output $\hat{y}_t$ at time $t$.

In Eq. (3), the weight vector $w \in \mathbb{R}^K$ represents the parameters of this output layer. We denote the parameters for the entire GRU-RNN model as $W = (w, U, V)$, concatenating all component parameters. We can train GRU-RNN time-series models (hereafter often just called GRUs) via the following loss minimization objective:
$$\min_W \; \lambda \Psi(W) + \sum_{n=1}^{N} \sum_{t=1}^{T_n} \mathrm{loss}\big(y_{nt}, \hat{y}_{nt}(x_n, W)\big) \qquad (4)$$

where again $\Psi(W)$ defines a regularization cost.
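For time series, Eq. (4) sums the per-timestep loss over every timestep $t = 1, \dots, T_n$ of every sequence $n$. A minimal sketch, assuming a hypothetical `predict_seq(x_n, W)` that returns one probability per timestep (for example, $\sigma(w^T h_t)$ from Eq. (3) applied to each GRU state):

```python
import math


def logistic_loss(y, p, eps=1e-12):
    p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


def sequence_objective(W, X_seqs, Y_seqs, predict_seq, psi, lam):
    """lam * psi(W) + sum_n sum_t loss(y_nt, yhat_nt(x_n, W)), as in Eq. (4)."""
    total = 0.0
    for x_n, y_n in zip(X_seqs, Y_seqs):
        probs = predict_seq(x_n, W)  # one probability per timestep of sequence n
        if len(probs) != len(y_n):
            raise ValueError("predict_seq must return one probability per timestep")
        total += sum(logistic_loss(y_nt, p_nt) for y_nt, p_nt in zip(y_n, probs))
    return lam * psi(W) + total


def always_half(x_n, W):
    """Hypothetical uninformative predictor: probability 0.5 at every timestep."""
    return [0.5] * len(x_n)


def l1(W):
    return sum(abs(w) for w in W)


# Two sequences of different lengths (T_1 = 2, T_2 = 3): the data term is
# 5 * log(2), one log(2) per timestep, plus 0.5 * ||W||_1 = 1.0.
X_seqs = [[[0.1], [0.2]], [[0.3], [0.4], [0.5]]]
Y_seqs = [[0, 1], [1, 1, 0]]
print(sequence_objective([1.0, -1.0], X_seqs, Y_seqs, always_half, psi=l1, lam=0.5))
```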
3 Tree Regularization for Deep Models
We now propose a novel tree regularization function $\Omega(W)$ for the parameters of a differentiable model, which attempts to penalize models whose predictions are not easily simulatable. Of course, it is difficult to measure “simulatability” directly for an arbitrary network, so we take inspiration from decision trees. Our chosen method has two stages: first, find a single binary decision tree which accurately reproduces the network's thresholded binary predictions $\hat{y}_n$ given input $x_n$. Second, measure the complexity of this decision tree as the output of $\Omega(W)$. We measure complexity as the average decision path length: the average number of decision nodes that must be touched to make a prediction for an input example $x_n$. We compute the average with respect to some designated reference dataset of example inputs $\mathcal{D} = \{x_n\}$ from the training set. While many ways to measure complexity exist, we find average path length is most relevant to our notion of simulatability. Remember that for us, human simulation requires stepping through every calculation required to make a prediction. Average path length exactly counts the number of true-or-false boolean calculations needed to make an average prediction, assuming the model is a decision tree. Total number of nodes could be used as a metric, but might penalize more accurate trees that have short paths for most examples but need more involved logic for a few outliers.
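To make the metric concrete, the sketch below applies it to the small decision tree in Fig. 3(c) (Section 5), using the per-leaf example counts read from that figure: 2 546 examples stop at the root split on x[0], 1 744 stop after a second split on x[3], and 587 + 123 reach a third split on x[4]. The nested-tuple tree encoding is an illustrative assumption, not the paper's toolbox. It shows the contrast drawn above: the tree has 3 decision nodes and a maximum depth of 3, yet the average example needs only about 1.63 decisions.

```python
def leaf(count):
    """A leaf that receives `count` reference examples."""
    return ("leaf", count)


def split(feature, left, right):
    """A decision node testing x[feature] <= 0.5 (True branch first)."""
    return ("split", feature, left, right)


def count_decision_nodes(tree):
    if tree[0] == "leaf":
        return 0
    _, _, left, right = tree
    return 1 + count_decision_nodes(left) + count_decision_nodes(right)


def average_path_length(tree):
    """Mean number of decision nodes touched per example, weighted by leaf counts."""
    def walk(node, depth):
        if node[0] == "leaf":
            return depth * node[1], node[1]
        _, _, left, right = node
        lw, ln = walk(left, depth + 1)
        rw, rn = walk(right, depth + 1)
        return lw + rw, ln + rn

    weighted, total = walk(tree, 0)
    return weighted / total


# Tree from Fig. 3(c): x[0] <= 0.5 -> off; else x[3] <= 0.5 -> off;
# else x[4] <= 0.5 -> on (587 examples), otherwise off (123 examples).
fig3c = split(0, leaf(2546), split(3, leaf(1744), split(4, leaf(587), leaf(123))))
print(count_decision_nodes(fig3c), average_path_length(fig3c))  # 3 1.6328
```

Here $(2546 \cdot 1 + 1744 \cdot 2 + 710 \cdot 3) / 5000 = 8164 / 5000 = 1.6328$.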
Our true-average-path-length cost function $\Omega(W)$ is detailed in Alg. 1. It requires two subroutines, TRAINTREE and PATHLENGTH. TRAINTREE trains a binary decision tree to accurately reproduce the provided labeled examples $\{x_n, \hat{y}_n\}$. We use the DecisionTree module distributed in Python's scikit-learn (Pedregosa et al. 2011) with post-pruning to simplify the tree. These trees can give probabilistic predictions at each leaf. (Complete decision-tree training details are in the supplement.) Next, PATHLENGTH counts how many decision nodes must be traversed to route a specific input to an output (leaf) node in the provided decision tree. In our evaluations, we will apply our average-decision-tree-path-length regularization, or simply “tree regularization,” to several neural models.

Algorithm 1 Average-Path-Length Cost Function
Require: $\hat{y}(\cdot, W)$: binary prediction function, with parameters $W$; $\mathcal{D} = \{x_n\}_{n=1}^N$: reference dataset with $N$ examples
1: function Ω(W)
2:   tree ← TRAINTREE($\{x_n, \hat{y}(x_n, W)\}$)
3:   return $\frac{1}{N} \sum_n$ PATHLENGTH(tree, $x_n$)

Alg. 1 defines our average-path-length cost function $\Omega(W)$, which can be plugged into the abstract regularization term $\Psi(W)$ in the objectives in Eqs. (1) and (4).
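Algorithm 1 can be sketched end to end without dependencies. Here `train_tree` is a small greedy CART-style learner on Gini impurity, a hypothetical stand-in for TRAINTREE (the paper uses scikit-learn's decision tree with post-pruning, which this sketch does not reproduce), and `predict_binary` plays the role of the thresholded network prediction $\hat{y}(\cdot, W)$ with $W$ held fixed. The `max_depth` and `min_leaf` defaults are arbitrary illustrative choices.

```python
import random
from collections import Counter


def gini(labels):
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values()) if n else 0.0


def train_tree(X, y, max_depth=4, min_leaf=5):
    """Greedy binary tree: ("leaf", label) or ("split", feature, threshold, left, right)."""
    majority = Counter(y).most_common(1)[0][0]
    if max_depth == 0 or len(set(y)) == 1 or len(y) < 2 * min_leaf:
        return ("leaf", majority)
    best = None  # (weighted impurity, feature, threshold)
    for f in range(len(X[0])):
        values = sorted({row[f] for row in X})
        for lo, hi in zip(values, values[1:]):
            thr = (lo + hi) / 2
            left = [lab for row, lab in zip(X, y) if row[f] <= thr]
            right = [lab for row, lab in zip(X, y) if row[f] > thr]
            if len(left) < min_leaf or len(right) < min_leaf:
                continue
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
            if best is None or score < best[0]:
                best = (score, f, thr)
    if best is None:
        return ("leaf", majority)
    _, f, thr = best
    go_left = [row[f] <= thr for row in X]
    Xl = [row for row, g in zip(X, go_left) if g]
    yl = [lab for lab, g in zip(y, go_left) if g]
    Xr = [row for row, g in zip(X, go_left) if not g]
    yr = [lab for lab, g in zip(y, go_left) if not g]
    return ("split", f, thr,
            train_tree(Xl, yl, max_depth - 1, min_leaf),
            train_tree(Xr, yr, max_depth - 1, min_leaf))


def path_length(tree, x):
    """PATHLENGTH: number of decision nodes visited while routing x to a leaf."""
    depth = 0
    while tree[0] == "split":
        _, f, thr, left, right = tree
        tree = left if x[f] <= thr else right
        depth += 1
    return depth


def omega(predict_binary, X_ref):
    """Alg. 1: fit a tree to the model's own binary predictions, then average path lengths."""
    tree = train_tree(X_ref, [predict_binary(x) for x in X_ref])
    return sum(path_length(tree, x) for x in X_ref) / len(X_ref)


rng = random.Random(0)
X_ref = [[rng.random(), rng.random()] for _ in range(200)]
print(omega(lambda x: 0, X_ref))                # 0.0: a constant model is a single leaf
print(omega(lambda x: int(x[0] > 0.5), X_ref))  # 1.0: one axis-aligned split suffices
print(omega(lambda x: int(x[1] > 5 * (x[0] - 0.5) ** 2 + 0.4), X_ref))  # curved: longer paths
```

Note that $\Omega$ here is a piecewise-constant function of the model's predictions, which is why the next paragraph turns to a differentiable surrogate.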
Making the Decision-Tree Loss Differentiable. Training decision trees is not differentiable, and thus $\Omega(W)$ as defined in Alg. 1 is not differentiable with respect to the network parameters $W$ (unlike standard regularizers such as the L1 or L2 norm). While one could resort to derivative-free optimization techniques (Audet and Kokkolaras 2016), gradient descent has been an extremely fast and robust way of training networks (Goodfellow, Bengio, and Courville 2016).
A key technical contribution of our work is introducing and training a surrogate regularization function $\hat{\Omega}(W): \mathrm{supp}(W) \to \mathbb{R}^+$ to map each candidate neural model parameter vector $W$ to an estimate of the average path length. Our approximate function $\hat{\Omega}$ is implemented as a standalone multi-layer perceptron network and is thus differentiable. Let the vector $\xi$ of size $k$ denote the parameters of this chosen MLP approximator. We can train $\hat{\Omega}$ to be a good estimator by minimizing a squared error loss function:

$$\min_\xi \; \sum_{j=1}^{J} \big(\Omega(W_j) - \hat{\Omega}(W_j, \xi)\big)^2 + \epsilon \|\xi\|_2^2 \qquad (5)$$

where each $W_j$ is a full parameter vector for our model, $\epsilon > 0$ is a regularization strength, and we assume we have a dataset of $J$ known parameter vectors and their associated true path lengths: $\{W_j, \Omega(W_j)\}_{j=1}^J$. This dataset can be assembled using the candidate $W$ vectors obtained while training our target neural model $\hat{y}(\cdot, W)$, as well as by evaluating $\Omega(W)$ for randomly generated $W$. Importantly, one can train the surrogate function $\hat{\Omega}$ in parallel with our network. In the supplement, we show evidence that our surrogate predictor $\hat{\Omega}(\cdot)$ tracks the true average path length as we train the target predictor $\hat{y}(\cdot, W)$.
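A minimal sketch of fitting the surrogate $\hat{\Omega}(W, \xi)$ by minimizing Eq. (5), written without Autograd: a one-hidden-layer tanh MLP with hand-derived gradients and plain full-batch gradient descent (the paper uses Autograd and Adam instead). The 4-dimensional "parameter vectors" and their target path lengths are synthetic placeholders, not real networks; in practice each pair would come from running Alg. 1 on a candidate $W$. The hidden width, `eps`, and `lr` are illustrative choices.

```python
import math
import random


def init_surrogate(dim, hidden, rng):
    """xi = (A, a, v, c) for a one-hidden-layer tanh MLP."""
    return {"A": [[rng.gauss(0.0, 0.1) for _ in range(dim)] for _ in range(hidden)],
            "a": [0.0] * hidden,
            "v": [rng.gauss(0.0, 0.1) for _ in range(hidden)],
            "c": 0.0}


def surrogate(xi, W):
    """Omega_hat(W, xi); also returns the hidden activations for backprop."""
    h = [math.tanh(sum(A_k[i] * W[i] for i in range(len(W))) + a_k)
         for A_k, a_k in zip(xi["A"], xi["a"])]
    return sum(v_k * h_k for v_k, h_k in zip(xi["v"], h)) + xi["c"], h


def eq5_loss(xi, pairs, eps):
    sq = sum((omega - surrogate(xi, W)[0]) ** 2 for W, omega in pairs)
    norm2 = (sum(x * x for row in xi["A"] for x in row) + sum(x * x for x in xi["a"])
             + sum(x * x for x in xi["v"]) + xi["c"] ** 2)
    return sq + eps * norm2


def gd_step(xi, pairs, eps, lr):
    """One full-batch gradient-descent step on Eq. (5), gradients derived by hand."""
    gA = [[2 * eps * x for x in row] for row in xi["A"]]  # gradient of eps * ||xi||^2
    ga = [2 * eps * x for x in xi["a"]]
    gv = [2 * eps * x for x in xi["v"]]
    gc = 2 * eps * xi["c"]
    for W, omega in pairs:
        pred, h = surrogate(xi, W)
        e = 2 * (pred - omega)  # derivative of (omega - pred)^2 w.r.t. pred
        gc += e
        for k, h_k in enumerate(h):
            gv[k] += e * h_k
            back = e * xi["v"][k] * (1 - h_k * h_k)  # tanh' = 1 - tanh^2
            ga[k] += back
            for i, w_i in enumerate(W):
                gA[k][i] += back * w_i
    for k in range(len(gv)):
        xi["v"][k] -= lr * gv[k]
        xi["a"][k] -= lr * ga[k]
        for i in range(len(gA[k])):
            xi["A"][k][i] -= lr * gA[k][i]
    xi["c"] -= lr * gc


rng = random.Random(0)
# Hypothetical pairs {W_j, Omega(W_j)}: random 4-dim vectors with a made-up target.
pairs = []
for _ in range(20):
    W = [rng.uniform(-1, 1) for _ in range(4)]
    pairs.append((W, 1.0 + 2.0 * abs(W[0]) + abs(W[1])))
xi = init_surrogate(dim=4, hidden=8, rng=rng)
before = eq5_loss(xi, pairs, eps=1e-3)
for _ in range(500):
    gd_step(xi, pairs, eps=1e-3, lr=1e-3)
print(before, eq5_loss(xi, pairs, eps=1e-3))  # the Eq. (5) loss should drop
```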
Training the Surrogate Loss. Even moderately sized GRUs can have parameter vectors $W$ with thousands of dimensions. Our labeled dataset for surrogate training, $\{W_j, \Omega(W_j)\}_{j=1}^J$, will only have one $W_j$ example from each target network training iteration. Thus, in early iterations, we will have only a few examples from which to learn a good surrogate function $\hat{\Omega}(W)$. We resolve this challenge by augmenting our training set with additional examples: we randomly sample weight vectors $W$ and calculate the true average path length $\Omega(W)$, and we also perform several random restarts on the unregularized GRU and use those weights in our training set.
A second challenge occurs later in training: as the model parameters $W$ shift away from their initial values, those early parameters may not be as relevant in characterizing the current decision function of the GRU. To address this, for each epoch, we use examples only from the past $E$ epochs (in addition to augmentation), where in practice $E$ is chosen empirically. Using examples from a fixed window of epochs also speeds up training. The supplement compares the importance of these heuristics for efficient and accurate training; empirically, data augmentation for stabilizing surrogate training allows us to scale to GRUs with hundreds of nodes. GRUs of this size are sufficient for many real problems, such as those we encounter in healthcare domains.
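The bookkeeping described above (augmentation pairs that are always kept, plus pairs from only the last $E$ epochs) can be sketched with a bounded `collections.deque`. The `SurrogateDataset` class, its method names, and the string placeholders standing in for weight vectors are hypothetical, not the paper's toolbox.

```python
from collections import deque


class SurrogateDataset:
    """Training pairs (W, Omega(W)) for the surrogate: augmentation + last E epochs."""

    def __init__(self, window_epochs, augmentation=()):
        # Pairs from random W draws or restarts of the unregularized GRU; always kept.
        self.augmentation = list(augmentation)
        # One list of pairs per epoch; deque(maxlen=E) silently drops the oldest epoch.
        self.recent = deque(maxlen=window_epochs)

    def start_epoch(self):
        self.recent.append([])

    def add(self, W, omega):
        self.recent[-1].append((W, omega))

    def pairs(self):
        return self.augmentation + [pair for epoch in self.recent for pair in epoch]


data = SurrogateDataset(window_epochs=2, augmentation=[("W_random", 7.0)])
for W, omega in [("W_1", 5.0), ("W_2", 4.0), ("W_3", 3.5)]:  # one pair per epoch
    data.start_epoch()
    data.add(W, omega)
print(data.pairs())  # [('W_random', 7.0), ('W_2', 4.0), ('W_3', 3.5)]
```

With $E = 2$, the pair from the first epoch has been evicted, while the augmentation pair survives.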
Typically, we use $J = 50$ labeled pairs for surrogate training for toy datasets and $J = 100$ for real-world datasets. Optimization of our surrogate objective is done via gradient descent. We use Autograd to compute gradients of the loss in Eq. (5) with respect to $\xi$, then use Adam to compute descent directions with step sizes set to 0.01 for toy datasets and 0.001 for real-world datasets.
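For reference, the Adam update (Kingma and Ba 2014) that turns the Eq. (5) gradients into descent directions can be sketched in a few lines. The defaults `beta1=0.9`, `beta2=0.999`, `eps=1e-8` are the commonly used values; `lr` corresponds to the step sizes quoted above (0.01 or 0.001), and the quadratic test function is purely illustrative.

```python
import math


def adam_minimize(grad, theta, lr=0.01, steps=1000, beta1=0.9, beta2=0.999, eps=1e-8):
    """Minimize a function given its gradient `grad(theta) -> list` using Adam."""
    theta = list(theta)
    m = [0.0] * len(theta)  # first-moment (mean) estimate
    v = [0.0] * len(theta)  # second-moment (uncentered variance) estimate
    for t in range(1, steps + 1):
        g = grad(theta)
        for i, g_i in enumerate(g):
            m[i] = beta1 * m[i] + (1 - beta1) * g_i
            v[i] = beta2 * v[i] + (1 - beta2) * g_i * g_i
            m_hat = m[i] / (1 - beta1 ** t)  # bias correction for zero initialization
            v_hat = v[i] / (1 - beta2 ** t)
            theta[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta


# Minimize (a - 3)^2 + (b + 1)^2 starting from (0, 0).
print(adam_minimize(lambda th: [2 * (th[0] - 3), 2 * (th[1] + 1)], [0.0, 0.0],
                    lr=0.01, steps=5000))
```

Because each Adam step is roughly bounded by `lr` in magnitude, the step size directly limits how far the surrogate parameters $\xi$ move per iteration.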
4 Tree-Regularized MLPs: A Demonstration
While time-series models are the main focus of this work, we first demonstrate tree regularization on a simple binary classification task to build intuition. We call this task the 2D Parabola problem because, as Fig. 2(a) shows, the training data consists of 2D input points whose two-class decision boundary is roughly shaped like a parabola. The true decision function is defined by $y = 5(x - 0.5)^2 + 0.4$. We sampled 500 input points $x_n$ uniformly within the unit square $[0, 1] \times [0, 1]$ and labeled those above the decision function as positive. To make it easy for models to overfit, we flipped 10% of the points in a region near the boundary. A random 30% were held out for testing.
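The 2D Parabola data described above can be generated along these lines. The text does not specify how “a region near the boundary” is defined, so the vertical band of half-width 0.1 around the curve, and the choice to flip 10% of all 500 points (chosen from within that band), are our assumptions.

```python
import random


def boundary(x):
    """True decision function y = 5 (x - 0.5)^2 + 0.4."""
    return 5 * (x - 0.5) ** 2 + 0.4


def make_parabola_data(n=500, flip_frac=0.10, band=0.1, test_frac=0.30, seed=0):
    rng = random.Random(seed)
    points = [(rng.random(), rng.random()) for _ in range(n)]
    labels = [int(y > boundary(x)) for x, y in points]  # above the curve -> positive
    # Assumption: "near the boundary" means within `band` of the curve vertically.
    near = [i for i, (x, y) in enumerate(points) if abs(y - boundary(x)) < band]
    for i in rng.sample(near, min(len(near), round(flip_frac * n))):
        labels[i] = 1 - labels[i]  # label noise that invites overfitting
    order = list(range(n))
    rng.shuffle(order)
    n_test = round(test_frac * n)
    test_idx, train_idx = order[:n_test], order[n_test:]
    return ([points[i] for i in train_idx], [labels[i] for i in train_idx],
            [points[i] for i in test_idx], [labels[i] for i in test_idx])


X_train, y_train, X_test, y_test = make_parabola_data()
print(len(X_train), len(X_test))  # 350 150
```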
For the classifier $\hat{y}$, we train a 3-layer MLP with 100 first-layer nodes, 100 second-layer nodes, and 10 third-layer nodes. This MLP is intentionally overly expressive to encourage overfitting and expose the impact of different forms of regularization: our proposed tree regularization $\Psi(W) = \Omega(W)$ and two baselines, an L2 penalty on the weights $\Psi(W) = \|W\|_2$ and an L1 penalty on the weights $\Psi(W) = \|W\|_1$. For each regularization function, we train models at many different regularization strengths $\lambda$ chosen to explore the full range of decision boundary complexities possible under each technique.
For our tree regularization, we model our surrogate $\hat{\Omega}(W)$ with a 1-hidden-layer MLP with 25 units. We find this simple architecture works well, but more complex MLPs could certainly be used on more complex problems. The objective in Eq. (1) was optimized via Adam gradient descent (Kingma and Ba 2014) using a batch size of 100 and a learning rate of 1e-3 for 250 epochs, and hyperparameters were set via cross validation using grid search (see supplement for full experimental details).

Fig. 2(b) shows each trained model as a single point in a 2D fitness space: the x-axis measures model complexity via our average-path-length metric, and the y-axis measures AUC prediction performance. These results show that simple L1 or L2 regularization does not produce models with both low complexity (short average path length) and good predictions at any value of the regularization strength $\lambda$. As expected, large $\lambda$ values for L1 and L2 only produce far-too-simple linear decision boundaries with poor accuracies. In contrast, our proposed tree regularization directly optimizes the MLP to have simple tree-like boundaries at high $\lambda$ values, which can still yield good predictions.
The lower panes of Fig. 2 show these boundaries. Our tree regularization is uniquely able to create axis-aligned functions, because decision trees prefer functions that are axis-aligned splits. These axis-aligned functions require very few nodes but are more effective than their L1 and L2 counterparts. The L1 boundary is sharper, whereas the L2 boundary is rounder.
5 Tree-Regularized Time-Series Models
We now evaluate our tree-regularization approach on time-series models. We focus on GRU-RNN models, with some later experiments on new hybrid GRU-HMM models. As with the MLP, each regularization technique (tree, L2, L1) can be applied to the output node of the GRU across a range of strength parameters $\lambda$. Importantly, Algorithm 1 can compute the average decision-tree path length for any fixed deep model given its parameters, and can hence be used to measure decision boundary complexity under any regularization, including L1 or L2. This means that when training any model, we can track both the predictive performance (as measured by area under the ROC curve (AUC); higher values mean better predictions) and the complexity of the decision tree required to explain each model (as measured by our average path length metric; lower values mean more interpretable models). We also show results for a baseline standalone decision tree classifier without any associated deep model, sweeping a range of parameters controlling leaf size to explore how this baseline trades off path length and prediction quality. Further details of our experimental protocol are in the supplement, as well as more extensive results with additional baselines.
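AUC, the y-axis of these fitness plots, can be computed directly from its probabilistic definition: the chance that a randomly chosen positive example is scored above a randomly chosen negative one, with ties counted as one half. This quadratic-time sketch is written for clarity; rank-based implementations such as scikit-learn's `roc_auc_score` compute the same quantity more efficiently.

```python
def auc(labels, scores):
    """P(score of a random positive > score of a random negative), ties count 1/2."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))


print(auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

Three of the four positive/negative pairs are ordered correctly (only 0.35 < 0.4 is not), giving 0.75.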
5.1 Tasks
Synthetic Task: Signal-and-noise HMM. We generated a toy dataset of $N = 100$ sequences, each with $T = 50$ timesteps. Each timestep has a data vector $x_{nt}$ of 14 binary features and a single binary output label $y_{nt}$. The data comes from two separate HMM processes. First, a “signal” HMM generates the first 7 data dimensions from 5 well-separated states. Second, an independent “noise” HMM generates the remaining 7 data dimensions from a different set of 5 states. Each timestep's output label $y_{nt}$ is produced by a rule involving both the signal data and the signal hidden state: the target is 1 at timestep $t$ only if both the first signal state is active and the first observation is turned on. We deliberately designed the generation process so that neither logistic regression with $x$ as features nor an RNN model that makes predictions from hidden states alone can perfectly separate this data.

[Figure 2 plots omitted. Recoverable panel titles: (a) Training Data and Binary Class Labels for 2D Parabola; (d) Decision Boundaries with L2 regularization; (e) Decision Boundaries with Tree regularization (panels labeled Tree 0.01, Tree 100.0, Tree 700.0, Tree 9500.0, Tree 12000.0, Tree 15000.0).]
Figure 2: 2D Parabola task: (a) Each training data point in 2D space, overlaid with the true parabolic class boundary. (b) Each method's prediction quality (AUC) and complexity (path length) metrics, across a range of regularization strengths $\lambda$. In the small path length regime between 0 and 5, tree regularization produces models with higher AUC than L1 or L2. (c-e) Decision boundaries (black lines) have qualitatively different shapes for different regularization schemes, as regularization strength $\lambda$ increases. We color predictions as true positive (red), true negative (yellow), false negative (green), and false positive (blue).
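A generator in the spirit of the signal-and-noise task can be sketched as follows. Only the sizes (N = 100, T = 50, two 5-state HMMs emitting 7 binary features each) and the labeling rule come from the text. The Fig. 3 caption gives only the first entries of the first signal state's emission probabilities, [.5 .5 .5 .5 0 . . .]; the remaining entries here, every other emission table, and the “sticky” transition model (stay with probability 0.9) are hypothetical choices.

```python
import random


def sample_hmm(T, emit, rng, stay=0.9):
    """Sample hidden states and binary emissions from a sticky HMM (hypothetical dynamics)."""
    n_states = len(emit)
    s = rng.randrange(n_states)
    states, obs = [], []
    for _ in range(T):
        states.append(s)
        obs.append([int(rng.random() < p) for p in emit[s]])
        if rng.random() > stay:  # otherwise keep the current state
            s = rng.randrange(n_states)
    return states, obs


def make_signal_noise_data(N=100, T=50, seed=0):
    rng = random.Random(seed)
    # Signal state 0: first five entries per the Fig. 3 caption; last two hypothetical.
    signal_emit = [[0.5, 0.5, 0.5, 0.5, 0.0, 0.0, 0.0]]
    # Signal states 1-4 (hypothetical): each mostly turns on one of dimensions 2-5.
    signal_emit += [[0.9 if d == k + 2 else 0.05 for d in range(7)] for k in range(4)]
    noise_emit = [[rng.random() for _ in range(7)] for _ in range(5)]  # hypothetical
    X, Y, S = [], [], []
    for _ in range(N):
        sig_states, sig_obs = sample_hmm(T, signal_emit, rng)
        _, noise_obs = sample_hmm(T, noise_emit, rng)
        x_n = [s_obs + n_obs for s_obs, n_obs in zip(sig_obs, noise_obs)]  # 14 features
        # Label rule from the text: first signal state active AND first observation on.
        y_n = [int(s == 0 and x_t[0] == 1) for s, x_t in zip(sig_states, x_n)]
        X.append(x_n)
        Y.append(y_n)
        S.append(sig_states)
    return X, Y, S


X, Y, S = make_signal_noise_data()
print(len(X), len(X[0]), len(X[0][0]), sum(map(sum, Y)))  # sizes, then positive timesteps
```

Since the label depends on the hidden signal state, no rule over $x_{nt}$ alone can reproduce it exactly, which mirrors the design goal stated above.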
Real-World Tasks: We tested our approach on several real tasks: predicting medical outcomes of hospitalized septic patients, predicting HIV therapy outcomes, and identifying stop phonemes in English speech recordings. To normalize scales, we independently standardized features $x$ via z-scoring.
• Sepsis Critical Care: We study time-series data for 11 786 septic ICU patients from the public MIMIC III dataset (Johnson et al. 2016). We observe at each hour $t$ a data vector $x_{nt}$ of 35 vital signs and lab results as well as a label vector $y_{nt}$ of 5 binary outcomes. Hourly data $x_{nt}$ measures continuous features such as respiration rate (RR), blood oxygen levels (paO2), fluid levels, and more. Hourly binary labels $y_{nt}$ include whether the patient died in hospital and whether mechanical ventilation was applied. Models are trained to predict all 5 output dimensions concurrently from one shared embedding. The average sequence length is 15 hours. 7 070 patients are used in training, 1 769 for validation, and 294 for test.
• HIV Therapy Outcome (HIV): We use the EuResist Integrated Database (Zazzi et al. 2012) for 53 236 patients diagnosed with HIV. We consider 4-6 month intervals (corresponding to hospital visits) as time steps. Each data vector $x_{nt}$ has 40 features, including blood counts, viral load measurements, and lab results. Each output vector $y_{nt}$ has 15 binary labels, including whether a therapy was successful in reducing viral load to below detection limits, whether therapy caused CD4 blood cell counts to drop to dangerous levels (indicating AIDS), and whether the patient suffered adherence issues to medication. The average sequence length is 14 steps. 37 618 patients are used for training, 7 986 for testing, and 7 632 for validation.
• Phonetic Speech (TIMIT): We have recordings of 630 speakers of eight major dialects of American English reading ten phonetically rich sentences (Garofolo et al. 1993). Each sentence contains time-aligned transcriptions of 60 phonemes. We focus on distinguishing stop phonemes (those that stop the flow of air, such as “b” or “g”) from non-stops. Each timestep has one binary label $y_{nt}$ indicating whether a stop phoneme occurs. Each input $x_{nt}$ has 26 continuous features: the acoustic signal's Mel-frequency cepstral coefficients and derivatives. There are 6 303 sequences, split into 3 697 for training, 925 for validation, and 1 681 for testing. The average length is 614 timesteps.
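The per-feature z-scoring mentioned for the real-world tasks amounts to subtracting each feature's mean and dividing by its standard deviation. A minimal sketch over a list of feature rows (for time series, one row per timestep); computing the statistics on the training split only and mapping constant features to zero are our assumptions, not details stated in the text.

```python
import math


def fit_zscore(rows):
    """Per-feature mean and (population) standard deviation; constant features get std 1."""
    n = len(rows)
    cols = list(zip(*rows))
    means = [sum(c) / n for c in cols]
    stds = [math.sqrt(sum((v - m) ** 2 for v in c) / n) or 1.0 for c, m in zip(cols, means)]
    return means, stds


def apply_zscore(rows, means, stds):
    return [[(v - m) / s for v, m, s in zip(row, means, stds)] for row in rows]


train_rows = [[1.0, 10.0], [2.0, 10.0], [3.0, 10.0]]
means, stds = fit_zscore(train_rows)          # statistics from training data only
print(apply_zscore(train_rows, means, stds))  # first column ≈ [-1.22, 0.0, 1.22]; second -> 0.0
```

Validation and test rows would reuse the same `means` and `stds`, so no information leaks from held-out data into the scaling.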
5.2 Results
The major conclusions of our experiments comparing GRUs with various regularizations are outlined below.
Tree-regularized models have fewer nodes than other forms of regularization. Across tasks, we see that in the target regime of small decision trees (low average path lengths), our proposed tree regularization achieves higher prediction quality (higher AUCs). In the signal-and-noise HMM task, tree regularization (green line in Fig. 3(d)) achieves AUC values near 0.9 when its trees have an average path length of 10. Similar models with L1 or L2 regularization reach this AUC only with trees that are nearly double in complexity (path length over 25). On the Sepsis task (Fig. 4) we see AUC gains of 0.05-0.1 at path lengths of 2-10. On the TIMIT task (Fig. 5a), we see AUC gains of 0.05-0.1 at path lengths of 20-30. Finally, on the HIV CD4 blood cell count task in Fig. 5b, we see AUC differences of between 0.03 and 0.15 for path lengths of 10-15. The HIV adherence task in
[Figure 3 panels: (a) GRU λ = 1, (b) GRU λ = 800, (c) GRU λ = 1 000: decision trees whose internal nodes test binary input features X[j] <= 0.5; (d) GRU: test AUC vs. average path length for GRU (L1), GRU (L2), GRU (Tree), and a standalone decision tree.]
Figure 3: Toy Signal-and-Noise HMM Task: (a)-(c) Decision trees trained to mimic predictions of GRU models with 25 hidden states at different regularization strengths λ; as expected, increasing λ decreases the size of the learned trees (see supplement for more trees). Decision tree (c) suggests the model learns to predict positive output (blue) if and only if “x[0] == 1 and x[3] == 1 and x[4] == 0”, which is consistent with the true rule we used to generate labels: assign a positive label only if the first dimension is on (x[0] == 1) and the first state is active (emission probabilities for this state: [.5 .5 .5 .5 0 ...]). (d) Tree-regularized GRU models reach a sweet spot of small path lengths yet high AUC predictions that alternatives cannot reach at any tested value of λ.
Figure 4: Sepsis task: Study of different regularization techniques for a GRU model with 100 states, trained to jointly predict 5 binary outcomes. Panels (a) and (c) show AUC vs. average path length for 2 of the 5 outcomes (remainder in the supplement); in both cases, tree regularization provides higher accuracy in the target regime of low-complexity decision trees. Panels (b) and (d) show the associated decision trees for λ = 2 000; these were found to be clinically interpretable by an ICU clinician (see main text).
[Figure 5 panels: (a) TIMIT Stop Phonemes, (b) HIV: CD4+ ≤ 200 cells/ml, (c) HIV Therapy Adherence: test AUC vs. average path length for GRU (L1), GRU (L2), GRU (Tree), and a standalone decision tree; (d) HIV Therapy Adherence: decision tree proxy splitting on baseline viral load (VL), number of prior treatment lines, baseline CD4, age, sex, and IDU.]
Figure 5: TIMIT and HIV tasks: Study of different regularization techniques for a GRU model with 75 states. Panels (a)-(c) are tradeoff curves showing how AUC predictive power and decision-tree complexity evolve with increasing regularization strength under L1, L2 or tree regularization on both TIMIT and HIV tasks. The GRU is trained to jointly predict 15 binary outcomes for HIV, of which 2 are shown here in Panels (b)-(c). The GRU’s decision tree proxy for HIV Adherence is shown in (d).
Fig. 5c has AUC gains of between 0.03 and 0.05 in the path-length range of 19 to 25, while at shorter path lengths all methods are quite poor, indicating the problem’s difficulty. Overall, these AUC gains are particularly useful in determining how to administer subsequent HIV therapies.
We emphasize that our tree regularization usually achieves a sweet spot of high AUCs at short path lengths not possible with standalone decision trees (orange lines), L1-regularized deep models (red lines) or L2-regularized deep models (blue lines). In experiments not shown here, we also tested elastic net regularization (Zou and Hastie 2005), a linear combination of L1 and L2 penalties. We found elastic nets to follow the same trend lines as L1 and L2, with no visible differences. In domains where human-simulatability is required, increases in prediction accuracy in the small-complexity regime can mean the difference between models that provide value on a task and models that are unusable, either because performance is too poor or predictions are uninterpretable.
Our learned decision tree proxies are interpretable. Across all tasks, the decision trees which mimic the predictions of tree-regularized deep models are small enough to simulate by hand (path length ≤ 25) and help users grasp the model’s nonlinear prediction logic. Intuitively, the trees for our synthetic task in Fig. 3(a)-(c) decrease in size as the strength λ increases. The logic of these trees also matches the true labeling process: even the simplest tree (c) checks a relevant subset of input dimensions necessary to verify that both the first state and the first output dimension are active.
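Because tree (c) in Fig. 3 is so small, it can be simulated by hand. The sketch below transcribes it into Python, reading the extracted figure under the assumption (scikit-learn's plotting convention) that the left, "True" branch of each node is taken when X[j] <= 0.5; the function names are illustrative only. It then checks that the tree agrees with the rule quoted in the Fig. 3 caption on every binary setting of the five features the tree reads.

```python
from itertools import product


def tree_c_predicts_on(x) -> bool:
    """Hand-simulate decision tree (c) of Fig. 3 (GRU, lambda = 1 000).

    Assumes the left ("True") branch is taken when x[j] <= 0.5, i.e. when
    binary feature j is off. Leaf counts [off, on] are copied from the figure.
    """
    if x[0] <= 0.5:   # leaf [2546, 0] -> off
        return False
    if x[3] <= 0.5:   # leaf [1744, 0] -> off
        return False
    if x[4] <= 0.5:   # leaf [0, 587] -> on
        return True
    return False      # leaf [123, 0] -> off


def caption_rule(x) -> bool:
    """Rule quoted in the Fig. 3 caption: x[0] == 1 and x[3] == 1 and x[4] == 0."""
    return x[0] == 1 and x[3] == 1 and x[4] == 0


# Exhaustively compare tree and rule over all 2**5 settings of features 0-4.
agrees = all(tree_c_predicts_on(b) == caption_rule(b) for b in product([0, 1], repeat=5))
print(agrees)  # True
```

Three comparisons are all a reader needs to reproduce the tree's output, which is the sense of human-simulatability used throughout.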
In Fig. 4, we show decision tree proxies for our deep models on two sepsis prediction tasks: mortality and need for ventilation. We consulted a clinical expert on sepsis treatment, who noted that the trees helped him understand what the models might be doing and thus determine if he would trust the deep model. For example, he said that using FiO2, RR, CO2 and paO2 to predict need for mechanical ventilation (Fig. 4d) was sensible, as these all measure breathing quality. In contrast, the in-hospital mortality tree (Fig. 4b) predicts that some young patients with no organ failure have high mortality rates while other young patients with organ failure have low mortality. These counter-intuitive results led to hypotheses about how uncaptured variables impact the training process. Such reasoning would not be possible from simple sensitivity analyses of the deep model.
Finally, we have verified that the decision tree proxies of our tree-regularized deep models of the HIV task in Fig. 5d are interpretable for understanding why a patient has trouble adhering to a prescription, that is, taking drugs regularly as directed. Our clinical collaborators confirm that the baseline viral load and number of prior treatment lines, which are prominent attributes for the decisions in Fig. 5d, are useful predictors of a patient with adherence issues. Several medical studies (Langford, Ananworanich, and Cooper 2007; Socas et al. 2011) suggest that patients with higher baseline viral loads tend to have faster disease progression, and hence have to take several drug cocktails to combat resistance. Juggling many drugs typically makes it difficult for these patients to adhere as directed. We hope interpretable predictive models for adherence could help assess a patient’s overall prognosis (Paterson et al. 2000) and offer opportunities for intervention (e.g. with alternative single-tablet regimens).
Decision trees trained to mimic deep models make faithful predictions. Across datasets, we find that each tree-regularized deep time-series model has predictions that agree with its corresponding decision tree proxy in about 85-90% of test examples. Table 1 shows exact fidelity scores for each dataset. Thus, the simulatable paths of the decision tree will be trustworthy in a majority of cases.
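Fidelity, as defined in the Table 1 caption, is a simple agreement rate. The sketch below computes it with NumPy, assuming the deep model's class probabilities are binarized at 0.5 before comparison and that time-series outputs are flattened so every timestep counts as one prediction; the threshold, the flattening, and the name `fidelity` are our assumptions.

```python
import numpy as np


def fidelity(deep_probs, tree_labels, threshold=0.5):
    """Percentage of examples (or timesteps) where the tree proxy agrees with the deep model."""
    deep_labels = (np.ravel(deep_probs) >= threshold).astype(int)  # binarize deep predictions
    tree_labels = np.ravel(tree_labels).astype(int)
    if deep_labels.shape != tree_labels.shape:
        raise ValueError("deep_probs and tree_labels must have the same number of entries")
    return 100.0 * float(np.mean(deep_labels == tree_labels))


# Deep labels are [1, 0, 1, 0]; the tree disagrees only on the third example.
print(fidelity([0.9, 0.2, 0.7, 0.4], [1, 0, 0, 0]))  # 75.0
```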
Practical runtimes for tree regularization are less than twice that of simpler L2. While our tree-regularized GRU with 10 states takes 3977 seconds per epoch on TIMIT, a similar L2-regularized GRU takes 2116 seconds per epoch. Thus, our new method costs less than twice as much as the baseline even when the surrogate is computed serially. Because the surrogate Ω(W) will in general be a much smaller model than the predictor y(x, W), we expect one could get faster per-epoch times by parallelizing the creation of (W, Ω(W)) training pairs and the training of the surrogate Ω(W). Additionally, 3977 seconds includes the time needed to train the surrogate. In practice, we do this sparingly, only once every 25 epochs, yielding an amortized per-epoch cost of 2191 seconds (more runtime results are in the supplement).
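The amortization arithmetic can be checked with a small helper. Under an assumption not stated above, namely that epochs without surrogate retraining cost about as much as an L2-regularized epoch (2116 s), one epoch at 3977 s plus 24 epochs at 2116 s averages to about 2190 s per epoch, within about a second of the reported 2191 s. The helper name `amortized_epoch_cost` is hypothetical.

```python
def amortized_epoch_cost(retrain_epoch_s: float, plain_epoch_s: float, retrain_every: int) -> float:
    """Mean seconds per epoch when the surrogate is retrained once every `retrain_every` epochs.

    Assumes one epoch per cycle pays the retraining cost and the rest cost `plain_epoch_s`.
    """
    cycle_total = retrain_epoch_s + (retrain_every - 1) * plain_epoch_s
    return cycle_total / retrain_every


# TIMIT numbers from the text; plain_epoch_s = 2116 is an assumption (the L2 epoch time).
print(round(amortized_epoch_cost(3977, 2116, 25), 2))  # 2190.44
print(round(3977 / 2116, 2))                           # 1.88 (unamortized ratio to L2)
```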
Decision trees are stable over multiple optimization runs. When tree regularization is strong (high λ), the decision trees trained to match the predictions of deep models are stable. For both signal-and-noise and sepsis tasks, multiple runs from different random restarts have nearly identical tree shape and size, perhaps differing by a few nodes. This stability is crucial to building trust in our method. On the signal-and-noise task (λ = 7000), 7 of 10 independent runs with random initializations resulted in trees of exactly the same structure, and the remaining runs produced closely resembling trees that shared the same subtrees and features (more details in supplement).
The deep residual GRU-HMM achieves high AUC with less complexity. So far, we have focused on regularizing standard deep models, such as MLPs or GRUs. Another option is to use a deep model as a residual on another model that is already interpretable: for example, discrete HMMs partition timesteps into clusters, each of which can be inspected, but their predictions might have limited accuracy. In Fig. 6, we show the performance of jointly training a GRU-HMM, a new model which combines an HMM with a tree-regularized GRU to improve its predictions (details and further results in the supplement). Here, the ideal path length is zero, indicating that only the HMM makes predictions. For small average path lengths, the GRU-HMM improves the original HMM’s predictions and has simulatability gains over earlier GRUs. On the mechanical ventilation task, the GRU-HMM requires an average path length of only 28 to reach an AUC of 0.88, while the GRU alone with the same number of states requires a path length of 60 to reach the same AUC. This suggests that jointly-trained deep residual models may provide even better interpretability.
Table 1: Fidelity of predictions from our trained deep GRU-RNN and its corresponding decision tree. Fidelity is defined as the percentage of test examples on which the prediction made by a tree agrees with the deep model (Craven and Shavlik 1996). We used 20 hidden GRU states for the signal-and-noise task and 50 states for all others.
[Figure 6 panels: test AUC vs. average path length for GRU-HMM (L1), GRU-HMM (L2), and GRU-HMM (Tree) on (a) Signal-and-noise 20+5, (b) In-Hosp. Mort. 50+50, (c) Mech. Vent. 50+50, (d) Stop Phonemes 50+25.]
Figure 6: Fitness curves for the GRU-HMM, showing prediction quality (AUC) vs. complexity (path length) across a range of regularization strengths λ. Panel titles show the number of HMM states plus the number of GRU states. See earlier figures to compare these GRU-HMM numbers to simpler GRU and decision tree baselines.
6 Discussion and Conclusion
We have introduced a novel tree-regularization technique that encourages the complex decision boundaries of any differentiable model to be well-approximated by human-simulatable functions, allowing domain experts to quickly understand and approximately compute what the more complex model is doing. Overall, our training procedure is robust and efficient; future work could continue to explore and increase the stability of the learned models as well as identify ways to apply our approach to situations in which the inputs are not inherently interpretable (e.g. pixels in an image).
Across three complex, real-world domains – HIV treatment, sepsis treatment, and human speech processing – our tree-regularized models provide gains in prediction accuracy in the regime of simpler, approximately human-simulatable models. Future work could apply tree regularization to local, example-specific approximations of a loss (Ribeiro, Singh, and Guestrin 2016) or to representation learning tasks (encouraging embeddings with simple boundaries). More broadly, our general training procedure could apply tree regularization or other procedure-regularization to a wide class of popular models, helping us move beyond sparsity toward models humans can easily simulate and thus trust.
Acknowledgements
MW is supported by the U.S. National Science Foundation. MCH is supported by Oracle Labs. SP is supported by the Swiss National Science Foundation project 51MRP0 158328. The authors thank the EuResist Network for providing HIV data for this study, and thank Matthieu Komorowski for the preprocessed sepsis data (Raghu et al. 2017). Computations were supported by the FAS Research Computing Group at Harvard and the sciCORE (http://scicore.unibas.ch/) scientific computing core facility at the University of Basel.
References
Adler, P.; Falk, C.; Friedler, S. A.; Rybeck, G.; Scheidegger, C.; Smith, B.; and Venkatasubramanian, S. 2016. Auditing black-box models for indirect influence. In ICDM.
Audet, C., and Kokkolaras, M. 2016. Blackbox and derivative-free optimization: theory, algorithms and applications. Springer.
Bahdanau, D.; Cho, K.; and Bengio, Y. 2014. Neural machine translation by jointly learning to align and translate. arXiv preprint arXiv:1409.0473.
Balan, A. K.; Rathod, V.; Murphy, K. P.; and Welling, M. 2015. Bayesian dark knowledge. In NIPS.
Che, Z.; Kale, D.; Li, W.; Bahadori, M. T.; and Liu, Y. 2015. Deep computational phenotyping. In KDD.
Chen, J. H., and Asch, S. M. 2017. Machine learning and prediction in medicine: beyond the peak of inflated expectations. N Engl J Med 376(26):2507–2509.
Cho, K.; van Merriënboer, B.; Gulcehre, C.; Bahdanau, D.; Bougares, F.; Schwenk, H.; and Bengio, Y. 2014. Learning phrase representations using RNN encoder-decoder for statistical machine translation. In EMNLP.
Choi, E.; Bahadori, M. T.; Schuetz, A.; Stewart, W. F.; and Sun, J. 2016. Doctor AI: Predicting clinical events via recurrent neural networks. In Machine Learning for Healthcare Conference.
Craven, M., and Shavlik, J. W. 1996. Extracting tree-structured representations of trained networks. In NIPS.
Drucker, H., and Le Cun, Y. 1992. Improving generalization performance using double backpropagation. IEEE Transactions on Neural Networks 3(6):991–997.
Erhan, D.; Bengio, Y.; Courville, A.; and Vincent, P. 2009. Visualizing higher-layer features of a deep network. Technical Report 1341, Department of Computer Science and Operations Research, University of Montreal.
Garofolo, J. S.; Lamel, L. F.; Fisher, W. M.; Fiscus, J. G.; Pallett, D. S.; Dahlgren, N. L.; and Zue, V. 1993. TIMIT acoustic-phonetic continuous speech corpus. Linguistic Data Consortium 10(5).
Goodfellow, I.; Bengio, Y.; and Courville, A. 2016. Deep Learning. MIT Press. http://www.deeplearningbook.org.
Han, S.; Pool, J.; Tran, J.; and Dally, W. 2015. Learning both weights and connections for efficient neural network. In NIPS.
Hinton, G.; Vinyals, O.; and Dean, J. 2015. Distilling the knowledge in a neural network. arXiv preprint arXiv:1503.02531.
Hochreiter, S., and Schmidhuber, J. 1997. Long short-term memory. Neural Computation 9(8):1735–1780.
Hu, Z.; Ma, X.; Liu, Z.; Hovy, E.; and Xing, E. 2016. Harnessing deep neural networks with logic rules. In ACL.
Johnson, A. E.; Pollard, T. J.; Shen, L.; Lehman, L. H.; Feng, M.; Ghassemi, M.; Moody, B.; Szolovits, P.; Celi, L. A.; and Mark, R. G. 2016. MIMIC-III, a freely accessible critical care database. Scientific Data 3.
Kingma, D., and Ba, J. 2014. Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
Krizhevsky, A.; Sutskever, I.; and Hinton, G. E. 2012. ImageNet classification with deep convolutional neural networks. In NIPS.
Lakkaraju, H.; Bach, S. H.; and Leskovec, J. 2016. Interpretable decision sets: A joint framework for description and prediction. In KDD.
Langford, S. E.; Ananworanich, J.; and Cooper, D. A. 2007. Predictors of disease progression in HIV infection: a review. AIDS Research and Therapy 4(1):11.
Lei, T.; Barzilay, R.; and Jaakkola, T. 2016. Rationalizing neural predictions. arXiv preprint arXiv:1606.04155.
Lipton, Z. C. 2016. The mythos of model interpretability. In ICML Workshop on Human Interpretability in Machine Learning.
Lundberg, S., and Lee, S.-I. 2016. An unexpected unity among methods for interpreting model predictions. arXiv preprint arXiv:1611.07478.
Miotto, R.; Li, L.; Kidd, B. A.; and Dudley, J. T. 2016. Deep patient: An unsupervised representation to predict the future of patients from the electronic health records. Scientific Reports 6:26094.
Ochiai, T.; Matsuda, S.; Watanabe, H.; and Katagiri, S. 2017. Automatic node selection for deep neural networks using group lasso regularization. In ICASSP.
Paterson, D. L.; Swindells, S.; Mohr, J.; Brester, M.; Vergis, E. N.; Squier, C.; Wagener, M. M.; and Singh, N. 2000. Adherence to protease inhibitor therapy and outcomes in patients with HIV infection. Annals of Internal Medicine 133(1):21–30.
Pedregosa, F.; Varoquaux, G.; Gramfort, A.; Michel, V.; et al. 2011. Scikit-learn: Machine learning in Python. Journal of Machine Learning Research 12:2825–2830.
Raghu, A.; Komorowski, M.; Celi, L. A.; Szolovits, P.; and Ghassemi, M. 2017. Continuous state-space models for optimal sepsis treatment - a deep reinforcement learning approach. In Machine Learning for Healthcare Conference.
Rastegari, M.; Ordonez, V.; Redmon, J.; and Farhadi, A. 2016. XNOR-Net: ImageNet classification using binary convolutional neural networks. In ECCV.
Ribeiro, M. T.; Singh, S.; and Guestrin, C. 2016. Why should I trust you?: Explaining the predictions of any classifier. In KDD.
Ross, A.; Hughes, M. C.; and Doshi-Velez, F. 2017. Right for the right reasons: Training differentiable models by constraining their explanations. In IJCAI.
Selvaraju, R. R.; Das, A.; Vedantam, R.; Cogswell, M.; Parikh, D.; and Batra, D. 2016. Grad-CAM: Why did you say that? Visual explanations from deep networks via gradient-based localization. arXiv preprint arXiv:1610.02391.
Singh, S.; Ribeiro, M. T.; and Guestrin, C. 2016. Programs as black-box explanations. arXiv preprint arXiv:1611.07579.
Socas, M. E.; Sued, O.; Laufer, N.; Lázaro, M. E.; Mingrone, H.; Pryluka, D.; Remondegui, C.; Figueroa, M. I.; Cesar, C.; Gun, A.; Turk, G.; Bouzas, M. B.; Kavasery, R.; Krolewiecki, A.; Pérez, H.; Salomón, H.; Cahn, P.; and Grupo Argentino de Seroconversión Study Group. 2011. Acute retroviral syndrome and high baseline viral load are predictors of rapid HIV progression among untreated Argentinean seroconverters. Journal of the International AIDS Society 14(1):40.
Sutskever, I.; Vinyals, O.; and Le, Q. V. 2014. Sequence to sequence learning with neural networks. In NIPS.
Tang, W.; Hua, G.; and Wang, L. 2017. How to train a compact binary neural network with high accuracy? In AAAI.
Zazzi, M.; Incardona, F.; Rosen-Zvi, M.; Prosperi, M.; Lengauer, T.; Altmann, A.; Sonnerborg, A.; Lavee, T.; Schulter, E.; and Kaiser, R. 2012. Predicting response to antiretroviral treatment by machine learning: the EuResist project. Intervirology 55(2):123–127.
Zhang, Y.; Lee, J. D.; and Jordan, M. I. 2016. ℓ1-regularized neural networks are improperly learnable in polynomial time. In ICML.
Zou, H., and Hastie, T. 2005. Regularization and variable selection via the elastic net. Journal of the Royal Statistical Society: Series B (Statistical Methodology) 67(2):301–320.
Supplementary Material
A Details for Decision-Tree Training
Training decision trees with post-pruning. Our average path length function Ω(W) for determining the complexity of a deep model with parameters W, defined in the main paper in Alg. 1, assumes that we have a robust, black-box way to train binary decision trees, called TRAINTREE, given a labeled dataset {x_n, y_n}. For this we use the DecisionTreeClassifier class from scikit-learn's sklearn.tree module, which chooses splits that maximize the decrease in Gini impurity. The specific syntax we use (for reproducibility) is:
    from sklearn.tree import DecisionTreeClassifier
    tree = DecisionTreeClassifier(min_samples_leaf=5)  # at least 5 training examples per leaf
    tree.fit(x_train, y_train)
    tree = prune_tree(tree, x_valid, y_valid)  # PRUNETREE (Algorithm 2), our own routine, not part of scikit-learn
The provided keyword options force the tree to have at least 5 examples from the training set in every leaf. We found tuning hyperparameters of the TRAINTREE subprocedure, such as the minimum size of a leaf node, to be important for making useful trees.
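As a rough illustration of how TRAINTREE feeds into the average path length Ω(W), the sketch below fits a scikit-learn tree with the leaf-size option above to a model's binary predictions and averages, over examples, the number of decision nodes on each root-to-leaf path. Counting internal nodes (visited nodes minus the leaf) is our assumption about the path-length convention; `average_path_length` and `model_labels` are hypothetical names, and the post-pruning step is omitted.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier


def average_path_length(X, model_labels, min_samples_leaf=5):
    """Fit a tree that mimics `model_labels` on X; return (mean decisions per example, tree)."""
    tree = DecisionTreeClassifier(min_samples_leaf=min_samples_leaf)
    tree.fit(X, model_labels)
    # decision_path gives a sparse (n_samples, n_nodes) indicator of the nodes each example visits.
    visited = np.asarray(tree.decision_path(X).sum(axis=1)).ravel()
    # Visited nodes include the leaf; subtracting 1 counts only the decisions (internal nodes).
    return float(np.mean(visited - 1)), tree


# Toy check: labels that one threshold separates give a single split, so every path has length 1.
X_demo = np.linspace(0.0, 1.0, 100).reshape(-1, 1)
apl_demo, _ = average_path_length(X_demo, (X_demo[:, 0] > 0.5).astype(int))
print(apl_demo)  # 1.0
```

In the paper's setting, `model_labels` would be the deep model's (binarized) predictions on the training inputs, so Ω(W) changes whenever the weights W change.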
Generally, the runtime cost of sklearn’s fitting procedure scales superlinearly with the number of examples N and linearly with the number of features F, for a total complexity of $O(FN \log N)$. In practice, we found that with N = 1000 examples and F = 10 features, tree construction takes 15.3 microseconds.
The pruning procedure, summarized in Algorithm 2, is a heuristic to create simpler trees. After TRAINTREE delivers a working decision tree, we iteratively propose removing each remaining leaf node, accepting the proposal if the squared prediction error on a validation set improves. This pruning removes subtrees that do not generalize to unseen data.
Algorithm 2 Post-pruning for training decision trees.
Require: T : initial decision tree; ERRONVAL(·) : squared error on validation data, $\mathrm{ERRONVAL}(T) \triangleq \sum_{n=1}^{N} \big(T(x_n) - y_n\big)^2$
1: procedure PRUNETREE(T, err)
2:   e ← ERRONVAL(T)
3:   for node n ∈ SORTLEAFTOROOT(T.nodes) do
4:     T′ ← REMOVENODE(T, n)
5:     e_new ← ERRONVAL(T′)
6:     if e_new < e then T ← T′; e ← e_new
7:   return T
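Algorithm 2 can be sketched in plain Python as below. scikit-learn's public API has no call for deleting an arbitrary node from a fitted tree (its built-in option is cost-complexity pruning via `ccp_alpha`), so this sketch uses its own minimal `Node` class; it is hypothetical, not the authors' code. We read REMOVENODE on a leaf as collapsing its parent split into a single leaf that predicts the parent's stored value, and SORTLEAFTOROOT as visiting internal nodes deepest first; both readings are assumptions.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Node:
    value: float                    # prediction returned if this node is (or becomes) a leaf
    feature: Optional[int] = None   # split feature; None marks a leaf
    threshold: float = 0.5
    left: Optional["Node"] = None   # taken when x[feature] <= threshold
    right: Optional["Node"] = None

    def is_leaf(self) -> bool:
        return self.feature is None


def predict(node: Node, x) -> float:
    while not node.is_leaf():
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.value


def err_on_val(tree: Node, xs, ys) -> float:
    """ERRONVAL: summed squared error of the tree on validation data."""
    return sum((predict(tree, x) - y) ** 2 for x, y in zip(xs, ys))


def internal_nodes_leaf_to_root(root: Node) -> List[Node]:
    """Internal nodes ordered deepest first (one reading of SORTLEAFTOROOT)."""
    found, stack = [], [(root, 0)]
    while stack:
        node, depth = stack.pop()
        if not node.is_leaf():
            found.append((depth, node))
            stack.extend([(node.left, depth + 1), (node.right, depth + 1)])
    found.sort(key=lambda item: -item[0])
    return [node for _, node in found]


def prune_tree(tree: Node, xs, ys) -> Node:
    """Greedy post-pruning; mutates and returns `tree`."""
    e = err_on_val(tree, xs, ys)
    for node in internal_nodes_leaf_to_root(tree):
        if not (node.left.is_leaf() and node.right.is_leaf()):
            continue  # only collapse splits whose children are both leaves
        saved = (node.feature, node.left, node.right)
        node.feature, node.left, node.right = None, None, None  # propose removal
        e_new = err_on_val(tree, xs, ys)
        if e_new < e:
            e = e_new                                     # accept and track the new error
        else:
            node.feature, node.left, node.right = saved   # reject: restore the split


    return tree


# Toy example: the split on feature 1 fits noise and should be pruned; the root split survives.
noisy = Node(value=1.0, feature=1, left=Node(1.0), right=Node(0.0))
toy_tree = Node(value=0.5, feature=0, left=Node(0.0), right=noisy)
val_x = [[1, 0], [1, 1], [0, 0]]
val_y = [1.0, 1.0, 0.0]
pruned = prune_tree(toy_tree, val_x, val_y)
```

Because nodes are visited deepest first, a parent becomes a candidate only after its children have had a chance to collapse, mirroring the leaf-to-root order of the algorithm.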
Sanity check: surrogate path length closely follows true path length. Fig. A.1 shows that our surrogate predictor Ω(·) tracks the true average path length as we train the target predictor y(·, W) on several different datasets.
Sensitivity to different choices for surrogate training. In Fig. A.2, we show sample learning curves for several ways of approximating the average path length (also called “node count”) of a decision tree. The true value is shown in blue. Each of the other 3 lines uses the same surrogate model: an MLP with 25 hidden nodes. Increasing its capacity too much (e.g., to 100 hidden nodes) leads to overfitting: the surrogate predicts the average path length extremely well for a small number of iterations, but its performance then quickly decays. With an MLP of the right capacity, four additional tricks greatly improve the accuracy of the average path length predictions: (1) weight augmentation, (2) random restarts with an unregularized model, (3) a fixed window of data, and (4) surrogate retraining.
Normally, if our differentiable model is a GRU, we compile examples using the GRU weights at every batch and calculate the true average path length for each. This dataset is used to train the surrogate model. If examples are scarce, surrogate predictions may be unstable. Augmentation addresses this by randomly sampling weight vectors and computing the average path length of each, artificially creating a larger dataset. Data scarcity is especially problematic in early epochs. In addition to augmentation, we use random restarts to separately train unregularized GRUs (each with different weight initializations), growing a dataset of weight vectors before training the regularized model.
As the GRU parameters move away from their initial values, examples from early epochs no longer describe the current state of the model. Retraining and a fixed window of data address this by re-learning the surrogate function at a fixed frequency using examples only from the last J epochs. In practice, the augmentation size, the retraining frequency, and J all depend on the learning rate and the dataset size. See Table B.1 for exact numbers.
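The bookkeeping described above (weight augmentation, a fixed window of the last J epochs, and periodic retraining) might look roughly like the following sketch. It is an illustration under stated assumptions, not the authors' implementation: `PathLengthSurrogate`, `true_apl_fn`, Gaussian perturbation as the augmentation scheme, and the noise scale are all hypothetical choices, and scikit-learn's `MLPRegressor` with one hidden layer of 25 units only stands in for the surrogate. In actual tree regularization the surrogate must be differentiable inside the training framework so that gradients of Ω(W) reach W; random restarts with unregularized models are not shown.

```python
import numpy as np
from sklearn.neural_network import MLPRegressor


class PathLengthSurrogate:
    """Hypothetical helper: regress average path length from flattened model weights."""

    def __init__(self, true_apl_fn, window_epochs=5, n_augment=20, noise_scale=0.1, seed=0):
        self.true_apl_fn = true_apl_fn      # maps a weight vector to its true average path length
        self.window_epochs = window_epochs  # J: keep examples from the last J epochs only
        self.n_augment = n_augment          # extra perturbed weight vectors recorded per call
        self.noise_scale = noise_scale
        self.rng = np.random.default_rng(seed)
        self.history = []                   # list of (epoch, weights, true path length)
        self.mlp = None

    def record(self, epoch, weights):
        weights = np.asarray(weights, dtype=float)
        self.history.append((epoch, weights, self.true_apl_fn(weights)))
        # Weight augmentation: label randomly perturbed copies with their true path length.
        for _ in range(self.n_augment):
            w = weights + self.noise_scale * self.rng.standard_normal(weights.shape)
            self.history.append((epoch, w, self.true_apl_fn(w)))

    def retrain(self, current_epoch):
        # Fixed window: drop examples older than the last J epochs, then refit from scratch.
        cutoff = current_epoch - self.window_epochs
        self.history = [h for h in self.history if h[0] > cutoff]
        X = np.stack([h[1] for h in self.history])
        y = np.array([h[2] for h in self.history])
        self.mlp = MLPRegressor(hidden_layer_sizes=(25,), max_iter=2000, random_state=0)
        self.mlp.fit(X, y)

    def predict(self, weights):
        x = np.asarray(weights, dtype=float).reshape(1, -1)
        return float(self.mlp.predict(x)[0])
```

For example, recording weights once per epoch for 10 epochs and calling `retrain(current_epoch=9)` with `window_epochs=5` keeps only the examples from epochs 5 through 9.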
(a) Path length estimates Ω for 2D Parabola task
(b) Path length estimates Ω for Signal-and-noise HMM task
Figure A.1: True average path lengths (yellow) and surrogate estimates Ω (green) across many iterations of network parameter training.
[Figure A.2 plot: average path length vs. epoch (0-300); lines: ground truth, retraining, augmentation, retraining + augmentation.]
Figure A.2: This figure shows the effects of weight augmentation and retraining. The blue line is the true average path length of the decision tree at each epoch. All other lines show predicted path lengths using the surrogate MLP. By randomly sampling weights and intermittently retraining the surrogate, we significantly improve the ability of the surrogate model to track changes in the ground truth.
B Experimental Protocol
See Table B.1 for model hyperparameters for each dataset. For standard recurrent models such as the HMM or GRU, the decision trees were trained on the input data and the predictions of the model’s output node. For our deep residual GRU-HMM, the decision trees were trained on the predictions of the GRU’s output node only. For both synthetic and real-world datasets, our surrogate for the tree loss is a multilayer perceptron with 1 hidden layer of 25 nodes. For each dataset, when we investigated several regularization strengths (λ), we initialized the model weights using the same random seed. We use the Adam algorithm (Kingma and Ba 2014) for all optimization.
Table B.1: Dataset summaries and training parameters used in our experiments.
B.1 2D Parabola
Dataset generation. The training data consists of 2D input points whose two-class decision boundary is roughly shaped like a parabola. The true decision function is defined by $y = 5(x - 0.5)^2 + 0.4$. We sampled all 200 input points $x_n$ uniformly within the unit square $[0, 1] \times [0, 1]$ and labeled those above the decision function as positive. To add randomness, we flipped 10% of the points in the region near the boundary, between $y = 5(x - 0.5)^2 + 0.2$ and $y = 5(x - 0.5)^2 + 0.6$.
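A minimal NumPy sketch of this generator follows. We read "flipped 10% of the points in the region near the boundary" as flipping the labels of a random 10% of the points that fall inside the band; that reading, the random seed, and the function name `make_parabola_data` are assumptions.

```python
import numpy as np


def make_parabola_data(n=200, flip_frac=0.10, seed=0):
    """Sample 2D points in the unit square and label them against y = 5(x - 0.5)^2 + 0.4."""
    rng = np.random.default_rng(seed)
    X = rng.uniform(0.0, 1.0, size=(n, 2))        # points uniform in [0, 1] x [0, 1]
    x, y = X[:, 0], X[:, 1]
    curve = 5.0 * (x - 0.5) ** 2
    labels = (y > curve + 0.4).astype(int)        # above the decision function -> positive
    # Band near the boundary: strictly between curve + 0.2 and curve + 0.6.
    in_band = np.flatnonzero((y > curve + 0.2) & (y < curve + 0.6))
    n_flip = int(round(flip_frac * in_band.size))
    flip_idx = rng.choice(in_band, size=n_flip, replace=False)
    labels[flip_idx] = 1 - labels[flip_idx]       # flip a random subset of band labels
    return X, labels


X_par, y_par = make_parabola_data()
```

Points outside the band keep their clean labels, so the noise is confined to the region where the boundary is genuinely ambiguous.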
B.2 Signal-and-noise HMM
Dataset generation. The transition and emission matrices describing the generative process used to create the signal-and-noise HMM are shown in Fig. B.1. The output $y_n$ at every timestep is created by concatenating a one-hot vector of an emitted state and the 7-dimensional binary input vector. We emphasize that to output 1, the HMM must be in state 1 and the first input feature must be 1.
B.5 TIMIT
Training details. We explore (1, 5, 6, 10, 11, 15, 20, 25, 26, 30, 35, 50, 51, 55, 60, 75) GRU nodes, (5, 6, 10, 11, 15, 20, 25, 26, 30, 35, 50, 51, 55, 60, 75) HMM states, and GRU-HMMs with (5, 10, 25) HMM states and (1, 5, 10, 25, 50) GRU nodes. As with Sepsis, the input features are z-scored prior to training.
C Extended Results
For the signal-and-noise HMM, Sepsis, and TIMIT tasks, we first show expanded versions of the fitness trace plots and the tree visualizations. For Sepsis and HIV, we show the additional output dimensions not included in the main paper.
We also include tables of test AUC for our synthetic and real datasets over a wide range of parameter settings (GRU node counts, HMM state counts, regularization strengths). Consistent with common wisdom about training deep models, we found that larger models, with regularization, tended to perform best.
C.1 Signal-and-noise HMM: Plots
[Figure C.1 panels: (a) GRU: Signal-and-noise HMM, test AUC vs. average path length for GRU (L1), GRU (L2), GRU (Tree), and a standalone decision tree; (b) GRUHMM: Signal-and-noise HMM, test AUC vs. average path length for GRU-HMM (L1), GRU-HMM (L2), and GRU-HMM (Tree).]
Figure C.1: Performance and complexity trade-offs using L1, L2, and tree regularization for (a) GRU and (b) GRU-HMM models on the signal-and-noise HMM dataset. Note the differences in axis scale.
C.2 Signal-and-noise HMM: Tree Visualization
[Tree visualizations: decision trees trained to mimic GRU predictions on the signal-and-noise HMM task at regularization strengths (a) GRU:0.1, (b) GRU:0.1, (c) GRU:1.0, (d) GRU:10, (e) GRU:20; each internal node tests a binary input feature X[j] <= 0.5.]
class = on
value = [0, 69]class = on
value = [16, 16]class = off
X[9] <= 0.5value = [36, 53]
class = onvalue = [99, 0]
class = off
value = [14, 42]class = on
value = [22, 11]class = off
(f) GRU:100
X[3] <= 0.5value = [4783, 217]
class = off
value = [3532, 0]class = off
True
X[0] <= 0.5value = [1251, 217]
class = off
False
value = [758, 0]class = off
X[4] <= 0.5value = [493, 217]
class = off
X[1] <= 0.5value = [370, 217]
class = offvalue = [123, 0]
class = off
X[8] <= 0.5value = [62, 217]
class = onvalue = [308, 0]
class = off
value = [0, 217]class = on
value = [62, 0]class = off
(g) GRU::400
X[0] <= 0.5value = [4439, 561]
class = off
value = [2546, 0]class = off
True
X[3] <= 0.5value = [1893, 561]
class = off
False
value = [1744, 0]class = off
X[4] <= 0.5value = [149, 561]
class = on
X[7] <= 0.5value = [26, 561]
class = onvalue = [123, 0]
class = off
value = [0, 537]class = on
value = [26, 24]class = off
(h) GRU:800
X[0] <= 0.5value = [4413, 587]
class = off
value = [2546, 0]class = off
True
X[3] <= 0.5value = [1867, 587]
class = off
False
value = [1744, 0]class = off
X[4] <= 0.5value = [123, 587]
class = on
value = [0, 587]class = on
value = [123, 0]class = off
(i) GRU:1 000
value = [5000, 0]class = off
(j) GRU:10 000
Figure C.2: Decision trees trained under varying tree regularization strengths for GRU models on the signal-and-noise HMM dataset. As the tree regularization increases, the tree collapses to a single node. In (h), the tree closely resembles the ground-truth data-generating function.
C.3 Signal-and-noise HMM: AUCs
Model           AUC (Test)   Average Path Length   Parameter Count
logreg          0.91832      17.302                6
decision tree   0.92050      29.4424               -
hmm (5)         0.93591      25.5736               71
Table C.2: Performance metrics for multi-dimensional classification on a held-out portion of the Sepsis dataset. Total Average Path Length refers to the sum of the average path lengths across the 5 output dimensions. Refer to Fig. C.3 for average path lengths split across dimensions.
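The "Average Path Length" columns and the summed total described in Table C.2 can be illustrated with a minimal sketch. It assumes a hypothetical encoding of each decision tree as nested tuples `(feature, threshold, left, right)` with integer class labels at the leaves, and it counts only the decision nodes on each root-to-leaf path, so a single-leaf tree has path length 0. Whether the leaf itself is counted is a convention this sketch chooses; it is not stated in this section.

```python
from typing import List, Sequence, Union

# Hypothetical tree encoding used only for illustration:
#   internal node: (feature_index, threshold, left_subtree, right_subtree)
#   leaf:          an integer class label
# Inputs with x[feature_index] <= threshold are routed to the left subtree.
Tree = Union[int, tuple]


def path_length(tree: Tree, x: Sequence[float]) -> int:
    """Count the decision nodes visited while routing x from the root to a leaf."""
    steps = 0
    node = tree
    while isinstance(node, tuple):
        feature, threshold, left, right = node
        node = left if x[feature] <= threshold else right
        steps += 1
    return steps


def average_path_length(tree: Tree, X: List[Sequence[float]]) -> float:
    """Mean path length of one tree over a set of inputs."""
    return sum(path_length(tree, x) for x in X) / len(X)


def total_average_path_length(trees: List[Tree], X: List[Sequence[float]]) -> float:
    """Sum of per-output average path lengths, one tree per output dimension."""
    return sum(average_path_length(t, X) for t in trees)


# Hypothetical two-output example: a depth-2 tree and a single-leaf tree.
tree_a = (0, 0.5, 0, (3, 0.5, 0, 1))
tree_b = 1
X = [[0, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 1], [0, 0, 0, 1]]
print(average_path_length(tree_a, X))                  # 1.5
print(total_average_path_length([tree_a, tree_b], X))  # 1.5
```

A collapsed, single-leaf tree such as `tree_b` contributes nothing to the total under this convention, which matches the intuition that it requires no decisions to simulate.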
C.7 HIV: Plots
[Plot images not reproduced in this text version. Panels: (a) Therapy Success, (b) CD4+ ≤ 200 cells/ml, (c) Adherence, (d) Mortality, (e) Onset of AIDS. Each panel plots AUC (Test) against Average Path Length for GRU (L1), GRU (L2), GRU (Tree), and a Decision Tree baseline; the legend of panel (d) labels the models GRUHMM rather than GRU.]
Figure C.5: Performance and complexity trade-offs using L1, L2, and Tree regularization on GRU for the HIV dataset. The 5 outputs shown here were trained jointly.
Table C.3: Performance metrics for multi-dimensional classification on a held-out portion of the HIV dataset. Total Average Path Length refers to the summed average path lengths across the output dimensions.
[Tree image not reproduced in this text version. (b) GRU:500: a decision tree predicting Stop vs. Non-Stop, with splits on Energy, MFCC 1, MFCC 2, and the derivatives of Energy and of MFCCs 1, 2, and 6.]
Figure C.6: (a) Performance and complexity trade-offs using L1, L2, and Tree regularization for GRU models on TIMIT. (b) Decision tree trained using λ = 500.0 tree regularization on GRU.
C.10 TIMIT: AUCs
Model           AUC      Average Path Length   Parameter Count
logreg          0.7747   23.460                27
decision tree   0.8668   59.2061               -
hmm (5)         0.8900   51.911                295
Table C.4: Performance metrics across models on a held-out portion of the TIMIT dataset.
D GRU-HMM: Deep Residual Time-Series Model
Hidden Markov Model. For our purposes, Hidden Markov Models (HMMs) can be viewed as stochastic RNNs that can be interpreted as probabilistic generative models. In this work, we consider an HMM that generates a latent variable sequence $z = [z_1, \ldots, z_T]$ via a Markov chain, where each latent variable indicates one of $K$ possible discrete states: $z_t \in \{1, \ldots, K\}$. This state sequence is then used to jointly produce the "data" $x_t$ and "outcomes" $y_t$ observed at each timestep. The joint distribution over $z$, $x$, $y$ factorizes as:
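$$
p(z, x, y) = \pi_0(z_0) \prod_{t=1}^{T} p(z_t \mid z_{t-1}, A)\, p(x_t \mid z_t, \phi)\, \mathrm{Bern}\Big(y_t \;\Big|\; \sigma\Big(\sum_{k=1}^{K} w_k\, \delta_k(z_t)\Big)\Big) \tag{6}
$$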
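where $A$ is a transition matrix with $A_{i,j} = \Pr(z_t = i \mid z_{t-1} = j)$, $\pi_0 = p(z_0)$ is the initial state distribution, and $\{\phi_k\}_{k=1}^{K}$ are the emission parameters that generate the data. In the outcome term, $\sigma$ is the logistic sigmoid, $\delta_k(z_t)$ equals 1 if $z_t = k$ and 0 otherwise, and $w_k$ is the outcome weight for state $k$. We can then apply the same objective as above for training.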
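To make Eq. (6) concrete, the sketch below runs standard forward filtering to obtain belief states $b_t(k) = p(z_t = k \mid x_1, \ldots, x_t)$. Reading $\delta_k(z_t)$ as a one-hot indicator, the outcome term reduces to $\sigma(w_{z_t})$, so marginalizing $z_t$ gives $p(y_t = 1 \mid x_{1:t}) = \sum_k b_t(k)\,\sigma(w_k)$. The sketch respects the column convention $A_{i,j} = \Pr(z_t = i \mid z_{t-1} = j)$ and treats $z_0$ as emitting no data, as in Eq. (6). The isotropic Gaussian emission model (state means plus a shared variance standing in for $\phi$) and all function names are assumptions for illustration; the emission family and training code used in the paper may differ.

```python
import numpy as np


def gaussian_likelihood(x, means, var):
    """Assumed emission model: isotropic Gaussian per state. Returns p(x | z = k) for each k."""
    d = x.shape[0]
    sq_dist = np.sum((means - x) ** 2, axis=1)
    return np.exp(-0.5 * sq_dist / var) / (2.0 * np.pi * var) ** (d / 2.0)


def hmm_filter(X, pi0, A, means, var):
    """Forward filtering: beliefs[t, k] = p(z_t = k | x_1, ..., x_t).

    A[i, j] = Pr(z_t = i | z_{t-1} = j), so each column of A sums to one and
    the one-step prediction of the state distribution is A @ b_prev.
    """
    T, K = X.shape[0], pi0.shape[0]
    beliefs = np.zeros((T, K))
    b_prev = pi0  # distribution over z_0, which emits no data in Eq. (6)
    for t in range(T):
        predicted = A @ b_prev
        unnormalized = predicted * gaussian_likelihood(X[t], means, var)
        beliefs[t] = unnormalized / unnormalized.sum()
        b_prev = beliefs[t]
    return beliefs


def outcome_probability(beliefs, w):
    """p(y_t = 1 | x_1..x_t) = sum_k b_t(k) * sigmoid(w_k), marginalizing z_t."""
    return beliefs @ (1.0 / (1.0 + np.exp(-w)))


# Hypothetical toy setting: K = 3 states, 2-D observations, T = 5 steps.
rng = np.random.default_rng(0)
pi0 = np.full(3, 1.0 / 3.0)
A = np.array([[0.8, 0.1, 0.1],
              [0.1, 0.8, 0.1],
              [0.1, 0.1, 0.8]])
means = np.array([[-3.0, 0.0], [0.0, 0.0], [3.0, 0.0]])
X = rng.normal(size=(5, 2))
w = np.array([-2.0, 0.0, 2.0])

beliefs = hmm_filter(X, pi0, A, means, var=1.0)
p_y = outcome_probability(beliefs, w)
```

The belief states computed this way are the HMM quantities from which the binary target is predicted in the GRU-HMM described next.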
GRU-HMM: Modeling the residuals of an HMM. We now consider an additional model, the GRU-HMM, designed for interpretability. The idea is to use a GRU to model the residual errors made when predicting the binary target from the HMM belief states. We can further penalize the complexity of the GRU predictions via our tree regularization, so that higher-quality predictions do not come at the price of a much less interpretable model.
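[Diagram not reproduced in this text version: a GRU cell (sigmoid reset gate r_t, sigmoid update gate z_t, tanh candidate state h̃_t, previous state h_{t−1}, new state h_t) running alongside the HMM state chain s_{t−1}, s_t, s_{t+1} over inputs x_{t−1}, x_t, x_{t+1}, together producing the output y_t.]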
Figure D.1: Deep residual model: GRU-HMM. The orange triangle indicates the output used in surrogate training for tree regularization.
We train the deep residual model on the same suite of synthetic and real-world datasets. See Tables C.1, C.2, and C.4 for a comparison of GRU-HMM with vanilla GRU and HMM models under different regularization and expressiveness parameters. Across the datasets, deep residual models perform around 1% better than their vanilla equivalents with roughly the same number of model parameters.
Because the GRU-HMM is a residual model, decision trees were trained only on the GRU output node, leaving the HMM unconstrained. See Figure D.1 for a pictorial representation. As for the GRU models, Figures C.1b and D.2 compare model performance as the λ parameter for L1, L2, and Tree regularization increases. We see a similar, albeit less pronounced, effect: Tree regularization dominates the other methods in regions of low node count. Note the range of the AUC axis in these figures: the worst the residual model can perform is the HMM-only AUC. Figure D.3 shows the regularized trees produced by the GRU-HMM. Although they share some structure with Figure C.4, there are important differences that lead us to conclude that the GRU in a residual model plays a different role than when it is trained alone.
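The residual structure can be sketched as follows, under stated assumptions. The GRU cell uses one common form of the reset/update-gate equations, consistent with the gates drawn in Figure D.1. How the HMM and GRU outputs are combined is not spelled out in this section, so the sketch assumes the HMM contributes a linear logit $w^\top b_t$ from its belief states and the GRU adds a residual logit $v^\top h_t$. Only the GRU term (`gru_logits`) would serve as the target for the surrogate decision tree, leaving the HMM unconstrained. The parameter dictionary layout and all names are hypothetical.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def gru_step(x, h_prev, P):
    """One GRU update; P is a hypothetical dict of weight matrices and biases."""
    r = sigmoid(P["Wr"] @ x + P["Ur"] @ h_prev + P["br"])  # reset gate r_t
    z = sigmoid(P["Wz"] @ x + P["Uz"] @ h_prev + P["bz"])  # update gate z_t
    h_tilde = np.tanh(P["Wh"] @ x + P["Uh"] @ (r * h_prev) + P["bh"])  # candidate state
    return (1.0 - z) * h_prev + z * h_tilde  # interpolate old state and candidate


def gru_hmm_forward(X, beliefs, P, w, v):
    """Assumed residual combination: logit_t = w . b_t (HMM) + v . h_t (GRU residual).

    Returns the combined probabilities and the GRU residual logits; the latter is
    the output a surrogate decision tree would be fit to.
    """
    h = np.zeros(P["Uh"].shape[0])
    gru_logits = np.zeros(X.shape[0])
    for t in range(X.shape[0]):
        h = gru_step(X[t], h, P)
        gru_logits[t] = v @ h
    hmm_logits = beliefs @ w
    return sigmoid(hmm_logits + gru_logits), gru_logits


# Hypothetical sizes: 4 input features, 3 GRU units, 2 HMM states, 6 time steps.
rng = np.random.default_rng(1)
D, H, K, T = 4, 3, 2, 6
P = {name: rng.normal(scale=0.5, size=(H, D)) for name in ("Wr", "Wz", "Wh")}
P.update({name: rng.normal(scale=0.5, size=(H, H)) for name in ("Ur", "Uz", "Uh")})
P.update({name: np.zeros(H) for name in ("br", "bz", "bh")})
X = rng.normal(size=(T, D))
beliefs = np.full((T, K), 1.0 / K)  # placeholder beliefs, e.g. from HMM filtering
w = np.array([-1.0, 1.0])
probs, residual = gru_hmm_forward(X, beliefs, P, w, v=rng.normal(size=H))
```

Setting `v` to zero recovers the HMM-only prediction, which mirrors the observation above that the HMM-only AUC bounds how badly the residual model can perform.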
D.1 GRU-HMM: Sepsis Plots
[Plot images not reproduced in this text version. Panels: (a) In-Hospital Mortality, (b) 90-Day Mortality, (c) Mechanical Ventilation, (d) Median Vasopressor, (e) Max Vasopressor. Each panel plots AUC (Test) against Average Path Length for GRU-HMM (L1), GRU-HMM (L2), and GRU-HMM (Tree).]
Figure D.2: Performance and complexity trade-offs using L1, L2, and Tree regularization for GRU-HMM models on the Sepsis dataset.
D.2 GRU-HMM: Sepsis Tree Visualization
[Tree image not reproduced in this text version: a GRU-HMM decision tree for in-hospital mortality (died_in_hosp), with splits on age, BUN, INR, PTT, WBC_count, Calcium, Sodium, and Arterial_BE.]
Figure D.4: HIV task: Study of different regularization techniques for a GRU-HMM model with 75 GRU nodes and 25 HMM states, trained to predict whether CD4+ ≤ 200 cells/ml. (a) Example decision tree for λ = 1000.0. (b) Example decision tree for λ = 3000.0. The tree in (b) is slightly smaller than the tree in (a) as a result of the stronger regularization.
[Tree image not reproduced in this text version. (c) GRU-HMM: Stop vs Non-Stop, a decision tree with splits on Energy, Energy derivative, and MFCC 4 derivative.]
Figure D.5: TIMIT task: Study of different regularization techniques for a GRU-HMM model with 75 GRU nodes and 25 HMM states, trained to predict STOP phonemes. (a) Trade-off curves showing how AUC and decision-tree complexity evolve with increasing regularization strength under L1, L2, or Tree regularization. (b) Example decision tree for λ = 3000.0. (c) Example decision tree for λ = 7000.0. Compared with Figure C.6b, this tree is significantly smaller, suggesting that the GRU plays a different role in the residual model.
E Runtime comparisons
Training Time for Tree-Regularized Models. Table E.1 shows the wall time for training one epoch of each model presented in this paper on each dataset. Note that the wall times for GRU-TREE and GRU-HMM-TREE include the cost of surrogate training. If the surrogate is retrained infrequently, its amortized cost should be small.
Table E.1: Training time for recurrent models measured on all datasets used in this paper. Epoch time denotes the number of seconds taken for a single pass through all the training data. The epoch times for GRU-TREE and GRU-HMM-TREE include surrogate training expenses. If the surrogate is retrained infrequently, its cost is amortized, and the epoch times of GRU and GRU-TREE (and of GRU-HMM and GRU-HMM-TREE) are approximately the same. To measure epoch time, we used 10 HMM states, 10 GRU states, and 5 of each for GRU-HMM models. We trained the surrogate model for 5000 epochs. These tests were run on a single Intel Core i5 CPU.
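The amortization argument can be written out with hypothetical symbols; none of them are measured quantities from Table E.1. Let $T_{\text{base}}$ be the per-epoch time of the unregularized model, $T_{\text{pen}}$ the extra per-epoch time of evaluating the surrogate penalty and its gradient, and $T_{\text{sur}}$ the time for one surrogate fit, and suppose the surrogate is retrained once every $k$ epochs. Then the average epoch time is roughly

$$
\bar{T}_{\text{epoch}} \approx T_{\text{base}} + T_{\text{pen}} + \frac{T_{\text{sur}}}{k}, \qquad \lim_{k \to \infty} \bar{T}_{\text{epoch}} = T_{\text{base}} + T_{\text{pen}}.
$$

Under this sketch, the tree-regularized and plain epoch times become approximately equal once $T_{\text{sur}}/k$ is small and $T_{\text{pen}}$ is negligible relative to $T_{\text{base}}$, which is the situation the caption describes.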
F Extended Stability Tests
In the paper, we noted that decision trees are stable over multiple runs. Here, we show that on the signal-and-noise HMM dataset, 10 independent runs with random initializations and λ = 1000.0 produce either the same or comparable trees. We also show that with weak regularization (λ = 0.01), the learned decision trees vary considerably. Figures F.1 and F.2 show examples of such trees on the signal-and-noise dataset. Similar results hold for the real-world datasets.
[Tree images not reproduced in this text version. Panels: (a) 7/10 Runs, (b) 2/10 Runs, (c) 1/10 Runs.]
Figure F.1: Decision trees from 10 independent runs on the signal-and-noise HMM dataset with λ = 1000.0. Seven of the ten runs resulted in a tree of the same structure. The other three trees are similar, often having additional subtrees but sharing the same splits and features.
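One way to produce tallies like "7/10 runs" is to reduce each fitted tree to a structural signature that keeps split features, thresholds, and leaf classes but drops the per-node sample counts (which can differ between runs even when the structure is identical), then count identical signatures. The dict-based tree encoding, the helper names, and the toy counts below are hypothetical; this section does not state how structures were compared.

```python
from collections import Counter

# Hypothetical encoding: internal nodes are dicts with "feature", "threshold",
# "value", "left", "right"; leaves are dicts with "label" and "value".
# "value" holds per-class sample counts, like the value = [...] in the tree figures.


def signature(node):
    """Structural signature that ignores per-node sample counts."""
    if "label" in node:
        return ("leaf", node["label"])
    return (node["feature"], node["threshold"],
            signature(node["left"]), signature(node["right"]))


def tally_structures(trees):
    """Count how many runs produced each distinct tree structure."""
    return Counter(signature(t) for t in trees)


def leaf(label, value):
    return {"label": label, "value": value}


def split(feature, left, right, value, threshold=0.5):
    return {"feature": feature, "threshold": threshold, "value": value,
            "left": left, "right": right}


# Toy runs: the first two share a structure but differ in counts; the third differs.
run_1 = split(0, leaf("off", [50, 0]), leaf("on", [5, 10]), value=[55, 10])
run_2 = split(0, leaf("off", [48, 1]), leaf("on", [6, 12]), value=[54, 13])
run_3 = split(3, leaf("off", [40, 0]), leaf("on", [3, 9]), value=[43, 9])

counts = tally_structures([run_1, run_2, run_3])
print(sorted(counts.values(), reverse=True))  # [2, 1]
```

Ignoring counts is the design choice that matters here: two runs that learn the same splits on slightly different training dynamics should still be grouped together.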
[Tree images not reproduced in this text version. Panels (a), (b), (c), and (d) show decision trees from individual runs.]
Figure F.2: Decision trees from 10 independent runs on the signal-and-noise HMM dataset with λ = 0.01. With low regularization, the variance in tree size and shape is high.