fix data loader and preprocessing visualisations and main
AndyMDH committed Nov 12, 2024
1 parent 218f865 commit 0dfad06
Showing 2 changed files with 122 additions and 111 deletions.
config.yaml (110 changes: 64 additions & 46 deletions)
@@ -1,17 +1,20 @@
# Model configuration
model:
type: 'multiview_vit' # Changed to multi-view approach
type: 'multiview_vit'
num_labels: 3
freeze_layers: true
input_size: 224
model_name: 'google/vit-large-patch16-224-in21k' # Using larger model
dropout_rate: 0.2 # Increased for better regularization
use_attention: true # Enable attention-based slice selection
use_middle_slices: true # For faster training
use_base_model: true # Use base instead of large model
dropout_rate: 0.1
architecture:
fusion_hidden_size: 768 # Matches base model size
use_simplified_fusion: true # Use simpler fusion for speed
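
A minimal sketch (not part of the commit) of how this model block might be consumed; resolve_checkpoint is a hypothetical helper, but the two checkpoint names match the ones hard-coded in models/architectures/vit3d.py below:

import yaml

def resolve_checkpoint(model_cfg: dict) -> str:
    # Hypothetical helper: map use_base_model onto a HF checkpoint name
    if model_cfg.get('use_base_model', True):
        return 'google/vit-base-patch16-224-in21k'   # hidden size 768
    return 'google/vit-large-patch16-224-in21k'      # hidden size 1024

with open('config.yaml') as f:
    cfg = yaml.safe_load(f)

print(resolve_checkpoint(cfg['model']))  # google/vit-base-patch16-224-in21k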

# Dataset configuration
dataset:
path: './adni'
batch_size: 8
batch_size: 16 # Increased for better utilization
val_ratio: 0.15
test_ratio: 0.15
input_size: 224
@@ -21,45 +24,55 @@ dataset:
intensity_norm: true
foreground_crop: true
crop_margin: 10
# Added preprocessing specific to multi-view
slice_selection:
method: 'attention' # Can be 'center', 'attention', or 'multi'
num_slices: 5 # Number of slices to consider for attention
augmentation:
enable: true
rotation_range: 10
contrast_range: [0.9, 1.1]
brightness_range: [0.9, 1.1]
noise_std: 0.02
# Speed optimizations
cache_processed: true # Cache preprocessed data
use_memory_cache: true # Keep frequently used data in memory
num_workers: 0 # For CPU training
pin_memory: false # For CPU training
prefetch_factor: 2
persistent_workers: false
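
A sketch of how these loader knobs map onto a PyTorch DataLoader (train_dataset is an assumed ADNI dataset object; note that prefetch_factor and persistent_workers only apply when num_workers > 0, so with num_workers: 0 they should stay at their defaults):

from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset,     # assumed dataset built from dataset.path
    batch_size=16,     # dataset.batch_size
    shuffle=True,
    num_workers=0,     # dataset.num_workers (CPU training)
    pin_memory=False,  # dataset.pin_memory
)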

# Training configuration
training:
epochs: 100 # Increased epochs since training is more stable
epochs: 50
device: 'cuda' # will fall back to CPU if not available
seed: 42
learning_rate: 0.0001
batch_size: 16
learning_rate: 0.001 # Increased for faster convergence

# Optimization settings
optimization:
compile_model: true # Use torch.compile for speedup
gradient_accumulation_steps: 4 # Effective batch size of 64
use_amp: false # Automatic Mixed Precision (set true if using GPU)
use_gradient_clipping: true
max_gradient_norm: 1.0
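
The accumulation arithmetic works out to 4 x 16 = 64 samples per optimizer step. A sketch of the usual pattern under these settings, with model, loss_fn, optimizer, and train_loader assumed to exist:

import torch

accum_steps = 4  # optimization.gradient_accumulation_steps

optimizer.zero_grad()
for step, (inputs, targets) in enumerate(train_loader):
    loss = loss_fn(model(inputs), targets) / accum_steps  # scale so accumulated grads average
    loss.backward()
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # max_gradient_norm
        optimizer.step()
        optimizer.zero_grad()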

optimizer:
type: 'adamw'
weight_decay: 0.01
layer_specific_lrs: # Different learning rates for different parts
vit: 0.00001 # Lower LR for pretrained parts
attention: 0.0001
fusion: 0.0001
classifier: 0.0001
layer_specific_lrs: # Different learning rates for different components
vit: 0.0001
fusion: 0.001
classifier: 0.001
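
These per-component rates map directly onto optimizer parameter groups; a sketch assuming the vit, fusion, classifier, and channel_proj attribute names from MultiViewViT below (channel_proj is not listed in the config, so giving it the fusion rate is an assumption):

import torch

optimizer = torch.optim.AdamW(
    [
        {'params': model.vit.parameters(), 'lr': 1e-4},           # layer_specific_lrs.vit
        {'params': model.fusion.parameters(), 'lr': 1e-3},        # layer_specific_lrs.fusion
        {'params': model.classifier.parameters(), 'lr': 1e-3},    # layer_specific_lrs.classifier
        {'params': model.channel_proj.parameters(), 'lr': 1e-3},  # assumed; not in the config
    ],
    weight_decay=0.01,  # optimizer.weight_decay
)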

scheduler:
type: 'cosine'
T_0: 10 # Increased period
warmup_epochs: 5 # More warmup epochs
T_0: 5
warmup_epochs: 2
min_lr: 1.0e-6
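
One way to realize T_0: 5 with two warmup epochs using stock PyTorch schedulers; this particular composition is an assumption, not code from the commit:

from torch.optim.lr_scheduler import (CosineAnnealingWarmRestarts, LinearLR,
                                      SequentialLR)

warmup = LinearLR(optimizer, start_factor=0.1, total_iters=2)         # warmup_epochs
cosine = CosineAnnealingWarmRestarts(optimizer, T_0=5, eta_min=1e-6)  # T_0, min_lr
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[2])
# scheduler.step() is then called once per epoch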

# Early stopping
early_stopping:
enable: true
patience: 10
patience: 5
min_delta: 0.001
gradient_clipping:
enable: true
max_norm: 1.0
label_smoothing: 0.1 # Added label smoothing
mixed_precision: true # Enable mixed precision training
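
The early-stopping block reduces to a small counter; a sketch that tracks validation loss (an assumption, since the config does not name the monitored metric; train_one_epoch and validate are placeholders):

best, bad_epochs = float('inf'), 0
patience, min_delta = 5, 0.001  # early_stopping.patience / .min_delta

for epoch in range(50):  # training.epochs
    train_one_epoch(model)      # placeholder
    val_loss = validate(model)  # placeholder
    if val_loss < best - min_delta:
        best, bad_epochs = val_loss, 0
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break  # no improvement beyond min_delta for `patience` epochs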

# Loss settings
loss:
type: 'cross_entropy'
label_smoothing: 0.1
class_weights: null # Can be set if classes are imbalanced
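
This loss block maps one-to-one onto torch.nn.CrossEntropyLoss, which has accepted label_smoothing since PyTorch 1.10; class_weights would become the weight tensor:

import torch.nn as nn

loss_fn = nn.CrossEntropyLoss(
    label_smoothing=0.1,  # loss.label_smoothing
    weight=None,          # loss.class_weights: e.g. a 3-element tensor if classes are imbalanced
)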

# Paths configuration
paths:
@@ -69,39 +82,44 @@ paths:
data:
raw: './adni/raw'
processed: './adni/processed'
cache: './adni/cache' # For preprocessed data
metadata: './metadata/adni.csv'

# Logging configuration
logging:
level: 'INFO'
save_to_file: true
log_frequency: 10
wandb: # Added W&B integration
log_frequency: 10 # Log every N batches
tensorboard:
enable: true
project: 'alzheimer_detection'
tags: ['multiview', 'vit-large']
log_images: true # Log attention maps and predictions
log_dir: './runs'
save_model_frequency: 5 # Save model every N epochs
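
The tensorboard block corresponds to a standard SummaryWriter; a sketch, with loss_value and global_step assumed to come from the training loop:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='./runs')  # tensorboard.log_dir
writer.add_scalar('train/loss', loss_value, global_step)  # every log_frequency batches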

# Performance monitoring
monitoring:
track_memory: true
profile_execution: false # Set to true for debugging performance
log_gpu_stats: false
batch_timing: true
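
A sketch of the two cheap monitors above (profile_execution would instead reach for torch.profiler); the batch body is elided:

import time
import torch

start = time.perf_counter()
# ... run one training batch here ...
if torch.cuda.is_available():
    peak_mb = torch.cuda.max_memory_allocated() / 2**20   # monitoring.track_memory
print(f'batch time: {time.perf_counter() - start:.3f}s')  # monitoring.batch_timing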

# Validation configuration
validation:
frequency: 1 # Validate every epoch
frequency: 1 # Validate every N epochs
metrics:
- accuracy
- precision
- recall
- f1
- confusion_matrix
save_predictions: true
save_attention_maps: true # Save attention maps for visualization
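
A sketch of computing the listed metrics with scikit-learn, assuming y_true and y_pred arrays collected during a validation pass (macro averaging is an assumption for this 3-class problem):

from sklearn.metrics import (accuracy_score, confusion_matrix,
                             precision_recall_fscore_support)

acc = accuracy_score(y_true, y_pred)
prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro')
cm = confusion_matrix(y_true, y_pred)  # 3x3 for num_labels: 3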

# Testing configuration
testing:
save_predictions: true
ensemble:
enable: true
num_models: 3 # Use ensemble of last 3 checkpoints
visualization:
enable: true
attention_maps: true
confusion_matrix: true
misclassified_samples: true
save_confusion_matrix: true
ensemble_predictions: false
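
The outgoing ensemble block described checkpoint averaging (the new config turns it off via ensemble_predictions: false). For reference, a sketch of softmax averaging over the last three checkpoints, with checkpoint_models and volume assumed:

import torch

probs = torch.stack([m(volume).softmax(dim=1) for m in checkpoint_models]).mean(dim=0)
pred = probs.argmax(dim=1)  # ensemble.num_models: 3 -> len(checkpoint_models) == 3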

# Memory management
memory:
clear_cache_frequency: 1 # Clear CUDA cache every N batches
optimize_memory: true
pin_memory: false # Since using CPU
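
clear_cache_frequency corresponds to torch.cuda.empty_cache(), which is worth guarding on CPU; a sketch of the per-batch hook:

import torch

clear_every = 1  # memory.clear_cache_frequency
if torch.cuda.is_available() and (step + 1) % clear_every == 0:
    torch.cuda.empty_cache()
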
models/architectures/vit3d.py (123 changes: 58 additions & 65 deletions)
@@ -1,5 +1,5 @@
"""
Multi-view Vision Transformer for Alzheimer's detection using transfer learning.
Optimized Multi-view Vision Transformer for faster training and better accuracy.
"""

import torch
@@ -16,63 +16,49 @@ def __init__(
num_labels: int,
freeze_layers: bool = True,
input_size: int = 224,
dropout_rate: float = 0.1
dropout_rate: float = 0.1,
use_middle_slices: bool = True # Added option for faster training
):
super().__init__()

self.input_size = input_size
self.use_middle_slices = use_middle_slices

logger.info(f"Initializing MultiViewViT with:")
logger.info(f"Initializing Optimized MultiViewViT with:")
logger.info(f"- Input size: {input_size}x{input_size}")
logger.info(f"- Views: Axial, Sagittal, Coronal")
logger.info(f"- Using middle slices: {use_middle_slices}")
logger.info(f"- Number of classes: {num_labels}")

# Load pre-trained ViT (now using a larger variant)
# Load pre-trained ViT (using base model for faster training)
self.vit = ViTModel.from_pretrained(
'google/vit-large-patch16-224-in21k', # Using larger model for better features
'google/vit-base-patch16-224-in21k', # Using base model for speed
add_pooling_layer=False
)
hidden_size = self.vit.config.hidden_size # 1024 for large model
hidden_size = self.vit.config.hidden_size # 768 for base model

# Channel projection to convert 1-channel MRI to 3-channel input
# Efficient channel projection
self.channel_proj = nn.Sequential(
nn.Conv2d(1, 3, 1, 1),
nn.Conv2d(1, 3, 1, 1, bias=False), # Removed bias for speed
nn.BatchNorm2d(3),
nn.GELU()
nn.ReLU() # Using ReLU instead of GELU for speed
)

# Adaptive slice selection
self.slice_attention = nn.Sequential(
nn.Conv3d(1, 16, kernel_size=3, padding=1),
nn.GELU(),
nn.Conv3d(16, 3, kernel_size=1), # Output 3 attention maps for 3 views
nn.Softmax(dim=2) # Softmax along depth dimension
)

# Feature fusion module
# Simplified feature fusion
self.fusion = nn.Sequential(
nn.Linear(hidden_size * 3, hidden_size * 2),
nn.LayerNorm(hidden_size * 2),
nn.GELU(),
nn.Dropout(dropout_rate),
nn.Linear(hidden_size * 2, hidden_size),
nn.Linear(hidden_size * 3, hidden_size),
nn.LayerNorm(hidden_size),
nn.GELU(),
nn.ReLU(),
nn.Dropout(dropout_rate)
)

# Classification head
self.classifier = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 2),
nn.GELU(),
nn.Dropout(dropout_rate),
nn.Linear(hidden_size // 2, num_labels)
)
self.classifier = nn.Linear(hidden_size, num_labels)

# Initialize weights
self._init_weights()

# Freeze pre-trained layers if specified
# Freeze and optimize pre-trained layers
if freeze_layers:
self._freeze_pretrained_layers()

@@ -83,57 +69,63 @@ def __init__(
logger.info(f"Trainable parameters: {trainable_params:,}")

def _init_weights(self):
"""Initialize custom layers."""
for m in [self.channel_proj, self.fusion, self.classifier]:
"""Initialize custom layers with simple initialization."""
for m in self.modules():
if isinstance(m, (nn.Linear, nn.Conv2d)):
nn.init.kaiming_normal_(m.weight)
if m.bias is not None:
nn.init.xavier_uniform_(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)

def _freeze_pretrained_layers(self):
"""Freeze pretrained ViT layers."""
"""Freeze pretrained ViT layers and optimize memory."""
self.vit.eval() # Set to eval mode for inference
for param in self.vit.parameters():
param.requires_grad = False

def _get_weighted_slices(self, x):
"""Get attention-weighted slices from each view."""
def _get_middle_slices(self, x):
"""Extract middle slices from each view efficiently."""
B, C, D, H, W = x.shape

# Generate attention weights
attention = self.slice_attention(x) # [B, 3, D, H, W]
# Get middle indices
d_mid = D // 2
h_mid = H // 2
w_mid = W // 2

# Extract weighted slices for each view
axial = (x[:, :, :, :, :] * attention[:, 0:1, :, :, :]).sum(dim=2)
sagittal = (x[:, :, :, :, :] * attention[:, 1:2, :, :, :]).sum(dim=3)
coronal = (x[:, :, :, :, :] * attention[:, 2:3, :, :, :]).sum(dim=4)
# Extract slices efficiently
axial = x[:, :, d_mid, :, :]
sagittal = x[:, :, :, h_mid, :]
coronal = x[:, :, :, :, w_mid]

return axial, sagittal, coronal
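
# Shape walk-through (added note): for an input volume x of shape [B, 1, D, H, W],
#   axial    = x[:, :, D//2, :, :]  -> [B, 1, H, W]
#   sagittal = x[:, :, :, H//2, :]  -> [B, 1, D, W]
#   coronal  = x[:, :, :, :, W//2]  -> [B, 1, D, H]
# Each slice is then projected to 3 channels and resized to 224x224 before the ViT.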

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass using attention-weighted views."""
B = x.shape[0]
def _normalize_view(self, view):
"""Normalize view efficiently (no @torch.no_grad() here: gradients must reach the trainable channel projection upstream)."""
view = F.interpolate(view, size=(224, 224), mode='bilinear', align_corners=False)
view = (view - view.mean(dim=[2, 3], keepdim=True)) / (view.std(dim=[2, 3], keepdim=True) + 1e-6)
return view

# Get weighted slices from each view
axial, sagittal, coronal = self._get_weighted_slices(x)

# Process each view
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Optimized forward pass."""
if self.use_middle_slices:
axial, sagittal, coronal = self._get_middle_slices(x)
else:
# Note: _get_weighted_slices was removed in this commit; restore it before disabling use_middle_slices
axial, sagittal, coronal = self._get_weighted_slices(x)

# Process each view efficiently
view_features = []
for view in [axial, sagittal, coronal]:
# Project to 3 channels and normalize
# Project and normalize efficiently
view = self.channel_proj(view)
view = self._normalize_view(view)

# Normalize to ImageNet range
view = F.interpolate(view, size=(224, 224))
view = (view - view.mean()) / view.std()

# Pass through ViT
outputs = self.vit(pixel_values=view, return_dict=True)
view_features.append(outputs.last_hidden_state[:, 0]) # Use CLS token
# Get ViT features
# Autograd stays on even when the ViT is frozen: gradients must flow
# through the ViT back to the trainable channel projection
outputs = self.vit(pixel_values=view, return_dict=True)
view_features.append(outputs.last_hidden_state[:, 0])

# Combine view features
# Combine features efficiently
x = torch.cat(view_features, dim=1)
x = self.fusion(x)
x = self.classifier(x)
@@ -142,11 +134,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


def create_model(config: dict) -> nn.Module:
"""Create a multi-view ViT model from config."""
"""Create an optimized multi-view ViT model from config."""
model = MultiViewViT(
num_labels=config['model']['num_labels'],
freeze_layers=config['model'].get('freeze_layers', True),
input_size=config['model'].get('input_size', 224),
dropout_rate=config['model'].get('dropout_rate', 0.1)
dropout_rate=config['model'].get('dropout_rate', 0.1),
use_middle_slices=config['model'].get('use_middle_slices', True)
)
return model
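
End-to-end usage sketch, assuming PyYAML, the config.yaml above saved alongside, and network access to download the pretrained weights; the volume shape is illustrative:

import torch
import yaml

from models.architectures.vit3d import create_model

with open('config.yaml') as f:
    config = yaml.safe_load(f)

model = create_model(config)
model.eval()

volume = torch.randn(2, 1, 64, 224, 224)  # [B, C, D, H, W], illustrative sizes
with torch.no_grad():
    logits = model(volume)
print(logits.shape)  # torch.Size([2, 3]) with num_labels: 3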
