From a3e2baf02fcdb259629a2375a677b1ee4e5eff65 Mon Sep 17 00:00:00 2001
From: andreas
Date: Fri, 27 Sep 2024 15:51:32 +0200
Subject: [PATCH] interpolate positional embeddings for pretrained image tower

---
 src/open_clip/model.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/src/open_clip/model.py b/src/open_clip/model.py
index e4e03b39e..969f1def1 100644
--- a/src/open_clip/model.py
+++ b/src/open_clip/model.py
@@ -930,19 +930,25 @@ def trace_model(model, batch_size=256, device=torch.device('cpu')):
     return model
 
 
-def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_dim=1):
+def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic'):
     # interpolate position embedding
     prepending = 'module.' if next(iter(state_dict)).startswith('module') else ''
-
+    visual_tower = (
+        'visual.' if any('visual' in key for key in state_dict.keys()) else ''
+    )
     # Check if the position embedding is in the state_dict
-    pos_embed_key = f'{prepending}visual.pos_embed'
+    pos_embed_key = f'{prepending}{visual_tower}pos_embed'
     if pos_embed_key in state_dict:
         pos_embed_checkpoint = state_dict[pos_embed_key]
         embedding_size = pos_embed_checkpoint.shape[-1]
         num_patches = (
             model.module.visual.patch_embed.num_patches
             if hasattr(model, 'module')
-            else model.visual.patch_embed.num_patches
+            else (
+                model.visual.patch_embed.num_patches
+                if hasattr(model, 'visual')
+                else model.patch_embed.num_patches
+            )
         )
         num_extra_tokens = model.state_dict()[pos_embed_key].shape[-2] - num_patches
         # height (== width) for the checkpoint position embedding
@@ -972,12 +978,16 @@ def resize_eva_pos_embed(state_dict, model, interpolation: str = 'bicubic', seq_
         state_dict[pos_embed_key] = new_pos_embed
 
     # Resize the patch embedding projection
-    patch_embed_key = f'{prepending}visual.patch_embed.proj.weight'
+    patch_embed_key = f'{prepending}{visual_tower}patch_embed.proj.weight'
     patch_embed_proj = state_dict[patch_embed_key]
     patch_size = (
         model.module.visual.patch_embed.patch_size
         if hasattr(model, 'module')
-        else model.visual.patch_embed.patch_size
+        else (
+            model.visual.patch_embed.patch_size
+            if hasattr(model, 'visual')
+            else model.patch_embed.patch_size
+        )
     )
 
     state_dict[patch_embed_key] = torch.nn.functional.interpolate(
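
For reviewers: the patch itself only changes how checkpoint keys and model attributes are located. It detects whether the checkpoint keys carry a 'visual.' prefix (a full CLIP state dict) or not (a bare image tower), and falls back to model.patch_embed when the model has no visual attribute. The resize that resize_eva_pos_embed performs is the standard ViT position-embedding interpolation: split off the extra (class) tokens, reshape the patch tokens into their square 2D grid, resize that grid with torch.nn.functional.interpolate, and flatten it back. A minimal standalone sketch of that recipe follows; names are illustrative only (interpolate_pos_embed is not a function in open_clip):

    import torch
    import torch.nn.functional as F

    def interpolate_pos_embed(pos_embed, new_num_patches, num_extra_tokens=1,
                              interpolation='bicubic'):
        # pos_embed: [1, num_extra_tokens + H*W, C]; assumes a square patch
        # grid, as the patched function does.
        extra_tokens = pos_embed[:, :num_extra_tokens]
        patch_tokens = pos_embed[:, num_extra_tokens:]
        embed_dim = pos_embed.shape[-1]
        orig_size = int(patch_tokens.shape[1] ** 0.5)
        new_size = int(new_num_patches ** 0.5)
        # [1, H*W, C] -> [1, C, H, W] so F.interpolate treats it as an image.
        grid = patch_tokens.reshape(1, orig_size, orig_size, embed_dim)
        grid = grid.permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=(new_size, new_size),
                             mode=interpolation, align_corners=False)
        # Back to [1, new_H*new_W, C] and re-attach the extra tokens.
        patch_tokens = grid.permute(0, 2, 3, 1).reshape(1, new_size * new_size,
                                                        embed_dim)
        return torch.cat((extra_tokens, patch_tokens), dim=1)

    # Example: grow a 14x14 grid (196 patches + 1 CLS token) to 16x16.
    pos_embed = torch.randn(1, 1 + 14 * 14, 768)
    print(interpolate_pos_embed(pos_embed, 16 * 16).shape)  # [1, 257, 768]

The patched function applies the same idea, plus the analogous bicubic resize of patch_embed.proj.weight so the convolutional patch projection matches the new patch size.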