Commit

Add Documentation for lightly/models/modules (#1704)

ayush22iitbhu authored Oct 28, 2024
1 parent c57bdf1 commit c4551db
Showing 5 changed files with 217 additions and 18 deletions.
51 changes: 45 additions & 6 deletions lightly/models/modules/ijepa.py
@@ -39,7 +39,6 @@ class IJEPAPredictor(vision_transformer.Encoder):
Percentage of elements set to zero after the MLP in the transformer.
attention_dropout:
Percentage of elements set to zero after the attention head.
"""

def __init__(
@@ -56,6 +55,8 @@ def __init__(
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
**kwargs,
):
"""Initializes the IJEPAPredictor with the specified dimensions."""

super().__init__(
seq_length=seq_length,
num_layers=num_layers,
@@ -81,7 +82,16 @@ def __init__(

@classmethod
def from_vit_encoder(cls, vit_encoder, num_patches):
"""Creates a I-JEPA predictor backbone (mhas and layernorm) from a torchvision ViT encoder."""
"""Creates an I-JEPA predictor backbone (multi-head attention and layernorm) from a torchvision ViT encoder.
Args:
vit_encoder: The Vision Transformer encoder from torchvision.
num_patches: The number of patches (tokens).
Returns:
IJEPAPredictor: An I-JEPA predictor backbone initialized from the ViT encoder.
"""

# Create a new instance with dummy values as they will be overwritten
# by the copied vit_encoder attributes
encoder = cls(
@@ -95,11 +105,27 @@ def from_vit_encoder(cls, vit_encoder, num_patches):
dropout=0,
attention_dropout=0,
)

# Copy attributes from the ViT encoder
encoder.layers = vit_encoder.layers
encoder.ln = vit_encoder.ln

return encoder

def forward(self, x, masks_x, masks):
"""Forward pass of the IJEPAPredictor.
Args:
x:
Input tensor.
masks_x:
Mask indices for the input tensor.
masks:
Mask indices for the predicted tokens.
Returns:
The predicted output tensor.
"""
assert (masks is not None) and (
masks_x is not None
), "Cannot run predictor without mask indices"
@@ -160,7 +186,6 @@ class IJEPAEncoder(vision_transformer.Encoder):
Percentage of elements set to zero after the MLP in the transformer.
attention_dropout:
Percentage of elements set to zero after the attention head.
"""

def __init__(
@@ -174,6 +199,8 @@ def __init__(
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
"""Initializes the IJEPAEncoder with the specified dimensions."""

super().__init__(
seq_length=seq_length,
num_layers=num_layers,
@@ -188,6 +215,7 @@ def __init__(
@classmethod
def from_vit_encoder(cls, vit_encoder: vision_transformer.Encoder):
"""Creates a IJEPA encoder from a torchvision ViT encoder."""

# Create a new instance with dummy values as they will be overwritten
# by the copied vit_encoder attributes
encoder = cls(
@@ -221,6 +249,7 @@ def forward(
Returns:
Batch of encoded output tokens.
"""

input = input + self.interpolate_pos_encoding(input)
if idx_keep is not None:
input = utils.apply_masks(input, idx_keep)
@@ -236,7 +265,11 @@ def interpolate_pos_encoding(self, input: torch.Tensor):
input:
Input tensor with shape (batch_size, num_sequences).
Returns:
Interpolated positional embedding.
"""

# code copied from:
# https://github.com/facebookresearch/msn/blob/4388dc1eadbe3042b85d3296d41b9b207656e043/src/deit.py#L291
npatch = input.shape[1] - 1
@@ -264,6 +297,7 @@ class IJEPABackbone(vision_transformer.VisionTransformer):
in the future.
Converts images into patches and encodes them. Code inspired by [1].
Note that this implementation uses a learned positional embedding while [0]
uses a fixed positional embedding.
@@ -342,6 +376,7 @@ def __init__(
@classmethod
def from_vit(cls, vit: vision_transformer.VisionTransformer):
"""Creates a IJEPABackbone from a torchvision ViT model."""

# Create a new instance with dummy values as they will be overwritten
# by the copied vit_encoder attributes
backbone = cls(
@@ -357,18 +392,20 @@ def from_vit(cls, vit: vision_transformer.VisionTransformer):
representation_size=vit.representation_size,
norm_layer=vit.norm_layer,
)

# Copy attributes from the ViT model
backbone.conv_proj = vit.conv_proj
backbone.class_token = vit.class_token
backbone.seq_length = vit.seq_length
backbone.heads = vit.heads
backbone.encoder = IJEPAEncoder.from_vit_encoder(vit.encoder)

return backbone

def forward(
self, images: torch.Tensor, idx_keep: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Returns encoded class tokens from a batch of images.
"""Returns encoded class tokens from a batch of images.
Args:
images:
@@ -382,8 +419,8 @@ def forward(
Returns:
Tensor with shape (batch_size, hidden_dim) containing the
encoded class token for every image.
"""

if idx_keep is not None:
if not isinstance(idx_keep, list):
idx_keep = [idx_keep]
@@ -421,6 +458,8 @@ def images_to_tokens(
Args:
images:
Tensor with shape (batch_size, channels, image_size, image_size).
prepend_class_token:
Whether to prepend the class token to the patch tokens.
Returns:
Tensor with shape (batch_size, sequence_length - 1, hidden_dim)
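The constructors documented in this file suggest a usage pattern along the following lines. This is a minimal sketch, not part of the commit: the torchvision vit_b_16 model, the shapes in the comments, and the direct import from lightly/models/modules/ijepa.py are assumptions, while the class and method names come from the diff above.

    import torch
    import torchvision

    from lightly.models.modules.ijepa import IJEPABackbone, IJEPAPredictor

    # Assumed starting point: a standard torchvision ViT.
    vit = torchvision.models.vit_b_16()

    # Documented constructors: the backbone is built from the full ViT,
    # the predictor from its encoder plus the number of patch tokens.
    backbone = IJEPABackbone.from_vit(vit)
    num_patches = (vit.image_size // vit.patch_size) ** 2  # 196 for vit_b_16
    predictor = IJEPAPredictor.from_vit_encoder(vit.encoder, num_patches)

    # Per the forward docstring, the backbone returns one encoded class
    # token per image, i.e. a tensor of shape (batch_size, hidden_dim).
    images = torch.rand(2, 3, 224, 224)
    class_tokens = backbone(images)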
31 changes: 25 additions & 6 deletions lightly/models/modules/ijepa_timm.py
@@ -42,7 +42,6 @@ class IJEPAPredictorTIMM(nn.Module):
Percentage of elements set to zero after the attention head.
norm_layer:
Normalization layer.
"""

def __init__(
@@ -59,6 +58,8 @@ def __init__(
attn_drop_rate: float = 0.0,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
"""Initializes the IJEPAPredictorTIMM with the specified dimensions."""

super().__init__()

self.predictor_embed = nn.Linear(mlp_dim, predictor_embed_dim, bias=True)
@@ -98,6 +99,20 @@ def forward(
masks_x: Union[List[torch.Tensor], torch.Tensor],
masks: Union[List[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
"""Forward pass of the IJEPAPredictorTIMM.
Args:
x:
Input tensor.
masks_x:
Mask indices for the input tensor.
masks:
Mask indices for the predicted tokens.
Returns:
The predicted output tensor.
"""

assert (masks is not None) and (
masks_x is not None
), "Cannot run predictor without mask indices"
@@ -147,16 +162,20 @@ def repeat_interleave_batch(
def apply_masks(
self, x: torch.Tensor, masks: Union[torch.Tensor, List[torch.Tensor]]
) -> torch.Tensor:
"""
"""Apply masks to the input tensor.
From https://github.com/facebookresearch/ijepa/blob/main/src/masks/utils.py
Apply masks to the input tensor.
Args:
x: tensor of shape [B (batch-size), N (num-patches), D (feature-dim)]
masks: tensor or list of tensors containing indices of patches in [N] to keep
x:
tensor of shape [B (batch-size), N (num-patches), D (feature-dim)].
masks:
tensor or list of tensors containing indices of patches in [N] to keep.
Returns:
tensor of shape [B, N', D] where N' is the number of patches to keep
Tensor of shape [B, N', D] where N' is the number of patches to keep.
"""

if not isinstance(masks, list):
masks = [masks]

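The apply_masks contract documented above takes patch tokens of shape [B, N, D] and index masks into [N], and returns only the kept patches. The following standalone, gather-based sketch illustrates that contract; it is not the module's implementation, and the shapes and names are chosen for the example.

    import torch

    def apply_masks_sketch(x, masks):
        # x: [B, N, D] patch tokens; masks: [B, N'] index tensor(s) into N.
        if not isinstance(masks, list):
            masks = [masks]
        kept = []
        for m in masks:
            index = m.unsqueeze(-1).repeat(1, 1, x.size(-1))  # [B, N', D]
            kept.append(torch.gather(x, dim=1, index=index))
        return torch.cat(kept, dim=0)  # [B * len(masks), N', D]

    x = torch.rand(2, 196, 768)                # 2 images, 196 patches, dim 768
    masks = torch.randint(0, 196, (2, 49))     # keep 49 patch indices per image
    print(apply_masks_sketch(x, masks).shape)  # torch.Size([2, 49, 768])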
31 changes: 28 additions & 3 deletions lightly/models/modules/masked_autoencoder.py
@@ -38,7 +38,6 @@ class MAEEncoder(vision_transformer.Encoder):
Percentage of elements set to zero after the MLP in the transformer.
attention_dropout:
Percentage of elements set to zero after the attention head.
"""

def __init__(
@@ -52,6 +51,8 @@ def __init__(
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
"""Initializes the MAEEncoder with the specified dimensions."""

super().__init__(
seq_length=seq_length,
num_layers=num_layers,
@@ -79,8 +80,8 @@ def from_vit_encoder(
Returns:
A MAEEncoder with the same architecture as vit_encoder.
"""

# Create a new instance with dummy values as they will be overwritten
# by the copied vit_encoder attributes
encoder = cls(
@@ -92,10 +93,13 @@ def from_vit_encoder(
dropout=0,
attention_dropout=0,
)

# Copy attributes from the ViT encoder
encoder.pos_embedding = vit_encoder.pos_embedding
encoder.dropout = vit_encoder.dropout
encoder.layers = vit_encoder.layers
encoder.ln = vit_encoder.ln

if initialize_weights:
encoder._initialize_weights()
return encoder
@@ -131,16 +135,22 @@ def interpolate_pos_encoding(self, input: torch.Tensor):
input:
Input tensor with shape (batch_size, num_sequences).
Returns:
Interpolated positional embedding.
"""
# code copied from:
# https://github.com/facebookresearch/msn/blob/4388dc1eadbe3042b85d3296d41b9b207656e043/src/deit.py#L291
npatch = input.shape[1] - 1
N = self.pos_embedding.shape[1] - 1
if npatch == N:
return self.pos_embedding

# Separate the class embedding from the positional embeddings
class_emb = self.pos_embedding[:, 0]
pos_embedding = self.pos_embedding[:, 1:]
dim = input.shape[-1]

pos_embedding = nn.functional.interpolate(
pos_embedding.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
0, 3, 1, 2
@@ -215,6 +225,8 @@ def __init__(
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
conv_stem_configs: Optional[List[ConvStemConfig]] = None,
):
"""Initializes the MAEBackbone with the specified dimensions."""

super().__init__(
image_size=image_size,
patch_size=patch_size,
@@ -272,6 +284,8 @@ def from_vit(
representation_size=vit.representation_size,
norm_layer=vit.norm_layer,
)

# Copy attributes from the ViT model
backbone.conv_proj = vit.conv_proj
backbone.class_token = vit.class_token
backbone.seq_length = vit.seq_length
Expand Down Expand Up @@ -334,18 +348,23 @@ def images_to_tokens(
Args:
images:
Tensor with shape (batch_size, channels, image_size, image_size).
prepend_class_token:
Whether to prepend the class token to the patch tokens.
Returns:
Tensor with shape (batch_size, sequence_length - 1, hidden_dim)
containing the patch tokens.
"""

x = self.conv_proj(images)
tokens = x.flatten(2).transpose(1, 2)
if prepend_class_token:
tokens = utils.prepend_class_token(tokens, self.class_token)
return tokens

def _initialize_weights(self) -> None:
"""Initializes weights for the backbone components."""

# Initialize the patch embedding layer like a linear layer instead of conv
# layer.
w = self.conv_proj.weight.data
Expand Down Expand Up @@ -404,6 +423,8 @@ def __init__(
attention_dropout: float = 0.0,
norm_layer: Callable[..., nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
"""Initializes the MAEDecoder with the specified dimensions."""

super().__init__(
seq_length=seq_length,
num_layers=num_layers,
@@ -427,8 +448,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
Returns:
Tensor with shape (batch_size, seq_length, out_dim).
"""

out = self.embed(input)
out = self.decode(out)
return self.predict(out)
@@ -487,6 +508,8 @@ def _initialize_weights(self) -> None:


def _initialize_2d_sine_cosine_positional_embedding(pos_embedding: Parameter) -> None:
"""Initializes a 2D sine-cosine positional embedding."""

_, seq_length, hidden_dim = pos_embedding.shape
grid_size = int((seq_length - 1) ** 0.5)
sine_cosine_embedding = utils.get_2d_sine_cosine_positional_embedding(
@@ -502,6 +525,8 @@ def _initialize_2d_sine_cosine_positional_embedding(pos_embedding: Parameter) ->


def _initialize_linear_layers(module: Module) -> None:
"""Initializes linear layers in the given module."""

def init(mod: Module) -> None:
if isinstance(mod, Linear):
nn.init.xavier_uniform_(mod.weight)
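Analogously, the documented MAEBackbone conversion and tokenization can be exercised as below. This is again a minimal sketch under the same torchvision assumption; the expected shapes follow the docstrings in the diff, and the (2, 197, 768) comment assumes vit_b_16.

    import torch
    import torchvision

    from lightly.models.modules.masked_autoencoder import MAEBackbone

    vit = torchvision.models.vit_b_16()
    backbone = MAEBackbone.from_vit(vit)

    images = torch.rand(2, 3, 224, 224)

    # Patch tokens only: (batch_size, sequence_length - 1, hidden_dim).
    patch_tokens = backbone.images_to_tokens(images, prepend_class_token=False)

    # With the class token prepended: (batch_size, sequence_length, hidden_dim),
    # e.g. (2, 197, 768) for vit_b_16.
    all_tokens = backbone.images_to_tokens(images, prepend_class_token=True)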