fix data loader and preprocessing visualisations and main
AndyMDH committed Nov 12, 2024
1 parent a9b44f9 commit 7a58bdd
Showing 2 changed files with 110 additions and 160 deletions.
config.yaml: 96 changes (51 additions, 45 deletions)
@@ -1,88 +1,94 @@
# Model configuration
# Model Configurations
model:
  type: 'vit3d' # Keep this as vit3d for compatibility
  # Shared settings
  num_labels: 3
  freeze_layers: true
  input_size: 224
  patch_size: 16 # Added this for the new architecture
  dropout_rate: 0.1
  architecture:
    hidden_size: 768
    num_attention_heads: 12
    intermediate_size: 3072
    pretrained_model: 'google/vit-base-patch16-224-in21k'
    freeze_last_n_layers: 4
  medical_preprocessing:
    enable: true
    kernel_sizes: [7, 5, 1]
    channels: [16, 32, 3]
    freeze_layers: true

  # Model specific settings
  vit3d:
    type: 'vit3d'
    patch_sizes: [8, 16, 32] # Different patch sizes to try
    pretrained: 'google/vit-base-patch16-224-in21k'

  vit2d:
    type: 'vit2d'
    patch_sizes: [8, 16, 32]
    pretrained: 'google/vit-base-patch16-224-in21k'
    slice_mode: 'center' # or 'average'

  cnn3d:
    type: 'cnn3d'
    patch_sizes: [8, 16, 32]
    pretrained: 'resnet50' # Will use inflated ResNet
    channels: [64, 128, 256, 512]

# Dataset configuration
dataset:
  path: './adni'
  batch_size: 16 # Increased batch size
  batch_size: 8
  val_ratio: 0.15
  test_ratio: 0.15
  input_size: 224
  preprocessing:
    voxel_spacing: [1.5, 1.5, 1.5]
    orientation: 'RAS'
    intensity_norm: true
    foreground_crop: true
    crop_margin: 10
  medical_augmentation:
    # 2D specific
    slice_selection:
      method: 'center' # or 'average'
      num_slices: 5
  # Data augmentation
  augmentation:
    enable: true
    rotation_range: [-10, 10]
    scale_range: [0.95, 1.05]
    flip_probability: 0.5
    intensity_shift: [-0.1, 0.1]
    gamma_range: [0.9, 1.1]

# Training configuration
training:
  epochs: 50
  device: 'cuda' # will fall back to CPU if not available
  device: 'cuda' # will fall back to CPU
  seed: 42
  learning_rate: 0.0001
  base_learning_rate: 0.0001
  layer_specific_lrs:
    pretrained: 0.00001
    new: 0.0001
  optimizer:
    type: 'adamw'
    weight_decay: 0.01
    layer_specific_lrs: # Different learning rates for different components
      vit_backbone: 0.00001
      medical_preprocessing: 0.0001
      feature_enhance: 0.0001
      classifier: 0.0001
  scheduler:
    type: 'cosine'
    T_0: 5
    warmup_epochs: 2
    min_lr: 1.0e-6
  early_stopping:
    enable: true
    patience: 10
    min_delta: 0.001
  loss:
    type: 'cross_entropy'
    label_smoothing: 0.1

# Paths configuration
# Experiment tracking
experiment:
  name: 'alzheimer_detection_comparison'
  models_to_run: ['vit3d', 'vit2d', 'cnn3d']
  metrics:
    - accuracy
    - precision
    - recall
    - f1
    - confusion_matrix
  save_predictions: true
  save_model: true
  plot_results: true

# Paths and logging
paths:
  output_dir: './output'
  log_dir: './logs'
  checkpoint_dir: './checkpoints'
  data:
    raw: './adni/raw'
    processed: './adni/processed'
    metadata: './metadata/adni.csv'
  results_dir: './results'

# Logging configuration
logging:
  level: 'INFO'
  save_to_file: true
  log_frequency: 10

# Memory management for large medical images
memory:
  clear_cache_frequency: 1
  pin_memory: false # For CPU training
  num_workers: 0 # For CPU training
  prefetch_factor: 2
  log_frequency: 10
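For reference, a minimal sketch of how this config might be consumed, assuming PyYAML is installed and using the create_model factory from models/architectures/vit3d.py shown below; the script name and the final print are illustrative, not part of this commit:

# usage_sketch.py (hypothetical, not in the repository)
import yaml

from models.architectures.vit3d import create_model

with open('config.yaml') as f:
    config = yaml.safe_load(f)

# create_model reads config['model']['num_labels'] and config['model']['freeze_layers']
model = create_model(config)
print(type(model).__name__)  # ViT3D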
models/architectures/vit3d.py: 174 changes (59 additions, 115 deletions)
@@ -1,27 +1,17 @@
"""
3D Vision Transformer model for Alzheimer's detection with fixed initialization.
Memory-efficient 3D Vision Transformer for Alzheimer's detection.
"""

import torch
import torch.nn as nn
from transformers import ViTModel
import logging
import numpy as np
import math
from torch.nn import functional as F

logger = logging.getLogger(__name__)

class LayerNormWithFixedInit(nn.LayerNorm):
    """Custom LayerNorm with fixed initialization."""
    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

class ViT3D(nn.Module):
    """3D Vision Transformer optimized for medical image analysis."""

    def __init__(
        self,
        num_labels: int,
@@ -32,13 +22,10 @@ def __init__(
    ):
        super().__init__()

        # Validate input parameters
        assert input_size % patch_size == 0, "Input size must be divisible by patch size"
        self.input_size = input_size
        self.patch_size = patch_size
        self.num_patches = (input_size // patch_size) ** 3

        # Log initialization parameters
        logger.info(f"Initializing ViT3D with:")
        logger.info(f"- Input size: {input_size}x{input_size}x{input_size}")
        logger.info(f"- Patch size: {patch_size}x{patch_size}x{patch_size}")
@@ -50,63 +37,36 @@ def __init__(
            'google/vit-base-patch16-224-in21k',
            add_pooling_layer=False
        )
        hidden_size = self.vit.config.hidden_size # 768
        hidden_size = self.vit.config.hidden_size

        # Medical image specific preprocessing
        # Efficient preprocessing
        self.preprocess = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=7, padding=3),
            nn.InstanceNorm2d(16),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.InstanceNorm2d(32),
            nn.GELU(),
            nn.Conv2d(32, 3, kernel_size=1),
            nn.InstanceNorm2d(3)
            nn.Conv2d(1, 3, 3, padding=1, bias=False), # Simpler conv
            nn.BatchNorm2d(3), # BatchNorm uses less memory
            nn.ReLU() # ReLU is more memory efficient than GELU
        )

        # Slice selection module with attention
        self.slice_attention = nn.Sequential(
            nn.Conv3d(1, 16, kernel_size=3, padding=1),
            nn.InstanceNorm3d(16),
            nn.GELU(),
            nn.Conv3d(16, 8, kernel_size=3, padding=1),
            nn.InstanceNorm3d(8),
            nn.GELU(),
            nn.Conv3d(8, 3, kernel_size=1),
        # Memory-efficient slice selection
        self.slice_selection = nn.Sequential(
            nn.Conv3d(1, 8, 1, bias=False), # 1x1x1 conv uses less memory
            nn.BatchNorm3d(8),
            nn.ReLU(),
            nn.Conv3d(8, 3, 1, bias=False),
            nn.Softmax(dim=2)
        )

        # Feature enhancement module
        self.feature_enhance = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 2),
            LayerNormWithFixedInit(hidden_size * 2),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size * 2, hidden_size)
        )

        # View fusion module
        # Feature fusion (simplified)
        self.fusion = nn.Sequential(
            nn.Linear(hidden_size * 3, hidden_size * 2),
            LayerNormWithFixedInit(hidden_size * 2),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size * 2, hidden_size),
            LayerNormWithFixedInit(hidden_size)
            nn.Linear(hidden_size * 3, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout_rate)
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 2, num_labels)
        )
        self.classifier = nn.Linear(hidden_size, num_labels)

        # Initialize weights explicitly
        self._init_weights()

        # Freeze layers selectively
        if freeze_layers:
            self._freeze_layers()

@@ -117,89 +77,73 @@ def __init__(
        logger.info(f"Trainable parameters: {trainable_params:,}")

    def _init_weights(self):
        """Initialize weights explicitly."""
        def init_module(m):
        """Initialize weights."""
        for m in self.modules():
            if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                if getattr(m, 'bias', None) is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.InstanceNorm2d, nn.InstanceNorm3d)):
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, LayerNormWithFixedInit):
                m.reset_parameters()

        self.apply(init_module)
            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def _freeze_layers(self):
        """Selective freezing for better transfer learning."""
        # Freeze early layers
        for param in self.vit.embeddings.parameters():
        """Freeze pretrained layers."""
        self.vit.eval() # Set to eval mode to save memory
        for param in self.vit.parameters():
            param.requires_grad = False

        # Only train the last 4 transformer layers
        for layer in self.vit.encoder.layer[:-4]:
            for param in layer.parameters():
                param.requires_grad = False

    def _get_attention_weighted_slices(self, x: torch.Tensor) -> tuple:
        """Extract attention-weighted slices from volume."""
    @torch.no_grad() # Memory optimization
    def _get_central_slices(self, x: torch.Tensor) -> tuple:
        """Get central slices efficiently."""
        B, C, D, H, W = x.shape

        # Generate attention weights for each direction
        attention_weights = self.slice_attention(x)
        # Get central indices
        d_mid = D // 2
        h_mid = H // 2
        w_mid = W // 2

        # Extract weighted slices for each view
        d_center, h_center, w_center = D//2, H//2, W//2
        span = 3 # Consider slices around center
        # Extract central slices directly
        axial = x[:, :, d_mid].clone() # [B, C, H, W]
        sagittal = x[:, :, :, h_mid].clone() # [B, C, D, W]
        coronal = x[:, :, :, :, w_mid].clone() # [B, C, D, H]

        # Compute weighted averages around central slices
        axial_region = x[:, :, d_center-span:d_center+span+1]
        sagittal_region = x[:, :, :, h_center-span:h_center+span+1]
        coronal_region = x[:, :, :, :, w_center-span:w_center+span+1]

        axial = (axial_region * attention_weights[:, 0:1, d_center-span:d_center+span+1]).sum(dim=2)
        sagittal = (sagittal_region * attention_weights[:, 1:2, :, h_center-span:h_center+span+1]).sum(dim=3)
        coronal = (coronal_region * attention_weights[:, 2:3, :, :, w_center-span:w_center+span+1]).sum(dim=4)

        return axial, sagittal, coronal
        return axial, sagittal.transpose(2, 3), coronal.transpose(2, 3)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass with enhanced medical image processing."""
        # Extract attention-weighted slices
        axial, sagittal, coronal = self._get_attention_weighted_slices(x)
        """Memory-efficient forward pass."""
        # Get central slices (more memory efficient than attention)
        axial, sagittal, coronal = self._get_central_slices(x)

        # Process each view
        # Process views in sequence to save memory
        view_features = []
        for view in [axial, sagittal, coronal]:
            # Medical image specific preprocessing
            # Preprocess
            view = self.preprocess(view)

            # Normalize to match pretrained model's distribution
            view = (view - view.mean(dim=[2, 3], keepdim=True)) / (view.std(dim=[2, 3], keepdim=True) + 1e-6)

            # Pass through ViT
            outputs = self.vit(pixel_values=view, return_dict=True)
            # Basic normalization
            view = F.normalize(view.flatten(2), dim=-1).reshape_as(view)

            # Enhance features
            # Get features
            with torch.no_grad(): # Don't store gradients for ViT if frozen
                outputs = self.vit(pixel_values=view, return_dict=True)
            features = outputs.last_hidden_state[:, 0]
            enhanced = self.feature_enhance(features) + features
            view_features.append(enhanced)
            view_features.append(features)

        # Combine view features
        combined = torch.cat(view_features, dim=1)
        fused = self.fusion(combined)
            # Clear cache after each view
            if hasattr(torch.cuda, 'empty_cache'):
                torch.cuda.empty_cache()

        # Final classification
        output = self.classifier(fused)
        # Combine features
        x = torch.cat(view_features, dim=1)
        x = self.fusion(x)
        x = self.classifier(x)

        return output
        return x


def create_model(config: dict) -> nn.Module:
    """Create a ViT3D model from config."""
    """Create a memory-efficient ViT3D model."""
    model = ViT3D(
        num_labels=config['model']['num_labels'],
        freeze_layers=config['model']['freeze_layers'],
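As a quick smoke-test sketch for the rewritten model, assuming the constructor arguments not shown in this diff keep defaults matching the config (input_size=224, patch_size=16, dropout_rate=0.1) and that the pretrained ViT weights can be fetched from the Hugging Face hub; hypothetical, not part of this commit:

# forward_sketch.py (hypothetical, not in the repository)
import torch

from models.architectures.vit3d import ViT3D

model = ViT3D(num_labels=3, freeze_layers=True)
model.eval()  # BatchNorm1d in the fusion head requires eval mode for batch size 1

x = torch.randn(1, 1, 224, 224, 224)  # [B, C, D, H, W] single-channel volume
with torch.no_grad():
    logits = model(x)  # central slices -> shared ViT -> fusion -> classifier
print(logits.shape)  # torch.Size([1, 3])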
