Multiple sclerosis (MS) is a chronic, often disabling disease that affects the central nervous system. It is characterized by the development of scar tissue (sclerosis) in place of the normal tissue component of the nervous system, interfering with the transmission of nerve impulses.
The Expanded Disability Status Scale (EDSS) is a scale designed to assess the level of disability of people with MS; it ranges from 0, corresponding to a normal neurological exam, to 10, with intermediate steps corresponding to increasingly higher levels of disability. The score is obtained by combining the partial scores of the various functional systems of the nervous system (pyramidal, cerebellar, sphincter, etc.). This makes it easier to assess the progression of the disease and to verify the effectiveness of the current therapy.
In the remainder of this work, we will analyze the classification of patients' disability levels in two settings:

- Binary classification: patients are grouped into two classes based on their EDSS score, namely the positive class, which includes all patients with an EDSS value less than or equal to 2.0, and the negative class, which includes all patients with an EDSS value greater than 2.0. A patient in the positive class has no significant lesions and minimal or no functional impairment, while the negative class includes patients with evident neurological lesions and clinical symptoms.
- Multiclass classification: EDSS scores were mapped into three categories: normal, mild, and severe. Scores from 0 to 2.0 were labeled as normal, indicating little to no disability; scores between 2.5 and 4.0 were labeled as mild, indicating patients with moderate impairment but preserved ambulation; scores above 4.0 were labeled as severe, corresponding to individuals with significant motor or systemic dysfunction. The mapping is sketched in code right after this list.
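For concreteness, the sketch below maps a raw EDSS score to the labels used in the two tasks. The function names are hypothetical and simply encode the thresholds described above; the actual preprocessing code in the repository may differ.

```python
def edss_to_binary(edss: float) -> str:
    """Binary task: positive = EDSS <= 2.0, negative = EDSS > 2.0."""
    return "positive" if edss <= 2.0 else "negative"


def edss_to_multiclass(edss: float) -> str:
    """Multiclass task: normal (0-2.0), mild (2.5-4.0), severe (> 4.0)."""
    if edss <= 2.0:
        return "normal"
    elif edss <= 4.0:
        return "mild"
    return "severe"


# Example: an EDSS of 3.0 is 'negative' in the binary task and 'mild' in the multiclass task.
assert edss_to_binary(3.0) == "negative"
assert edss_to_multiclass(3.0) == "mild"
```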
To address this task, we will develop and compare two types of models: one based on Convolutional Neural Networks (CNNs) and one on Vision Transformers (ViTs). Since we have a limited dataset and the intended use is clinical, we will focus on "lightweight" models that can be deployed in environments with limited computational resources. Furthermore, complex models may not be suitable in a medical setting, where data is scarce due to high acquisition and annotation costs, as well as patient privacy concerns and administrative policies that limit data sharing.
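As a reference point, here is a hedged Keras sketch of the kind of lightweight CNN we have in mind. It reuses the CLI defaults listed later (3 convolutional layers, 2 dense layers, 128 units), but the filter counts, pooling strategy, and classifier head are illustrative assumptions, not necessarily the architecture implemented in this repository:

```python
import keras
from keras import layers


def build_lightweight_cnn(input_shape=(256, 256, 1), n_conv_layers=3,
                          n_dense_layers=2, units=128, n_classes=2):
    """Small CNN sketch: a few conv/pool blocks followed by a compact dense head."""
    inputs = keras.Input(shape=input_shape)
    x = inputs
    filters = 32
    for _ in range(n_conv_layers):
        # Each block halves the spatial resolution and doubles the filter count.
        x = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
        x = layers.MaxPooling2D()(x)
        filters *= 2
    x = layers.GlobalAveragePooling2D()(x)
    for i in range(n_dense_layers):
        x = layers.Dense(units // (2 ** i), activation="relu")(x)
    # Single sigmoid unit for the binary task, softmax over the classes otherwise.
    if n_classes == 2:
        outputs = layers.Dense(1, activation="sigmoid")(x)
    else:
        outputs = layers.Dense(n_classes, activation="softmax")(x)
    return keras.Model(inputs, outputs)


model = build_lightweight_cnn()
model.summary()
```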
The dataset employed in this work consists of T1, T2, and FLAIR cerebral MRI sequences. The dataset, with pre-defined training, validation, and test splits for each task, is available under `edss-dataset.zip`. For more information, refer to the documentation under the `deliverables` folder.
The models' weights are available under `weights.zip`. Note that only the best models were uploaded. For more information, refer to the documentation under the `deliverables` folder.
Verify you have Python installed on your machine. This project is compatible with Python 3.11. Run `python --version` to check which Python version you have installed. If you do not have Python installed, please refer to the official Python Guide.
It is strongly recommended to create a virtual environment before proceeding with the installation. We recommend you use `pip`. Please refer to Creating a virtual environment for more information.
Clone the repository by running the following command in your terminal:
git clone https://github.com/angelonazzaro/EDSS-classifier.git
Before installing the requirements, make sure you have your virtual environment activated. If that's the case, your terminal prompt should look something like: `(name-of-your-venv) user@userpath`.
Install the requirements using `pip`:
pip install -r requirements.txt
NOTE: If you have an Nvidia GPU, replace `tensorflow==2.19.0` with `tensorflow[and-cuda]==2.19.0`. If you have an Apple Silicon chip (M1, M2, M3, or M4), please also install `tensorflow-metal` for Metal (MPS) support.
NOTE: A Weights & Biases account is needed for experiment tracking.
To start training, execute the `train.py` script from the command line:
python train.py --data_dir <path_to_data> [OPTIONS]
- `--data_dir`: Path to the dataset directory.
- `--resize`: Size to resize the images to, given as a tuple of ints.
| Argument | Type | Default | Description |
|---|---|---|---|
| `--seed` | int | 42 | Random seed for reproducibility |
| `--epochs` | int | 10 | Number of training epochs |
| `--check_val_every_n_epochs` | int | 1 | Interval for running validation |
| `--batch_size` | int | 32 | Batch size |
| `--resize` | tuple | None | Resize dimensions (height, width) |
| `--include_augmented` | bool | False | Include augmented images |
| `--model_type` | str | CNN | Model type (`CNN` or `ViT`) |
| Argument | Type | Default | Description |
|---|---|---|---|
| **CNN-specific** | | | |
| `--units` | int | 128 | Hidden units in the first dense layer |
| `--n_conv_layers` | int | 3 | Number of convolutional layers |
| `--n_dense_layers` | int | 2 | Number of dense layers in classifier |
| **ViT-specific** | | | |
| `--patch_size` | int | None | Patch size |
| `--d_model` | int | None | Embedding dimension |
| `--mlp_dim` | int | None | MLP hidden units |
| `--num_heads` | int | 8 | Number of attention heads |
| `--n_layers` | int | 6 | Number of encoder layers |
| `--channels` | int | 1 | Number of channels |
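To show how the ViT-specific hyperparameters fit together, the following condensed Keras sketch builds a ViT-style encoder from the same parameter names. It is purely illustrative: details such as the patch-embedding scheme, positional embeddings, and pooling in the repository's ViT may differ.

```python
import keras
from keras import layers


class AddPositionEmbedding(layers.Layer):
    """Adds a learned positional embedding to the sequence of patch tokens."""

    def build(self, input_shape):
        # One embedding vector per patch position, broadcast over the batch.
        self.pos_emb = self.add_weight(
            shape=(1, input_shape[1], input_shape[2]),
            initializer="zeros",
            name="pos_emb",
        )

    def call(self, x):
        return x + self.pos_emb


def build_tiny_vit(image_size=224, patch_size=16, d_model=128, mlp_dim=256,
                   num_heads=8, n_layers=6, channels=1, n_classes=2):
    """Minimal ViT sketch: patchify, embed, run encoder blocks, classify."""
    inputs = keras.Input(shape=(image_size, image_size, channels))
    # Patch embedding via a strided convolution: one d_model-dim token per patch.
    x = layers.Conv2D(d_model, kernel_size=patch_size, strides=patch_size)(inputs)
    num_patches = (image_size // patch_size) ** 2
    x = layers.Reshape((num_patches, d_model))(x)
    x = AddPositionEmbedding()(x)

    for _ in range(n_layers):
        # Pre-norm multi-head self-attention with a residual connection.
        h = layers.LayerNormalization()(x)
        h = layers.MultiHeadAttention(num_heads=num_heads,
                                      key_dim=d_model // num_heads)(h, h)
        x = x + h
        # Pre-norm MLP block (d_model -> mlp_dim -> d_model) with a residual connection.
        h = layers.LayerNormalization()(x)
        h = layers.Dense(mlp_dim, activation="gelu")(h)
        h = layers.Dense(d_model)(h)
        x = x + h

    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(1 if n_classes == 2 else n_classes,
                           activation="sigmoid" if n_classes == 2 else "softmax")(x)
    return keras.Model(inputs, outputs)
```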
| Argument | Type | Default | Description |
|---|---|---|---|
| `--task` | str | binary | Task type (`binary` or `multi-class`) |
| `--alpha` | float | 0.25 | Alpha for Focal Loss |
| `--gamma` | float | 2.0 | Gamma for Focal Loss |
| Argument | Type | Default | Description |
|---|---|---|---|
| `--lr` | float | 1e-3 | Learning rate |
| `--weight_decay` | float | 1e-5 | Weight decay for AdamW |
| Argument | Type | Default | Description |
|---|---|---|---|
| `--patience` | int | 2 | Early stopping patience |
| `--min_delta` | float | 0.0001 | Minimum improvement in monitored metric |
| `--monitor` | str | val_loss | Metric to monitor |
| `--checkpoint_dir` | str | experiments | Directory to save checkpoints |
| `--save_top_k` | int | 1 | Max checkpoints to keep |
| `--experiment_name` | str | None | Experiment name (W&B run name if None) |
| Argument | Type | Default | Description |
|---|---|---|---|
| `--project` | str | deep-learning | W&B project name |
| `--tune_hyperparameters` | bool | False | Enable W&B hyperparameter sweep |
| `--sweep_config` | str | sweep.yaml | Path to sweep configuration file |
| `--sweep_count` | int | 10 | Number of sweep runs |
python train.py --data_dir ./data --model_type CNN --resize 256,256 --epochs 20 --batch_size 64
python train.py --data_dir ./data --model_type ViT --patch_size 16 --d_model 128 --mlp_dim 256 --epochs 30
python train.py --data_dir ./data --task binary --include_augmented --resize 224,224
python train.py --data_dir ./data --tune_hyperparameters --sweep_config sweep.yaml --sweep_count 20
- If `--experiment_name` is not provided, the W&B run name will be used for saving checkpoints.
- For multi-class classification, `CategoricalFocalCrossentropy` is used; for binary classification, `BinaryCrossentropy` is used.
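The note above corresponds roughly to the following Keras loss selection. Treat it as a sketch of the idea rather than the repository's exact code; in particular, how `--alpha` and `--gamma` interact with the binary loss is not shown here.

```python
import keras


def build_loss(task: str, alpha: float = 0.25, gamma: float = 2.0):
    """Select the loss according to --task (illustrative sketch, not the repo's code)."""
    if task == "multi-class":
        # Focal parameters come from --alpha / --gamma.
        return keras.losses.CategoricalFocalCrossentropy(alpha=alpha, gamma=gamma)
    # Per the note above, the binary task uses plain binary cross-entropy.
    return keras.losses.BinaryCrossentropy()


loss = build_loss("multi-class", alpha=0.25, gamma=2.0)
```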
To evaluate a trained model, execute the `test.py` script from the command line:
python test.py --data_dir <path_to_data> [OPTIONS]
| Argument | Type | Default | Description |
|---|---|---|---|
| `--seed` | int | 42 | Random seed for reproducibility |
| `--batch_size` | int | 32 | Batch size |
| `--resize` | tuple | None | Resize dimensions (height, width) |
| `--model_type` | str | CNN | Model type (`CNN` or `ViT`) |
| Argument | Type | Default | Description |
|---|---|---|---|
| `--data_dir` | str | Required | Path to dataset directory |
| `--modality` | str list | T1 | MRI modality (`T1`, `T2`, `FLAIR`) |
| `--task` | str | binary | Task type (`binary` or `multi-class`) |
| Argument | Type | Default | Description |
|---|---|---|---|
| `--model_name` | str | Required | Model name |
| `--checkpoint_path` | str | Required | Path to model checkpoint |
| Argument | Type | Default | Description |
|---|---|---|---|
| `--results_dir` | str | results | Directory to save results and confusion matrices |
When executed, the script will:
- Load the model checkpoint.
- Run inference on the test dataset.
- Compute accuracy, precision, recall, and F1 score.
- Save results in a CSV file: `results/<task>_results.csv`
- Save confusion matrix plots in: `results/<task>/<model_name>/cm_<modality>.png`
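A minimal sketch of how these steps might be implemented, assuming scikit-learn, pandas, and matplotlib are available; the actual `test.py` may organize this differently:

```python
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import (ConfusionMatrixDisplay, accuracy_score,
                             confusion_matrix, precision_recall_fscore_support)


def evaluate(model, test_ds, task, model_name, modality, results_dir="results"):
    """Run inference, compute the metrics, and persist a CSV plus a confusion-matrix plot."""
    # Collect the ground-truth labels from a batched (images, labels) dataset.
    y_true = np.concatenate([y.numpy() for _, y in test_ds])
    if y_true.ndim > 1:
        y_true = y_true.argmax(axis=1)  # handle one-hot encoded labels

    probs = model.predict(test_ds)
    # Binary: threshold the sigmoid output; multi-class: argmax over the softmax scores.
    y_pred = (probs.ravel() > 0.5).astype(int) if task == "binary" else probs.argmax(axis=1)

    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary" if task == "binary" else "macro")
    row = {"model": model_name, "modality": modality,
           "accuracy": accuracy_score(y_true, y_pred),
           "precision": precision, "recall": recall, "f1": f1}

    # Writes results/<task>_results.csv and results/<task>/<model_name>/cm_<modality>.png
    os.makedirs(os.path.join(results_dir, task, model_name), exist_ok=True)
    pd.DataFrame([row]).to_csv(os.path.join(results_dir, f"{task}_results.csv"), index=False)
    ConfusionMatrixDisplay(confusion_matrix(y_true, y_pred)).plot()
    plt.savefig(os.path.join(results_dir, task, model_name, f"cm_{modality}.png"))
```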
python test.py --data_dir ./data --model_name cnn_model_v1 --checkpoint_path ./experiments/cnn_model_v1/best_model.h5 --task binary --modality T1 --resize 256,256
python test.py --data_dir ./data --model_name vit_model_v1 --checkpoint_path ./experiments/vit_model_v1/best_model.h5 --task multi-class --modality FLAIR --resize 224,224