diff --git a/config.yaml b/config.yaml
index 477d27f..4110c17 100644
--- a/config.yaml
+++ b/config.yaml
@@ -1,88 +1,94 @@
-# Model configuration
+# Model Configurations
 model:
-  type: 'vit3d'  # Keep this as vit3d for compatibility
+  # Shared settings
   num_labels: 3
-  freeze_layers: true
   input_size: 224
-  patch_size: 16  # Added this for the new architecture
   dropout_rate: 0.1
-  architecture:
-    hidden_size: 768
-    num_attention_heads: 12
-    intermediate_size: 3072
-  pretrained_model: 'google/vit-base-patch16-224-in21k'
-  freeze_last_n_layers: 4
-  medical_preprocessing:
-    enable: true
-    kernel_sizes: [7, 5, 1]
-    channels: [16, 32, 3]
+  freeze_layers: true
+
+  # Model specific settings
+  vit3d:
+    type: 'vit3d'
+    patch_sizes: [8, 16, 32]  # Different patch sizes to try
+    pretrained: 'google/vit-base-patch16-224-in21k'
+
+  vit2d:
+    type: 'vit2d'
+    patch_sizes: [8, 16, 32]
+    pretrained: 'google/vit-base-patch16-224-in21k'
+    slice_mode: 'center'  # or 'average'
+
+  cnn3d:
+    type: 'cnn3d'
+    patch_sizes: [8, 16, 32]
+    pretrained: 'resnet50'  # Will use inflated ResNet
+    channels: [64, 128, 256, 512]

 # Dataset configuration
 dataset:
   path: './adni'
-  batch_size: 16  # Increased batch size
+  batch_size: 8
   val_ratio: 0.15
   test_ratio: 0.15
-  input_size: 224
   preprocessing:
     voxel_spacing: [1.5, 1.5, 1.5]
     orientation: 'RAS'
     intensity_norm: true
     foreground_crop: true
     crop_margin: 10
-  medical_augmentation:
+  # 2D specific
+  slice_selection:
+    method: 'center'  # or 'average'
+    num_slices: 5
+  # Data augmentation
+  augmentation:
     enable: true
     rotation_range: [-10, 10]
-    scale_range: [0.95, 1.05]
+    flip_probability: 0.5
     intensity_shift: [-0.1, 0.1]
-    gamma_range: [0.9, 1.1]

 # Training configuration
 training:
   epochs: 50
-  device: 'cuda'  # will fall back to CPU if not available
+  device: 'cuda'  # will fall back to CPU
   seed: 42
-  learning_rate: 0.0001
+  base_learning_rate: 0.0001
+  layer_specific_lrs:
+    pretrained: 0.00001
+    new: 0.0001
   optimizer:
     type: 'adamw'
     weight_decay: 0.01
-    layer_specific_lrs:  # Different learning rates for different components
-      vit_backbone: 0.00001
-      medical_preprocessing: 0.0001
-      feature_enhance: 0.0001
-      classifier: 0.0001
   scheduler:
     type: 'cosine'
-    T_0: 5
     warmup_epochs: 2
     min_lr: 1.0e-6
   early_stopping:
-    enable: true
     patience: 10
     min_delta: 0.001
-  loss:
-    type: 'cross_entropy'
-    label_smoothing: 0.1

-# Paths configuration
+# Experiment tracking
+experiment:
+  name: 'alzheimer_detection_comparison'
+  models_to_run: ['vit3d', 'vit2d', 'cnn3d']
+  metrics:
+    - accuracy
+    - precision
+    - recall
+    - f1
+    - confusion_matrix
+  save_predictions: true
+  save_model: true
+  plot_results: true
+
+# Paths and logging
 paths:
   output_dir: './output'
   log_dir: './logs'
   checkpoint_dir: './checkpoints'
-  data:
-    raw: './adni/raw'
-    processed: './adni/processed'
-    metadata: './metadata/adni.csv'
+  results_dir: './results'

-# Logging configuration
 logging:
   level: 'INFO'
   save_to_file: true
-  log_frequency: 10
-
-# Memory management for large medical images
-memory:
-  clear_cache_frequency: 1
-  pin_memory: false  # For CPU training
-  num_workers: 0  # For CPU training
-  prefetch_factor: 2
\ No newline at end of file
+  log_frequency: 10
\ No newline at end of file
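The training section now carries two learning rates in `layer_specific_lrs` (`pretrained` for the ViT backbone, `new` for the added heads) alongside `base_learning_rate`. A minimal sketch of how that block could be mapped onto AdamW parameter groups; the `build_optimizer` helper and the `vit.`-prefix split are illustrative assumptions, not code from this patch:

```python
import torch

def build_optimizer(model, cfg):
    """Two AdamW parameter groups: pretrained ViT backbone vs. newly added layers."""
    lrs = cfg['training']['layer_specific_lrs']
    pretrained, new = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen backbone weights need no optimizer state
        (pretrained if name.startswith('vit.') else new).append(param)
    return torch.optim.AdamW(
        [
            {'params': pretrained, 'lr': lrs['pretrained']},
            {'params': new, 'lr': lrs['new']},
        ],
        lr=cfg['training']['base_learning_rate'],
        weight_decay=cfg['training']['optimizer']['weight_decay'],
    )
```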
""" import torch @@ -7,21 +7,11 @@ from transformers import ViTModel import logging import numpy as np -import math from torch.nn import functional as F logger = logging.getLogger(__name__) -class LayerNormWithFixedInit(nn.LayerNorm): - """Custom LayerNorm with fixed initialization.""" - def reset_parameters(self) -> None: - if self.elementwise_affine: - nn.init.ones_(self.weight) - nn.init.zeros_(self.bias) - class ViT3D(nn.Module): - """3D Vision Transformer optimized for medical image analysis.""" - def __init__( self, num_labels: int, @@ -32,13 +22,10 @@ def __init__( ): super().__init__() - # Validate input parameters - assert input_size % patch_size == 0, "Input size must be divisible by patch size" self.input_size = input_size self.patch_size = patch_size self.num_patches = (input_size // patch_size) ** 3 - # Log initialization parameters logger.info(f"Initializing ViT3D with:") logger.info(f"- Input size: {input_size}x{input_size}x{input_size}") logger.info(f"- Patch size: {patch_size}x{patch_size}x{patch_size}") @@ -50,63 +37,36 @@ def __init__( 'google/vit-base-patch16-224-in21k', add_pooling_layer=False ) - hidden_size = self.vit.config.hidden_size # 768 + hidden_size = self.vit.config.hidden_size - # Medical image specific preprocessing + # Efficient preprocessing self.preprocess = nn.Sequential( - nn.Conv2d(1, 16, kernel_size=7, padding=3), - nn.InstanceNorm2d(16), - nn.GELU(), - nn.Conv2d(16, 32, kernel_size=5, padding=2), - nn.InstanceNorm2d(32), - nn.GELU(), - nn.Conv2d(32, 3, kernel_size=1), - nn.InstanceNorm2d(3) + nn.Conv2d(1, 3, 3, padding=1, bias=False), # Simpler conv + nn.BatchNorm2d(3), # BatchNorm uses less memory + nn.ReLU() # ReLU is more memory efficient than GELU ) - # Slice selection module with attention - self.slice_attention = nn.Sequential( - nn.Conv3d(1, 16, kernel_size=3, padding=1), - nn.InstanceNorm3d(16), - nn.GELU(), - nn.Conv3d(16, 8, kernel_size=3, padding=1), - nn.InstanceNorm3d(8), - nn.GELU(), - nn.Conv3d(8, 3, kernel_size=1), + # Memory-efficient slice selection + self.slice_selection = nn.Sequential( + nn.Conv3d(1, 8, 1, bias=False), # 1x1x1 conv uses less memory + nn.BatchNorm3d(8), + nn.ReLU(), + nn.Conv3d(8, 3, 1, bias=False), nn.Softmax(dim=2) ) - # Feature enhancement module - self.feature_enhance = nn.Sequential( - nn.Linear(hidden_size, hidden_size * 2), - LayerNormWithFixedInit(hidden_size * 2), - nn.GELU(), - nn.Dropout(dropout_rate), - nn.Linear(hidden_size * 2, hidden_size) - ) - - # View fusion module + # Feature fusion (simplified) self.fusion = nn.Sequential( - nn.Linear(hidden_size * 3, hidden_size * 2), - LayerNormWithFixedInit(hidden_size * 2), - nn.GELU(), - nn.Dropout(dropout_rate), - nn.Linear(hidden_size * 2, hidden_size), - LayerNormWithFixedInit(hidden_size) + nn.Linear(hidden_size * 3, hidden_size), + nn.BatchNorm1d(hidden_size), + nn.ReLU(), + nn.Dropout(dropout_rate) ) # Classification head - self.classifier = nn.Sequential( - nn.Linear(hidden_size, hidden_size // 2), - nn.GELU(), - nn.Dropout(dropout_rate), - nn.Linear(hidden_size // 2, num_labels) - ) + self.classifier = nn.Linear(hidden_size, num_labels) - # Initialize weights explicitly self._init_weights() - - # Freeze layers selectively if freeze_layers: self._freeze_layers() @@ -117,89 +77,73 @@ def __init__( logger.info(f"Trainable parameters: {trainable_params:,}") def _init_weights(self): - """Initialize weights explicitly.""" - def init_module(m): + """Initialize weights.""" + for m in self.modules(): if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)): 
@@ -117,89 +77,73 @@ def __init__(
         logger.info(f"Trainable parameters: {trainable_params:,}")

     def _init_weights(self):
-        """Initialize weights explicitly."""
-        def init_module(m):
+        """Initialize weights."""
+        for m in self.modules():
             if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out')
-                if m.bias is not None:
+                if getattr(m, 'bias', None) is not None:
                     nn.init.zeros_(m.bias)
-            elif isinstance(m, (nn.InstanceNorm2d, nn.InstanceNorm3d)):
-                if m.weight is not None:
-                    nn.init.ones_(m.weight)
-                if m.bias is not None:
-                    nn.init.zeros_(m.bias)
-            elif isinstance(m, LayerNormWithFixedInit):
-                m.reset_parameters()
-
-        self.apply(init_module)
+            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)

     def _freeze_layers(self):
-        """Selective freezing for better transfer learning."""
-        # Freeze early layers
-        for param in self.vit.embeddings.parameters():
+        """Freeze pretrained layers."""
+        self.vit.eval()  # Set to eval mode to save memory
+        for param in self.vit.parameters():
             param.requires_grad = False

-        # Only train the last 4 transformer layers
-        for layer in self.vit.encoder.layer[:-4]:
-            for param in layer.parameters():
-                param.requires_grad = False
-
-    def _get_attention_weighted_slices(self, x: torch.Tensor) -> tuple:
-        """Extract attention-weighted slices from volume."""
+    @torch.no_grad()  # Memory optimization
+    def _get_central_slices(self, x: torch.Tensor) -> tuple:
+        """Get central slices efficiently."""
         B, C, D, H, W = x.shape

-        # Generate attention weights for each direction
-        attention_weights = self.slice_attention(x)
+        # Get central indices
+        d_mid = D // 2
+        h_mid = H // 2
+        w_mid = W // 2

-        # Extract weighted slices for each view
-        d_center, h_center, w_center = D//2, H//2, W//2
-        span = 3  # Consider slices around center
+        # Extract central slices directly
+        axial = x[:, :, d_mid].clone()  # [B, C, H, W]
+        sagittal = x[:, :, :, h_mid].clone()  # [B, C, D, W]
+        coronal = x[:, :, :, :, w_mid].clone()  # [B, C, D, H]

-        # Compute weighted averages around central slices
-        axial_region = x[:, :, d_center-span:d_center+span+1]
-        sagittal_region = x[:, :, :, h_center-span:h_center+span+1]
-        coronal_region = x[:, :, :, :, w_center-span:w_center+span+1]
-
-        axial = (axial_region * attention_weights[:, 0:1, d_center-span:d_center+span+1]).sum(dim=2)
-        sagittal = (sagittal_region * attention_weights[:, 1:2, :, h_center-span:h_center+span+1]).sum(dim=3)
-        coronal = (coronal_region * attention_weights[:, 2:3, :, :, w_center-span:w_center+span+1]).sum(dim=4)
-
-        return axial, sagittal, coronal
+        return axial, sagittal.transpose(2, 3), coronal.transpose(2, 3)
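`_get_central_slices` swaps the attention-weighted slice pooling for plain indexing of the three central orthogonal planes, with transposes so each view comes out as a 2-D image. A toy check of those shapes; the volume is deliberately non-cubic so each axis is visible, whereas the real model expects every resulting slice to be `input_size` x `input_size`:

```python
import torch

x = torch.randn(2, 1, 32, 48, 64)  # [B, C, D, H, W]
d_mid, h_mid, w_mid = 32 // 2, 48 // 2, 64 // 2

axial = x[:, :, d_mid]                          # [2, 1, 48, 64] -> (H, W) plane
sagittal = x[:, :, :, h_mid].transpose(2, 3)    # [2, 1, 64, 32] -> (W, D) after transpose
coronal = x[:, :, :, :, w_mid].transpose(2, 3)  # [2, 1, 48, 32] -> (H, D) after transpose
print(axial.shape, sagittal.shape, coronal.shape)
```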
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass with enhanced medical image processing."""
-        # Extract attention-weighted slices
-        axial, sagittal, coronal = self._get_attention_weighted_slices(x)
+        """Memory-efficient forward pass."""
+        # Get central slices (more memory efficient than attention)
+        axial, sagittal, coronal = self._get_central_slices(x)

-        # Process each view
+        # Process views in sequence to save memory
         view_features = []
         for view in [axial, sagittal, coronal]:
-            # Medical image specific preprocessing
+            # Preprocess
             view = self.preprocess(view)

-            # Normalize to match pretrained model's distribution
-            view = (view - view.mean(dim=[2, 3], keepdim=True)) / (view.std(dim=[2, 3], keepdim=True) + 1e-6)
-
-            # Pass through ViT
-            outputs = self.vit(pixel_values=view, return_dict=True)
+            # Basic normalization
+            view = F.normalize(view.flatten(2), dim=-1).reshape_as(view)

-            # Enhance features
+            # Get features
+            with torch.no_grad():  # Don't store gradients for ViT if frozen
+                outputs = self.vit(pixel_values=view, return_dict=True)
             features = outputs.last_hidden_state[:, 0]
-            enhanced = self.feature_enhance(features) + features
-            view_features.append(enhanced)
+            view_features.append(features)

-        # Combine view features
-        combined = torch.cat(view_features, dim=1)
-        fused = self.fusion(combined)
+            # Clear cache after each view
+            if hasattr(torch.cuda, 'empty_cache'):
+                torch.cuda.empty_cache()

-        # Final classification
-        output = self.classifier(fused)
+        # Combine features
+        x = torch.cat(view_features, dim=1)
+        x = self.fusion(x)
+        x = self.classifier(x)

-        return output
+        return x


 def create_model(config: dict) -> nn.Module:
-    """Create a ViT3D model from config."""
+    """Create a memory-efficient ViT3D model."""
     model = ViT3D(
         num_labels=config['model']['num_labels'],
         freeze_layers=config['model']['freeze_layers'],
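A usage sketch of the slimmed-down model. `create_model`'s remaining arguments are cut off above, so `ViT3D` is instantiated directly; the keyword names are inferred from the attribute assignments in `__init__`, and the run assumes the `google/vit-base-patch16-224-in21k` weights can be fetched from the Hugging Face hub. `eval()` matters because the fusion head's `BatchNorm1d` cannot normalize a batch of one in training mode:

```python
import torch
from models.architectures.vit3d import ViT3D

model = ViT3D(num_labels=3, freeze_layers=True,
              input_size=224, patch_size=16, dropout_rate=0.1)
model.eval()  # needed for BatchNorm1d with a single-sample batch

with torch.no_grad():
    volume = torch.randn(1, 1, 224, 224, 224)  # [B, C, D, H, W]; every central slice is 224x224
    logits = model(volume)
print(logits.shape)  # expected: torch.Size([1, 3])
```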