By Giovanni Cinà (Pacmed)
This post is especially for you if: you work in health care, you are interested in AI, you have time to read cool blog posts.
The world of health care is abuzz with excitement over innovative AI applications, with new ideas and products seeing the light of day on a monthly basis. Despite these initial successes, anyone involved in the field is sooner or later faced with the same problem.
No matter how mind-blowing your model is, you have to be able to explain its results to doctors and patients. In other words, you have to produce an argument for why your algorithm came up with a certain prediction or suggestion.
“I showed him the matrix of weights of the seventh hidden layer of my Neural Network, but somehow he still didn’t get it.” — Anonymous Data Scientist
At Pacmed we focus on building decision support software, namely tools to aid doctors and nurses in making difficult decisions. Thus our aim is not to produce explanations accessible to the (wo)man in the street, but to engage with those who have medical knowledge. Nevertheless, most of what I am about to say still applies, mutatis mutandis, to AI applications directed at patients.
Explainability is key, for two reasons. First, care providers are the users of the model: they should trust it and understand it. To be clear: I am not suggesting that doctors should learn data science. I am saying that the models should provide explanations of why, for example, this patient is diagnosed with a certain disease and another one is not, so that doctors can check that it makes sense and be convinced.
Second, providing care is a highly sensitive task. We are all fine not knowing how and why Google Translate’s French translation of English often sounds kinky; but what about software that suggests therapies, or surgery options? Would you trust a black-box model? The answer may vary from person to person, but it’s safe to say that this is a major concern.
The current state-of-the-art models in the vast majority of AI applications are Neural Networks (NNs henceforth). Unfortunately, as most AI practitioners will admit, NNs are hard to interpret. Whence the question: can we make NNs explainable while retaining high performance?
Luckily there is a great community of researchers already working on explainability of AI, so we can stand on the shoulders of giants and exploit some of the newest innovations. In this post we try out the ideas of this paper, which are already masterfully explained in this blog post, in one of our projects. All the credit for these insights goes to the authors of the paper.
Explaining Neural Networks with Decision Trees
For classification tasks there is a model that is relatively straightforward to understand: the Decision Tree (DT for short).
Intuitively, a DT is a flowchart: each example enters at the root node and ‘travels’ through the graph from top to bottom until it reaches a terminal node (called a ‘leaf’) and receives a classification. Each inner node of the tree is a check on some attribute of the example, determining which branch is more appropriate. DTs are often used for decision making in many contexts and are very easy to visualize.
Note how easy it is to justify the result: for any instance you just need to report the outcomes of the inner nodes it traverses. Using the DT in the picture, for example, you could explain that Mrs Fletcher did not receive a loan because:
- she is senior (outcome of top node)
- she has a fair credit rating (outcome of the node ‘Credit_rating’)
As a result she ends up in a ‘no’ leaf. Very intuitive.
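Such a flowchart translates directly into code. Here is a minimal sketch of the loan tree described above, written as plain conditionals; the feature names, values, and exact branches are hypothetical reconstructions from the example in the text:

```python
# Hypothetical reconstruction of the loan DT from the figure.
# Each `if` corresponds to an inner node; each `return` to a leaf.

def loan_decision(age_group: str, student: bool, credit_rating: str) -> str:
    """Classify a loan applicant; returns 'yes' or 'no'."""
    if age_group == "youth":                    # top node: check age
        return "yes" if student else "no"       # node 'Student'
    elif age_group == "senior":                 # Mrs Fletcher's branch
        # node 'Credit_rating': only an excellent rating gets a loan
        return "yes" if credit_rating == "excellent" else "no"
    else:                                       # middle-aged applicants
        return "yes"

# Mrs Fletcher: senior, fair credit rating -> 'no'
print(loan_decision("senior", False, "fair"))
```

The explanation of any prediction is exactly the sequence of conditions that fired on the way to the `return` statement.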
I hear you say: “Why don’t we use DTs then?” Because they have sub-optimal performance in most cases. Ideally we would want a model with the performance of a NN and the explainability of a DT.
Here is the first idea: let’s train a NN that achieves a high performance and then use a DT to approximate that NN.
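As a first sketch of this idea, here is how one could distill a trained model into a surrogate tree with scikit-learn: the tree is fit on the model’s *predictions* rather than on the true labels, so it learns to mimic the model’s decision surface. The small MLP and the synthetic data below are illustrative stand-ins, not the actual project model:

```python
# Sketch: approximate a trained NN with a decision tree (assumed setup).
import numpy as np
from sklearn.datasets import make_classification
from sklearn.neural_network import MLPClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=20, random_state=0)

# The "black box" we want to explain (stand-in for the real NN).
nn = MLPClassifier(hidden_layer_sizes=(32, 32), max_iter=500,
                   random_state=0).fit(X, y)

# Fit the surrogate tree to the NN's predicted labels, not the true ones.
surrogate = DecisionTreeClassifier(max_depth=8, random_state=0)
surrogate.fit(X, nn.predict(X))

# Fidelity: how often the tree agrees with the NN it approximates.
fidelity = (surrogate.predict(X) == nn.predict(X)).mean()
```

A high fidelity means the tree’s explanations are faithful to what the NN actually does.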
That is actually a pretty good intuition. The only catch is that DTs are not always nice. When problems are very complex, DTs can become massive, which means that an example must traverse many nodes to be classified, resulting in very complicated explanations.
Thus we want a small DT or, to be more precise, a DT that on average gives us simple explanations. Let me call ‘path length’ the length of the path that an example must travel from the root to reach its designated leaf. In a nutshell, we want a DT with a low Average Path Length (APL), where the average is taken over all the examples we want to classify.
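For a fitted scikit-learn tree, the APL can be computed directly from the node-indicator matrix returned by `decision_path`. A minimal sketch on synthetic data (all names and sizes are illustrative):

```python
# Sketch: compute the Average Path Length (APL) of a fitted tree.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=10, random_state=0)
dt = DecisionTreeClassifier(max_depth=8, random_state=0).fit(X, y)

# decision_path returns a sparse indicator matrix: entry (i, j) is 1
# iff example i passes through node j. The row sum is the number of
# nodes on the example's path; subtracting 1 counts edges from the root.
path_lengths = dt.decision_path(X).sum(axis=1) - 1
apl = float(np.mean(path_lengths))
print(f"Average Path Length: {apl:.2f}")
```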
Example: Pacmed at the ICU
Let’s look at a concrete example. Pacmed’s team is currently working together with Dr. Patrick Thoral, MD and Dr. Paul Elbers, MD, PhD at the Intensive Care Unit of Amsterdam UMC (location VUmc) to solve the following task: predict the probabilities of readmission and/or mortality at the moment of discharge, enabling doctors to choose the right time to transfer a patient from the Intensive Care Unit to another ward.
This problem is challenging due to the limited amount of data, a great imbalance between classes (luckily, in proportion only very few people die or are readmitted to the ICU), the high dimensionality of the data points and several other technical difficulties. If you are curious about this exciting project you can find all the juicy details in this blog post.
Here I will just employ one of the models used in the development phase (so what you see here are not the final results): a NN that had good performance but poor explainability. When we fit a DT to approximate that model, we get the following tree.
Here is how to read the tree: each node reports the check that is performed and, in a statement ‘value = [a, b]’, how many patients pass through it. The first number is the number of patients classified as safe to discharge; the second is the number of patients not safe to discharge. The nodes shade from red to blue depending on the proportion of patients in these two classes that go through them: red means mostly safe patients, blue means mostly patients at risk.
The tree does a decent job at identifying patients that end up in blue leaves, but if you inspect the leaf indicated by the black arrow something else becomes apparent. Out of the 2,400 patients classified by this tree, more than 1,000 ‘safe’ patients are classified only after 8 steps.
Summarizing the content of the nodes from the top to the leaf indicated by the black arrow, these are patients that:
- do not come from gastroenterology
- had low respiratory rate on the first day of recovery
- have a stable score in the Glasgow Coma Scale (measuring alertness and being a proxy for neurological disorders)
- do not come from an emergency setting
- have respiratory measurements that are not worrisome at the moment
- come from cardiology
- have heart-rate measurements that are not worrisome at the moment
This list, which is already a summary, is rather long, which means that the APL of this tree is rather high. Concretely, doctors will have to read through long explanations for a large portion of the patients in order to understand the model’s classifications.
Back to the main story
If on the other hand we have a DT with low APL, a user will get (on average) a simple and clear explanation of why an example has been classified in a certain way. Now comes the second idea: we want to ensure that the NN we train ‘takes into account’ the APL of the DT that will approximate it.
The most direct way to implement this idea would be to add a term APL(DT(Ω)), capturing the APL of the DT approximating a NN with parameters Ω, to the loss function of the NN itself. Unfortunately this is not feasible since APL(DT(Ω)) is not differentiable, so we need a workaround.
It’s now time for the third idea. The term APL(DT(Ω)) is in fact a function that takes as input the parameters Ω of a NN and returns the APL of the DT approximating said NN. So we need something to approximate APL(DT(Ω)) in a differentiable manner. We are going to use…another NN!
Optimizing the Neural Network for explainability
So, to recap (I see some confused faces from those webcams):
- We have a NN, let’s call it MarkBot, that solves the task at hand but lacks interpretability.
- To MarkBot’s loss function we add a term A that indicates how hard it is to understand MarkBot, so that MarkBot will try to minimize it during training (i.e. MarkBot will try to maximize its own explainability). Call this tweaked NN MarkBot_v2.
- The term A is provided by another NN, let’s call it BettyBot, which estimates the APL of the DT approximating MarkBot (when the latter has parameters Ω).
It only remains, therefore, to train BettyBot. For this purpose we need a dataset of pairs (Ω, APL(DT(Ω))), so we train the original MarkBot many times, producing many such pairs and effectively building a dataset. In my implementation I trained MarkBot 1000 times.
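This dataset-building step can be sketched as follows. All names and sizes are illustrative (a handful of runs instead of 1000, to keep the sketch fast), and the small scikit-learn models stand in for the real MarkBot and BettyBot:

```python
# Sketch: build BettyBot's training set of (Omega, APL(DT(Omega))) pairs.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=300, n_features=10, random_state=0)

def flatten_params(nn):
    """Concatenate all weights and biases into one vector (the Omega)."""
    return np.concatenate([w.ravel() for w in nn.coefs_]
                          + [b.ravel() for b in nn.intercepts_])

def apl_of_surrogate(nn):
    """APL of the decision tree that approximates nn, averaged over X."""
    dt = DecisionTreeClassifier(random_state=0).fit(X, nn.predict(X))
    return float(dt.decision_path(X).sum(axis=1).mean()) - 1.0

omegas, apls = [], []
for seed in range(10):  # the post reports 1000 runs
    markbot = MLPClassifier(hidden_layer_sizes=(16,), max_iter=300,
                            random_state=seed).fit(X, y)
    omegas.append(flatten_params(markbot))
    apls.append(apl_of_surrogate(markbot))

# BettyBot: a differentiable stand-in for the map Omega -> APL(DT(Omega)).
betty = MLPRegressor(hidden_layer_sizes=(32,), max_iter=2000,
                     random_state=0).fit(np.array(omegas), np.array(apls))
```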
Once BettyBot is trained, we plug its prediction in the loss function of MarkBot_v2 and train the latter. Lo and behold: we have a NN optimized for explainability!
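The tweaked loss itself can be sketched in a few lines: standard cross-entropy plus a penalty λ · BettyBot(Ω). Everything here is illustrative; `betty_predict` is a placeholder for the trained surrogate:

```python
# Sketch: MarkBot_v2's loss = cross-entropy + lambda * estimated APL.
import numpy as np

def cross_entropy(p_pred, y_true, eps=1e-12):
    """Binary cross-entropy over predicted probabilities."""
    p = np.clip(p_pred, eps, 1 - eps)
    return float(-np.mean(y_true * np.log(p) + (1 - y_true) * np.log(1 - p)))

def tweaked_loss(p_pred, y_true, omega, betty_predict, lam=0.1):
    """Task loss plus a differentiable explainability penalty."""
    return cross_entropy(p_pred, y_true) + lam * betty_predict(omega)

# Toy check: a NN whose surrogate tree would be deeper (higher estimated
# APL) incurs a higher loss, pushing training toward simpler trees.
p = np.array([0.9, 0.2, 0.8])
y = np.array([1, 0, 1])
omega = np.zeros(5)  # placeholder parameter vector
loss_simple = tweaked_loss(p, y, omega, lambda w: 2.0)   # shallow tree
loss_complex = tweaked_loss(p, y, omega, lambda w: 8.0)  # deep tree
```

The hyperparameter λ trades off task performance against explainability.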
When we try this on our ICU model, we get a new model whose associated decision tree looks like this:
The tree looks bigger at first glance, but look at the leaf indicated by the black arrow. Now 1300 or so ‘safe’ patients are classified in 3 steps, namely the patients that:
- come from cardiology
- have both heart-rate and blood pressure under control
These are typically patients who underwent some cardiac surgery and whose heart rate and blood pressure are back to acceptable levels. Note: these are mostly the same patients that the previous tree classified with a complex argument. The model has learned that these patients can be classified with an easy explanation. Since these patients form a large part of the population, this tree now has a lower APL.
It must be added that, in general, this improvement comes at a price: the performance of MarkBot_v2 may be lower than that of the vanilla MarkBot. Depending on the task you are trying to solve, explainability might be anything from a nice perk to an absolute necessity, so you may or may not want to pay this price. In health care we certainly cannot do without it.
Explainability is a core issue for AI applications in health care. In the case of Pacmed, we need to devise methods to render our models accessible to care providers. In this post I focused on using ‘short’ Decision Trees to approximate Neural Networks. I described the solution proposed in this paper and showed how it worked when I implemented these ideas on Pacmed’s model for the Intensive Care Unit.