-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix data loader and preprocessing visualisations and main
- Loading branch information
Showing
2 changed files
with
110 additions
and
160 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,88 +1,94 @@ | ||
# Model configuration | ||
# Model Configurations | ||
model: | ||
type: 'vit3d' # Keep this as vit3d for compatibility | ||
# Shared settings | ||
num_labels: 3 | ||
freeze_layers: true | ||
input_size: 224 | ||
patch_size: 16 # Added this for the new architecture | ||
dropout_rate: 0.1 | ||
architecture: | ||
hidden_size: 768 | ||
num_attention_heads: 12 | ||
intermediate_size: 3072 | ||
pretrained_model: 'google/vit-base-patch16-224-in21k' | ||
freeze_last_n_layers: 4 | ||
medical_preprocessing: | ||
enable: true | ||
kernel_sizes: [7, 5, 1] | ||
channels: [16, 32, 3] | ||
freeze_layers: true | ||
|
||
# Model specific settings | ||
vit3d: | ||
type: 'vit3d' | ||
patch_sizes: [8, 16, 32] # Different patch sizes to try | ||
pretrained: 'google/vit-base-patch16-224-in21k' | ||
|
||
vit2d: | ||
type: 'vit2d' | ||
patch_sizes: [8, 16, 32] | ||
pretrained: 'google/vit-base-patch16-224-in21k' | ||
slice_mode: 'center' # or 'average' | ||
|
||
cnn3d: | ||
type: 'cnn3d' | ||
patch_sizes: [8, 16, 32] | ||
pretrained: 'resnet50' # Will use inflated ResNet | ||
channels: [64, 128, 256, 512] | ||
|
||
# Dataset configuration | ||
dataset: | ||
path: './adni' | ||
batch_size: 16 # Increased batch size | ||
batch_size: 8 | ||
val_ratio: 0.15 | ||
test_ratio: 0.15 | ||
input_size: 224 | ||
preprocessing: | ||
voxel_spacing: [1.5, 1.5, 1.5] | ||
orientation: 'RAS' | ||
intensity_norm: true | ||
foreground_crop: true | ||
crop_margin: 10 | ||
medical_augmentation: | ||
# 2D specific | ||
slice_selection: | ||
method: 'center' # or 'average' | ||
num_slices: 5 | ||
# Data augmentation | ||
augmentation: | ||
enable: true | ||
rotation_range: [-10, 10] | ||
scale_range: [0.95, 1.05] | ||
flip_probability: 0.5 | ||
intensity_shift: [-0.1, 0.1] | ||
gamma_range: [0.9, 1.1] | ||
|
||
# Training configuration | ||
training: | ||
epochs: 50 | ||
device: 'cuda' # will fall back to CPU if not available | ||
device: 'cuda' # will fall back to CPU | ||
seed: 42 | ||
learning_rate: 0.0001 | ||
base_learning_rate: 0.0001 | ||
layer_specific_lrs: | ||
pretrained: 0.00001 | ||
new: 0.0001 | ||
optimizer: | ||
type: 'adamw' | ||
weight_decay: 0.01 | ||
layer_specific_lrs: # Different learning rates for different components | ||
vit_backbone: 0.00001 | ||
medical_preprocessing: 0.0001 | ||
feature_enhance: 0.0001 | ||
classifier: 0.0001 | ||
scheduler: | ||
type: 'cosine' | ||
T_0: 5 | ||
warmup_epochs: 2 | ||
min_lr: 1.0e-6 | ||
early_stopping: | ||
enable: true | ||
patience: 10 | ||
min_delta: 0.001 | ||
loss: | ||
type: 'cross_entropy' | ||
label_smoothing: 0.1 | ||
|
||
# Paths configuration | ||
# Experiment tracking | ||
experiment: | ||
name: 'alzheimer_detection_comparison' | ||
models_to_run: ['vit3d', 'vit2d', 'cnn3d'] | ||
metrics: | ||
- accuracy | ||
- precision | ||
- recall | ||
- f1 | ||
- confusion_matrix | ||
save_predictions: true | ||
save_model: true | ||
plot_results: true | ||
|
||
# Paths and logging | ||
paths: | ||
output_dir: './output' | ||
log_dir: './logs' | ||
checkpoint_dir: './checkpoints' | ||
data: | ||
raw: './adni/raw' | ||
processed: './adni/processed' | ||
metadata: './metadata/adni.csv' | ||
results_dir: './results' | ||
|
||
# Logging configuration | ||
logging: | ||
level: 'INFO' | ||
save_to_file: true | ||
log_frequency: 10 | ||
|
||
# Memory management for large medical images | ||
memory: | ||
clear_cache_frequency: 1 | ||
pin_memory: false # For CPU training | ||
num_workers: 0 # For CPU training | ||
prefetch_factor: 2 | ||
log_frequency: 10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters