Feat mtl pair training #11

Draft · wants to merge 3 commits into base: main
19 changes: 19 additions & 0 deletions src/open_clip/factory.py
@@ -14,6 +14,7 @@
    ClipLoss,
    CoCaLoss,
    DistillClipLoss,
    MTLPairLoss,
    SigLipLoss,
    ThreeTowerLoss,
    ThreeTowersCosEmbeddingLoss,
@@ -26,6 +27,7 @@
    load_checkpoint,
    set_model_preprocess_cfg,
)
from .mtl_model import MTLPairCLIP
from .multi_tower_model import ThreeTowersCustomTextCLIP
from .openai import load_openai_model
from .pretrained import (
@@ -254,6 +256,10 @@ def create_model(
            model = ThreeTowersCustomTextCLIP(
                **model_cfg, cast_dtype=cast_dtype, cache_dir=cache_dir
            )
        elif 'mtl_training' in model_cfg:
            model = MTLPairCLIP(
                **model_cfg, cast_dtype=cast_dtype, cache_dir=cache_dir
            )
        else:
            model = CustomTextCLIP(
                **model_cfg, cast_dtype=cast_dtype, cache_dir=cache_dir
@@ -365,6 +371,19 @@ def create_loss(args):
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    elif 'mtl-pair' in args.model.lower():
        return MTLPairLoss(
            pair_loss_weight=args.mtl_pair_loss_weight,
            temperature=args.temperature_pair_loss,
            bidirectional=args.bidirectional_pair_loss,
            clip_loss_weight=args.mtl_clip_loss_weight,
            local_loss=args.local_loss,
            gather_with_grad=args.gather_with_grad,
            cache_labels=True,
            rank=args.rank,
            world_size=args.world_size,
            use_horovod=args.horovod,
        )
    elif '3towers-text' in args.model.lower():
        return ThreeTowersCosEmbeddingLoss(
            mse_loss_weight=args.coca_caption_loss_weight,
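For reference, a minimal self-contained sketch of the dispatch these two hunks add: a model config containing an 'mtl_training' key is routed to MTLPairCLIP in create_model, and a model name containing 'mtl-pair' selects MTLPairLoss in create_loss. The stand-in classes and the example model name below are illustrative only, not part of the diff.

class CustomTextCLIP:                 # stand-in for open_clip.CustomTextCLIP
    pass

class MTLPairCLIP(CustomTextCLIP):    # stand-in for the new model class
    pass

def pick_model_class(model_cfg: dict):
    # Mirrors the new create_model branch: an 'mtl_training' key in the
    # model config routes construction to MTLPairCLIP.
    return MTLPairCLIP if 'mtl_training' in model_cfg else CustomTextCLIP

def pick_loss_name(model_name: str) -> str:
    # Mirrors create_loss: the loss is chosen from the lower-cased model name.
    return 'MTLPairLoss' if 'mtl-pair' in model_name.lower() else 'ClipLoss'

assert pick_model_class({'mtl_training': True, 'embed_dim': 512}) is MTLPairCLIP
assert pick_loss_name('mtl-pair-ViT-B-32') == 'MTLPairLoss'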
102 changes: 102 additions & 0 deletions src/open_clip/loss.py
@@ -187,6 +187,108 @@ def forward(self, image_features, text_features, logit_scale, output_dict=False)
        return {'contrastive_loss': total_loss} if output_dict else total_loss


def cos_sim(a: torch.Tensor, b: torch.Tensor):
    """
    Computes the cosine similarity cos_sim(a[i], b[j]) for all i and j.
    :return: Matrix with res[i][j] = cos_sim(a[i], b[j])
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)

    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    if len(a.shape) == 1:
        a = a.unsqueeze(0)

    if len(b.shape) == 1:
        b = b.unsqueeze(0)

    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))


def info_nce(left, right, temperature):
    logits = nn.functional.log_softmax(cos_sim(left, right) / temperature, dim=1)
    return -torch.mean(torch.diag(logits))


class InfoNCELoss(nn.Module):
    def __init__(
        self,
        temperature: float = 0.05,
        bidirectional: bool = True,
    ):
        super(InfoNCELoss, self).__init__()
        self.temperature = temperature
        self.bidirectional = bidirectional

    def forward(self, embeddings_left, embeddings_right):
        loss = info_nce(embeddings_left, embeddings_right, self.temperature)
        if self.bidirectional:
            loss += info_nce(embeddings_right, embeddings_left, self.temperature)
            return loss / 2
        return loss
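
A quick usage sketch of the helpers above, assuming this branch's open_clip.loss is importable; with identical pairs the loss should be close to zero.

import torch
from open_clip.loss import InfoNCELoss  # available on this branch

left = torch.nn.functional.normalize(torch.randn(8, 32), dim=-1)
right = left.clone()                     # perfect matches on the diagonal

loss_fn = InfoNCELoss(temperature=0.05, bidirectional=True)
loss = loss_fn(left, right)              # averages left->right and right->left terms
print(loss.item())                       # near zero for identical pairs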


class MTLPairLoss(nn.Module):
    def __init__(
        self,
        pair_loss_weight,
        temperature,
        bidirectional,
        clip_loss_weight,
        pad_id=0,  # pad_token for open_clip custom tokenizer
        local_loss=False,
        gather_with_grad=False,
        cache_labels=False,
        rank=0,
        world_size=1,
        use_horovod=False,
    ):
        super(MTLPairLoss, self).__init__()

        self._clip_loss = ClipLoss(
            local_loss=local_loss,
            gather_with_grad=gather_with_grad,
            cache_labels=cache_labels,
            rank=rank,
            world_size=world_size,
            use_horovod=use_horovod,
        )
        self._clip_loss_weight = clip_loss_weight

        self._pair_loss = InfoNCELoss(
            temperature=temperature, bidirectional=bidirectional
        )
        self._pair_loss_weight = pair_loss_weight

    def forward(
        self,
        image_features,
        text_features,
        embedding_batch_features,
        logit_scale,
        output_dict=False,
    ):
        clip_loss = torch.tensor(0)
        pair_loss = torch.tensor(0)

        if image_features is not None and text_features is not None:
            clip_loss = self._clip_loss(image_features, text_features, logit_scale)
            clip_loss = self._clip_loss_weight * clip_loss

        if embedding_batch_features is not None:
            pair_loss = self._pair_loss(*embedding_batch_features)
            pair_loss = pair_loss * self._pair_loss_weight

        if output_dict:
            return {'multimodal_loss': clip_loss, 'text_pair_loss': pair_loss}

        return clip_loss, pair_loss
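
And a small end-to-end sketch of the combined loss with random features, again assuming this branch's open_clip.loss; the (anchor, positive) tuple stands in for the two encoded sides of a text-pair batch, and the caller is expected to sum the returned per-task losses.

import torch
from open_clip.loss import MTLPairLoss  # available on this branch

loss_fn = MTLPairLoss(
    pair_loss_weight=1.0,
    temperature=0.1,
    bidirectional=True,
    clip_loss_weight=1.0,
)

image_features = torch.nn.functional.normalize(torch.randn(4, 16), dim=-1)
text_features = torch.nn.functional.normalize(torch.randn(4, 16), dim=-1)
anchor, positive = torch.randn(4, 16), torch.randn(4, 16)
logit_scale = torch.tensor(1 / 0.07)     # a typical initial value of logit_scale.exp()

losses = loss_fn(
    image_features,
    text_features,
    (anchor, positive),                  # unpacked into InfoNCELoss(left, right)
    logit_scale,
    output_dict=True,
)
total = losses['multimodal_loss'] + losses['text_pair_loss']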


class ThreeTowersCosEmbeddingLoss(ClipLoss):
    def __init__(
        self,
89 changes: 89 additions & 0 deletions src/open_clip/mtl_model.py
@@ -0,0 +1,89 @@
from typing import List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as f
from torch import nn
from transformers.tokenization_utils import BatchEncoding

from .hf_model import HFTextEncoder
from .model import (
    CLIPTextCfg,
    CLIPVisionCfg,
    CustomTextCLIP,
    _build_text_tower,
    _build_vision_tower,
)
from .transformer import TextTransformer, VisionTransformer


class MTLPairCLIP(CustomTextCLIP):
    def __init__(
        self,
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        init_logit_scale: float = np.log(1 / 0.07),
        init_logit_bias: Optional[float] = None,
        cast_dtype: Optional[torch.dtype] = None,
        output_dict: bool = False,
        cache_dir: Optional[str] = None,
        tie_projections: bool = False,
    ):
        super(MTLPairCLIP, self).__init__(
            embed_dim=embed_dim,
            vision_cfg=vision_cfg,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            init_logit_scale=init_logit_scale,
            init_logit_bias=init_logit_bias,
            cast_dtype=cast_dtype,
            output_dict=output_dict,
            cache_dir=cache_dir,
        )

    def forward(
        self,
        image: Optional[torch.Tensor] = None,
        text: Optional[torch.Tensor] = None,
        embedding_batch: Optional[List[BatchEncoding]] = None,
    ):
        image_features = (
            self.encode_image(image, normalize=True) if image is not None else None
        )
        text_features = (
            self.encode_text(text, normalize=True) if text is not None else None
        )

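        # The pair-batch texts below are encoded without normalize=True; the pair
        # loss normalizes both sides inside cos_sim, so raw features are fine here.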
        embedding_batch_features = (
            [self.encode_text(embedding) for embedding in embedding_batch]
            if embedding_batch is not None
            else None
        )

        if self.output_dict:
            out_dict = {
                'image_features': image_features,
                'text_features': text_features,
                'embedding_batch_features': embedding_batch_features,
                'logit_scale': self.logit_scale.exp(),
            }
            if self.logit_bias is not None:
                out_dict['logit_bias'] = self.logit_bias
            return out_dict

        if self.logit_bias is not None:
            return (
                image_features,
                text_features,
                embedding_batch_features,
                self.logit_scale.exp(),
                self.logit_bias,
            )
        return (
            image_features,
            text_features,
            embedding_batch_features,
            self.logit_scale.exp(),
        )
28 changes: 26 additions & 2 deletions src/training/params.py
@@ -488,13 +488,37 @@ def parse_args(args):
        '--3towers-cos-embeddings-loss-weight',
        type=float,
        default=2.0,
        help='Weight assigned to caption loss in CoCa.',
        help='Weight assigned to cosine embedding loss when training with 3towers loss.',
    )
    parser.add_argument(
        '--3towers-contrastive-loss-weight',
        type=float,
        default=1.0,
        help='Weight assigned to contrastive loss when training CoCa.',
        help='Weight assigned to contrastive loss when training with 3towers loss.',
    )
    parser.add_argument(
        '--mtl-pair-loss-weight',
        type=float,
        default=1.0,
        help='Weight assigned to the text pair loss in multitask learning.',
    )
    parser.add_argument(
        '--mtl-clip-loss-weight',
        type=float,
        default=1.0,
        help='Weight assigned to the contrastive image-text (CLIP) loss in multitask learning.',
    )
    parser.add_argument(
        '--temperature-pair-loss',
        type=float,
        default=0.1,
        help='Temperature for the InfoNCE loss in text pair training.',
    )
    parser.add_argument(
        '--bidirectional-pair-loss',
        default=False,
        action='store_true',
        help='Compute the InfoNCE loss for text pair training bidirectionally.',
    )
    parser.add_argument(
        '--remote-sync',
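
To see the new flags end to end, a hedged sketch of parsing them with this branch's parse_args (assuming src/training is importable, no other flags are mandatory, and with an illustrative model name):

from training.params import parse_args  # src/training/params.py on this branch

args = parse_args([
    '--model', 'mtl-pair-ViT-B-32',          # 'mtl-pair' in the name selects MTLPairLoss
    '--mtl-pair-loss-weight', '1.0',
    '--mtl-clip-loss-weight', '1.0',
    '--temperature-pair-loss', '0.1',
    '--bidirectional-pair-loss',
])
assert args.bidirectional_pair_loss is True
assert args.temperature_pair_loss == 0.1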