From 53a756cca5560a1424a5ee72b3ece673293e5367 Mon Sep 17 00:00:00 2001
From: andyh
Date: Mon, 18 Nov 2024 16:11:18 +0100
Subject: [PATCH] fix data loader and preprocessing visualisations and main

---
 config.yaml                      |  71 ++++-----
 data/data_loader.py              | 253 ++++++++++++++++---------------
 main.py                          | 225 ++++++---------------------
 models/__init__.py               |   4 +-
 models/architectures/__init__.py |  13 +-
 models/architectures/cnn3d.py    | 160 ++++++++++++++-----
 models/architectures/vit2d.py    |  93 ++++++++++--
 models/factory.py                |  55 +++++++
 8 files changed, 465 insertions(+), 409 deletions(-)
 create mode 100644 models/factory.py

diff --git a/config.yaml b/config.yaml
index 4110c17..089bcbf 100644
--- a/config.yaml
+++ b/config.yaml
@@ -1,46 +1,45 @@
-# Model Configurations
+# Model configuration
 model:
-  # Shared settings
+  type: 'vit2d'
   num_labels: 3
+  freeze_layers: true
   input_size: 224
+  patch_size: 16
   dropout_rate: 0.1
-  freeze_layers: true
 
   # Model specific settings
-  vit3d:
-    type: 'vit3d'
-    patch_sizes: [8, 16, 32]  # Different patch sizes to try
+  vit2d:
     pretrained: 'google/vit-base-patch16-224-in21k'
+    hidden_size: 768
+    intermediate_size: 3072
+    num_attention_heads: 12
 
-  vit2d:
-    type: 'vit2d'
-    patch_sizes: [8, 16, 32]
+  vit3d:
     pretrained: 'google/vit-base-patch16-224-in21k'
-    slice_mode: 'center'  # or 'average'
+    hidden_size: 768
+    use_memory_efficient: true
 
   cnn3d:
-    type: 'cnn3d'
-    patch_sizes: [8, 16, 32]
-    pretrained: 'resnet50'  # Will use inflated ResNet
-    channels: [64, 128, 256, 512]
+    pretrained: true
+    base_channels: 64
+    num_blocks: [3, 4, 6, 3]
 
 # Dataset configuration
 dataset:
   path: './adni'
-  batch_size: 8
+  batch_size: 16
   val_ratio: 0.15
   test_ratio: 0.15
+  input_size: 224
 
   preprocessing:
     voxel_spacing: [1.5, 1.5, 1.5]
     orientation: 'RAS'
     intensity_norm: true
     foreground_crop: true
     crop_margin: 10
-    # 2D specific
     slice_selection:
       method: 'center'  # or 'average'
       num_slices: 5
 
-  # Data augmentation
   augmentation:
     enable: true
     rotation_range: [-10, 10]
 
 training:
   epochs: 50
   device: 'cuda'  # will fall back to CPU
   seed: 42
-  base_learning_rate: 0.0001
-  layer_specific_lrs:
-    pretrained: 0.00001
-    new: 0.0001
+  learning_rate: 0.0001
 
   optimizer:
     type: 'adamw'
     weight_decay: 0.01
+    layer_specific_lrs:
+      pretrained: 0.00001
+      new: 0.0001
 
   scheduler:
     type: 'cosine'
+    T_0: 10  # Added for cosine annealing
+    T_mult: 2
+    eta_min: 1.0e-6
     warmup_epochs: 2
-    min_lr: 1.0e-6
 
   early_stopping:
     patience: 10
     min_delta: 0.001
+
+  gradient_clipping:
+    enable: true
+    max_norm: 1.0
+
+  mixed_precision: false
 
-# Experiment tracking
-experiment:
-  name: 'alzheimer_detection_comparison'
-  models_to_run: ['vit3d', 'vit2d', 'cnn3d']
-  metrics:
-    - accuracy
-    - precision
-    - recall
-    - f1
-    - confusion_matrix
-  save_predictions: true
-  save_model: true
-  plot_results: true
-
-# Paths and logging
+# Paths configuration
 paths:
   output_dir: './output'
   log_dir: './logs'
   checkpoint_dir: './checkpoints'
-  results_dir: './results'
+  data:
+    raw: './adni/raw'
+    processed: './adni/processed'
+    metadata: './metadata/adni.csv'
 
+# Logging configuration
 logging:
   level: 'INFO'
   save_to_file: true
diff --git a/data/data_loader.py b/data/data_loader.py
index a8d803e..14bf792 100644
--- a/data/data_loader.py
+++ b/data/data_loader.py
@@ -26,86 +26,32 @@
     RandAffined,
     RandGaussianNoised,
     EnsureTyped,
-    ToTensord
+    ToTensord,
+    Lambda
 )
 import logging
 from tqdm import tqdm
 
 logger = logging.getLogger(__name__)
 
-def collate_fn(batch):
-    """
-    Custom collate function to ensure proper batch dimension and tensor types.
- Must be defined at module level for multiprocessing to work. - - Args: - batch: List of samples from the dataset - - Returns: - Dict containing batched tensors - """ - batch_data = {} - for key in batch[0].keys(): - if key == 'image': - # Stack images and ensure 5D shape [B, C, D, H, W] - images = torch.stack([item[key] for item in batch]) - if len(images.shape) == 4: # [B, D, H, W] - images = images.unsqueeze(1) # Add channel dimension - batch_data[key] = images - else: - batch_data[key] = torch.tensor([item[key] for item in batch]) - return batch_data - -def load_config(config_path: str = "config.yaml") -> Dict[str, Any]: - """Load configuration from YAML file.""" - with open(config_path, 'r') as f: - config = yaml.safe_load(f) - return config - -class ShapeCheckd(object): - """Custom transform to verify tensor shapes.""" - - def __init__(self, keys: List[str]): - self.keys = keys - - def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: - d = dict(data) - for key in self.keys: - if key in d: - img = d[key] - if isinstance(img, torch.Tensor): - if len(img.shape) != 4: # [C, D, H, W] - logger.warning(f"Unexpected shape in {key}: {img.shape}") - if len(img.shape) == 3: # [D, H, W] - img = img.unsqueeze(0) # Add channel dim - d[key] = img - logger.debug(f"Shape after checking {key}: {img.shape}") - return d - class ADNIDataset(Dataset): - """Custom Dataset for loading ADNI data.""" + """Dataset for loading ADNI data in both 2D and 3D modes.""" def __init__( self, config: Dict[str, Any], transform: Optional[Compose] = None, - split: str = 'train' + split: str = 'train', + mode: str = '3d' ): - """ - Initialize the dataset. - - Args: - config: Configuration dictionary - transform: MONAI transforms to apply - split: Dataset split ('train', 'val', or 'test') - """ self.data_root = Path(config['dataset']['path']) self.transform = transform self.split = split + self.mode = mode self.file_list = self._create_file_list() self.label_to_idx = {'AD': 0, 'CN': 1, 'MCI': 2} - logger.info(f"Initialized {split} dataset with {len(self.file_list)} samples") + logger.info(f"Initialized {split} dataset with {len(self.file_list)} samples in {mode} mode") def _create_file_list(self) -> List[Tuple[Path, str]]: """Create list of file paths and labels.""" @@ -147,15 +93,13 @@ def __getitem__(self, idx: int) -> Dict[str, Any]: logger.error(f"Error loading {file_path}: {str(e)}") raise -def get_transforms(config: Dict[str, Any], split: str) -> Compose: - """ - Get preprocessing transforms based on configuration. 
-
-    Args:
-        config: Configuration dictionary
-        split: Dataset split ('train', 'val', or 'test')
-    """
-    spatial_size = (config['dataset']['input_size'],) * 3
+def get_transforms(config: Dict[str, Any], split: str, mode: str = '3d') -> Compose:
+    """Get transforms based on configuration and mode."""
+    # Set spatial size based on mode
+    if mode == '2d':
+        spatial_size = (config['dataset']['input_size'],) * 2
+    else:
+        spatial_size = (config['dataset']['input_size'],) * 3
 
     common_transforms = [
         LoadImaged(keys=["image"]),
@@ -170,68 +114,109 @@
             keys=["image"],
             source_key="image",
             margin=config['dataset']['preprocessing']['crop_margin']
-        ),
-        ResizeWithPadOrCropd(
-            keys=["image"],
-            spatial_size=spatial_size
-        ),
+        )
+    ]
+
+    # Add mode-specific resize
+    if mode == '2d':
+        resize_transform = [
+            # Extract center slice for 2D
+            Lambda(lambda x: {
+                'image': x['image'][:, x['image'].shape[1]//2, :, :]
+                if len(x['image'].shape) == 4
+                else x['image'],
+                'label': x['label']
+            }),
+            ResizeWithPadOrCropd(
+                keys=["image"],
+                spatial_size=spatial_size
+            )
+        ]
+    else:
+        resize_transform = [
+            ResizeWithPadOrCropd(
+                keys=["image"],
+                spatial_size=spatial_size
+            )
+        ]
+
+    # Add remaining transforms
+    post_transforms = [
         ScaleIntensityd(keys=["image"]),
         NormalizeIntensityd(keys=["image"], nonzero=True),
         EnsureTyped(keys=["image", "label"]),
-        ToTensord(keys=["image", "label"]),
-        ShapeCheckd(keys=["image"])
+        ToTensord(keys=["image", "label"])
     ]
 
+    # Combine transforms
+    transforms = common_transforms + resize_transform + post_transforms
+
     # Add augmentation for training
     if split == 'train':
-        augmentation_transforms = [
-            RandRotate90d(
-                keys=["image"],
-                prob=0.5,
-                spatial_axes=(0, 1)
-            ),
-            RandFlipd(
-                keys=["image"],
-                prob=0.5,
-                spatial_axis=0
-            ),
-            RandAffined(
-                keys=["image"],
-                prob=0.5,
-                rotate_range=(np.pi/12, np.pi/12, np.pi/12),
-                scale_range=(0.1, 0.1, 0.1),
-                mode="bilinear"
-            ),
-            RandGaussianNoised(
-                keys=["image"],
-                prob=0.2,
-                mean=0.0,
-                std=0.1
-            )
-        ]
-        transforms = common_transforms + augmentation_transforms
-    else:
-        transforms = common_transforms
+        if mode == '2d':
+            augmentation = [
+                RandRotate90d(
+                    keys=["image"],
+                    prob=0.5,
+                    spatial_axes=(0, 1)
+                ),
+                RandFlipd(
+                    keys=["image"],
+                    prob=0.5,
+                    spatial_axis=0
+                ),
+                RandAffined(
+                    keys=["image"],
+                    prob=0.5,
+                    rotate_range=[np.pi/12] * 2,
+                    scale_range=[0.1] * 2,
+                    mode="bilinear"
+                )
+            ]
+        else:
+            augmentation = [
+                RandRotate90d(
+                    keys=["image"],
+                    prob=0.5,
+                    spatial_axes=(0, 1)
+                ),
+                RandFlipd(
+                    keys=["image"],
+                    prob=0.5,
+                    spatial_axis=0
+                ),
+                RandAffined(
+                    keys=["image"],
+                    prob=0.5,
+                    rotate_range=[np.pi/12] * 3,
+                    scale_range=[0.1] * 3,
+                    mode="bilinear"
+                )
+            ]
+        transforms = transforms + augmentation
 
     return Compose(transforms)
 
-def create_data_loaders(config: Dict[str, Any]) -> Tuple[DataLoader, DataLoader, DataLoader]:
-    """Create train, validation, and test data loaders."""
+def create_data_loaders(
+    config: Dict[str, Any],
+    mode: str = '3d'
+) -> Tuple[DataLoader, DataLoader, DataLoader]:
+    """Create data loaders with specified mode."""
+
     # Create transforms
-    train_transforms = get_transforms(config, 'train')
-    val_transforms = get_transforms(config, 'val')
-    test_transforms = get_transforms(config, 'test')
+    train_transforms = get_transforms(config, 'train', mode)
+    val_transforms = get_transforms(config, 'val', mode)
+    test_transforms = get_transforms(config, 'test', mode)
 
     # Create datasets
-    train_dataset = ADNIDataset(config, transform=train_transforms, split='train')
+    train_dataset = ADNIDataset(config, transform=train_transforms, split='train', mode=mode)
+    val_dataset = ADNIDataset(config, transform=val_transforms, split='val', mode=mode)
+    test_dataset = ADNIDataset(config, transform=test_transforms, split='test', mode=mode)
 
     # Split data
     train_size = 1 - config['dataset']['val_ratio'] - config['dataset']['test_ratio']
-
-    # Get all labels for stratification
     labels = [label for _, label in train_dataset.file_list]
 
-    # Create splits
     train_idx, temp_idx = train_test_split(
         range(len(train_dataset)),
         train_size=train_size,
@@ -246,33 +231,48 @@ def create_data_loaders(config: Dict[str, Any]) -> Tuple[DataLoader, DataLoader,
         random_state=config['training']['seed']
     )
 
-    # Determine number of workers based on OS
-    num_workers = 0 if os.name == 'nt' else 4  # Use 0 workers on Windows for debugging
-
-    # Create data loaders
+    def collate_fn(batch):
+        batch_data = {}
+        for key in batch[0].keys():
+            if key == 'image':
+                # Stack images
+                images = torch.stack([item[key] for item in batch])
+
+                # Handle channel dimension based on mode
+                if mode == '2d' and len(images.shape) == 3:  # [B, H, W]
+                    images = images.unsqueeze(1)  # Add channel dim [B, C, H, W]
+                elif mode == '3d' and len(images.shape) == 4:  # [B, D, H, W]
+                    images = images.unsqueeze(1)  # Add channel dim [B, C, D, H, W]
+
+                batch_data[key] = images
+            else:
+                batch_data[key] = torch.tensor([item[key] for item in batch])
+        return batch_data
+
+    # Create loaders
     train_loader = DataLoader(
         Subset(train_dataset, train_idx),
         batch_size=config['dataset']['batch_size'],
         shuffle=True,
-        num_workers=num_workers,
+        num_workers=0 if os.name == 'nt' else 4,
        pin_memory=True,
         collate_fn=collate_fn
     )
 
     val_loader = DataLoader(
-        Subset(ADNIDataset(config, transform=val_transforms), val_idx),
+        Subset(val_dataset, val_idx),
         batch_size=config['dataset']['batch_size'],
         shuffle=False,
-        num_workers=num_workers,
+        num_workers=0 if os.name == 'nt' else 4,
         pin_memory=True,
         collate_fn=collate_fn
     )
 
     test_loader = DataLoader(
-        Subset(ADNIDataset(config, transform=test_transforms), test_idx),
+        Subset(test_dataset, test_idx),
         batch_size=config['dataset']['batch_size'],
         shuffle=False,
-        num_workers=num_workers,
+        num_workers=0 if os.name == 'nt' else 4,
         pin_memory=True,
         collate_fn=collate_fn
     )
@@ -281,12 +281,13 @@ def create_data_loaders(config: Dict[str, Any]) -> Tuple[DataLoader, DataLoader,
     logger.info(f"Dataset splits - Train: {len(train_idx)}, "
                 f"Val: {len(val_idx)}, Test: {len(test_idx)}")
 
-    # Verify shapes of the first batch
+    # Verify shapes
     try:
         train_batch = next(iter(train_loader))
-        logger.info(f"Train batch image shape: {train_batch['image'].shape}")
+        expected_shape = "B, C, H, W" if mode == '2d' else "B, C, D, H, W"
+        logger.info(f"Train batch image shape ({expected_shape}): {train_batch['image'].shape}")
         logger.info(f"Train batch label shape: {train_batch['label'].shape}")
     except Exception as e:
         logger.warning(f"Could not verify train batch shapes: {str(e)}")
-
+
     return train_loader, val_loader, test_loader
\ No newline at end of file
diff --git a/main.py b/main.py
index 58a0572..5973763 100644
--- a/main.py
+++ b/main.py
@@ -1,5 +1,5 @@
 """
-Main script for Alzheimer's detection using Vision Transformers.
+Main script for Alzheimer's detection with model selection.
""" import sys @@ -11,8 +11,9 @@ import random import numpy as np from datetime import datetime -from models.architectures import create_model +from models import create_model from data.data_loader import create_data_loaders +from models.train import train_model # Configure logging logging.basicConfig( @@ -25,25 +26,26 @@ ) logger = logging.getLogger(__name__) +def parse_args(): + parser = argparse.ArgumentParser(description="Alzheimer's Detection Training") + parser.add_argument('--config', type=str, default='config.yaml', + help='Path to config file') + parser.add_argument('--model', type=str, choices=['vit2d', 'vit3d', 'cnn3d'], + default='vit2d', help='Model architecture to use') + parser.add_argument('--device', type=str, choices=['cuda', 'cpu'], + help='Device to use (overrides config file)') + parser.add_argument('--patch_size', type=int, default=16, + help='Patch size for ViT models') + parser.add_argument('--debug', action='store_true', + help='Enable debug mode') + return parser.parse_args() def load_config(config_path: str = "config.yaml") -> dict: """Load configuration from YAML file.""" with open(config_path, 'r') as f: return yaml.safe_load(f) - -def set_seed(seed: int) -> None: - """Set random seeds for reproducibility.""" - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - - -def setup_experiment(config: dict) -> Path: +def setup_experiment(config: dict, model_type: str) -> Path: """Setup experiment directories and logging.""" try: if 'paths' not in config: @@ -67,12 +69,12 @@ def setup_experiment(config: dict) -> Path: # Create experiment directory with timestamp timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - exp_name = f"{config['model']['type']}_{timestamp}" + exp_name = f"{model_type}_{timestamp}" exp_dir = Path(config['paths']['output_dir']) / exp_name exp_dir.mkdir(parents=True, exist_ok=True) # Create experiment subdirectories - for subdir in ['checkpoints', 'logs', 'results', 'visualizations']: + for subdir in ['checkpoints', 'logs', 'results']: (exp_dir / subdir).mkdir(exist_ok=True) # Save config to experiment directory @@ -86,185 +88,51 @@ def setup_experiment(config: dict) -> Path: logger.error(f"Error in setup_experiment: {str(e)}") raise - -def validate(model, val_loader, criterion, device): - """Validate the model.""" - model.eval() - val_loss = 0.0 - correct = 0 - total = 0 - - with torch.no_grad(): - for batch in val_loader: - images = batch['image'].to(device) - labels = batch['label'].to(device) - outputs = model(images) - loss = criterion(outputs, labels) - - val_loss += loss.item() - _, predicted = outputs.max(1) - total += labels.size(0) - correct += predicted.eq(labels).sum().item() - - return val_loss / len(val_loader), 100. 
-
-
-def train_model(model, train_loader, val_loader, config, device, exp_dir):
-    """Training loop."""
-    try:
-        # Initialize optimizer and criterion
-        optimizer = torch.optim.AdamW(
-            model.parameters(),
-            lr=config['training']['learning_rate'],
-            weight_decay=config['training']['optimizer']['weight_decay']
-        )
-        criterion = torch.nn.CrossEntropyLoss()
-
-        # Initialize tracking variables
-        best_val_acc = 0.0
-        best_val_loss = float('inf')
-        val_losses = []
-        checkpoint_dir = exp_dir / 'checkpoints'
-
-        # Training loop
-        for epoch in range(config['training']['epochs']):
-            # Training phase
-            model.train()
-            running_loss = 0.0
-            correct = 0
-            total = 0
-
-            for batch_idx, batch in enumerate(train_loader):
-                # Get data
-                images = batch['image'].to(device)
-                labels = batch['label'].to(device)
-
-                # Forward pass
-                optimizer.zero_grad()
-                outputs = model(images)
-                loss = criterion(outputs, labels)
-
-                # Backward pass
-                loss.backward()
-                optimizer.step()
-
-                # Update statistics
-                running_loss += loss.item()
-                _, predicted = outputs.max(1)
-                total += labels.size(0)
-                correct += predicted.eq(labels).sum().item()
-
-                # Log progress
-                if batch_idx % 5 == 0:
-                    logger.info(
-                        f'Epoch: {epoch+1}/{config["training"]["epochs"]}, '
-                        f'Batch: {batch_idx}/{len(train_loader)}, '
-                        f'Loss: {loss.item():.4f}, '
-                        f'Acc: {100.*correct/total:.2f}%'
-                    )
-
-            # Compute epoch statistics
-            train_loss = running_loss / len(train_loader)
-            train_acc = 100. * correct / total
-
-            # Validation phase
-            val_loss, val_acc = validate(model, val_loader, criterion, device)
-            val_losses.append(val_loss)
-
-            # Log epoch results
-            logger.info(
-                f'\nEpoch {epoch+1}/{config["training"]["epochs"]}:\n'
-                f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%\n'
-                f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%\n'
-            )
-
-            # Save checkpoint
-            checkpoint = {
-                'epoch': epoch + 1,
-                'model_state_dict': model.state_dict(),
-                'optimizer_state_dict': optimizer.state_dict(),
-                'train_loss': train_loss,
-                'val_loss': val_loss,
-                'train_acc': train_acc,
-                'val_acc': val_acc,
-                'config': config
-            }
-
-            # Save latest checkpoint
-            torch.save(
-                checkpoint,
-                checkpoint_dir / f'checkpoint_epoch_{epoch+1}.pt'
-            )
-
-            # Save best model
-            if val_acc > best_val_acc:
-                best_val_acc = val_acc
-                best_val_loss = val_loss
-                torch.save(
-                    checkpoint,
-                    checkpoint_dir / 'best_model.pt'
-                )
-                logger.info(f'New best model saved with validation accuracy: {val_acc:.2f}%')
-
-            # Early stopping
-            if config['training'].get('early_stopping', {}).get('enable', False):
-                patience = config['training']['early_stopping']['patience']
-                min_delta = config['training']['early_stopping']['min_delta']
-                if (epoch > patience and
-                    val_loss > min(val_losses[-patience:]) - min_delta):
-                    logger.info(f'Early stopping triggered at epoch {epoch+1}')
-                    break
-
-    except Exception as e:
-        logger.error(f"Error during training: {str(e)}")
-        import traceback
-        logger.error(traceback.format_exc())
-        raise
-
-
 def main():
     """Main function to run the training pipeline."""
-    parser = argparse.ArgumentParser(description="Alzheimer's Detection Model Training")
-    parser.add_argument('--config', type=str, default='config.yaml',
-                        help='Path to config file')
-    parser.add_argument('--device', type=str, choices=['cuda', 'cpu'],
-                        help='Device to use (overrides config file)')
-    parser.add_argument('--debug', action='store_true',
-                        help='Enable debug mode')
-    args = parser.parse_args()
+    args = parse_args()
 
     try:
         # Load configuration
         logger.info("Loading configuration...")
         config = load_config(args.config)
-
+
+        # Override config with command line arguments
         if args.device:
             config['training']['device'] = args.device
-        elif not torch.cuda.is_available() and config['training']['device'] == 'cuda':
+        config['model']['type'] = args.model
+        config['model']['patch_size'] = args.patch_size
+
+        # Device handling
+        if not torch.cuda.is_available() and config['training']['device'] == 'cuda':
             logger.warning("CUDA not available, falling back to CPU")
             config['training']['device'] = 'cpu'
-
         device = torch.device(config['training']['device'])
         logger.info(f"Using device: {device}")
 
         # Set random seed
-        set_seed(config['training']['seed'])
-        logger.info(f"Set random seed: {config['training']['seed']}")
+        random.seed(config['training']['seed'])
+        np.random.seed(config['training']['seed'])
+        torch.manual_seed(config['training']['seed'])
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(config['training']['seed'])
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
 
         # Setup experiment directory
-        exp_dir = setup_experiment(config)
+        exp_dir = setup_experiment(config, args.model)
 
         # Create data loaders
         logger.info("Creating data loaders...")
-        train_loader, val_loader, test_loader = create_data_loaders(config)
+        mode = '2d' if args.model == 'vit2d' else '3d'
+        train_loader, val_loader, test_loader = create_data_loaders(config, mode=mode)
         logger.info("Data loaders created successfully")
 
         # Create model
-        logger.info("Creating model...")
-        model = create_model(config)
+        logger.info(f"Creating {args.model} model...")
+        model = create_model(config, model_type=args.model)
         model = model.to(device)
-        logger.info(f"Model moved to {device} successfully")
+        logger.info(f"Model created and moved to {device} successfully")
 
         # Train model
         logger.info("Starting training...")
@@ -278,6 +146,8 @@
         )
         logger.info("Training completed successfully")
 
+    except KeyboardInterrupt:
+        logger.info("Training interrupted by user")
     except Exception as e:
         logger.error(f"Error in main execution: {str(e)}")
         if args.debug:
@@ -285,14 +155,5 @@
             traceback.print_exc()
         raise
 
-
 if __name__ == "__main__":
-    try:
-        main()
-    except KeyboardInterrupt:
-        logger.info("Training interrupted by user")
-    except Exception as e:
-        logger.error(f"An error occurred: {str(e)}")
-        if '--debug' in sys.argv:
-            import traceback
-            traceback.print_exc()
\ No newline at end of file
+    main()
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
index c09de5a..dfd3430 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1 +1,3 @@
-# models/__init__.py
+from .factory import create_model
+
+__all__ = ['create_model']
diff --git a/models/architectures/__init__.py b/models/architectures/__init__.py
index 3a39539..084b2d3 100644
--- a/models/architectures/__init__.py
+++ b/models/architectures/__init__.py
@@ -1,10 +1,5 @@
-"""
-Initialize the architectures module and provide a unified interface for model creation.
-""" +from .vit2d import ViT2D +from .vit3d import ViT3D +from .cnn3d import CNN3D -from .vit3d import create_model -import logging - -logger = logging.getLogger(__name__) - -__all__ = ['create_model'] \ No newline at end of file +__all__ = ['ViT2D', 'ViT3D', 'CNN3D'] \ No newline at end of file diff --git a/models/architectures/cnn3d.py b/models/architectures/cnn3d.py index ffbf642..ce05ea9 100644 --- a/models/architectures/cnn3d.py +++ b/models/architectures/cnn3d.py @@ -1,59 +1,94 @@ """ -3D CNN model for Alzheimer's detection. +models/cnn_3d.py - 3D CNN for Alzheimer's detection with transfer learning. """ +import torch import torch.nn as nn import torchvision.models as models +import logging +logger = logging.getLogger(__name__) class CNN3D(nn.Module): - def __init__(self, num_labels: int, freeze_layers: bool = True): + def __init__( + self, + num_labels: int, + freeze_layers: bool = True, + input_size: int = 224, + patch_size: int = 16, + dropout_rate: float = 0.1 + ): super().__init__() + + # Load pretrained 2D ResNet resnet = models.resnet50(pretrained=True) - # Modify the first convolutional layer for 3D input - self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=2, padding=3, bias=False) + # Convert first conv layer to 3D + self.conv1 = nn.Conv3d( + 1, 64, + kernel_size=(7, 7, 7), + stride=(2, 2, 2), + padding=(3, 3, 3), + bias=False + ) + + # Initialize from 2D weights + with torch.no_grad(): + self.conv1.weight.copy_( + resnet.conv1.weight.unsqueeze(2).repeat(1, 1, 7, 1, 1) / 7 + ) - # Use other ResNet layers - self.bn1 = resnet.bn1 - self.relu = resnet.relu + # Convert other layers to 3D + self.bn1 = nn.BatchNorm3d(64) + self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer_3d(resnet.layer1) - self.layer2 = self._make_layer_3d(resnet.layer2) - self.layer3 = self._make_layer_3d(resnet.layer3) - self.layer4 = self._make_layer_3d(resnet.layer4) + + # Convert ResNet blocks to 3D + def convert_layer(layer2d): + blocks = [] + for block in layer2d: + blocks.append( + Block3D( + block.conv1.in_channels, + block.conv1.out_channels, + stride=block.stride[0] + ) + ) + return nn.Sequential(*blocks) + + self.layer1 = convert_layer(resnet.layer1) + self.layer2 = convert_layer(resnet.layer2) + self.layer3 = convert_layer(resnet.layer3) + self.layer4 = convert_layer(resnet.layer4) + self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1)) - self.fc = nn.Linear(2048, num_labels) + + # Classification head + self.fc = nn.Sequential( + nn.Linear(2048, 512), + nn.ReLU(inplace=True), + nn.Dropout(dropout_rate), + nn.Linear(512, num_labels) + ) if freeze_layers: - for param in self.parameters(): + self._freeze_layers() + + # Log model statistics + total_params = sum(p.numel() for p in self.parameters()) + trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + logger.info(f"Total parameters: {total_params:,}") + logger.info(f"Trainable parameters: {trainable_params:,}") + + def _freeze_layers(self): + """Freeze early layers.""" + frozen_layers = [self.conv1, self.bn1, self.layer1, self.layer2] + for layer in frozen_layers: + for param in layer.parameters(): param.requires_grad = False - # Unfreeze the final fully connected layer - for param in self.fc.parameters(): - param.requires_grad = True - - def _make_layer_3d(self, layer): - new_layer = nn.Sequential() - for i, bottleneck in enumerate(layer): - new_bottleneck = nn.Sequential() - for name, module in bottleneck.named_children(): - if isinstance(module, 
-                    # Replace 2D convolutions with 3D
-                    new_conv = nn.Conv3d(
-                        in_channels=module.in_channels,
-                        out_channels=module.out_channels,
-                        kernel_size=module.kernel_size[0],
-                        stride=module.stride[0],
-                        padding=module.padding[0],
-                        bias=module.bias
-                    )
-                    new_bottleneck.add_module(name, new_conv)
-                else:
-                    new_bottleneck.add_module(name, module)
-            new_layer.add_module(str(i), new_bottleneck)
-        return new_layer
-
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass."""
         x = self.conv1(x)
         x = self.bn1(x)
         x = self.relu(x)
@@ -65,11 +100,52 @@ def forward(self, x):
         x = self.layer4(x)
 
         x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
+        x = torch.flatten(x, 1)
         x = self.fc(x)
+
         return x
 
+class Block3D(nn.Module):
+    """3D version of ResNet block."""
+
+    def __init__(self, in_planes, planes, stride=1):
+        super().__init__()
+        self.conv1 = nn.Conv3d(
+            in_planes, planes, kernel_size=3,
+            stride=stride, padding=1, bias=False
+        )
+        self.bn1 = nn.BatchNorm3d(planes)
+        self.conv2 = nn.Conv3d(
+            planes, planes, kernel_size=3,
+            stride=1, padding=1, bias=False
+        )
+        self.bn2 = nn.BatchNorm3d(planes)
+
+        if stride != 1 or in_planes != planes:
+            self.downsample = nn.Sequential(
+                nn.Conv3d(
+                    in_planes, planes,
+                    kernel_size=1, stride=stride, bias=False
+                ),
+                nn.BatchNorm3d(planes)
+            )
+        else:
+            self.downsample = None
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = torch.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = torch.relu(out)
 
-def create_cnn_3d(num_labels: int, freeze_layers: bool = True) -> nn.Module:
-    """Create a 3D CNN model with transfer learning."""
-    return CNN3D(num_labels, freeze_layers)
\ No newline at end of file
+        return out
\ No newline at end of file
diff --git a/models/architectures/vit2d.py b/models/architectures/vit2d.py
index 50ddd45..fc42a72 100644
--- a/models/architectures/vit2d.py
+++ b/models/architectures/vit2d.py
@@ -1,20 +1,91 @@
 """
-2D Vision Transformer model for Alzheimer's detection.
+models/architectures/vit2d.py - 2D Vision Transformer for Alzheimer's detection.
""" -from torch import nn -from transformers import ViTForImageClassification +import torch +import torch.nn as nn +from transformers import ViTModel +import logging +logger = logging.getLogger(__name__) -def create_vit_2d(num_labels: int, freeze_layers: bool = True) -> nn.Module: - """Create a 2D Vision Transformer model with transfer learning.""" - model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=num_labels) +class ViT2D(nn.Module): + def __init__( + self, + num_labels: int, + freeze_layers: bool = True, + input_size: int = 224, + patch_size: int = 16, + dropout_rate: float = 0.1 + ): + super().__init__() - if freeze_layers: - for param in model.vit.parameters(): + # Load pretrained ViT + self.vit = ViTModel.from_pretrained( + 'google/vit-base-patch16-224-in21k', + add_pooling_layer=False + ) + hidden_size = self.vit.config.hidden_size + + # Medical image preprocessing + self.preprocess = nn.Sequential( + nn.Conv2d(1, 16, kernel_size=7, padding=3, bias=False), + nn.BatchNorm2d(16), + nn.ReLU(inplace=True), + nn.Conv2d(16, 3, kernel_size=1, bias=False), + nn.BatchNorm2d(3) + ) + + # Classification head + self.classifier = nn.Sequential( + nn.LayerNorm(hidden_size), + nn.Linear(hidden_size, hidden_size // 2), + nn.GELU(), + nn.Dropout(dropout_rate), + nn.Linear(hidden_size // 2, num_labels) + ) + + if freeze_layers: + self._freeze_layers() + + # Initialize weights + self._init_weights() + + # Log model statistics + total_params = sum(p.numel() for p in self.parameters()) + trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) + logger.info(f"Total parameters: {total_params:,}") + logger.info(f"Trainable parameters: {trainable_params:,}") + + def _init_weights(self): + """Initialize new weights.""" + for m in self.modules(): + if isinstance(m, (nn.Linear, nn.Conv2d)): + torch.nn.init.kaiming_normal_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.LayerNorm, nn.BatchNorm2d)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + + def _freeze_layers(self): + """Freeze pretrained layers.""" + for param in self.vit.parameters(): param.requires_grad = False - # Only train the classification head - model.classifier = nn.Linear(model.config.hidden_size, num_labels) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" + # Preprocess grayscale to RGB-like + x = self.preprocess(x) + + # Normalize + x = (x - x.mean(dim=[2, 3], keepdim=True)) / (x.std(dim=[2, 3], keepdim=True) + 1e-6) + + # Pass through ViT + outputs = self.vit(pixel_values=x, return_dict=True) + + # Get CLS token and classify + x = outputs.last_hidden_state[:, 0] + x = self.classifier(x) - return model \ No newline at end of file + return x \ No newline at end of file diff --git a/models/factory.py b/models/factory.py new file mode 100644 index 0000000..fd3caa6 --- /dev/null +++ b/models/factory.py @@ -0,0 +1,55 @@ +""" +Model factory for creating different model architectures. +""" + +import logging +from typing import Dict, Any +from .architectures import ViT2D, ViT3D, CNN3D + +logger = logging.getLogger(__name__) + + +def create_model(config: Dict[str, Any], model_type: str = None) -> Any: + """ + Create a model based on configuration and type. 
+ + Args: + config: Configuration dictionary + model_type: Type of model to create ('vit2d', 'vit3d', 'cnn3d') + + Returns: + Instantiated model + """ + # Use model type from config if not specified + model_type = model_type or config['model']['type'] + + # Map model types to classes + model_map = { + 'vit2d': ViT2D, + 'vit3d': ViT3D, + 'cnn3d': CNN3D + } + + if model_type not in model_map: + raise ValueError(f"Unknown model type: {model_type}. " + f"Available types: {list(model_map.keys())}") + + # Get the model class + model_class = model_map[model_type] + + # Create model instance + logger.info(f"Creating model of type: {model_type}") + try: + model = model_class( + num_labels=config['model']['num_labels'], + freeze_layers=config['model'].get('freeze_layers', True), + input_size=config['model'].get('input_size', 224), + patch_size=config['model'].get('patch_size', 16), + dropout_rate=config['model'].get('dropout_rate', 0.1) + ) + logger.info(f"Successfully created {model_type} model") + return model + + except Exception as e: + logger.error(f"Error creating model: {str(e)}") + raise \ No newline at end of file
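
--

Usage note (reviewer sketch, not part of the patch): with this series applied, the
architecture and data mode are chosen on the command line; main.py derives
mode='2d' only for vit2d and feeds full volumes to the 3D models. The flags below
are exactly those defined in parse_args; the paths and device values are illustrative.

    python main.py --config config.yaml --model vit2d --patch_size 16
    python main.py --model cnn3d --device cuda
    python main.py --model vit3d --debug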
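For a shape sanity check without a full training run, a minimal smoke test is
sketched below. It assumes the ADNI layout expected by ADNIDataset exists under
./adni, and it deliberately avoids models.train.train_model, which main.py imports
but which is not included in this diff; the script name and chosen model type are
hypothetical.

    # smoke_test.py (hypothetical, not part of this patch)
    import yaml
    import torch

    from models import create_model
    from data.data_loader import create_data_loaders

    with open('config.yaml') as f:
        config = yaml.safe_load(f)

    model_type = 'vit2d'                             # one of: vit2d, vit3d, cnn3d
    mode = '2d' if model_type == 'vit2d' else '3d'   # mirrors the rule in main.py

    model = create_model(config, model_type=model_type).eval()
    train_loader, _, _ = create_data_loaders(config, mode=mode)

    batch = next(iter(train_loader))
    with torch.no_grad():
        logits = model(batch['image'])

    # With the shipped config this should print torch.Size([16, 3]):
    # batch_size=16 rows, num_labels=3 classes (AD / CN / MCI)
    print(logits.shape)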