How To Predict ICU Mortality with Digital Health Data, DL4J, Apache Spark and Cloudera

Modeling EHR Data in Healthcare

In this case study, we take a look at modeling electronic health record (EHR) data with deep learning and Deeplearning4j (DL4J). We draw inspiration from recent research showing that carefully designed neural network architectures can learn effectively from the complex, messy data collected in EHRs. Specifically, we describe how to train a long short-term memory recurrent neural network (LSTM RNN) to predict in-hospital mortality among patients hospitalized in the intensive care unit (ICU). Of particular note, we show that even for a dataset orders of magnitude smaller than those commonly used in deep learning, the LSTM is competitive with more traditional approaches.

Cloudera has over one hundred healthcare and life science customers, including Cerner, Children’s Healthcare of Atlanta, and Docbox. Skymind recently announced a partnership with the China Electronic Corporation (CEC) to analyze healthcare data with deep learning as a part of the Healthy China 2030 initiative.

We are in the midst of a digital health data “boom.” According to the Agency for Healthcare Research and Quality (AHRQ), there are 30-40 million annual hospital admissions in the United States (US) alone. Because adoption of electronic health records (EHRs) climbed from fewer than 10% of US hospitals in 2008 to nearly 85% in 2015, records from the majority of these encounters are now captured digitally. These data represent an unprecedented opportunity to generate insights about disease and to transform the way in which healthcare is delivered. Increasingly, clinicians and researchers are becoming comfortable using deep learning to help make sense of this trove of data and realize its promise. One need only look at the last two proceedings of the Machine Learning for Healthcare Conference (of which one of the co-authors of this article, Mr. Kale, is a co-organizer) to witness the explosion in innovative applications of deep learning to health-related problems and data.

DL4J and CDH

DL4J’s native Spark integration on CDH enables the Fortune 500 enterprise data science practitioner to build deep learning models on their own next-generation Apache Hadoop data warehouse infrastructure. A complete deep learning suite certified on CDH gives us the ability to model complex datasets while maintaining the proper security necessary to productionize deep learning models. Certification with CDH also gives DL4J a way to perform scale-out parallelization on Spark. While we can run DL4J on arbitrary frameworks and other variants of Spark, having the ability to run on a supported Spark framework the Fortune 500 enterprise has already invested in is a key capability. Let’s now move on to how we’d build a basic DL4J Spark job on CDH.

DL4J is a suite of tools that focuses on making deep learning real and actionable for today’s enterprises. DL4J also includes ND4J, our JVM port of NumPy, for high-performance computing on the JVM with CPUs and GPUs.

DL4J also includes an ETL tool, DataVec, which is specifically geared towards data vectorization. DataVec allows us to transform the time series data to produce 3-dimensional matrices, or “tensors”, that work natively with DL4J and Spark. ND4J allows us to manipulate these tensors with fast linear algebra operations on CPUs and GPUs. However, we still need to operate where the data lives, so we often run these deep learning workflows on Hadoop clusters.

Through certification with CDH, we can rely on the cluster’s security compliance and data integrity while avoiding data silos.

This post leads up to our talk with Cloudera at Strata NYC 2017 on leveraging Cloudera Data Science Workbench and DL4J to build healthcare deep learning applications.

Cloudera’s Data Science Workbench is a great tool for interactive development and training of models, while maintaining the integrity and security of the data in the main CDH cluster. The examples in this blog can also be run, as is, in client mode from the Data Science Workbench.

In this blog post, we focus primarily on reducing training time using Spark-on-YARN and CDH, in either CPU or GPU mode. Spark allows the programmer to abstract away parallel processing and focus more on the algorithm at hand.

EHR adoption at non-federal acute care hospitals from ONC Data Brief, No. 35, May 2016.

 

A recent example of such work is the ICLR 2016 paper “Learning to Diagnose with LSTM Recurrent Neural Networks” (of which Mr. Kale is a joint first author in his capacity as a PhD candidate at the USC Information Sciences Institute). In it, the authors trained an LSTM RNN, or simply LSTM, to classify acute care diseases such as respiratory distress in critically ill children. The RNN (and the more complex LSTM RNN) is a neural net architecture with recurrent connections between hidden states, giving it a form of persistent state (or “memory”) across sequential inputs. These connections enable RNNs to detect relationships not only between inputs, e.g., heart rate and blood pressure, but also over time, e.g., between a patient’s state at time of admission and later in an ICU stay. This makes them especially well-suited to health problems, which often involve modeling changes over time. In the “Learning to Diagnose” paper, the LSTM was significantly more accurate than logistic regression, a traditional clinical baseline, and even outperformed a stronger baseline combining a multilayer perceptron (MLP) with hand-engineered features.

While the resulting model was not ready for bedside deployment, the results were noteworthy because they demonstrated that modern deep learning architectures, e.g., an LSTM combined with dropout and deep supervision, can effectively model complex clinical problems with even modest-sized training data (fewer than 8,000 cases). What is more, the LSTM required far less preprocessing and feature engineering than traditional approaches and was able to capture complex temporal dependencies. In just the last year, researchers have demonstrated similar successes in applying deep learning to a variety of medical applications, including diagnosis from longitudinal laboratory test results and early detection of sepsis in hospitalized patients. RNNs and related neural network architectures appear well-suited to modeling clinical data, which pose a number of challenges such as sequential structure with irregular sampling, heterogeneity, and non-random missing values.

Mortality is the main outcome of interest for ICU patients, and mortality risk models have a variety of applications in early warning systems and quality and reporting. To enable readers to reproduce our work, we utilize the publicly available 2012 PhysioNet Challenge data set, which is representative of the EHR data recorded by modern adult ICUs. We adapt the “Learning to Diagnose” LSTM architecture to predict future in-hospital mortality and implement it using the Deeplearning4j (DL4J) open source deep learning framework.

With the growing avalanche of digital clinical data, large-scale storage and parallel processing are paramount for clinical researchers. Hadoop clusters excel at collecting and processing transactional (or time series) data, as is common in health care. Marrying a strong storage and computing engine with a deep learning architecture adept at handling sequential data makes a lot of sense for next-generation data science applications, including in medicine. Thus, we also adapt our mortality prediction LSTM to be trained in distributed fashion using Spark and parameter averaging implemented in DL4J.

The PhysioNet Dataset

The publicly available dataset used in our example originates from a 2012 PhysioNet Challenge focused on predicting in-hospital mortality. We use only the labeled training data available in Training Set A.

To obtain the preprocessed version of the Challenge data that we use in this post, simply run the commands below:

  wget https://skymindacademy.blob.core.windows.net/physionet2012/physionet2012.tar.gz
  tar xzvf physionet2012.tar.gz
  mv -i physionet2012 src/main/resources/physionet2012

LSTMs can be computationally expensive to train, especially when the training data include long sequences, which is quite common with irregularly sampled clinical time series data. This is because during learning, we unroll the LSTM’s computational graph, adding one layer per time step and backpropagating the error signal through each (an algorithm referred to as backpropagation through time). In this blog post, we will demonstrate how to accelerate training by parallelizing it using Spark 2.x on CDH 5.12.x. In particular, we will utilize DL4J’s implementation of parameter averaging, which leverages the data parallelism of Spark.

This example shares many similarities with the simpler DL4J LSTM example. Again we are using DL4J’s LSTM variant of a recurrent neural network, but this time instead of generating new characters, we are making a binary classification. Following the LSTM layer, an output layer produces the predicted probabilities of in-hospital death and of survival until discharge from the hospital.

Refer to the pom.xml for Hadoop- and Spark-related considerations.

Let us begin by examining the challenges of dealing with non-trivial sequence data in the PhysioNet Challenge 2012 dataset.

Our code is publicly available in a GitHub repository.

git clone -b cdh_blog_july2017 https://github.com/deeplearning4j/dl4j-examples.git

Working with the PhysioNet Challenge 2012 Data

Our goal is to predict future in-hospital mortality for ICU patients using records taken during the 48 hours after admission to the ICU. Six variables are general descriptors, such as gender and weight, while the remaining variables are measurements, e.g., heart rate, taken throughout the 48-hour admission window, each with an associated timestamp. Because they are recorded manually by clinical staff, the interval between measurements often varies over time, across patients, and across variables. This is called irregular sampling. Also, not all variables are measured at all times, creating missing values. As a point of comparison, standard video has regular sampling with images at a fixed frame rate, and all pixels are fully observed in each frame.

We work with a preprocessed version of these data, the details of which are beyond the scope of this post. Here we provide only a very brief description of the preprocessing:

  • First we flatten the records so that each variable is a separate column. The original records are normalized with only three columns (timestamp, variable name, value).

  • Next, we transform the inputs, using one-hot encodings for categorical variables, e.g., gender, and rescaling real-valued variables like blood pressure to fall between 0 and 1. To handle missing values, we use carry-forward imputation and add binary “missing” flags to indicate the imputed values, as recommended in Lipton, Kale, and Wetzel, MLHC 2016.

  • Finally, we replicate the descriptor values at each step, a common strategy for combining static and temporal inputs for LSTMs. Note that unlike the “Learning to Diagnose” paper, we work with the original irregularly sampled time series, rather than resampling them to a fixed timestep.
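The imputation and rescaling steps above can be sketched in plain Java. This is only an illustration under stated assumptions, not the preprocessing code that produced the dataset: the class and method names are hypothetical, missing measurements are assumed to be encoded as NaN, and values missing before the first observation are assumed to take a caller-supplied default.

```java
import java.util.Arrays;

/** Hypothetical helper illustrating carry-forward imputation with "missing" flags. */
final class CarryForwardImputer {

    /**
     * Fills NaN entries with the most recent observed value.
     * Returns a 2 x T array: row 0 holds the imputed series, row 1 holds
     * binary flags (1.0 where the value was imputed, 0.0 where it was observed).
     * Entries missing before the first observation get defaultValue (an assumption).
     */
    static double[][] impute(double[] series, double defaultValue) {
        double[] filled = new double[series.length];
        double[] missing = new double[series.length];
        double last = defaultValue;
        for (int t = 0; t < series.length; t++) {
            if (Double.isNaN(series[t])) {
                filled[t] = last;      // carry the previous value forward
                missing[t] = 1.0;      // flag this time step as imputed
            } else {
                last = series[t];
                filled[t] = last;
            }
        }
        return new double[][] { filled, missing };
    }

    /** Rescales values linearly so that min maps to 0 and max maps to 1. */
    static double[] rescale(double[] values, double min, double max) {
        double[] out = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            out[i] = (values[i] - min) / (max - min);
        }
        return out;
    }

    public static void main(String[] args) {
        // Made-up heart rate readings; NaN marks a time step with no measurement.
        double[] heartRate = { Double.NaN, 80.0, Double.NaN, Double.NaN, 100.0 };
        double[][] result = impute(heartRate, 60.0);
        System.out.println(Arrays.toString(result[0])); // [60.0, 80.0, 80.0, 80.0, 100.0]
        System.out.println(Arrays.toString(result[1])); // [1.0, 0.0, 1.0, 1.0, 0.0]
        System.out.println(Arrays.toString(rescale(result[0], 60.0, 100.0))); // [0.0, 0.5, 0.5, 0.5, 1.0]
    }
}
```

Keeping the flags as a separate input channel lets the network distinguish a measured value from a carried-forward one, which matters because missingness in clinical data is not random.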

The preprocessed data are organized into folders within the outer physionet2012 folder, with inputs stored separately from labels. We are primarily interested in the sequence and mortality folders, which store the input time series and mortality labels, respectively. Each patient record is stored in a separate CSV file (labels have one row per record).

Even after preprocessing, the data still require several more steps of ETL and vectorization and pose several interesting challenges, such as aligning labels to the variable length sequences.

Building Vectorized Tensor Input for Training with DataVec

Our input data consists of multiple columns sampled (irregularly) over time, which prevents us from representing these data as an input vector of fixed size, as required by many machine learning models, including MLPs. Using an LSTM affords us the opportunity to retain the structure of the input by representing each record as a two-dimensional data structure (matrix) with time and variables on opposing axes.

During training we typically use minibatch-based stochastic gradient descent (or some variant), introducing an additional dimension. For sequential inputs, this means our minibatch of inputs actually has three dimensions (example x variable x time), in contrast to traditional two-dimensional minibatches (example x feature). This is illustrated in the figure below.

The DL4J ecosystem includes DataVec, a suite of machine learning ETL and vectorization tools designed for working with such data. DataVec can handle multiple types of data and has excellent support for sequential data commonly captured in HDFS. DataVec also runs natively on Spark so producing RDDs of NDArrays (the vector representation in DL4J) is relatively straightforward once we understand the API.

More specifically, we need to take the data files and convert them into JavaRDD objects. The JavaRDD objects are required for training neural networks on Spark. To start, we will read in the time series input features and the binary mortality labels. We assume that the data is stored in the Hadoop File System (HDFS) with one directory containing the feature files and another directory containing the mortality labels for each patient.

Our decision to store inputs and labels in separate files presents a minor challenge: we have to ensure we can match an input record to its corresponding label across different RDDs. We first instantiate a pathToKeyConverter object (PTKC), which is used to match up files based on their file names. For example, “0.csv” in the feature directory and “0.csv” in the label directory would map to the same key: “0”. We then use the combineFilesForSequenceFile method from DataVecSparkUtil. This method combines data from separate files into pairs and stores them in a JavaPairRDD object.

  PathToKeyConverter pathToKeyConverter = new PathToKeyConverterFilename();
  JavaPairRDD<Text, BytesPairWritable> rdd = DataVecSparkUtil.combineFilesForSequenceFile(sc, featuresPath, labelsPath, pathToKeyConverter);
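To see what matching by file name amounts to, here is a small plain-Java sketch. It does not call DataVec; the class, the method names, and the example paths are hypothetical, and we assume the key is simply the file name with its directory and extension removed, as in the “0.csv” example above.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical illustration of pairing feature and label files by file name. */
final class FilenameKeySketch {

    /** Strips the directory and the extension, e.g. "sequence/0.csv" becomes "0". */
    static String keyFor(String path) {
        String name = path.substring(path.lastIndexOf('/') + 1);
        int dot = name.lastIndexOf('.');
        return dot > 0 ? name.substring(0, dot) : name;
    }

    /** Maps each key to {featurePath, labelPath}; features without a label are skipped. */
    static Map<String, String[]> pairByKey(List<String> featurePaths, List<String> labelPaths) {
        Map<String, String> labelsByKey = new LinkedHashMap<>();
        for (String labelPath : labelPaths) {
            labelsByKey.put(keyFor(labelPath), labelPath);
        }
        Map<String, String[]> pairs = new LinkedHashMap<>();
        for (String featurePath : featurePaths) {
            String key = keyFor(featurePath);
            String labelPath = labelsByKey.get(key);
            if (labelPath != null) {
                pairs.put(key, new String[] { featurePath, labelPath });
            }
        }
        return pairs;
    }

    public static void main(String[] args) {
        // Illustrative paths only; the listing order of the two directories may differ.
        Map<String, String[]> pairs = pairByKey(
                Arrays.asList("sequence/0.csv", "sequence/1.csv"),
                Arrays.asList("mortality/1.csv", "mortality/0.csv"));
        for (Map.Entry<String, String[]> e : pairs.entrySet()) {
            System.out.println(e.getKey() + " -> " + Arrays.toString(e.getValue()));
        }
    }
}
```

The point of the key is that the pairing no longer depends on the order in which files are listed, which is not guaranteed on a distributed file system.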

The next snippet of code constructs a function to read and convert the paired files. Because of the sequential nature of the data, we use CSVSequenceRecordReaders (CSRR) to generate two-dimensional structured data. We pass two CSRRs (one each for inputs and labels) into a PairSequenceRecordReaderBytesFunction (PSRRBF) function, which is used to combine two sets of data into DataVec format data. Note that the input files contain a header row, which we want to ignore.

  SequenceRecordReader srr1 = new CSVSequenceRecordReader(1, ",");
  SequenceRecordReader srr2 = new CSVSequenceRecordReader();
  PairSequenceRecordReaderBytesFunction fn = new PairSequenceRecordReaderBytesFunction(srr1, srr2);

The following line uses the PSRRBF to convert the paired data stored in the JavaPairRDD into a Spark tuple containing the paired inputs and label, each represented as a list of lists of DataVec Writables (each label list is a singleton).

  JavaRDD<Tuple2<List<List<Writable>>, List<List<Writable>>>> writables = rdd.map(fn);

The next step converts each tuple into a DataSet object containing NDArrays and metadata using a DataVecSequencePairDatasetFunction. We use the ALIGN_END flag to indicate that the mortality label should be aligned to the last timestep of the input sequence and that label sequence should be padded with zeroes everywhere else.

  int nClasses = 2;
  boolean regression = false;
  JavaRDD<DataSet> dataSets = writables.map(new DataVecSequencePairDataSetFunction(nClasses, regression, DataVecSequencePairDataSetFunction.AlignmentMode.ALIGN_END));
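To make the end-aligned label layout concrete, the plain-Java sketch below builds a one-hot label sequence that is zero everywhere except the last time step, and pads a minibatch of variable-length sequences into a single (example x variable x time) array. It is not DataVec or ND4J code: the class name, method names, and array layout are hypothetical, and we assume every sequence has at least one variable and one time step.

```java
import java.util.Arrays;

/** Hypothetical illustration of end-aligned labels for variable-length sequences. */
final class AlignEndSketch {

    /**
     * Returns a [nClasses][length] one-hot label sequence whose only non-zero
     * entry sits at the last time step; all earlier steps are zero padding.
     */
    static double[][] alignLabelToEnd(int label, int nClasses, int length) {
        double[][] labels = new double[nClasses][length];
        labels[label][length - 1] = 1.0;
        return labels;
    }

    /** Returns a [length] mask that is 1.0 only where a real label is present. */
    static double[] labelMask(int length) {
        double[] mask = new double[length];
        mask[length - 1] = 1.0;
        return mask;
    }

    /**
     * Packs sequences (each [variable][time]) into one [example][variable][maxTime]
     * array, zero-padding shorter sequences after their final time step.
     */
    static double[][][] padFeatures(double[][][] sequences) {
        int maxTime = 0;
        for (double[][] seq : sequences) {
            maxTime = Math.max(maxTime, seq[0].length);
        }
        double[][][] batch = new double[sequences.length][][];
        for (int i = 0; i < sequences.length; i++) {
            int nVars = sequences[i].length;
            batch[i] = new double[nVars][maxTime];
            for (int v = 0; v < nVars; v++) {
                System.arraycopy(sequences[i][v], 0, batch[i][v], 0, sequences[i][v].length);
            }
        }
        return batch;
    }

    public static void main(String[] args) {
        // A patient with 4 time steps and label 1, using 2 classes.
        System.out.println(Arrays.deepToString(alignLabelToEnd(1, 2, 4)));
        System.out.println(Arrays.toString(labelMask(4)));

        // Two patients, one variable each, with 3 and 2 time steps.
        double[][][] batch = padFeatures(new double[][][] {
                { { 0.1, 0.2, 0.3 } },
                { { 0.4, 0.5 } }
        });
        System.out.println(Arrays.deepToString(batch));
    }
}
```

The mask is what allows the loss to ignore the padded positions, so that only the prediction at the final time step of each stay contributes to training and evaluation.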

Finally, we need to split the JavaRDD object into train, validation, and test sets. This can easily be done using the randomSplit function. The train set consists of 80% of the data, while the validation and test sets consist of 10% of the data each.

  JavaRDD<DataSet>[] splits = dataSets.randomSplit(new double[] { 0.80, 0.10, 0.10 }, 1);
  JavaRDD<DataSet> jTrainData = splits[0];
  JavaRDD<DataSet> jValidData = splits[1];
  JavaRDD<DataSet> jTestData = splits[2];

We are now ready to configure a model and train it on Spark using the RDDs we created.

See here for more information about DataVec.

Let’s now move on and take a look at the LSTM RNN architecture.

Deep Learning Architecture and Training

In contrast with traditional machine learning, where practical expertise often focused on good feature design, effective application of deep learning requires matching problem and data to neural network architecture. The high-level version of this is the following:

  • For data without special structure, use traditional feedforward networks, e.g., Multi-Layer Perceptron Networks (MLPs)

  • For image data, use convolutional network architectures

  • For sequential (time series) data, use recurrent neural network architectures

  • For data with more complex structure, e.g., video, consider combining convolutional and recurrent layers within a hybrid neural network architecture

These are, of course, rules of thumb with a variety of exceptions. For example, one-dimensional convolutional networks can also be quite effective for sequential problems, often with shorter temporal dependencies. Nonetheless, the guidelines above form a good starting place. Because we are working with sequential data, we will configure an LSTM, similar to the “Learning to Diagnose” model. We like LSTMs for ICU data because they naturally handle variable length sequences and can encode (“remember”) longer histories, enabling us to capture important changes from time of admission.

The DL4J code for describing our target LSTM architecture is shown below:

  MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
          .seed(RANDOM_SEED)
          .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
          .learningRate(LEARNING_RATE)
          .weightInit(WeightInit.XAVIER)
          .updater(Updater.ADAM)
          .dropOut(0.25)
          .list()
          .layer(0, new GravesLSTM.Builder().nIn(NB_INPUTS).nOut(lstmLayerSize).activation(Activation.TANH).build())
          .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX)
                  .nIn(lstmLayerSize).nOut(nClasses).build())
          .pretrain(false).backprop(true)
          .build();

In the code above, we see the LSTM expressed as layers in the Java builder pattern. Notice that we feed our input directly into an LSTM layer, created using GravesLSTM.Builder(). We then feed that layer into an output layer with a softmax with two output units, one for each class (we could have also used a single sigmoid output). In this section of the code, we can change the network architecture by adding and removing layers and alter its behavior during training by changing, e.g., our update rule (set to ADAM above) or hyperparameters like learning rate.

Once the neural network is configured, we must specify how the job performs distributed training by creating a TrainingMaster instance with parameter averaging. We can then control the batch size of each worker and how frequently parameters are averaged and redistributed during training. See here for a more in-depth description of the parameters.

  TrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1)
          .averagingFrequency(3)
          .workerPrefetchNumBatches(2)
          .batchSizePerWorker(BATCH_SIZE)
          .build();
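Conceptually, each averaging round lets every worker train its own copy of the network on a few minibatches, after which the copies are combined by taking the element-wise mean of their parameters and the result is sent back to the workers. The sketch below shows only that combining step in plain Java. It is a hypothetical, single-process illustration (the class and method names are ours), not DL4J’s implementation, and it treats each worker’s parameters as a flat array of doubles.

```java
import java.util.Arrays;

/** Hypothetical illustration of the averaging step in parameter averaging. */
final class ParameterAveragingSketch {

    /** Returns the element-wise mean of the parameter vectors held by each worker. */
    static double[] average(double[][] workerParams) {
        int nParams = workerParams[0].length;
        double[] mean = new double[nParams];
        for (double[] params : workerParams) {
            if (params.length != nParams) {
                throw new IllegalArgumentException("All workers must hold the same number of parameters");
            }
            for (int j = 0; j < nParams; j++) {
                mean[j] += params[j];
            }
        }
        for (int j = 0; j < nParams; j++) {
            mean[j] /= workerParams.length;
        }
        return mean;
    }

    public static void main(String[] args) {
        // Two workers whose local copies have drifted apart after a few minibatches.
        double[][] workerParams = {
                { 1.0, 2.0 },
                { 3.0, 6.0 }
        };
        System.out.println(Arrays.toString(average(workerParams))); // [2.0, 4.0]
    }
}
```

Averaging less often reduces communication between the driver and the workers, but gives the local copies more time to diverge, so the averaging frequency is a trade-off worth tuning.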

We then initialize a SparkDl4jMultiLayer (SDML) instance which takes in the Spark context, neural network configuration, and training master as input. Using the SDML, we can train the model on the training data and evaluate it on the train and validation sets at every epoch:

  SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, conf, tm);
  for (int i = 0; i < NB_EPOCHS; i++) {
      sparkNet.fit(jTrainData);
      log.info("Completed Epoch {}", i);
      ROC roc = sparkNet.evaluateROC(jTrainData);
      log.info("*** Train Evaluation ***");
      log.info("{}", roc.calculateAUC());
      roc = sparkNet.evaluateROC(jValidData);
      log.info("*** Valid Evaluation ***");
      log.info("{}", roc.calculateAUC());
  }

Because SDML handles all of the coordination and Spark parallelization behind the scenes for us, training is extremely user-friendly.

Evaluating the Results

Now that we have configured and trained our model using Spark, the last step is to test how well our deep neural network generalizes to new data. Thus, we evaluate our neural network on the testing split of the data below.

  ROC roc = sparkNet.evaluateROC(jTestData);
  log.info("*** Test Evaluation ***");
  log.info("{}", roc.calculateAUC());

We report AUC, or Area Under the Receiver Operating Characteristic (ROC) Curve, the most commonly reported discriminative metric in the medical literature. Deploying predictive models in the real world often involves trade-offs. In the ICU, one such trade-off involves identifying as many at-risk patients as possible while minimizing the number of false alarms. The ROC curve visualizes this trade-off (true positive rate vs. false positive rate) for different choices of decision threshold, while the AUC summarizes how well a given model makes the trade-off using the total area under the ROC curve. Alternatively, AUC measures the probability that the model assigns a higher score to a randomly chosen positive example (for instance, an at-risk patient) than to a randomly chosen negative example (for instance, a healthy patient).
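The probabilistic reading of AUC can be checked directly on a handful of scores. The plain-Java sketch below counts, over all (positive, negative) pairs, how often the positive example receives the higher score, with ties counted as half. It is a hypothetical illustration with made-up scores, not the computation performed by DL4J’s ROC class, and its quadratic cost makes it suitable only for small examples.

```java
/** Hypothetical illustration of AUC as a pairwise ranking probability. */
final class AucSketch {

    /**
     * Fraction of (positive, negative) pairs in which the positive example
     * has the higher score; ties contribute 0.5. Labels are 1 (positive) or 0.
     */
    static double auc(double[] scores, int[] labels) {
        double wins = 0.0;
        long pairs = 0;
        for (int i = 0; i < scores.length; i++) {
            if (labels[i] != 1) {
                continue;
            }
            for (int j = 0; j < scores.length; j++) {
                if (labels[j] != 0) {
                    continue;
                }
                pairs++;
                if (scores[i] > scores[j]) {
                    wins += 1.0;
                } else if (scores[i] == scores[j]) {
                    wins += 0.5;
                }
            }
        }
        if (pairs == 0) {
            throw new IllegalArgumentException("Need at least one positive and one negative example");
        }
        return wins / pairs;
    }

    public static void main(String[] args) {
        // Made-up risk scores: two patients who died (1) and two who survived (0).
        double[] scores = { 0.9, 0.8, 0.3, 0.1 };
        int[] labels = { 1, 0, 1, 0 };
        System.out.println(auc(scores, labels)); // 0.75: three of the four pairs are ranked correctly
    }
}
```

A model that ranks every at-risk patient above every other patient scores 1.0, while random scores give about 0.5, which is why AUC does not depend on any particular decision threshold.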

As we see, we achieve a test set AUC of 0.852, which is quite impressive for modeling raw time series data with only limited preprocessing. This exceeds previously published work (AUC 0.8450) using an MLP and hand-engineered features and is competitive with the winning entry of the PhysioNet 2012 Challenge (AUC 0.8602), which used a Bayesian ensemble of decision trees and hand-engineered features. With careful tuning, we expect these results could be further improved.

We have demonstrated how to use DL4J to build and train an RNN to forecast future in-hospital mortality from the first 48 hours of clinical observations recorded from patients in an ICU. We used only limited preprocessing, for example rescaling values, and no feature engineering, as well as simple strategies to handle irregular sampling and missing values. The resulting model is competitive with the winning entry from the original PhysioNet 2012 Challenge and outperforms previously published results using MLPs and hand-engineered features. Our results, while far from ready for deployment at the bedside, are encouraging because they suggest that deep learning can be successfully applied to complex clinical problems and data, even of modest size.

This general framework could be applied to other clinical prediction problems, e.g., length of stay or readmission, as well as to other applications involving time series data, with a bare minimum of upfront feature engineering. Because deep learning requires less upfront feature engineering, and because tools for automating the tuning of neural nets are available, deep learning may in fact be more economical for healthcare organizations where data science expertise is in short supply. We also showed how training can be easily scaled up using DL4J’s tight integration with Spark and Cloudera’s CDH, making DL4J especially well-suited for enterprise environments, like those found at many clinical institutions.

For more information on deep learning in general and DL4J specifically, check out our website. We also suggest the O’Reilly book “Deep Learning: A Practitioner’s Approach” by Josh Patterson and Adam Gibson. We also invite the reader to join us at Strata NYC 2017 on Tuesday Sept 27 for a tutorial session on building deep learning healthcare applications with DL4J and Cloudera Data Science Workbench. In this workshop we’ll take the above deep learning example and learn how to model sequence data with DL4J and Cloudera’s Data Science Workbench.

  • Audrey J. Weiss and Anne Elixhauser. Healthcare Cost and Utilization Project (H-CUP) Statistical Brief #180, October 2014.

  • JaWanna Henry, Yuriy Pylypchuk, Talisha Searcy, and Vaishali Patel. “Adoption of Electronic Health Record Systems among U.S. Non-Federal Acute Care Hospitals: 2008-2015.” ONC Data Brief 35, May 2016.

  • Zachary C. Lipton,* David C. Kale,* Charles Elkan, and Randall Wetzel. “Learning to Diagnose with LSTM Recurrent Neural Networks.” International Conference on Learning Representations (ICLR), 2016.

  • Narges Razavian, Jake Marcus, and David Sontag. “Multi-task Prediction of Disease Onsets from Longitudinal Laboratory Tests.” Machine Learning for Healthcare Conference (MLHC), 2016.

  • Joseph Futoma, et al. “An Improved Multi-Output Gaussian Process RNN with Real-Time Validation for Early Sepsis Detection.” Machine Learning for Healthcare Conference (MLHC), 2017.

  • The Royal College of Physicians. “National Early Warning Score (NEWS): Standardising the assessment of acute-illness severity in the NHS.” Report of a working party, July 2012.

  • Jack E. Zimmerman, et al. “Acute Physiology and Chronic Health Evaluation (APACHE) IV: hospital mortality assessment for today’s critically ill patients.” Critical Care Medicine 34.5 (2006): 1297-1310.

  • Zachary C. Lipton, David C. Kale, and Randall Wetzel. “Directly Modeling Missing Data in Sequences with RNNs: Improved Classification of Clinical Time Series.” Machine Learning for Healthcare Conference (MLHC), 2016.

  • Alistair Johnson, et al. “Patient Specific Predictions in the Intensive Care Unit Using a Bayesian Ensemble.” Computers in Cardiology (CinC), 2012.

  • David C. Kale, et al. “Causal Phenotype Discovery via Deep Networks.” American Medical Informatics Association (AMIA) Annual Symposium, 2015.

  • Harpreet Singh, et al. “iNICU – Integrated Neonatal Care Unit: Capturing Neonatal Journey in an Intelligent Data Way.” Journal of Medical Systems 41 (2017): 132.