Visualization approach to assess the robustness of neural networks for medical image classification
The use of neural networks for diagnosis classification is becoming more and more prevalent in the medical imaging community. However, the outputs of deep learning methods remain hard to explain. Another difficulty is choosing among the large number of techniques developed to analyze how networks learn, as they all have different limitations. In this paper, we extended the framework of Fong and Vedaldi [IEEE International Conference on Computer Vision (ICCV), 2017] to visualize the training of convolutional neural networks (CNNs) on 3D quantitative neuroimaging data. Our application focuses on the detection of Alzheimer’s disease with gray matter probability maps extracted from structural MRI. We first assessed the robustness of the visualization method by studying the coherence of the longitudinal patterns and regions identified by the network. We then studied the stability of the CNN training by computing visualization-based similarity indexes between different re-runs of the CNN. We demonstrated that the areas identified by the CNN were consistent with what is known of Alzheimer’s disease and that the visualization approach extracts coherent longitudinal patterns. We also showed that the CNN training is not stable and that the areas identified mainly depend on the initialization and the training process. This issue may exist in many other medical studies using deep learning methods on datasets in which the number of samples is too small and the data dimension is high. This means that it may not yet be possible to rely on deep learning to detect stable regions of interest in this field.
💡 Research Summary
The paper investigates the robustness of convolutional neural networks (CNNs) used for Alzheimer’s disease (AD) classification from 3‑dimensional quantitative neuroimaging data, specifically gray‑matter (GM) probability maps derived from structural T1‑weighted MRI. The authors extend the mask‑based visual explanation framework originally proposed by Fong and Vedaldi (ICCV 2017) to the domain of 3‑D medical images. Their workflow consists of two stages: (1) training a CNN to discriminate AD patients from cognitively normal (CN) controls, and (2) fixing the trained network’s parameters and learning a voxel‑wise mask that, when applied to correctly classified AD images, forces the network to misclassify them as CN. The mask is constrained by two regularization terms: an ℓ1‑norm encouraging sparsity (minimal number of voxels altered) and a total‑variation term promoting spatial smoothness. After optimization, voxels with mask values above 0.95 are binarized to 1, effectively “restoring” GM in a minimal set of regions.
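The two regularization terms described above can be illustrated with a minimal sketch. It is not the authors' implementation: the function names, the weights `lambda_l1` and `lambda_tv`, and the convention that the input volume holds the per-voxel amount of alteration (0 meaning "unchanged") are all assumptions made for illustration, and the classification term of the objective is left out.

```python
def l1_penalty(volume):
    """Sum of absolute values over a 3D nested list (sparsity term)."""
    return sum(abs(v) for plane in volume for row in plane for v in row)


def total_variation(volume):
    """Anisotropic total variation: sum of absolute differences between
    each voxel and its forward neighbour along the three axes."""
    nx, ny, nz = len(volume), len(volume[0]), len(volume[0][0])
    tv = 0.0
    for i in range(nx):
        for j in range(ny):
            for k in range(nz):
                v = volume[i][j][k]
                if i + 1 < nx:
                    tv += abs(volume[i + 1][j][k] - v)
                if j + 1 < ny:
                    tv += abs(volume[i][j + 1][k] - v)
                if k + 1 < nz:
                    tv += abs(volume[i][j][k + 1] - v)
    return tv


def mask_regularizer(alteration, lambda_l1=1.0, lambda_tv=1.0):
    """Weighted sum of the sparsity and smoothness penalties.
    `alteration` is assumed to hold how much each voxel is changed
    (hypothetical convention: 0 = untouched)."""
    return lambda_l1 * l1_penalty(alteration) + lambda_tv * total_variation(alteration)


# A 2x2x2 volume in which a single voxel is altered.
example = [[[1.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]]
print(mask_regularizer(example))
```

In this toy case the single altered voxel contributes 1 to the sparsity term and 3 to the smoothness term (one for each of its three neighbours), which shows why isolated voxels are penalized more heavily than contiguous regions.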
Data were drawn from two public cohorts: the Alzheimer’s Disease Neuroimaging Initiative (ADNI) and the Australian Imaging, Biomarkers and Lifestyle (AIBL) study. After standard preprocessing with the Clinica pipeline (bias correction, non‑linear registration to MNI space, and Unified Segmentation), GM probability maps were obtained for each subject. The ADNI dataset was split into a 5‑fold cross‑validation (CV) training/validation set and an independent test set of 100 age‑ and sex‑matched subjects per class. AIBL served as an external test set. Hyper‑parameter selection for the CNN architecture was performed via random search; the final architecture comprised seven convolutional blocks (each containing 1–3 sub‑blocks of Conv‑BatchNorm‑LeakyReLU), followed by dropout and a fully‑connected layer. Training employed Adam optimization, early stopping based on validation loss, and cross‑entropy loss.
Classification performance was strong: average balanced accuracy across the five CV folds was 0.89 on validation and 0.88 on the ADNI test set. On AIBL the average balanced accuracy was 0.90, which suggests that the models generalize beyond the ADNI cohort rather than over-fitting to it, and supports their use for interpretability studies.
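For readers unfamiliar with the metric, balanced accuracy in a two-class problem is the mean of sensitivity and specificity. The sketch below shows that arithmetic on made-up labels; the function name and the encoding (1 for AD, 0 for CN) are assumptions for illustration, and the numbers are not taken from the paper.

```python
def balanced_accuracy(y_true, y_pred, positive=1):
    """Mean of sensitivity (recall on the positive class) and
    specificity (recall on the negative class) for binary labels."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p == positive)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p != positive)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t != positive and p != positive)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != positive and p == positive)
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    return (sensitivity + specificity) / 2


# Toy labels: 4 AD (1) and 2 CN (0) subjects, one AD subject misclassified.
y_true = [1, 1, 1, 1, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0]
print(balanced_accuracy(y_true, y_pred))  # 0.875
```

Unlike plain accuracy, this score is not inflated when one class is larger than the other, which matters for cohorts where the numbers of AD and CN subjects differ.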
The visualization component was evaluated at two granularity levels. Group‑level masking used all correctly classified AD subjects from the training/validation folds to learn a single mask per CV model. Session‑level masking optimized a mask for each individual image, with stronger regularization to avoid over‑fitting to a single sample. To assess robustness, the authors compared masks obtained from different CV folds (different data splits) and from repeated runs of the same fold (different random initializations). Similarity was quantified in two ways: (a) the cosine similarity of ROI‑wise mask density vectors (using the 120 AAL2 regions), and (b) the difference in the CNN’s output probability for the true class when an image is masked by a mask derived from another model. Low similarity scores indicated that the visual explanations depended heavily on the particular training trajectory.
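The first similarity measure can be sketched as follows. This is an illustrative reading of the description above, not the authors' code: `roi_density` and `cosine_similarity` are hypothetical helpers, the mask and atlas are assumed to be flattened to equal-length lists, atlas label 0 is assumed to mean background, and a real atlas would have 120 regions rather than the 2 used here.

```python
import math


def roi_density(mask, labels, n_rois):
    """Fraction of voxels selected by a binary mask within each ROI.
    `labels` holds an atlas label per voxel (1..n_rois, 0 = background)."""
    counts = [0] * n_rois
    hits = [0] * n_rois
    for m, label in zip(mask, labels):
        if label == 0:
            continue  # skip voxels outside the atlas
        counts[label - 1] += 1
        if m:
            hits[label - 1] += 1
    return [h / c if c else 0.0 for h, c in zip(hits, counts)]


def cosine_similarity(u, v):
    """Cosine of the angle between two vectors; 0.0 if either is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


# Two toy masks over six voxels and a two-region atlas.
atlas = [1, 1, 2, 2, 0, 0]
mask_a = [1, 0, 1, 1, 1, 0]
mask_b = [1, 1, 0, 0, 0, 0]
density_a = roi_density(mask_a, atlas, n_rois=2)
density_b = roi_density(mask_b, atlas, n_rois=2)
print(density_a, density_b, cosine_similarity(density_a, density_b))
```

Summarizing each mask as a per-region density vector makes masks from different runs comparable even when their voxel-level overlap is poor, at the cost of ignoring where inside a region the mask falls.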
Results showed that, despite the variability, all group‑level masks consistently highlighted brain regions known to be affected in AD, such as the hippocampus, parahippocampal gyrus, fusiform gyrus, and amygdala. However, the exact spatial extent and intensity of the masks differed markedly across runs, revealing that the CNN’s learned decision boundaries are unstable when trained on high‑dimensional, low‑sample‑size data. Session‑level masks captured subject‑specific longitudinal changes, with higher intra‑subject similarity than inter‑subject similarity, suggesting that the method can reflect individual disease progression but still suffers from over‑fitting at the single‑subject level.
The authors conclude that while the adapted mask‑based visualization provides biologically plausible explanations and can extract coherent longitudinal patterns, the underlying CNN training is not robust: the regions identified are strongly influenced by weight initialization and stochastic training dynamics. This instability likely extends to many medical imaging studies where sample sizes are limited and data dimensionality is high, implying that deep‑learning models cannot yet be relied upon to discover stable biomarkers without additional safeguards (e.g., larger cohorts, ensemble methods, uncertainty quantification).
In summary, the paper makes three main contributions: (1) a methodological extension of mask‑based visual explanations to 3‑D neuroimaging, (2) an empirical demonstration that the visualized regions align with known AD pathology, and (3) a systematic analysis showing that CNN training variability leads to divergent visual explanations, highlighting the need for caution when interpreting deep‑learning models in small‑sample medical imaging contexts. Future work should focus on improving model stability, integrating robustness metrics into the training pipeline, and validating visual explanations on larger, multi‑site datasets.