Modeling electronic health records with recurrent neural networks David C. Kale, 1,2 Zachary C. Lipton, 3 Josh Patterson 4 STRATA - San Jose - 2016 1 University of Southern California 2 Virtual PICU, Children’s Hospital Los Angeles 3 University of California San Diego 4 Patterson Consulting
Outline
• Machine (and deep) learning
• Sequence learning with recurrent neural networks
• Clinical sequence classification using LSTM RNNs
• A real world case study using DL4J
• Conclusion and looking forward
We need functions, brah
Various Inputs and Outputs
“Time underlies many interesting human behaviors”
{0,1}, {A,B,C,…}, captions, "email mom", "fire nukes", "eject pop tart"
But how do we produce functions? We need a function for that…
One function-generator: Programmers
Which are expensive
When/why does this fail?
• Sometimes the correct function cannot be specified a priori (what is spam?)
• The optimal solution might change over time
• Programmers are expensive
Sometimes We Need to Learn These Functions
From Data
One Class of Learnable Functions:
Feedforward Neural Network
Artificial Neurons
Activation Functions
• At internal nodes, common choices for the activation function are the sigmoid, tanh, and ReLU functions.
• At the output, the activation function could be linear (regression), sigmoid (multilabel classification), or softmax (multi-class classification).
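The activations above can be sketched in plain Python (a toy illustration for intuition, not a library implementation):

```python
import math

def sigmoid(x):
    # Squashes any real input into (0, 1); common at internal nodes
    # and at the output for binary/multilabel classification.
    return 1.0 / (1.0 + math.exp(-x))

def relu(x):
    # Rectified linear unit: zero for negative inputs, identity otherwise.
    return max(0.0, x)

def softmax(xs):
    # Exponentiate (shifted by the max for numerical stability) and
    # normalize, producing a probability distribution over classes.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]
```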
Training with Backpropagation
• Goal: calculate the rate of change of the loss function with respect to each parameter (weight) in the model
• Update the weights by gradient following:
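The gradient-following update can be illustrated with the smallest possible model: a single linear neuron with squared loss (a toy sketch; the learning rate 0.1 and the target function are arbitrary choices for illustration):

```python
def sgd_step(w, x, y, lr=0.1):
    # Squared loss L = 0.5 * (w*x - y)**2, so dL/dw = (w*x - y) * x.
    # Gradient following: move the weight opposite the gradient.
    grad = (w * x - y) * x
    return w - lr * grad

# Repeated updates drive w toward the value that fits (x=2, y=4), i.e. w=2.
w = 0.0
for _ in range(100):
    w = sgd_step(w, x=2.0, y=4.0)
```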
Forward Pass
Backward Pass
Deep Networks
• It used to be difficult (it seemed impossible) to train nets with many hidden layers
• TLDR: Turns out we just needed to do everything 1000x faster…
Outline
• Machine (and deep) learning
• Sequence learning with recurrent neural networks
• Clinical sequence classification using LSTM RNNs
• A real world case study using DL4J
• Conclusion and looking forward
Feedforward Nets work for Fixed-Size Data
Less Suitable for Text
We would like to capture temporal/sequential dynamics in the data
• Standard approaches address sequential structure: Markov models, Conditional Random Fields, linear dynamical systems
• Problem: we desire a system that learns representations, captures nonlinear structure, and captures long-term sequential relationships
To Model Sequential Data:
Recurrent Neural Networks
Recurrent Net (Unfolded)
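Unfolding a vanilla recurrent net over a sequence can be sketched with a scalar hidden state (the weights here are arbitrary illustrative values):

```python
import math

def rnn_forward(xs, w_xh=0.5, w_hh=0.8, b=0.0):
    # Unrolled vanilla RNN: h_t = tanh(w_xh * x_t + w_hh * h_{t-1} + b).
    # The same weights are reused at every time step; the hidden state
    # carries information from earlier inputs forward through time.
    h = 0.0
    states = []
    for x in xs:
        h = math.tanh(w_xh * x + w_hh * h + b)
        states.append(h)
    return states
```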
Vanishing / Exploding Gradients
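The problem is easy to see in a linear toy recurrence: backpropagating through T steps of h_t = w · h_{t-1} multiplies the gradient by w a total of T times, so it vanishes for |w| < 1 and explodes for |w| > 1:

```python
def backprop_factor(w, T):
    # Gradient of h_T with respect to h_0 in the linear recurrence
    # h_t = w * h_{t-1}: simply w**T, computed step by step.
    g = 1.0
    for _ in range(T):
        g *= w
    return g
```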
LSTM Memory Cell (Hochreiter & Schmidhuber, 1997)
Memory Cell with Forget Gate
(Gers et al., 2000)
LSTM Forward Pass
LSTM (full network)
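One LSTM step, including the forget gate of Gers et al. (2000), can be sketched in scalar form; `W` and its gate keys are hypothetical names chosen for this toy version, not DL4J's API:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def lstm_step(x, h_prev, c_prev, W):
    # One scalar LSTM step. W maps each gate to a
    # (input weight, recurrent weight, bias) triple.
    i = sigmoid(W['i'][0] * x + W['i'][1] * h_prev + W['i'][2])   # input gate
    f = sigmoid(W['f'][0] * x + W['f'][1] * h_prev + W['f'][2])   # forget gate
    o = sigmoid(W['o'][0] * x + W['o'][1] * h_prev + W['o'][2])   # output gate
    g = math.tanh(W['g'][0] * x + W['g'][1] * h_prev + W['g'][2]) # candidate update
    c = f * c_prev + i * g   # memory cell: gated mix of old state and new input
    h = o * math.tanh(c)     # hidden state exposed to the rest of the network
    return h, c
```

Because the cell state c is updated additively (scaled by the forget gate) rather than squashed at every step, gradients flow over many more time steps than in a vanilla RNN.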
Large Scale Architecture
Standard supervised learning
Image captioning
Sentiment analysis
Video captioning, natural language translation
Part-of-speech tagging
Generative models for text
Outline
• Machine (and deep) learning
• Sequence learning with recurrent neural networks
• Clinical sequence classification using LSTM RNNs
• Lab results (e.g., glucose)
• Clinical assessments (e.g., cognitive function)
• One treatment: mechanical ventilation
• Clinical notes
• Diagnoses (often buried in free-text notes)
• Outcomes: in-hospital mortality
• Billing codes
• Sparse, irregular, unaligned sampling in time, across variables
• Sample selection bias (e.g., more likely to record abnormal)
• Entire sequences (non-random) missing
Challenges: sampling rates, missingness
[Figure: HR, RR, and ETCO2 time series from admission to discharge, sampled at different rates and with gaps]
Figures courtesy of Ben Marlin, UMass Amherst
Challenges: alignment, variable length
[Figure: HR time series for two patients, each from admission to discharge]
• Observations begin at time of admission, not at onset of illness
• Sequences vary in length from hours to weeks (or longer)
• Variable dynamics across patients, even with the same disease
• Long-term dependencies: future state depends on earlier condition
Figures courtesy of Ben Marlin, UMass Amherst
PhysioNet Challenge 2012
• Task: predict mortality from only the first 48 hours of data
• Classic models (SAPS, APACHE, PRISM): expert features + regression
  • Useful: quantifying illness at admission, standardized performance
  • Not accurate enough to be used for decision support
• Each record includes:
  • patient descriptors (age, gender, weight, height, unit)
  • irregular sequences of ~40 vitals and labs from the first 48 hours
  • one treatment variable: mechanical ventilation
  • binary outcome: in-hospital survival or mortality (~13% mortality)
• Only 4000 labeled records publicly available ("set A")
• 4000 unlabeled records ("set B") used for tuning during the competition (we didn't use them)
• 4000 test examples ("set C") not available
• Trained on full set A (4K), tuned on set B (4K), tested on set C
• All made extensive use of hand-engineered features
Our best model so far: min(P,R) = 0.4907
• 60/20/20 training/validation/test split of set A
• LSTM with 2 x 300-cell layers on inputs
• Different test sets, so not directly comparable
• Disadvantage: much smaller training set
• Required no feature engineering or domain knowledge
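The min(P,R) score reported above is, as we understand the PhysioNet 2012 scoring, the minimum of precision (positive predictive value) and recall (sensitivity) on the mortality class; it can be computed as:

```python
def min_precision_recall(y_true, y_pred):
    # Count outcomes for the positive (in-hospital mortality) class.
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    # The score rewards models that are good at both; a classifier can't
    # game it by trading one for the other.
    return min(precision, recall)
```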
Map sequences into fixed vector representation
• Not perfectly separable in 2D but some cluster structure related to mortality
• Can repurpose “representation” for other tasks (e.g., searching for similar patients, clustering, etc.)
Final comments
• We believe we could improve performance to well over 0.5
  • overfitting: training min(P,R) > 0.6 (vs. test: 0.49)
  • remedies: smaller or simpler RNN layers, adding dropout, multitask training
• Flexible NN architectures are well suited to complex clinical data
  • but they will likely demand much larger data sets
  • may be better matched to "raw" signals (e.g., waveforms)
• More general challenges:
  • missing (or unobserved) inputs and outcomes
  • treatment effects confound predictive models
  • outcomes often have temporal components (posing the problem as binary classification ignores that)
• You can try it out: https://github.com/jpatanooga/dl4j-rnn-timeseries-examples/
See related paper to appear at ICLR 2016: http://arxiv.org/abs/1511.03677