Detection of Knee Pathologies from MRI Using Deep Learning Models (CNNs) - Spring 2026
1 Introduction
Musculoskeletal disorders affecting the knee and its surrounding tissues, such as osteoarthritis, ACL injury, and meniscal injury, are widespread and greatly reduce patients' quality of life and mobility worldwide. X-ray radiography, a common diagnostic method, is ineffective for soft-tissue assessment, so diagnosis of early changes is often delayed until the knee has significantly degenerated (Panwar et al. 2025). Magnetic resonance imaging, by contrast, has emerged as a leading modality for musculoskeletal and sports-medicine imaging: its high contrast resolution and multiplanar capability support thorough evaluation of large joints such as the knee, including ligaments, cartilage, bone, tendons, and muscles, and yield distinct images of the different soft tissues for a complete overview (Qiu et al. 2021). However, manual analysis of MRI scans poses a significant challenge for radiologists, as it is time-consuming and suffers from high error rates, low reproducibility, heavy cognitive load, and high inter-observer variability (Bien et al. 2018).
Manual segmentation of knee structures is a time-consuming step in MRI-based knee assessment. Deep learning methods can improve this process by enabling fast and reproducible automated segmentation of knee tissues. (Zhou et al. 2018) developed a deep convolutional neural network for automatic cartilage and meniscus segmentation of knee joints, achieving highly accurate results for multiple types of knee tissue and thereby facilitating efficient assessment of knee anatomy and pathology.
Artificial intelligence techniques based on deep learning (DL), particularly Convolutional Neural Networks (CNNs), allow for automated, objective, and scalable image analysis, a major paradigm shift for overcoming these challenges in medical diagnostics. CNNs are especially well suited to medical image analysis because they automatically learn hierarchical features directly from raw pixel data, including pathological changes too subtle for the human eye to detect. Both 2D and 3D CNN models have recently shown promising results in segmenting musculoskeletal tissues and even grading the degree of structural damage. Although crucial advances have been achieved, important challenges remain, including the acquisition and labeling of large, high-quality datasets and limited generalization across different MRI platforms and hardware.
Fast diagnosis and monitoring of patients remains a challenge, and efficient automated systems are needed to support clinicians in time-consuming workflows. In particular, discriminating multiple co-occurring knee injuries is still an open problem. Building on recent CNN models, this paper presents an automated system for detecting common knee injuries and degeneration from knee MRI images. Such a system can enable early intervention and support better patient-care strategies.
Knee osteoarthritis assessment based on visual inspection of clinical images is traditionally time-consuming and operator-dependent. Recent studies indicate that knee osteoarthritis can be automatically detected from imaging data using a deep learning-based convolutional neural network (CNN) approach, resulting in robust and accurate classification.
Can a 3D convolutional neural network learn spatial and anatomical information from a multi-plane knee MRI volume to accurately identify multiple knee pathologies? Our results suggest that it can. In this study, we trained a 3D CNN to identify three knee pathologies simultaneously: ACL tears, meniscus tears, and general knee abnormalities.
To achieve this goal, a custom deep learning architecture based on a 3D Convolutional Neural Network (CNN) was implemented and fine-tuned on the large MRNet image dataset provided by Stanford University. The architecture addresses three knee-MRI diagnoses: ACL tears, meniscus tears, and general abnormalities. In practice, three individual binary classification problems were formulated and solved in parallel: presence or absence of an ACL tear, presence or absence of a meniscus tear, and presence or absence of any other abnormality.
The performance of our approach was evaluated for each classification task individually in terms of Area Under the ROC Curve (AUC), accuracy, sensitivity, and specificity. To gauge clinical applicability, we also evaluated the approach on all knee pathologies simultaneously.
2 Methodology
Convolutional Neural Network (CNN)
A Convolutional Neural Network (CNN) is a deep learning model for processing grid-structured data such as images. CNNs have been widely adopted due to their remarkable efficiency in image processing and medical imaging applications: the network automatically extracts salient features directly from images, without the need for manual feature extraction.
In the early layers of a deep CNN, the network learns basic spatial features such as gradients and structural contours. In subsequent layers, it learns progressively more complex and abstract features related to normal and pathological tissue, such as severely degenerated cartilage or disruption of ACL fibers.
1D CNNs are commonly employed for processing sequential data. For image data, 2D CNNs are standard practice, since individual 2D slices are processed independently. For volumetric data, however, a 3D convolution can be performed that considers neighbouring voxels not only in height and width but also in depth. This is particularly relevant for MRI data, where adjacent slices are highly correlated.
Deep learning techniques such as 3D CNNs have achieved state-of-the-art results in knee MRI for the detection and grading of ACL tears as well as for the assessment of meniscal tears and osteoarthritis. Inter-slice spatial relationships are crucial for an adequate understanding of 3D knee anatomy, and 3D CNNs are currently among the most advanced artificial intelligence approaches for computerized diagnosis in orthopedic imaging.
Basic Architecture of a CNN
A typical CNN consists of several key components that enable hierarchical feature extraction and classification (Yeoh et al. 2021; Awan et al. 2021).
1. Convolution Layer
Applies filters (kernels) to extract features from input images.
Mathematical representation:
\[ Y(i,j) = \sum_{m}\sum_{n} \left[ X(i+m, j+n)\cdot W(m,n) \right] + b \]
Where:

- \(X\) = Input image
- \(W\) = Filter/kernel
- \(b\) = Bias
- \(Y\) = Output feature map
This operation allows CNNs to learn spatially localized features and is fundamental to many medical image analysis frameworks (Yeoh et al. 2021; Qiu et al. 2021).
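As a concrete illustration, the convolution above can be implemented directly in NumPy. This is a minimal sketch for a single-channel image with "valid" padding and stride 1; the vertical-edge filter is an illustrative choice, not part of the original text:

```python
import numpy as np

def conv2d(X, W, b=0.0):
    """Direct implementation of Y(i,j) = sum_m sum_n X(i+m, j+n) * W(m,n) + b."""
    kh, kw = W.shape
    out_h = X.shape[0] - kh + 1
    out_w = X.shape[1] - kw + 1
    Y = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            # Element-wise product of the window with the kernel, then sum
            Y[i, j] = np.sum(X[i:i + kh, j:j + kw] * W) + b
    return Y

# A vertical-edge filter applied to a tiny 4x4 "image" with a dark/bright boundary
X = np.array([[0, 0, 1, 1],
              [0, 0, 1, 1],
              [0, 0, 1, 1],
              [0, 0, 1, 1]], dtype=float)
W = np.array([[-1, 0, 1],
              [-1, 0, 1],
              [-1, 0, 1]], dtype=float)
print(conv2d(X, W))  # responds strongly where the vertical edge lies
```

In a real CNN, the kernel values \(W\) and bias \(b\) are not hand-crafted as here but learned by backpropagation.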
2. Activation Function
To enable a deep neural network to learn complex features, non-linear activation functions are commonly applied after the convolutional layers. The most commonly used activation function in this context is the Rectified Linear Unit (ReLU).
\[ f(x) = \max(0, x) \]
ReLU facilitates faster training and mitigates the vanishing-gradient problem encountered with traditional sigmoid and tanh functions in deep architectures.
3. Pooling Layer
The pooling layers reduce the spatial dimensions of the feature maps extracted from input images keeping only the most relevant information. Among the different types of pooling operations, max pooling is the most popular one. In the max pooling approach, the model selects the maximum value within a defined window during a sliding operation. It helps in reducing the number of parameters and contributes to the model’s robustness for small spatial variations across images.
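The sliding max operation described above can be sketched in NumPy; this is a minimal illustration with a 2x2 window and stride 2, the most common configuration:

```python
import numpy as np

def max_pool2d(X, k=2):
    """Keep the maximum of each non-overlapping k x k window (stride = k)."""
    h, w = X.shape[0] // k * k, X.shape[1] // k * k
    # Reshape into (rows, k, cols, k) blocks and reduce each block to its max
    return X[:h, :w].reshape(h // k, k, w // k, k).max(axis=(1, 3))

X = np.array([[1, 3, 2, 0],
              [4, 6, 5, 1],
              [7, 2, 9, 8],
              [0, 1, 3, 4]], dtype=float)
print(max_pool2d(X))
# [[6. 5.]
#  [7. 9.]]
```

Each 2x2 block collapses to a single value, quartering the spatial size while preserving the strongest activations, which is what makes the representation robust to small shifts.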
4. Fully Connected Layer
These classification layers combine the features extracted by the previous layers and use them to determine a target classification (e.g., the presence of a knee pathology or its severity), as in the improved knee diagnostic model of (Awan et al. 2021).
5. Output Layer
Uses Softmax (multi-class) or Sigmoid (binary/multi-label).
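Since the three diagnoses in this study are independent binary labels, a per-output Sigmoid is the appropriate choice, whereas Softmax couples the outputs into a single distribution. A small numerical illustration (the logit values are made up):

```python
import numpy as np

logits = np.array([2.0, -1.0, 0.5])  # illustrative raw scores for three labels

# Sigmoid: each label gets an independent probability (multi-label setting)
sigmoid = 1.0 / (1.0 + np.exp(-logits))

# Softmax: probabilities are coupled and must sum to 1 (multi-class setting)
softmax = np.exp(logits) / np.exp(logits).sum()

print(np.round(sigmoid, 3))                 # independent probabilities
print(np.round(softmax, 3), softmax.sum())  # coupled; sums to 1.0
```

With Sigmoid, a scan can be positive for several pathologies at once; Softmax would force the pathologies to compete for probability mass.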
Deep learning techniques, particularly convolutional neural networks (CNNs), have achieved outstanding performance in several medical imaging applications, including MRI-based disease diagnosis and tissue segmentation.
Types of CNN:
CNNs are categorized based on the dimensionality of the input data and the directions along which the kernel slides.
| Type | Typical input data | Kernel movement |
|---|---|---|
| 1D CNN | Time-series, audio, ECG signals (Ige & Sibiya, 2024). | The kernel slides along a single dimension (time/sequence) rather than over 2D images. |
| 2D CNN | 2D images (e.g., grayscale/RGB images, medical X-rays). | The kernel slides over two dimensions (height and width). |
| 3D CNN | Videos, MRI/CT scans, 3D point clouds (Guida, Zhang, and Shan 2021). | The kernel slides along three dimensions (height, width, and depth). |
3D Convolutional Neural Networks:
A 3D Convolutional Neural Network (3D CNN) is a natural extension to the standard 2D Convolutional Neural Network (2D CNN) architecture for handling volumetric data by employing three-dimensional convolutional kernels that slide over all three spatial dimensions (i.e., height, width, and depth). Unlike 2D CNNs, which treat each 2D image slice independently, 3D CNNs maintain the spatial relationships between slices and leverage these relationships to learn more effective anatomical features from the entire volume. This architecture is especially beneficial for many medical imaging modalities, such as MRI, where relevant information extends several slices in depth.
Current approaches for 3D medical image analysis using deep learning are mostly based on 3D CNNs. Typically, a 3D CNN consists of repeated 3D convolutional layers followed by nonlinear activation functions, 3D pooling (down-sampling) layers, and fully connected (classification) layers. The early layers learn low-level volumetric features (edges, textures, etc.), while deeper layers learn high-level features representing organs or diseases. In terms of computational cost, 3D CNNs are significantly more expensive than 2D CNNs, but they yield state-of-the-art results in volumetric classification, detection, and segmentation tasks.
Three-dimensional convolutional neural networks (3D CNNs) have been successfully employed in multiple knee MRI analysis tasks. Building on these results, in this work we develop a deep learning network, trained on the MRNet dataset, capable of accurately diagnosing anterior cruciate ligament tears and abnormalities of the meniscal structures. In related work, (Pedoia et al. 2019) employed 3D CNNs for the detection of degenerative changes in knee MRI, and (Guida, Zhang, and Shan 2021) trained a 3D CNN to classify knee MRI scans according to the severity of knee osteoarthritis.
Besides classification, 3D CNNs have also been explored for various anatomical segmentation tasks. For example, (Zhou et al. 2018), (Liu et al. 2018) developed deep convolutional networks that, combined with deformable modeling, enable accurate tissue segmentation of the knee joint, an essential component for quantitative musculoskeletal assessment.
Researchers have made further progress in enhancing CNN-based systems for knee-injury analysis, especially ACL tear detection and knee osteoarthritis grading. Self-supervised learning methods such as BYOL can reduce the need for large labeled datasets. Federated learning, in turn, allows institutions to jointly develop and share a knee-injury system without sharing patient data. Emerging frameworks for multimodal and continuous monitoring suggest that such AI systems will be integrated into future orthopedic care. Alternative models such as Vision Transformers are also being explored for osteoarthritis grading, yet 3D CNNs remain one of the preferred solutions for volumetric knee MRI analysis.
3D CNNs have significant potential in musculoskeletal imaging for the detection, classification, grading and segmentation of knee pathologies in an automated fashion. This architecture preserves the 3D spatial relationships within the images enabling an anatomical understanding of the disease.
The 3D convolution is defined as:
\[ Y(i,j,k) = \sum_{m}\sum_{n}\sum_{p} X(i+m, j+n, k+p)\cdot W(m,n,p) + b \]
Where:
- \(X\) = Input 3D MRI volume
- \(W\) = 3D convolution kernel
- \(b\) = Bias
- \(Y\) = Output feature map
- \(i,j,k\) = Spatial voxel indices
This operation extracts volumetric features across depth, height, and width (Guida, Zhang, and Shan 2021).
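The same triple sum is what PyTorch's `nn.Conv3d` computes efficiently. A minimal sketch showing how a 3D kernel consumes depth, height, and width together (the shapes are illustrative, not those of the MRNet volumes):

```python
import torch
import torch.nn as nn

# One grayscale MRI volume: (batch, channels, depth, height, width)
volume = torch.randn(1, 1, 16, 64, 64)

# Eight 3x3x3 kernels sliding jointly over depth, height, and width
conv3d = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=3)

features = conv3d(volume)
print(features.shape)  # torch.Size([1, 8, 14, 62, 62]); each spatial dim shrinks by 2 (no padding)
```

Unlike a 2D convolution applied slice by slice, the 3x3x3 kernel mixes information from three adjacent slices at every position, which is how inter-slice correlations are captured.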
Applications in Knee MRI Analysis
3D Convolutional Neural Networks (3D CNNs) have shown great potential in processing musculoskeletal images, especially knee MRI.
Tissue Segmentation: Deep CNN-based architectures developed in 2018, such as DDN, DCAN, and AC-Net, have achieved accurate knee joint anatomy segmentation and improved tissue segmentation accuracy (Liu et al. 2018) and (Zhou et al. 2018).
Meniscus and Cartilage Degeneration Detection: In (Pedoia et al. 2019), 3D CNNs were used to identify and stage degenerative changes of the meniscus and patellofemoral cartilage, with special attention paid to extracting volumetric features from knee MRI scans.
Osteoarthritis Classification: (Guida, Zhang, and Shan 2021) classified knee osteoarthritis severity from MRI scans using a 3D CNN, achieving higher accuracy than the corresponding 2D technique. Classification and severity assessment of osteoarthritis using CNNs is also shown by (Rani et al. 2024).
ACL Tear Detection: ACL tears can be detected from knee MRI scans using the MRNet deep learning approach proposed in (Bien et al. 2018), subsequently improved with more efficient learning techniques such as self-supervised learning.
Multimodal and Federated Learning Approaches: New frameworks combine 3D CNNs with federated and few-shot learning techniques to improve generalization across institutions, demonstrating the potential of AI-based multimodal systems for orthopedic diagnostics.
Performance and Advantages:
3D CNNs provide significant benefits in medical imaging:
Volumetric Context: Because features from adjacent slices are taken into account, biomarkers for certain conditions, such as cartilage degradation, that are not visible in any single 2D slice can be detected.
Higher Accuracy: Evidence from brain and knee imaging studies shows that 3D models achieve higher accuracy than 2D and 2.5D approaches.
Efficiency in Convergence: On volumetric data, 3D models can often converge 20% to 40% faster than their 2D equivalents.
Limitations and Assumptions
Despite their power, 3D CNNs face specific challenges:
Computational Cost: 3D models require far more memory than their 2D counterparts, reportedly around 20 times more (Avesta et al. 2023).
Large Data Requirements: Large annotated datasets are required for CNNs to reach their highest accuracy. Medical imaging datasets, however, are typically small and expensive to obtain and label manually, which hampers the generalization performance of the learned models. In recent years, self-supervised learning has emerged as a promising technique for tackling these challenges.
Risk of Overfitting: A major challenge in medical image analysis is overfitting of the CNN due to the limited number of available images; the greater depth and parameter count of 3D CNNs further increase this risk.
3 Analysis & Results:
3.1 Dataset Description:
The MRNet dataset, a publicly available set of knee MRI scans gathered by the Stanford University School of Medicine, is used in this research. It contains images from clinical studies performed at the Stanford University Medical Center over an eleven-year period from 2001 to 2012. The dataset is divided into a training and a validation set, with each MRI study labeled for three different diagnoses:
- The presence of any abnormality,
- Tears of the anterior cruciate ligament (ACL),
- Meniscal injuries.
| exam_id | acl_label | meniscus_label | abnormal_label | acl_diagnosis | meniscus_diagnosis | abnormal_diagnosis |
|---|---|---|---|---|---|---|
| 0 | 0 | 0 | 1 | No ACL Tear | No Meniscus Tear | Abnormal |
| 1 | 1 | 1 | 1 | ACL Tear | Meniscus Tear | Abnormal |
| 2 | 0 | 0 | 1 | No ACL Tear | No Meniscus Tear | Abnormal |
| 3 | 0 | 1 | 1 | No ACL Tear | Meniscus Tear | Abnormal |
| 4 | 0 | 0 | 1 | No ACL Tear | No Meniscus Tear | Abnormal |
3.1.2 Data Organization

The dataset has a standardized partition structure, with a train/ directory for model development and a valid/ directory for model evaluation. The diagnostic ground truth is provided in two CSV files, train-acl.csv and valid-acl.csv, each containing the examination identifier and a binary diagnostic label, where 0 represents the absence of an ACL tear and 1 a confirmed ACL tear.
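A minimal sketch of loading such a label file with pandas, assuming the MRNet convention of header-less two-column CSVs; the inline CSV text below stands in for train-acl.csv and its rows are invented for illustration:

```python
import io
import pandas as pd

# Simulated contents of train-acl.csv: no header row, a zero-padded exam ID
# plus a binary label (0 = no ACL tear, 1 = ACL tear).
csv_text = "0000,0\n0001,1\n0002,0\n"

labels_df = pd.read_csv(io.StringIO(csv_text), header=None,
                        names=["exam_id", "label"], dtype={"exam_id": str})

print(labels_df["label"].value_counts())  # quick check of the class balance
```

Passing `header=None` matters: without it, pandas would silently consume the first exam as a header row. Keeping `exam_id` as a string preserves the zero padding used in the `.npy` filenames.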
3.1.3 Problem Formulation:
This study aims to investigate the effectiveness of a 3D Convolutional Neural Network (3D CNN) in automatically detecting knee pathologies (ACL tears, meniscus tears, and general abnormalities) from volumetric MRI scans in the MRNet dataset.
The central research question is: Can a 3D CNN effectively learn spatial and anatomical features from multi-plane MRI volumes to accurately detect multiple types of knee injuries?
To address this, the proposed model leverages:
- 3D convolutional layers for volumetric feature extraction
- Multi-plane fusion to integrate complementary anatomical views
- Supervised learning to identify diagnostic patterns
This formulation enables the development of an automated system capable of assisting in clinical diagnosis by providing reliable predictions across multiple knee conditions.
3.2 Dataset Visualization & Exploratory Analysis:
Exploratory Data Analysis (EDA) is a fundamental component of any machine learning pipeline: it is critical to understand the structure and composition of the data before training any model. This section presents a detailed exploratory analysis of the MRNet dataset, proposed by (Bien et al. 2018), which consists of knee MRI images from three imaging planes, i.e., sagittal, coronal, and axial. Each examination is associated with a binary variable representing the ACL tear condition of the knee. The class balance, volumetric structure, imaging characteristics, and cross-plane properties of the MRNet dataset are addressed in the following visualizations, which were critical to the decisions made in subsequent sections of this report.
3.2.1 Class Distribution:
Code
class_counts = labels_df["diagnosis"].value_counts()
plt.figure(figsize=(6, 4))
ax = class_counts.plot(kind="bar",
                       color=["#4C9BE8", "#E8654C"],
                       edgecolor="white",
                       width=0.5)
# Add count labels on top of each bar
for bar in ax.patches:
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 1,
        str(int(bar.get_height())),
        ha="center", va="bottom", fontweight="bold"
    )
plt.title("Class Distribution – ACL Tear Labels", fontsize=13, fontweight="bold")
plt.xlabel("")
plt.ylabel("Number of Exams")
plt.xticks(rotation=0)
plt.tight_layout()
plt.savefig("figures/class_distribution.png", dpi=300)  # save before show, so the saved figure is not blank
plt.show()

The bar chart describes the class distribution of the ACL tear labels in the dataset. The horizontal axis indicates the two diagnostic classes, "No ACL Tear" and "ACL Tear," while the vertical axis indicates the number of examinations. The dataset contains a total of 1,130 examinations, with 922 (81.6%) labeled "No ACL Tear" and 208 (18.4%) labeled "ACL Tear," a significant imbalance with an approximate ratio of 4.4:1 in favor of the negative class. This imbalance is clinically realistic, since ACL tears are relatively uncommon in the general population; nonetheless, it is a significant methodological consideration for supervised machine learning, since a model trained on such an unbalanced dataset will likely develop a bias toward the majority class, achieving high accuracy at the expense of poor sensitivity for the clinically important minority class. This figure is important because it establishes dataset characteristics that directly influence the model design, loss function, and evaluation metrics used in the study.
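One standard way to account for this imbalance during training is to weight the positive class in the loss by the negative-to-positive ratio. A sketch using the counts above (the exact weighting used later in this report may differ):

```python
n_negative, n_positive = 922, 208  # counts from the class-distribution chart

# Weight positive examples by the negative/positive ratio so that, in
# expectation, both classes contribute equally to the loss.
pos_weight = n_negative / n_positive
print(round(pos_weight, 2))  # ~4.43, matching the ~4.4:1 ratio noted above
```

In PyTorch this value would typically be passed as the `pos_weight` argument of `BCEWithLogitsLoss`.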
3.2.2 Sample Middle Slices from the Sagittal Plane
Code
# Visualise Sample MRI Slices (Sagittal Plane)
# Grab exam IDs for each class
positive_ids = labels_df[labels_df["label"] == 1]["exam_id"].values[:2]
negative_ids = labels_df[labels_df["label"] == 0]["exam_id"].values[:2]
# Combine them with their labels for the plot title
sample_exams = [(eid, "ACL Tear") for eid in positive_ids] + \
               [(eid, "No ACL Tear") for eid in negative_ids]
fig, axes = plt.subplots(1, 4, figsize=(14, 4))
for ax, (exam_id, diagnosis) in zip(axes, sample_exams):
    npy_path = os.path.join(DATA_DIR, "train", "sagittal", f"{exam_id:0>4}.npy")
    volume = np.load(npy_path)  # shape: (S, H, W)
    # Pick the middle slice as the most representative view
    middle_idx = volume.shape[0] // 2
    slice_img = volume[middle_idx]  # shape: (H, W)
    ax.imshow(slice_img, cmap="gray")
    ax.set_title(f"ID: {exam_id}\n{diagnosis}", fontsize=10, fontweight="bold",
                 color="#E8654C" if diagnosis == "ACL Tear" else "#4C9BE8")
    ax.axis("off")
fig.suptitle("Sample Middle Slices – Sagittal Plane", fontsize=13, fontweight="bold")
plt.tight_layout()
plt.show()

The figure depicts four sample middle sagittal-plane MRI slices from the dataset, covering both classes. Each image is labeled with its exam ID and diagnosis: "ACL Tear" cases (IDs 1 and 18) are titled in red and "No ACL Tear" cases (IDs 0 and 2) in blue. The images are grayscale, consistent with standard MRI acquisition. Discernible structural differences separate the two classes: the ACL-tear cases lack the characteristic linear hypointense band of the intact ligament, whereas the no-tear cases show a well-defined ligamentous structure. At the same time, noticeable variability in knee orientation, contrast, and soft-tissue intensity can be observed across the four samples, reflecting the inherent heterogeneity of clinical data. This variability illustrates the complexity of the classification task and motivates deep learning architectures capable of learning complex hierarchical representations without manual feature engineering.
3.2.3 Distribution of Slice Counts per Examination
Code
# Slice Count Histogram
slice_counts = []
for exam_id in labels_df["exam_id"]:
    npy_path = os.path.join(DATA_DIR, "train", "sagittal", f"{exam_id:0>4}.npy")
    volume = np.load(npy_path)
    slice_counts.append(volume.shape[0])  # number of slices
labels_df["num_slices"] = slice_counts
plt.figure(figsize=(7, 4))
plt.hist(labels_df["num_slices"], bins=20, color="#4C9BE8", edgecolor="white")
plt.title("Distribution of Slice Counts (Sagittal Plane)", fontsize=13, fontweight="bold")
plt.xlabel("Number of Slices per Exam")
plt.ylabel("Number of Exams")
plt.tight_layout()
plt.show()
# Quick stats
print(f"Min slices : {labels_df['num_slices'].min()}")
print(f"Max slices : {labels_df['num_slices'].max()}")
print(f"Mean slices: {labels_df['num_slices'].mean():.1f}")

The histogram shows the distribution of the number of sagittal-plane slices per MRI exam. The horizontal axis gives the number of slices per exam, ranging from roughly 15 to 50, and the vertical axis gives the number of exams. The distribution is roughly unimodal, with a dominant peak around 25 slices per exam (about 150 exams) and a smaller secondary peak around 30-35 slices (about 100-110 exams), and it is right-skewed: a long tail extends toward the maximum of around 50 slices per exam, where fewer than 5 exams fall. This variation arises from the inherent variability of clinical imaging, since different hospitals and institutions follow different acquisition protocols. The figure is important because it guides how the volumetric inputs are standardized during preprocessing.
3.2.4 Pixel Intensity
Code
# Pixel Intensity Histogram
# Only load a small sample to avoid slow runtimes
n_sample = 10
sample_ids = labels_df["exam_id"].sample(n=n_sample, random_state=42).values
all_pixels = []
for exam_id in sample_ids:
    npy_path = os.path.join(DATA_DIR, "train", "sagittal", f"{exam_id:0>4}.npy")
    volume = np.load(npy_path).astype(np.float32)
    # Flatten volume and take a random subset of pixels (for speed)
    flat = volume.ravel()
    sample_px = flat[np.random.choice(len(flat), size=3000, replace=False)]
    all_pixels.append(sample_px)
all_pixels = np.concatenate(all_pixels)
plt.figure(figsize=(7, 4))
plt.hist(all_pixels, bins=60, color="#E8A84C", edgecolor="white")
plt.title(f"Pixel Intensity Distribution (sample of {n_sample} exams)",
fontsize=13, fontweight="bold")
plt.xlabel("Pixel Intensity Value")
plt.ylabel("Frequency")
plt.tight_layout()
plt.show()
print(f"Intensity range: {all_pixels.min():.1f} – {all_pixels.max():.1f}")
print(f"Mean intensity : {all_pixels.mean():.1f}")

The histogram shows the pixel intensity values for a random sample of ten exams. The horizontal axis spans intensities from near 0 up to 255, and the vertical axis gives their frequencies. The distribution is highly right-skewed: the highest frequencies are concentrated at low intensities, peaking around values of 20-25, where the frequency approaches 2,000, then decreasing gradually as intensity increases. This shows that pixel intensities in an MRI image are far from uniformly distributed. A secondary peak with a frequency of around 300 occurs at the maximum intensity of 255, possibly due to very bright regions such as joint effusion. This figure is useful because it justifies normalizing all images before they are fed to the model for training.
3.3 Modeling and Results
3.3.1 Data Preprocessing:
Our data preprocessing focused on preparing 3D MRI volumes for input into a deep learning model. The dataset consisted of knee MRI scans stored as NumPy arrays, along with three associated labels for each exam: ACL tear, meniscus tear, and abnormality.
Each MRI scan is a 3D volume composed of multiple 2D slices. Since the number of slices varies across exams, we standardized all inputs to a fixed depth of 32 slices. Volumes with more than 32 slices were truncated, while those with fewer slices were padded with zeros. This ensured consistent input dimensions for the model.
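The truncate-or-pad step can be sketched as follows. This is a minimal illustration assuming slices along the first axis; whether truncation keeps the first or the central slices is not specified in the report, so the sketch keeps the first 32:

```python
import numpy as np

TARGET_DEPTH = 32

def standardize_depth(volume, target=TARGET_DEPTH):
    """Truncate or zero-pad a (slices, H, W) volume to a fixed slice count."""
    s = volume.shape[0]
    if s >= target:
        return volume[:target]                    # keep the first `target` slices
    pad = np.zeros((target - s, *volume.shape[1:]), dtype=volume.dtype)
    return np.concatenate([volume, pad], axis=0)  # append all-zero slices

print(standardize_depth(np.ones((40, 8, 8))).shape)  # (32, 8, 8)
print(standardize_depth(np.ones((20, 8, 8))).shape)  # (32, 8, 8)
```

A centered crop (taking the middle 32 slices) is a common alternative, since the anatomy of interest usually sits near the center of the stack.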
We implemented a 3D convolutional neural network using a pretrained 3D ResNet-18 architecture. The pretrained model was loaded and the final fully connected layer of the network was modified to output three values corresponding to the three prediction tasks: ACL tear, meniscus tear, and abnormality. Since each condition is independent, the model performs multi-label classification.
Code
# Code used for modeling
train_dataset = MRNetMultiPlaneDataset(DATA_DIR, "train", augment=True)
valid_dataset = MRNetMultiPlaneDataset(DATA_DIR, "valid", augment=False)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False)
3.3.2 3D CNN Model Architecture:
We implemented a custom 3D convolutional neural network with residual connections, inspired by 3D ResNet architectures. The model consists of four residual blocks with increasing channels, followed by global average pooling. A learnable plane-attention mechanism was applied to fuse features from sagittal, coronal, and axial planes. The final fully connected layer outputs three values corresponding to the three prediction tasks: ACL tear, meniscus tear, and abnormality.
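The plane-attention fusion can be sketched as a softmax-weighted average of the three per-plane feature vectors. This is a minimal illustration of the mechanism, not the report's actual `Improved3DCNN`:

```python
import torch
import torch.nn as nn

class PlaneAttentionFusion(nn.Module):
    """Fuse sagittal/coronal/axial feature vectors with learned weights."""
    def __init__(self, feat_dim):
        super().__init__()
        self.score = nn.Linear(feat_dim, 1)  # one attention score per plane

    def forward(self, sag, cor, axi):                   # each: (batch, feat_dim)
        planes = torch.stack([sag, cor, axi], dim=1)    # (batch, 3, feat_dim)
        weights = torch.softmax(self.score(planes), dim=1)  # (batch, 3, 1)
        return (weights * planes).sum(dim=1)            # (batch, feat_dim)

fusion = PlaneAttentionFusion(feat_dim=128)
fused = fusion(torch.randn(2, 128), torch.randn(2, 128), torch.randn(2, 128))
print(fused.shape)  # torch.Size([2, 128])
```

Because the weights are produced per example, the model can lean on the sagittal view for one scan and the coronal view for another, rather than using a fixed average.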
Code
model = Improved3DCNN().to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3,2,1]).to(device))
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
3.3.3 Model Training
The model was trained using batches of multi-plane MRI volumes. Binary Cross-Entropy loss with logits was used to accommodate multi-label classification, and class weights were applied to mitigate label imbalance. The Adam optimizer updated the network weights. Mixed-precision training was optionally used to reduce memory consumption.
Code
epochs = 20
train_losses = []
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for sag, cor, axi, labels in train_loader:
        sag, cor, axi, labels = sag.to(device), cor.to(device), axi.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(sag, cor, axi)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    epoch_loss = total_loss / len(train_loader)
    train_losses.append(epoch_loss)
    print(f"Epoch {epoch+1}: {epoch_loss:.4f}")
3.4 Results
3.4.1 Training Performance
The training loss consistently decreased over 20 epochs, indicating the model successfully learned meaningful patterns from the MRI data. Early epochs showed rapid improvement, followed by slower convergence in later epochs. Minor fluctuations may indicate slight overfitting, but overall convergence was stable.
Code
plt.figure()
plt.plot(train_losses, marker='o')
plt.title("Training Loss Curve")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid()
plt.show()

3.4.2 ROC Curves
The Receiver Operating Characteristic (ROC) curves were plotted to evaluate the model’s ability to distinguish between positive and negative cases across the three classification tasks: ACL tear, meniscus tear, and abnormality detection.
Code
plt.figure()
for i, name in enumerate(labels_names):
    fpr, tpr, _ = roc_curve(all_labels[:, i], all_probs[:, i])
    plt.plot(fpr, tpr, label=f"{name} (AUC={auc(fpr,tpr):.3f})")
plt.plot([0,1],[0,1],'k--')
plt.title("ROC Curves")
plt.xlabel("FPR")
plt.ylabel("TPR")
plt.legend()
plt.grid()
plt.show()

The Area Under the Curve (AUC) values obtained are as follows:
- ACL: 0.793
- Meniscus: 0.767
- Abnormal: 0.806
These results indicate that the model achieves good discriminative performance across all classes, with AUC values close to 0.8.
Among the three tasks, abnormality detection shows the highest performance (AUC = 0.806), suggesting that the model is most effective at identifying general abnormalities in knee MRI scans. The ACL classification also performs strongly (AUC = 0.793), while meniscus tear detection shows slightly lower performance (AUC = 0.767), likely due to more subtle structural variations and higher inter-class similarity.
3.4.3 Confusion Matrix Heatmaps
Confusion matrices were used to evaluate the classification performance of the model for each task: ACL tear, meniscus tear, and abnormality detection. Each matrix presents the number of true negatives (TN), false positives (FP), false negatives (FN), and true positives (TP).
Code
fig, axes = plt.subplots(1,3,figsize=(15,4))
for i, name in enumerate(labels_names):
    cm = confusion_matrix(all_labels[:, i], all_preds[:, i])
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", ax=axes[i])
    axes[i].set_title(name)
plt.tight_layout()
plt.show()
The model demonstrates strong detection capability, particularly for abnormalities and meniscus tears.
ACL detection is moderate, with room for improvement in reducing both false positives and false negatives.
The model tends to favor higher sensitivity (detecting positives) at the cost of lower specificity (more false positives), especially for the meniscus class.
This behavior is consistent with the applied class weighting and thresholding strategy, which prioritizes capturing positive cases.
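The sensitivity/specificity trade-off described above can be read directly off each 2x2 confusion matrix. A minimal sketch with hypothetical counts (not the actual validation counts) in scikit-learn's row/column layout:

```python
import numpy as np

# Hypothetical 2x2 confusion matrix in scikit-learn's layout:
# rows = true class (0, 1), columns = predicted class (0, 1).
cm = np.array([[34, 55],    # TN, FP
               [ 8, 23]])   # FN, TP

tn, fp, fn, tp = cm.ravel()
sensitivity = tp / (tp + fn)   # recall on the positive class
specificity = tn / (tn + fp)   # recall on the negative class
accuracy = (tp + tn) / cm.sum()

print(f"sens={sensitivity:.3f} spec={specificity:.3f} acc={accuracy:.3f}")
```

A matrix like this one, with many false positives relative to false negatives, is exactly the high-sensitivity/low-specificity pattern noted for the meniscus class.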
3.4.4 Probability Distribution Plots
- ACL:
The model shows good discriminative power, with clear separation between the peaks for positive and negative cases and only a moderate amount of overlap in the middle probability range.
- Meniscus:
This is the most challenging category, as seen in the substantial overlap between the two distributions; the model frequently assigns high probabilities to negative cases, which is likely to produce an elevated false-positive rate.
- Abnormal: The model is highly confident and accurate here, with the vast majority of positive cases clustered extremely close to a probability of \(1.0\), indicating it is very effective at identifying general knee pathologies.
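The degree of separation described above can be quantified with a simple histogram overlap coefficient (1.0 for identical distributions, 0.0 for perfectly separated ones). A sketch on synthetic probabilities (the Beta-distributed scores are illustrative, not the model's outputs):

```python
import numpy as np

def overlap_coefficient(pos_probs, neg_probs, bins=20):
    """Area shared by the two normalized histograms over [0, 1]."""
    edges = np.linspace(0.0, 1.0, bins + 1)
    h_pos, _ = np.histogram(pos_probs, bins=edges, density=True)
    h_neg, _ = np.histogram(neg_probs, bins=edges, density=True)
    width = edges[1] - edges[0]
    return np.minimum(h_pos, h_neg).sum() * width

# Synthetic scores: one well-separated pair, one heavily overlapping pair.
rng = np.random.default_rng(0)
well_separated = overlap_coefficient(rng.beta(8, 2, 500), rng.beta(2, 8, 500))
heavy_overlap = overlap_coefficient(rng.beta(5, 4, 500), rng.beta(4, 5, 500))
print(well_separated, heavy_overlap)  # the second value is much larger
```

Applied to `all_probs`, a higher coefficient for the meniscus class than for the abnormal class would put a number on the visual impression from the histograms.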
Code
for i, name in enumerate(labels_names):
    plt.figure()
    plt.hist(all_probs[:, i][all_labels[:, i] == 1], bins=20, alpha=0.6, label="Positive")
    plt.hist(all_probs[:, i][all_labels[:, i] == 0], bins=20, alpha=0.6, label="Negative")
    plt.title(f"{name} Probability Distribution")
    plt.legend()
    plt.show()
3.4.5 Plane Attention Weights
The learned plane-attention weights indicate the relative importance of sagittal, coronal, and axial planes for the predictions.
Code
weights = torch.softmax(model.plane_weights, dim=0).cpu().detach().numpy()
plt.figure()
plt.bar(["Sagittal","Coronal","Axial"], weights)
plt.title("Plane Attention Weights")
plt.ylabel("Importance")
plt.show()
- Axial has the highest importance (~0.36): the model relies most heavily on the axial plane, suggesting that this view contains the most distinctive features, or "evidence", for the model's decision-making process.
- Sagittal is a close second (~0.35): the importance of the sagittal plane is nearly equal to that of the axial plane, indicating that the model finds significant diagnostic value in both views.
- Coronal has the lowest importance (~0.29): while still contributing substantially, the coronal plane is weighted the least, implying that its features are more redundant or slightly less informative for the final classification than those of the other two planes.
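Mechanically, plane attention of this kind can be implemented as a learnable weight vector that is softmax-normalized and used to blend per-plane feature vectors. A minimal sketch (module and feature names are illustrative, not the project's exact implementation):

```python
import torch
import torch.nn as nn

class PlaneAttentionFusion(nn.Module):
    """Blend per-plane feature vectors with softmax-normalized learnable
    weights. Illustrative sketch; not the project's exact implementation."""
    def __init__(self):
        super().__init__()
        # One learnable scalar per plane: sagittal, coronal, axial.
        self.plane_weights = nn.Parameter(torch.zeros(3))

    def forward(self, sag_feat, cor_feat, axi_feat):
        w = torch.softmax(self.plane_weights, dim=0)  # non-negative, sums to 1
        return w[0] * sag_feat + w[1] * cor_feat + w[2] * axi_feat

fusion = PlaneAttentionFusion()
feats = [torch.randn(4, 256) for _ in range(3)]  # batch of 4, 256-dim features
fused = fusion(*feats)
print(fused.shape)  # torch.Size([4, 256])
print(torch.softmax(fusion.plane_weights, 0))  # starts uniform at 1/3 each
```

Because the weights are trained end-to-end with the classifier, the softmax values reported above (~0.36/0.35/0.29) directly reflect how much each plane contributed to the fused representation.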
3.4.6 Validation Metrics
After training, the model was evaluated on the validation dataset. Predictions were converted into probabilities using a sigmoid function and then thresholded to produce binary outputs. Sensitivity, specificity, accuracy, and AUC-ROC were calculated for each condition.
Code
model.eval()
all_labels, all_probs, all_preds = [], [], []
with torch.no_grad():
    for sag, cor, axi, labels in valid_loader:
        sag, cor, axi = sag.to(device), cor.to(device), axi.to(device)
        probs = torch.sigmoid(model(sag, cor, axi)).cpu().numpy()
        preds = (probs > np.array([0.5, 0.5, 0.7])).astype(int)  # adjustable per-class thresholds
        all_labels.extend(labels.numpy())
        all_probs.extend(probs)
        all_preds.extend(preds)
all_labels = np.array(all_labels)
all_probs = np.array(all_probs)
all_preds = np.array(all_preds)
labels_names = ["ACL", "Meniscus", "Abnormal"]
3.4.7 Summary of the evaluation metrics:
- ACL:
Moderate performance with balanced sensitivity (0.630) and specificity (0.697), and good discriminative ability (AUC 0.793).
- Meniscus:
Very high sensitivity (0.923) but low specificity (0.382), indicating overprediction of meniscus tears. The AUC is 0.767.
- Abnormal:
Very high sensitivity (0.926) but low specificity (0.400), suggesting the model frequently flags normal cases as abnormal. AUC is 0.806.
These results suggest that the model detects positive cases effectively but tends to overpredict meniscus tears and abnormalities. Threshold tuning, class balancing, or loss function adjustments may improve specificity and overall reliability for clinical application.
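One concrete form of the threshold tuning suggested above is to choose, per condition, the validation threshold that maximizes Youden's J statistic (sensitivity + specificity - 1) along the ROC curve. A sketch with synthetic scores (the Beta-distributed probabilities are illustrative, not the model's outputs):

```python
import numpy as np
from sklearn.metrics import roc_curve

def youden_threshold(y_true, y_prob):
    """Threshold maximizing sensitivity + specificity - 1 (Youden's J)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_prob)
    j = tpr - fpr  # tpr = sensitivity, 1 - fpr = specificity
    return thresholds[np.argmax(j)]

# Synthetic validation scores for illustration.
rng = np.random.default_rng(42)
y_true = np.concatenate([np.zeros(200), np.ones(100)])
y_prob = np.concatenate([rng.beta(2, 5, 200), rng.beta(5, 2, 100)])

t = youden_threshold(y_true, y_prob)
preds = (y_prob >= t).astype(int)
print(f"chosen threshold: {t:.3f}")
```

Replacing the fixed `[0.5, 0.5, 0.7]` thresholds with per-class Youden-optimal values (estimated on a held-out split) would trade some sensitivity for the specificity the meniscus and abnormal classes currently lack.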