diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..5b3f4e8
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,9 @@
+__pycache__
+
+.ipynb_checkpoints
+
+output/*
+!output/.gitkeep
+
+data/*
+!data/.gitkeep
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e613aeb
--- /dev/null
+++ b/README.md
@@ -0,0 +1,97 @@
+# Heron - A Library for Vision/Video and Language models
+
+
+
+
+
+Welcome to the "heron" repository. Heron is a library that seamlessly integrates multiple Vision and Language models, as well as Video and Language models. One of its standout features is its support for Japanese V&L models. Additionally, we provide pretrained weights trained on various datasets.
+
+
+# Installation
+1. Clone this repository
+```bash
+git clone https://github.com/turingmotors/heron
+cd heron
+```
+
+2. Install Packages
+```bash
+conda create -n git_llm python=3.10 -y
+conda activate git_llm
+pip install --upgrade pip # enable PEP 660 support
+
+pip install -r requirements.txt
+pip install -e .
+```
+
+## For Llama 2
+First, request access to the Llama 2 models on the [Hugging Face page](https://huggingface.co/meta-llama/Llama-2-7b) and the [Meta website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/).
+
+Then sign in to your Hugging Face account:
+```bash
+huggingface-cli login
+```
+
+# Training
+
+We currently support LLaMA, MPT, and OPT as the LLM module.
+
+```bash
+./scripts/run.sh
+```
+
+# Evaluation
+
+You can get the pretrained weights from the Hugging Face Hub: [Inoichan/GIT-Llama-2-7B](https://huggingface.co/Inoichan/GIT-Llama-2-7B)
+See also [notebooks](./notebooks).
+
+```python
+import requests
+import torch  # needed for torch.no_grad() below
+from PIL import Image  # needed for Image.open() below
+from transformers import AutoProcessor
+from git_llm.git_llama import GitLlamaForCausalLM
+
+device_id = 0
+
+# prepare a pretrained model
+model = GitLlamaForCausalLM.from_pretrained('Inoichan/GIT-Llama-2-7B')
+model.eval()
+model.to(f"cuda:{device_id}")
+
+# prepare a processor
+processor = AutoProcessor.from_pretrained('Inoichan/GIT-Llama-2-7B')
+
+# prepare inputs
+url = "https://www.barnorama.com/wp-content/uploads/2016/12/03-Confusing-Pictures.jpg"
+image = Image.open(requests.get(url, stream=True).raw)
+
+text = f"##Instruction: Please answer the following question concletely. ##Question: What is unusual about this image? Explain precisely and concletely what he is doing? ##Answer: "
+
+# do preprocessing
+inputs = processor(
+    text,
+    image,
+    return_tensors="pt",
+    truncation=True,
+)
+# move every input tensor to the device the model lives on
+inputs = {k: v.to(f"cuda:{device_id}") for k, v in inputs.items()}
+
+# set eos token: generation stops at either the pad or the eos token
+eos_token_id_list = [
+    processor.tokenizer.pad_token_id,
+    processor.tokenizer.eos_token_id,
+]
+
+# do inference
+with torch.no_grad():
+    out = model.generate(**inputs, max_length=256, do_sample=False, temperature=0., eos_token_id=eos_token_id_list)
+
+# print result
+print(processor.tokenizer.batch_decode(out))
+```
+
+# Acknowledgements
+
+- [GenerativeImage2Text](https://github.com/microsoft/GenerativeImage2Text): The main idea of the model is based on the original GIT.
+- [Llava](https://github.com/haotian-liu/LLaVA): This project learned a lot from the great LLaVA project.
+- [GIT-LLM](https://github.com/Ino-Ichan/GIT-LLM)
+- [video_blip](https://github.com/kotarotanahashi/video_blip)
diff --git a/configs/datasets/m3it.yaml b/configs/datasets/m3it.yaml
new file mode 100644
index 0000000..acba611
--- /dev/null
+++ b/configs/datasets/m3it.yaml
@@ -0,0 +1,3 @@
+dataset_type: coco
+
+tarain_val: true
diff --git a/configs/deepspeed/ds_config_zero1.json b/configs/deepspeed/ds_config_zero1.json
new file mode 100644
index 0000000..e9200a5
--- /dev/null
+++ b/configs/deepspeed/ds_config_zero1.json
@@ -0,0 +1,48 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": "auto",
+ "betas": "auto",
+ "eps": "auto",
+ "weight_decay": "auto"
+ }
+ },
+ "scheduler": {
+ "type": "WarmupLR",
+ "params": {
+ "warmup_min_lr": "auto",
+ "warmup_max_lr": "auto",
+ "warmup_num_steps": "auto"
+ }
+ },
+ "zero_optimization": {
+ "stage": 1,
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": 1e7,
+ "stage3_prefetch_bucket_size": 1e7,
+ "stage3_param_persistence_threshold": 10240,
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_gather_16bit_weights_on_model_save": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 100,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/configs/deepspeed/ds_config_zero2.json b/configs/deepspeed/ds_config_zero2.json
new file mode 100644
index 0000000..79a93e1
--- /dev/null
+++ b/configs/deepspeed/ds_config_zero2.json
@@ -0,0 +1,48 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": "auto",
+ "betas": "auto",
+ "eps": "auto",
+ "weight_decay": "auto"
+ }
+ },
+ "scheduler": {
+ "type": "WarmupLR",
+ "params": {
+ "warmup_min_lr": "auto",
+ "warmup_max_lr": "auto",
+ "warmup_num_steps": "auto"
+ }
+ },
+ "zero_optimization": {
+ "stage": 2,
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": 1e7,
+ "stage3_prefetch_bucket_size": 1e7,
+ "stage3_param_persistence_threshold": 10240,
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_gather_16bit_weights_on_model_save": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 100,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/configs/deepspeed/ds_config_zero3.json b/configs/deepspeed/ds_config_zero3.json
new file mode 100644
index 0000000..a778176
--- /dev/null
+++ b/configs/deepspeed/ds_config_zero3.json
@@ -0,0 +1,56 @@
+{
+ "fp16": {
+ "enabled": "auto",
+ "loss_scale": 0,
+ "loss_scale_window": 1000,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": "auto",
+ "betas": "auto",
+ "eps": "auto",
+ "weight_decay": "auto"
+ }
+ },
+ "scheduler": {
+ "type": "WarmupLR",
+ "params": {
+ "warmup_min_lr": "auto",
+ "warmup_max_lr": "auto",
+ "warmup_num_steps": "auto"
+ }
+ },
+ "zero_optimization": {
+ "stage": 3,
+ "offload_optimizer": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "offload_param": {
+ "device": "cpu",
+ "pin_memory": true
+ },
+ "overlap_comm": true,
+ "contiguous_gradients": true,
+ "sub_group_size": 1e9,
+ "reduce_bucket_size": 1e7,
+ "stage3_prefetch_bucket_size": 1e7,
+ "stage3_param_persistence_threshold": 10240,
+ "stage3_max_live_parameters": 1e9,
+ "stage3_max_reuse_distance": 1e9,
+ "stage3_gather_16bit_weights_on_model_save": true
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "steps_per_print": 100,
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/configs/llama/training_config_exp050_llama.yml b/configs/llama/training_config_exp050_llama.yml
new file mode 100644
index 0000000..4b8e692
--- /dev/null
+++ b/configs/llama/training_config_exp050_llama.yml
@@ -0,0 +1,64 @@
+training:
+ per_device_train_batch_size: 2
+ gradient_accumulation_steps: 4
+ num_train_epochs: 1
+ dataloader_num_workers: 16
+ fp16: true
+ optim: "adamw_torch"
+ learning_rate: 5.0e-5
+ logging_steps: 100
+ evaluation_strategy: "steps"
+ save_strategy: "steps"
+ eval_steps: 4000
+ save_steps: 4000
+ save_total_limit: 1
+ deepspeed: configs/deepspeed/ds_config_zero1.json
+ output_dir: ./output/
+ report_to: "wandb"
+
+settings:
+ model_name: meta-llama/Llama-2-7b-chat-hf
+ vision_model_name: openai/clip-vit-base-patch16
+ num_image_with_embedding: # None or video sequence num
+ max_length: 512
+ keys_finetune:
+ - visual_projection
+ - num_image_with_embedding
+
+use_lora: true
+lora:
+ r: 8
+ lora_alpha: 32
+ target_modules:
+ - q_proj
+ - v_proj
+ lora_dropout: 0.01
+ bias: none
+ task_type: CAUSAL_LM
+
+dataset_type: path/to/config
+ - coco
+ - textcap
+ - image-paragraph-captioning
+ - coco-goi
+ - coco-text
+ - imagenet
+ - coco-itm
+ - snli-ve
+ - mocheg
+ - iqa
+ - vqa-v2
+ - shapes
+ - docvqa
+ - ocr-vqa
+ - st-vqa
+ - text-vqa
+ - gqa
+ - okvqa
+ - a-okvqa
+ - viquae
+ - clevr
+ - vcr
+ - visual-mrc
+ - visual-dialog
+ - multi30k
diff --git a/heron/__init__.py b/heron/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/heron/datasets/README.md b/heron/datasets/README.md
new file mode 100644
index 0000000..64f9c00
--- /dev/null
+++ b/heron/datasets/README.md
@@ -0,0 +1,17 @@
+# Datasets Description
+
+# Supported Datasets
+
+## English
+- [M3IT](https://huggingface.co/datasets/MMInstruction/M3IT)
+
+## Japanese
+- [STAIR](http://captions.stair.center/)
+- [Japanese Visual Genome VQA dataset](https://github.com/yahoojapan/ja-vg-vqa)
+
+### Preparing CSV files for Japanese STAIR/Visual Genome
+
+Download the datasets into the [data](../../data/) directory.
+To use the Japanese datasets, first generate the preprocessed CSV files; see the notebooks in [preprocess](./preprocess/).
+
+
diff --git a/heron/datasets/__init__.py b/heron/datasets/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/heron/datasets/base_datasets.py b/heron/datasets/base_datasets.py
new file mode 100644
index 0000000..41da241
--- /dev/null
+++ b/heron/datasets/base_datasets.py
@@ -0,0 +1,30 @@
+import abc
+
+from torch.utils.data import Dataset
+
+
+class BaseDataset(Dataset):
+    """Base class for datasets that can serve training or inference samples.
+
+    `__getitem__` dispatches to `_get_item_inference` or `_get_item_train`
+    depending on `is_inference`; subclasses provide those two hooks and the
+    `create` factory.
+
+    NOTE(review): `abc.ABCMeta` is not set as the metaclass here, so unless
+    `Dataset` already provides it the `abstractmethod` markers below are
+    informational only -- an unimplemented hook then fails with
+    `NotImplementedError` when it is called, not at instantiation.
+    """
+
+    def __init__(self, is_inference: bool = False):
+        """Remember which kind of sample `__getitem__` should return.
+
+        Args:
+            is_inference: When true, items come from `_get_item_inference`;
+                otherwise from `_get_item_train`.
+        """
+        super().__init__()
+        self.is_inference = is_inference
+
+    # `abstractmethod` has to be the innermost decorator: applying it on top
+    # of a `classmethod` object fails while the class body is executed.
+    @classmethod
+    @abc.abstractmethod
+    def create(cls, *args, **kwargs):
+        """Alternate constructor; the accepted arguments are defined by each subclass."""
+        raise NotImplementedError
+
+    def __getitem__(self, index):
+        """Return the sample at `index`, built by the inference or the training hook."""
+        if self.is_inference:
+            return self._get_item_inference(index)
+        return self._get_item_train(index)
+
+    @abc.abstractmethod
+    def _get_item_train(self, index):
+        """Build the training sample at `index` (subclass hook)."""
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def _get_item_inference(self, index):
+        """Build the inference sample at `index` (subclass hook)."""
+        raise NotImplementedError
diff --git a/heron/datasets/ja_csv_datasets.py b/heron/datasets/ja_csv_datasets.py
new file mode 100644
index 0000000..dc8a831
--- /dev/null
+++ b/heron/datasets/ja_csv_datasets.py
@@ -0,0 +1,106 @@
+import cv2
+import datasets
+import numpy as np
+from PIL import Image
+from torch.utils.data import Dataset
+from transformers import (
+ AutoProcessor,
+ AutoTokenizer,
+ CLIPImageProcessor,
+ LlamaTokenizer,
+)
+
+
+class JapaneseCSVDataset(Dataset):
+    """Dataset built from the preprocessed Japanese CSV tables.
+
+    Rows are grouped by `img_path`: one item is one unique image together with
+    all of its question/caption rows merged into a single prompt, so the
+    dataset length is the number of distinct image paths.
+    """
+
+    def __init__(
+        self,
+        model_name: str,
+        vision_model_name: str,
+        loaded_dataset: datasets.GeneratorBasedBuilder,
+        max_length: int = 128,
+    ):
+        """Keep the table and assemble the processor (image processor + tokenizer).
+
+        Args:
+            model_name: Name of the language model; it selects which tokenizer
+                is loaded and how its special tokens are configured.
+            vision_model_name: Checkpoint the `CLIPImageProcessor` is loaded from.
+            loaded_dataset: Table of samples with `img_path`, `question` and
+                `caption` columns.  NOTE(review): despite the annotation it is
+                used like a pandas DataFrame (`.img_path.unique()`, boolean
+                masks, `.iloc`) -- confirm and correct the annotation.
+            max_length: Length every tokenized prompt is padded/truncated to.
+        """
+        super().__init__()
+        self.loaded_dataset = loaded_dataset
+        self.unique_img_path = loaded_dataset.img_path.unique()
+
+        self.max_length = max_length
+
+        # Start from the GIT processor, then swap in the requested image
+        # processor and a tokenizer matching the language model.
+        self.processor = AutoProcessor.from_pretrained("microsoft/git-base")
+        self.processor.image_processor = CLIPImageProcessor.from_pretrained(vision_model_name)
+        if "japanese-stablelm" in model_name:
+            self.processor.tokenizer = LlamaTokenizer.from_pretrained(
+                "novelai/nerdstash-tokenizer-v1",
+                padding_side="right",
+                additional_special_tokens=["▁▁"],
+            )
+        elif (
+            "mpt" in model_name
+            or "matsuo-lab/weblab" in model_name
+            or "cyberagent/open-calm-7b" in model_name
+        ):
+            self.processor.tokenizer = AutoTokenizer.from_pretrained(
+                model_name, padding_side="right", use_fast=True
+            )
+        else:
+            self.processor.tokenizer = AutoTokenizer.from_pretrained(
+                model_name, padding_side="right", use_fast=False
+            )
+
+        # Give these model families an explicit padding token.
+        if "llama" in model_name or "mpt" in model_name:
+            self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token
+        elif "matsuo-lab/weblab" in model_name:
+            self.processor.tokenizer.add_special_tokens(
+                {
+                    "bos_token": "<|endoftext|>",
+                    "eos_token": "<|endoftext|>",
+                    "pad_token": "<|padding|>",
+                    "unk_token": "<|endoftext|>",
+                }
+            )
+
+    def __len__(self) -> int:
+        """Return the number of unique image paths."""
+        return len(self.unique_img_path)
+
+    def __getitem__(self, index) -> dict:
+        """Build the processor inputs for the `index`-th unique image.
+
+        Returns:
+            The unbatched processor outputs plus `labels`, which is the very
+            same tensor object as `input_ids`.
+        """
+        img_path = self.unique_img_path[index]
+
+        df_interest = self.loaded_dataset[self.loaded_dataset.img_path == img_path].reset_index(
+            drop=True
+        )
+
+        # Concatenate every QA pair of this image exactly once, in random order.
+        # A permutation is used on purpose: drawing the indices with `randint`
+        # samples with replacement, which repeats some pairs and drops others.
+        text = "".join(
+            f"##問: {df_interest.iloc[i]['question']} ##答: {df_interest.iloc[i]['caption']}。"
+            for i in np.random.permutation(len(df_interest))
+        )
+
+        # Drop the last character of the prompt, i.e. the "。" that closes the
+        # final answer.  NOTE(review): the original comment called this a
+        # "final space" -- confirm that stripping the "。" is intended.
+        text = text[:-1]
+
+        # load the image
+        img = Image.open(img_path).convert("RGB")
+        img = np.array(img)
+        # NOTE(review): after convert("RGB") the array should always have three
+        # channels, so this grayscale fallback looks unreachable.
+        if img.shape[2] != 3:
+            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+
+        inputs = self.processor(
+            text,
+            img,
+            return_tensors="pt",
+            max_length=self.max_length,
+            padding="max_length",
+            truncation=True,
+        )
+
+        # batch size 1 -> unbatch
+        inputs = {k: v[0] for k, v in inputs.items()}
+        inputs["labels"] = inputs["input_ids"]
+        return inputs
diff --git a/heron/datasets/m3it_datasets.py b/heron/datasets/m3it_datasets.py
new file mode 100644
index 0000000..9f1b355
--- /dev/null
+++ b/heron/datasets/m3it_datasets.py
@@ -0,0 +1,97 @@
+from base64 import b64decode
+from io import BytesIO
+
+import cv2
+import datasets
+import numpy as np
+from PIL import Image
+from torch.utils.data import Dataset
+from transformers import (
+ AutoProcessor,
+ AutoTokenizer,
+ CLIPImageProcessor,
+ LlamaTokenizer,
+)
+
+
+class M3ITDataset(Dataset):
+    """Dataset that turns rows of the M3IT dataset into processor inputs."""
+
+    def __init__(
+        self,
+        model_name: str,
+        vision_model_name: str,
+        loaded_dataset: datasets.GeneratorBasedBuilder,
+        max_length: int = 128,
+    ):
+        """Keep the rows and assemble the processor (image processor + tokenizer).
+
+        Args:
+            model_name: Name of the language model; it selects which tokenizer
+                is loaded and how its special tokens are configured.
+            vision_model_name: Checkpoint the `CLIPImageProcessor` is loaded from.
+            loaded_dataset: Indexable collection of M3IT rows (only `len()` and
+                integer indexing are used here).
+            max_length: Length every tokenized prompt is padded/truncated to.
+        """
+        super().__init__()
+        self.loaded_dataset = loaded_dataset
+        self.max_length = max_length
+
+        # The GIT processor is the container; its parts are replaced below.
+        self.processor = AutoProcessor.from_pretrained("microsoft/git-base")
+        self.processor.image_processor = CLIPImageProcessor.from_pretrained(vision_model_name)
+
+        if "japanese-stablelm" in model_name:
+            tokenizer = LlamaTokenizer.from_pretrained(
+                "novelai/nerdstash-tokenizer-v1",
+                padding_side="right",
+                additional_special_tokens=["▁▁"],
+            )
+        else:
+            # Only these model families are loaded with the fast tokenizer.
+            fast_families = ("mpt", "matsuo-lab/weblab", "cyberagent/open-calm-7b")
+            wants_fast = any(family in model_name for family in fast_families)
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name, padding_side="right", use_fast=wants_fast
+            )
+        self.processor.tokenizer = tokenizer
+
+        # Give these model families an explicit padding token.
+        if "llama" in model_name or "mpt" in model_name:
+            self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token
+        elif "matsuo-lab/weblab" in model_name:
+            weblab_special_tokens = {
+                "bos_token": "<|endoftext|>",
+                "eos_token": "<|endoftext|>",
+                "pad_token": "<|padding|>",
+                "unk_token": "<|endoftext|>",
+            }
+            self.processor.tokenizer.add_special_tokens(weblab_special_tokens)
+
+    def __len__(self) -> int:
+        """Return the number of rows in the wrapped dataset."""
+        return len(self.loaded_dataset)
+
+    def __getitem__(self, index) -> dict:
+        """Build the processor inputs for row `index`.
+
+        Row layout: https://huggingface.co/datasets/MMInstruction/M3IT#data-instances
+        (original author's note: some of the nlvr data were broken).
+
+        Returns:
+            The unbatched processor outputs plus `labels`, which is the very
+            same tensor object as `input_ids`.
+        """
+        sample = self.loaded_dataset[index]
+
+        prompt = (
+            f"##Instruction: {sample['instruction']} "
+            f"##Question: {sample['inputs']} "
+            f"##Answer: {sample['outputs']}"
+        )
+
+        # Load the image: only the first base64-encoded entry of the row is used.
+        encoded_images = sample["image_base64_str"]
+        raw_bytes = b64decode(encoded_images[0])
+        pixels = np.array(Image.open(BytesIO(raw_bytes)).convert("RGB"))
+        # NOTE(review): after convert("RGB") the array should always have three
+        # channels, so this grayscale fallback looks unreachable.
+        if pixels.shape[2] != 3:
+            pixels = cv2.cvtColor(pixels, cv2.COLOR_GRAY2RGB)
+
+        batch = self.processor(
+            prompt,
+            pixels,
+            return_tensors="pt",
+            max_length=self.max_length,
+            padding="max_length",
+            truncation=True,
+        )
+
+        # The processor returns a batch of size one; strip that leading axis.
+        unbatched = {name: tensor[0] for name, tensor in batch.items()}
+        unbatched["labels"] = unbatched["input_ids"]
+        return unbatched
+
diff --git a/heron/datasets/preprocess/make_STAIR_csv.ipynb b/heron/datasets/preprocess/make_STAIR_csv.ipynb
new file mode 100644
index 0000000..e5ecced
--- /dev/null
+++ b/heron/datasets/preprocess/make_STAIR_csv.ipynb
@@ -0,0 +1,245 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "8a82402e-eea9-4227-8d08-4e4178895602",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import glob\n",
+ "import json\n",
+ "import os\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "4cf98eff-c7b9-4f6e-8c7f-781c52cf4378",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set path to STAIR data\n",
+ "# Images are download from here: https://cocodataset.org/#download\n",
+ "# Text data are download from here: https://github.com/STAIR-Lab-CIT/STAIR-captions\n",
+ "# Download data at data/coco\n",
+ "\n",
+ "PATH_TO_COCO = os.path.abspath(\"../../../../data/coco\") + \"/\"\n",
+ "PATH_TO_TRAIN_JSON = PATH_TO_COCO + \"stair_captions_v1.2_train.json\"\n",
+ "PATH_TO_VAL_JSON = PATH_TO_COCO + \"stair_captions_v1.2_val.json\"\n",
+ "\n",
+ "# Add some pseudo question's text\n",
+ "random_question_list = [\n",
+ " \"画像の内容を教えてください。\",\n",
+ " \"この画像を説明できますか?\",\n",
+ " \"画像に何が写っていますか?\",\n",
+ " \"画像の詳細を話してください。\",\n",
+ " \"画像に関する情報を共有して。\",\n",
+ " \"画像を解説してもらえますか?\",\n",
+ " \"この画像の主題は何ですか?\",\n",
+ " \"画像を簡潔に説明してください。\",\n",
+ " \"画像についての概要を教えて。\",\n",
+ " \"この画像に関する基本情報を話してください。\",\n",
+ " \"これは何の写真ですか?\",\n",
+ " \"写真には何が写っていますか?\",\n",
+ " \"写真について説明してください。\",\n",
+ " \"この写真はどういう状況ですか?説明してください。\",\n",
+ "]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c2578a80-db4a-4be4-b1ee-43f2c718ab7d",
+ "metadata": {},
+ "source": [
+ "# STAIR / COCO"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "d3243048-4846-4ff2-8687-b6c80fc4796d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(PATH_TO_TRAIN_JSON, 'r') as f:\n",
+ " coco_train = json.load(f)\n",
+ "\n",
+ "with open(PATH_TO_VAL_JSON, 'r') as f:\n",
+ " coco_val = json.load(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "f6a23770-495b-4444-9665-825747b4d7c4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# annotations to pandas DataFrame\n",
+ "df_coco_train = pd.DataFrame(coco_train[\"annotations\"])\n",
+ "df_coco_val = pd.DataFrame(coco_val[\"annotations\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "7625e256-4cf0-44d7-9035-a70aef61525e",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 413915/413915 [00:35<00:00, 11705.96it/s]\n",
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 202520/202520 [00:17<00:00, 11630.47it/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
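+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>img_path</th>\n",
+       "      <th>caption</th>\n",
+       "      <th>question</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>/home/y_inoue/coco/val2014/COCO_val2014_000000...</td>\n",
+       "      <td>踏切の近くにワイナリーが開店している</td>\n",
+       "      <td>写真について説明してください。</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>/home/y_inoue/coco/val2014/COCO_val2014_000000...</td>\n",
+       "      <td>渋滞の中2人乗りのバイクが走っている</td>\n",
+       "      <td>画像の内容を教えてください。</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>/home/y_inoue/coco/val2014/COCO_val2014_000000...</td>\n",
+       "      <td>部屋の中に白い自転車が置いてある</td>\n",
+       "      <td>画像の詳細を話してください。</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>/home/y_inoue/coco/val2014/COCO_val2014_000000...</td>\n",
+       "      <td>街の道を白い馬車が走っている</td>\n",
+       "      <td>これは何の写真ですか?</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>/home/y_inoue/coco/val2014/COCO_val2014_000000...</td>\n",
+       "      <td>ビーチ近くの水際に鳥が羽を広げて停まっている</td>\n",
+       "      <td>画像を解説してもらえますか?</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"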
+ ],
+ "text/plain": [
+ " img_path caption \\\n",
+ "0 /home/y_inoue/coco/val2014/COCO_val2014_000000... 踏切の近くにワイナリーが開店している \n",
+ "1 /home/y_inoue/coco/val2014/COCO_val2014_000000... 渋滞の中2人乗りのバイクが走っている \n",
+ "2 /home/y_inoue/coco/val2014/COCO_val2014_000000... 部屋の中に白い自転車が置いてある \n",
+ "3 /home/y_inoue/coco/val2014/COCO_val2014_000000... 街の道を白い馬車が走っている \n",
+ "4 /home/y_inoue/coco/val2014/COCO_val2014_000000... ビーチ近くの水際に鳥が羽を広げて停まっている \n",
+ "\n",
+ " question \n",
+ "0 写真について説明してください。 \n",
+ "1 画像の内容を教えてください。 \n",
+ "2 画像の詳細を話してください。 \n",
+ "3 これは何の写真ですか? \n",
+ "4 画像を解説してもらえますか? "
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "for target, df_coco in [\n",
+ " [\"train\", df_coco_train],\n",
+ " [\"val\", df_coco_val]\n",
+ "]:\n",
+ " img_path_list = []\n",
+ " caption_list = []\n",
+ " question_list = []\n",
+ "\n",
+ " for i in tqdm.tqdm(range(len(df_coco))):\n",
+ " row = df_coco.iloc[i]\n",
+ " image_id = row.image_id\n",
+ " img_path = PATH_TO_COCO + f\"{target}2014/COCO_{target}2014_{image_id:012}.jpg\"\n",
+ " if os.path.exists(img_path):\n",
+ " img_path_list.append(img_path)\n",
+ " caption_list.append(row.caption)\n",
+ " q_index = np.random.randint(len(random_question_list))\n",
+ " question_list.append(random_question_list[q_index])\n",
+ " else:\n",
+ " print(f\"Fail path: {img_path}\")\n",
+ " \n",
+ " df = pd.DataFrame({\n",
+ " \"img_path\": img_path_list,\n",
+ " \"caption\": caption_list,\n",
+ " \"question\": question_list,\n",
+ " })\n",
+ " \n",
+ " df.to_csv(PATH_TO_COCO + f\"df_{target}.csv\", index=False)\n",
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c8657132-6913-4a16-baa8-b34cfe49f8ec",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/heron/datasets/preprocess/make_japanese_visual_genome_csv.ipynb b/heron/datasets/preprocess/make_japanese_visual_genome_csv.ipynb
new file mode 100644
index 0000000..958dcd0
--- /dev/null
+++ b/heron/datasets/preprocess/make_japanese_visual_genome_csv.ipynb
@@ -0,0 +1,220 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "8a82402e-eea9-4227-8d08-4e4178895602",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import glob\n",
+ "import json\n",
+ "import os\n",
+ "\n",
+ "import pandas as pd\n",
+ "import tqdm"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "4cf98eff-c7b9-4f6e-8c7f-781c52cf4378",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set path to visual genome data\n",
+ "# Images are download from here: https://homes.cs.washington.edu/~ranjay/visualgenome/api.html\n",
+ "# Text data are download from here: https://github.com/yahoojapan/ja-vg-vqa\n",
+ "# Download data at data/visual_genome_ja\n",
+ "\n",
+ "PATH_TO_VISUAL_GENOME = os.path.abspath(\"../../../../data/visual_genome_ja\") + \"/\""
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "889eb7b5-d60d-4502-8603-56465ca7de0b",
+ "metadata": {},
+ "source": [
+ "# Japanese Visual Genome"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "5b017be4-2400-4fbc-b275-bd5032922a5e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(PATH_TO_VISUAL_GENOME + 'question_answers.json', 'r') as f:\n",
+ " v_g = json.load(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "f9160aaf-bc19-4aea-8e32-76aafeb422f8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Extract question/answer pairs\n",
+ "qas_list = []\n",
+ "for data in v_g:\n",
+ " qas_list.extend(data[\"qas\"])\n",
+ "d_vg = pd.DataFrame(qas_list)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "6235d3e0-df32-4836-a6dd-75e0ad8e4535",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 793664/793664 [01:01<00:00, 12947.91it/s]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
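+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>img_path</th>\n",
+       "      <th>caption</th>\n",
+       "      <th>question</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>/home/y_inoue/visual_genome_ja/VG_100K/2395966...</td>\n",
+       "      <td>オレンジ色</td>\n",
+       "      <td>バイクの車体の色はどんな色をしていますか?</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>/home/y_inoue/visual_genome_ja/VG_100K/2395966...</td>\n",
+       "      <td>白色</td>\n",
+       "      <td>バイクを固定している器具の色はどんな色をしていますか?</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>/home/y_inoue/visual_genome_ja/VG_100K/2395966...</td>\n",
+       "      <td>黒色</td>\n",
+       "      <td>バイクが置かれている床は何色ですか?</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>/home/y_inoue/visual_genome_ja/VG_100K/2395966...</td>\n",
+       "      <td>緑色</td>\n",
+       "      <td>バイクの右側の通路の床は何色ですか?</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>/home/y_inoue/visual_genome_ja/VG_100K/2395966...</td>\n",
+       "      <td>5つ</td>\n",
+       "      <td>バイクの左に落ちている箱側面に書かれたメニューはいくつですか?</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"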
+ ],
+ "text/plain": [
+ " img_path caption \\\n",
+ "0 /home/y_inoue/visual_genome_ja/VG_100K/2395966... オレンジ色 \n",
+ "1 /home/y_inoue/visual_genome_ja/VG_100K/2395966... 白色 \n",
+ "2 /home/y_inoue/visual_genome_ja/VG_100K/2395966... 黒色 \n",
+ "3 /home/y_inoue/visual_genome_ja/VG_100K/2395966... 緑色 \n",
+ "4 /home/y_inoue/visual_genome_ja/VG_100K/2395966... 5つ \n",
+ "\n",
+ " question \n",
+ "0 バイクの車体の色はどんな色をしていますか? \n",
+ "1 バイクを固定している器具の色はどんな色をしていますか? \n",
+ "2 バイクが置かれている床は何色ですか? \n",
+ "3 バイクの右側の通路の床は何色ですか? \n",
+ "4 バイクの左に落ちている箱側面に書かれたメニューはいくつですか? "
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "img_path_list = []\n",
+ "caption_list = []\n",
+ "question_list = []\n",
+ "\n",
+ "for i in tqdm.tqdm(range(len(d_vg))):\n",
+ " row = d_vg.iloc[i]\n",
+ "\n",
+ " image_id = row.image_id\n",
+ " img_path = PATH_TO_VISUAL_GENOME + f\"VG_100K/{image_id}.jpg\"\n",
+ "\n",
+ " if os.path.exists(img_path):\n",
+ " img_path_list.append(img_path)\n",
+ " caption_list.append(row.answer)\n",
+ " question_list.append(row.question)\n",
+ " else:\n",
+ " print(f\"Fail path: {img_path}\")\n",
+ "\n",
+ "df = pd.DataFrame({\n",
+ " \"img_path\": img_path_list,\n",
+ " \"caption\": caption_list,\n",
+ " \"question\": question_list,\n",
+ "})\n",
+ "\n",
+ "df.to_csv(PATH_TO_VISUAL_GENOME + f\"df_vg.csv\", index=False)\n",
+ "df.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5cb70772-220c-4cd8-a4a0-866eab65f987",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/heron/datasets/utils.py b/heron/datasets/utils.py
new file mode 100644
index 0000000..fd4e187
--- /dev/null
+++ b/heron/datasets/utils.py
@@ -0,0 +1,40 @@
+from typing import Optional, Tuple, Union
+
+import datasets
+import pandas as pd
+from torch.utils.data import ConcatDataset, Dataset
+
+from .ja_csv_datasets import JapaneseCSVDataset
+from .m3it_datasets import M3ITDataset
+
+def get_dataset(
+    config: dict, model_name: str, vision_model_name: str, max_length: int
+) -> Tuple[Dataset, Dataset]:
+    """Create the training and validation datasets named by `config["dataset_type"]`.
+
+    * `"m3it"`: loads every subset listed in `config["dataset_names"]` from
+      `MMInstruction/M3IT`, concatenates their `train` splits for training and
+      the `validation` splits of the subsets that have one for validation.
+    * `"japanese_csv"`: reads the preprocessed CSV files under `./data`; the
+      COCO train table plus the Visual Genome table form the training set and
+      the COCO val table is the validation set.
+
+    Args:
+        config: Dataset configuration; `dataset_type` is always read and
+            `dataset_names` is read for the `"m3it"` type.
+        model_name: Passed through to the dataset class.
+        vision_model_name: Passed through to the dataset class.
+        max_length: Passed through to the dataset class.
+
+    Returns:
+        The pair `(train_dataset, val_dataset)`.
+
+    Raises:
+        ValueError: If `dataset_type` is not one of the supported values.
+    """
+    dataset_type = config.get("dataset_type")
+
+    if dataset_type == "m3it":
+        dataset_list = [
+            datasets.load_dataset("MMInstruction/M3IT", i) for i in config["dataset_names"]
+        ]
+        train_dataframe = ConcatDataset([d["train"] for d in dataset_list])
+        train_dataset = M3ITDataset(model_name, vision_model_name, train_dataframe, max_length)
+
+        # Some subsets ship without a validation split; skip those and report
+        # them.  Only the missing-split lookup is caught, so unrelated errors
+        # are no longer swallowed.
+        val_dataset_list = []
+        for d in dataset_list:
+            try:
+                val_dataset_list.append(d["validation"])
+            except KeyError:
+                print(f"{d['train']._info.config_name} has no validation set.")
+        # NOTE(review): if none of the subsets has a validation split this list
+        # is empty when handed to ConcatDataset -- confirm that cannot happen.
+        val_dataframe = ConcatDataset(val_dataset_list)
+        val_dataset = M3ITDataset(model_name, vision_model_name, val_dataframe, max_length)
+    elif dataset_type == "japanese_csv":
+        # Paths are relative to the current working directory.
+        df_train = pd.read_csv("./data/coco/df_train.csv")
+        df_val = pd.read_csv("./data/coco/df_val.csv")
+        df_vg = pd.read_csv("./data/visual_genome_ja/df_vg.csv")
+
+        train_dataframe = pd.concat([df_train, df_vg], axis=0, ignore_index=True)
+        train_dataset = JapaneseCSVDataset(model_name, vision_model_name, train_dataframe, max_length)
+
+        val_dataframe = df_val
+        val_dataset = JapaneseCSVDataset(model_name, vision_model_name, val_dataframe, max_length)
+    else:
+        raise ValueError(f"dataset_type: {config.get('dataset_type')} is not supported.")
+
+    return train_dataset, val_dataset
diff --git a/heron/models/__init__.py b/heron/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/heron/models/git_llm/__init__.py b/heron/models/git_llm/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/heron/models/git_llm/git_gpt_neox/__init__.py b/heron/models/git_llm/git_gpt_neox/__init__.py
new file mode 100644
index 0000000..4f669d9
--- /dev/null
+++ b/heron/models/git_llm/git_gpt_neox/__init__.py
@@ -0,0 +1 @@
+from .modeling_git_gpt_neox import GitGPTNeoXConfig, GitGPTNeoXForCausalLM, GitGPTNeoXModel
diff --git a/heron/models/git_llm/git_gpt_neox/modeling_git_gpt_neox.py b/heron/models/git_llm/git_gpt_neox/modeling_git_gpt_neox.py
new file mode 100644
index 0000000..bcace8e
--- /dev/null
+++ b/heron/models/git_llm/git_gpt_neox/modeling_git_gpt_neox.py
@@ -0,0 +1,527 @@
+"""PyTorch GIT GPTneoX model."""
+import copy
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+from transformers import (
+ CLIPVisionConfig,
+ CLIPVisionModel,
+ GPTNeoXConfig,
+ GPTNeoXForCausalLM,
+ GPTNeoXModel,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ BaseModelOutputWithPooling,
+ CausalLMOutputWithPast,
+)
+from transformers.models.git.modeling_git import GitProjection
+
+
+class GitGPTNeoXConfig(GPTNeoXConfig):
+    """GPT-NeoX configuration extended with the CLIP vision settings used by GIT."""
+    model_type = "git_gpt_neox"
+
+    def __init__(self, **kwargs):
+        # Saved vision fields arrive via kwargs on reload; keep them, don't reset them.
+        vision_config = kwargs.pop("vision_config", None)
+        self.num_image_with_embedding = kwargs.pop("num_image_with_embedding", None)
+        super().__init__(**kwargs)
+        if isinstance(vision_config, dict):
+            vision_config = CLIPVisionConfig.from_dict(vision_config)
+        self.vision_config = CLIPVisionConfig() if vision_config is None else vision_config
+
+    def set_vision_configs(
+        self,
+        num_image_with_embedding: Union[int, None] = None,
+        vision_model_name: Union[str, None] = None,
+    ):
+        """Store both arguments, then load the CLIP vision config of `vision_model_name`."""
+        self.num_image_with_embedding = num_image_with_embedding
+        self.vision_model_name = vision_model_name
+        self.vision_config = CLIPVisionConfig.from_pretrained(vision_model_name)
+
+    def to_dict(self):
+        """Return a deep copy of the attributes with `vision_config` expanded to a dict."""
+        output = copy.deepcopy(self.__dict__)
+        output["vision_config"] = self.vision_config.to_dict()
+        output["model_type"] = self.__class__.model_type
+        return output
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """Broadcast a `[bsz, src_len]` keep-mask into an additive `[bsz, 1, tgt_len, src_len]` mask.
+
+    Positions where `mask` is 1 become 0; every other position becomes `torch.finfo(dtype).min`.
+    """
+    batch, source_len = mask.size()
+    if tgt_len is None:
+        tgt_len = source_len
+    keep = mask[:, None, None, :].expand(batch, 1, tgt_len, source_len).to(dtype)
+    # 1 - keep is non-zero exactly where attention must be blocked.
+    blocked = 1.0 - keep
+    return blocked.masked_fill(blocked.to(torch.bool), torch.finfo(dtype).min)
+
+
+class GitGPTNeoXModel(GPTNeoXModel):
+ config_class = GitGPTNeoXConfig
+
+    def __init__(self, config: GPTNeoXConfig):
+        """Initialise the parent model, then attach the CLIP image encoder and its projection."""
+        super().__init__(config)
+        vision_cfg = config.vision_config
+        num_frames = config.num_image_with_embedding
+
+        # Git modules
+        self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
+        self.visual_projection = GitProjection(config)
+        if num_frames is not None:
+            # One learnable offset per frame, added to that frame's visual features in forward.
+            self.img_temporal_embedding = nn.ParameterList(
+                nn.Parameter(torch.zeros(1, 1, vision_cfg.hidden_size)) for _ in range(num_frames)
+            )
+
+        # Visual tokens per image: (image_size / patch_size) squared, plus one.
+        tokens_per_image = int((vision_cfg.image_size / vision_cfg.patch_size) ** 2 + 1)
+        self.image_patch_tokens = tokens_per_image * (1 if num_frames is None else num_frames)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+ # def get_input_embeddings(self):
+ # return self.decoder.embed_in
+
+ # def set_input_embeddings(self, value):
+ # self.decoder.embed_in = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prune heads; `heads_to_prune` maps a layer index to the list of heads to prune there.
+        NOTE(review): this reads `self.encoder.layer`, while `forward` iterates `self.layers` — confirm.
+        """
+        for layer_index, head_ids in heads_to_prune.items():
+            self.encoder.layer[layer_index].attention.prune_heads(head_ids)
+
+    def _generate_future_mask(
+        self, size: int, dtype: torch.dtype, device: torch.device
+    ) -> torch.Tensor:
+        """Return a `(size, size)` additive mask: 0 on/below the diagonal, -inf strictly above."""
+        upper = torch.ones(size, size, device=device, dtype=dtype).triu(diagonal=1)
+        # Cells equal to 1 lie above the diagonal (later positions) and get -inf.
+        return upper.masked_fill(upper == 1, float("-inf"))
+
+    def create_attention_mask(
+        self,
+        tgt,
+        memory,
+        tgt_mask,
+        past_key_values_length,
+        memory_key_padding_mask=None,
+    ):
+        """Assemble the additive attention mask over [memory tokens | text tokens].
+
+        Layout before batching (rows are queries, columns are keys):
+            memory -> memory: 0       memory -> text: -inf
+            text   -> memory: 0       text   -> text: `tgt_mask`
+        When `past_key_values_length > 0`, `tgt_mask` is replaced by zeros widened by that length.
+        `memory_key_padding_mask` (bool, True = padded) adds -inf to the memory key columns.
+
+        Returns a 4-D tensor (batch, 1, num_memory + num_tgt, num_memory + past + num_tgt).
+        Raises ValueError if `memory_key_padding_mask` is not a boolean tensor.
+        """
+        text_len = tgt.shape[1]
+        mem_len = memory.shape[1]
+        batch = memory.shape[0]
+        device, dtype = tgt.device, tgt.dtype
+        neg_inf = float("-inf")
+
+        if past_key_values_length > 0:
+            rows = tgt_mask.shape[0]
+            tgt_mask = torch.zeros(
+                (rows, rows + past_key_values_length), dtype=dtype, device=tgt_mask.device
+            )
+
+        # Left column block: every query may look at every memory token.
+        mem_to_mem = torch.zeros((mem_len, mem_len), device=device, dtype=dtype)
+        text_to_mem = torch.zeros((text_len, mem_len), dtype=dtype, device=tgt_mask.device)
+        # Right column block: memory queries never look at text keys.
+        mem_to_text = torch.full(
+            (mem_len, text_len + past_key_values_length), neg_inf, device=device, dtype=dtype
+        )
+        left_block = torch.cat((mem_to_mem, text_to_mem), dim=0)
+        right_block = torch.cat((mem_to_text, tgt_mask.to(dtype)), dim=0)
+        mask = torch.cat((left_block, right_block), dim=1)[None, :]
+
+        if memory_key_padding_mask is None:
+            memory_key_padding_mask = torch.full((batch, mem_len), fill_value=False, device=device)
+        # if it is False, it means valid. That is, it is not a padding
+        if memory_key_padding_mask.dtype != torch.bool:
+            raise ValueError("Memory key padding mask must be a boolean tensor.")
+        padding_bias = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
+        padding_bias[memory_key_padding_mask] = neg_inf
+
+        # expand() returns a view, so clone before writing into the memory columns.
+        mask = mask.expand(
+            (
+                memory_key_padding_mask.shape[0],
+                mem_len + text_len,
+                mem_len + past_key_values_length + text_len,
+            )
+        ).clone()
+        mask[:, :, :mem_len] = mask[:, :, :mem_len] + padding_bias[:, None, :]
+
+        # add axis for multi-head
+        return mask[:, None, :, :]
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
+        r"""
+        Run the decoder layers over the projected visual tokens followed by the text embeddings.
+
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+
+        Returns:
+            `BaseModelOutputWithPast`, or a tuple of its non-None fields when `return_dict` is false.
+        """
+        # Flags left unset fall back to the config. `output_attentions` reads its own config
+        # field here; it used to fall back to `output_hidden_states` by mistake.
+        if output_attentions is None:
+            output_attentions = self.config.output_attentions
+        if output_hidden_states is None:
+            output_hidden_states = self.config.output_hidden_states
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+        seq_length_with_past = seq_length
+
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        # GIT Vision Encoder part (only runs when no cache is supplied)
+        projected_visual_features = None
+        if pixel_values is not None and past_key_values is None:
+            if pixel_values.ndim == 4:
+                # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
+                visual_features = self.image_encoder(pixel_values).last_hidden_state
+
+            elif pixel_values.ndim == 5:
+                # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
+                visual_features = []
+                for frame_idx in range(pixel_values.shape[1]):
+                    visual_features_frame = self.image_encoder(
+                        pixel_values[:, frame_idx, :, :]
+                    ).last_hidden_state
+                    visual_features_frame += self.img_temporal_embedding[frame_idx]
+                    visual_features.append(visual_features_frame)
+
+                # finally, concatenate all features along sequence dimension
+                visual_features = torch.cat(visual_features, dim=1)
+            else:
+                raise ValueError("pixel_values must be of rank 4 or 5")
+
+            projected_visual_features = self.visual_projection(visual_features)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_in(input_ids)
+
+        # default to an all-ones mask over past + current text positions
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+
+        embedding_output = self.emb_dropout(inputs_embeds)
+
+        if projected_visual_features is None:
+            projected_visual_features = torch.zeros(
+                (embedding_output.shape[0], 0, embedding_output.shape[2]),
+                dtype=embedding_output.dtype,
+                device=embedding_output.device,
+            )
+
+        # Repeat visual features to match embedding batch size.
+        projected_visual_features = projected_visual_features.repeat(
+            embedding_output.size(0) // projected_visual_features.size(0), 1, 1
+        )
+
+        # concatenate patch token and text token embeddings
+        hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length,
+                seq_length + projected_visual_features.shape[1] + past_key_values_length,
+                dtype=torch.long,
+                device=device,
+            )
+            position_ids = position_ids.unsqueeze(0).view(
+                -1, seq_length + projected_visual_features.shape[1]
+            )
+        else:
+            position_ids = position_ids.view(
+                -1, seq_length + projected_visual_features.shape[1]
+            ).long()
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                # This module defines no module-level `logger`, so obtain one here.
+                from transformers.utils import logging
+                logging.get_logger(__name__).warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # By default, an additive causal mask is created
+        # for masking the future (one direction).
+        tgt_mask = self._generate_future_mask(
+            seq_length, embedding_output.dtype, embedding_output.device
+        )
+
+        # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
+        combined_attention_mask = self.create_attention_mask(
+            tgt=embedding_output,
+            memory=projected_visual_features,
+            tgt_mask=tgt_mask,
+            past_key_values_length=past_key_values_length,
+        )
+
+        if attention_mask is not None:
+            # if the user provides an attention mask, we add it to the default one
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(
+                attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]
+            ).to(embedding_output.device)
+            if past_key_values_length > 0:
+                expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
+            else:
+                combined_attention_mask[
+                    :, :, -input_shape[1] :, -input_shape[1] :
+                ] += expanded_attn_mask
+
+        if past_key_values is None:
+            past_key_values = tuple([None] * self.config.num_hidden_layers)
+
+        # decoder layers
+        presents = () if use_cache else None
+        all_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+        for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for layer_past
+                        return module(*inputs, use_cache, None, output_attentions)
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer),
+                    hidden_states,
+                    combined_attention_mask,
+                    position_ids,
+                    head_mask[i],
+                )
+            else:
+                outputs = layer(
+                    hidden_states,
+                    attention_mask=combined_attention_mask,
+                    position_ids=position_ids,
+                    head_mask=head_mask[i],
+                    layer_past=layer_past,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+            hidden_states = outputs[0]
+            if use_cache is True:
+                presents = presents + (outputs[1],)
+            if output_attentions:
+                all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
+
+        hidden_states = self.final_layer_norm(hidden_states)
+        # Add last hidden state
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, presents, all_hidden_states, all_attentions]
+                if v is not None
+            )
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
+        )
+
+
+class GitGPTNeoXForCausalLM(GPTNeoXForCausalLM):
+ config_class = GitGPTNeoXConfig
+
+    def __init__(self, config):
+        """Initialise the parent causal LM, then attach the GIT backbone.
+
+        After the parent constructor runs, `self.gpt_neox` is set to a `GitGPTNeoXModel`
+        built from the same `config`.
+        """
+        super().__init__(config)
+        self.gpt_neox = GitGPTNeoXModel(config)
+        self.post_init()  # Initialize weights and apply final processing
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for the left-to-right language modeling loss (next word prediction). Positions set to
+            `-100` are ignored by the loss. Passing `labels` forces `use_cache=False`.
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+
+        Returns:
+            `CausalLMOutputWithPast`, or a tuple (loss first when computed) when `return_dict` is false.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        outputs = self.gpt_neox(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            pixel_values=pixel_values,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        logits = self.embed_out(sequence_output)
+
+        loss = None
+        if labels is not None:
+            # Next-token prediction: drop the visual prefix, then shift logits and labels by one.
+            # The prefix length is derived from the shapes, so it is 0 when no image was encoded.
+            num_image_tokens = logits.shape[1] - labels.shape[1]
+            shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
+            labels = labels[:, 1:].contiguous()
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        use_cache=None,
+        **kwargs,
+    ):
+        """Build the keyword arguments for one generation step.
+
+        With a cache only the last token of `input_ids` is kept. A missing `attention_mask`
+        defaults to all ones shaped like the (possibly trimmed) `input_ids`; `pixel_values` is
+        read from `kwargs` and is None when absent.
+        """
+        has_cache = past_key_values is not None
+        step_ids = input_ids[:, -1:] if has_cache else input_ids
+        step_mask = step_ids.new_ones(step_ids.shape) if attention_mask is None else attention_mask
+        return dict(
+            input_ids=step_ids,
+            attention_mask=step_mask,
+            pixel_values=kwargs.get("pixel_values"),
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+        )
+
+    def _reorder_cache(self, past_key_values, beam_idx):
+        """Reorder every cached tensor along dim 0 by `beam_idx`, keeping the tuple nesting."""
+        # index_select on dim 0 picks, for each output row, the cache row named by beam_idx.
+        return tuple(
+            tuple(state.index_select(0, beam_idx) for state in layer_past)
+            for layer_past in past_key_values
+        )
diff --git a/heron/models/git_llm/git_japanese_stablelm_alpha/__init__.py b/heron/models/git_llm/git_japanese_stablelm_alpha/__init__.py
new file mode 100644
index 0000000..6d293ae
--- /dev/null
+++ b/heron/models/git_llm/git_japanese_stablelm_alpha/__init__.py
@@ -0,0 +1,10 @@
+from .configuration_japanese_stablelm_alpha import JapaneseStableLMAlphaConfig
+from .modeling_git_japanese_stablelm_alpha import (
+ GitJapaneseStableLMAlphaConfig,
+ GitJapaneseStableLMAlphaForCausalLM,
+ GitJapaneseStableLMAlphaModel,
+)
+from .modeling_japanese_stablelm_alpha import (
+ JapaneseStableLMAlphaForCausalLM,
+ JapaneseStableLMAlphaModel,
+)
diff --git a/heron/models/git_llm/git_japanese_stablelm_alpha/configuration_japanese_stablelm_alpha.py b/heron/models/git_llm/git_japanese_stablelm_alpha/configuration_japanese_stablelm_alpha.py
new file mode 100644
index 0000000..79f17e1
--- /dev/null
+++ b/heron/models/git_llm/git_japanese_stablelm_alpha/configuration_japanese_stablelm_alpha.py
@@ -0,0 +1,120 @@
+# coding=utf-8
+# Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" JapaneseStableLMAlpha model configuration"""
+
+from transformers import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+STABLE_LM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+
+
+class JapaneseStableLMAlphaConfig(PretrainedConfig):
+    r"""
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 65536):
+            Vocabulary size of the JapaneseStableLMAlphaModel. Defines the number of different tokens that
+            can be represented by the `inputs_ids` passed when calling [`JapaneseStableLMAlphaModel`].
+        hidden_size (`int`, *optional*, defaults to 4096):
+            Dimension of the decoder layers and the pooler layer.
+        num_hidden_layers (`int`, *optional*, defaults to 32):
+            Number of hidden layers in the Transformer decoder.
+        num_attention_heads (`int`, *optional*, defaults to 32):
+            Number of attention heads for each attention layer in the Transformer decoder.
+        intermediate_size (`int`, *optional*):
+            NOTE(review): not an explicit `__init__` parameter; if passed it only travels through `**kwargs`.
+        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
+            The non-linear activation function (function or string).
+        rotary_pct (`float`, *optional*, defaults to 0.25):
+            Percentage of hidden dimensions to allocate to rotary embeddings.
+        rotary_emb_base (`int`, *optional*, defaults to 10000)
+            Base for computing rotary embeddings frequency.
+        rotary_scale_base (`int`, *optional*, defaults to 512)
+            Base `scale` for computing XPos rotary embeddings scale.
+        classifier_dropout (`float`, *optional*, defaults to 0.1):
+            Argument used when doing token classification, used in the model
+            [`StableLMForTokenClassification`]. The dropout ratio for the hidden layer.
+        max_position_embeddings (`int`, *optional*, defaults to 2048):
+            The maximum sequence length that this model might ever be used with.
+            Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing
+            all weight matrices.
+        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
+            The epsilon used by the layer normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions
+            (not used by all models). Only relevant if `config.is_decoder=True`.
+        use_parallel_residual (`bool`, *optional*, defaults to `True`):
+            Whether to use a "parallel" formulation in each Transformer layer,
+            which can provide a slight training speedup at large scales.
+    Example:
+
+    ```python
+    >>> from transformers import JapaneseStableLMAlphaConfig, JapaneseStableLMAlphaModel
+
+    >>> # Initializing a JapaneseStableLMAlpha style configuration
+    >>> configuration = JapaneseStableLMAlphaConfig()
+
+    >>> # Initializing a model (with random weights) from the style configuration
+    >>> model = JapaneseStableLMAlphaModel(configuration)  # doctest: +SKIP
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config  # doctest: +SKIP
+    ```"""
+
+    def __init__(
+        self,
+        vocab_size=65536,
+        hidden_size=4096,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        hidden_act="silu",
+        rotary_pct=0.25,
+        rotary_emb_base=10000,
+        rotary_scale_base=512,
+        classifier_dropout=0.1,
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        layer_norm_eps=1e-5,
+        use_cache=True,
+        bos_token_id=3,
+        eos_token_id=3,
+        tie_word_embeddings=False,
+        use_parallel_residual=True,
+        use_bias_in_mlp=True,
+        **kwargs,
+    ):
+        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.rotary_pct = rotary_pct
+        self.rotary_emb_base = rotary_emb_base
+        self.rotary_scale_base = rotary_scale_base
+        self.classifier_dropout = classifier_dropout
+        self.initializer_range = initializer_range
+        self.layer_norm_eps = layer_norm_eps
+        self.use_cache = use_cache
+        self.tie_word_embeddings = tie_word_embeddings
+        self.use_parallel_residual = use_parallel_residual
+        self.use_bias_in_mlp = use_bias_in_mlp
diff --git a/heron/models/git_llm/git_japanese_stablelm_alpha/modeling_git_japanese_stablelm_alpha.py b/heron/models/git_llm/git_japanese_stablelm_alpha/modeling_git_japanese_stablelm_alpha.py
new file mode 100644
index 0000000..8852390
--- /dev/null
+++ b/heron/models/git_llm/git_japanese_stablelm_alpha/modeling_git_japanese_stablelm_alpha.py
@@ -0,0 +1,527 @@
+"""PyTorch GIT Japanese StableLM alpha model."""
+import copy
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+from transformers import CLIPVisionConfig, CLIPVisionModel
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ BaseModelOutputWithPooling,
+ CausalLMOutputWithPast,
+)
+from transformers.models.git.modeling_git import GitProjection
+
+from .configuration_japanese_stablelm_alpha import JapaneseStableLMAlphaConfig
+from .modeling_japanese_stablelm_alpha import (
+ JapaneseStableLMAlphaForCausalLM,
+ JapaneseStableLMAlphaModel,
+)
+
+
+class GitJapaneseStableLMAlphaConfig(JapaneseStableLMAlphaConfig):
+    """Japanese StableLM alpha configuration extended with the CLIP vision settings used by GIT."""
+    model_type = "git_japanese_stablelm_alpha"
+
+    def __init__(self, **kwargs):
+        # Saved vision fields arrive via kwargs on reload; keep them, don't reset them.
+        vision_config = kwargs.pop("vision_config", None)
+        self.num_image_with_embedding = kwargs.pop("num_image_with_embedding", None)
+        super().__init__(**kwargs)
+        if isinstance(vision_config, dict):
+            vision_config = CLIPVisionConfig.from_dict(vision_config)
+        self.vision_config = CLIPVisionConfig() if vision_config is None else vision_config
+
+    def set_vision_configs(
+        self,
+        num_image_with_embedding: Union[int, None] = None,
+        vision_model_name: Union[str, None] = None,
+    ):
+        """Store both arguments, then load the CLIP vision config of `vision_model_name`."""
+        self.num_image_with_embedding = num_image_with_embedding
+        self.vision_model_name = vision_model_name
+        self.vision_config = CLIPVisionConfig.from_pretrained(vision_model_name)
+
+    def to_dict(self):
+        """Return a deep copy of the attributes with `vision_config` expanded to a dict."""
+        output = copy.deepcopy(self.__dict__)
+        output["vision_config"] = self.vision_config.to_dict()
+        output["model_type"] = self.__class__.model_type
+        return output
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """Turn a `[bsz, src_len]` keep-mask into an additive `[bsz, 1, tgt_len, src_len]` mask.
+
+    Entries where `mask` is 1 map to 0; all other entries map to `torch.finfo(dtype).min`.
+    """
+    n_batch, n_src = mask.size()
+    n_tgt = n_src if tgt_len is None else tgt_len
+    # Broadcast the per-key mask over every query row, then flip keep -> block.
+    widened = mask[:, None, None, :].expand(n_batch, 1, n_tgt, n_src).to(dtype)
+    flipped = 1.0 - widened
+    fill_value = torch.finfo(dtype).min
+    return flipped.masked_fill(flipped.to(torch.bool), fill_value)
+
+
+class GitJapaneseStableLMAlphaModel(JapaneseStableLMAlphaModel):
+ config_class = GitJapaneseStableLMAlphaConfig
+
+    def __init__(self, config: JapaneseStableLMAlphaConfig):
+        """Initialise the parent model, then attach the CLIP image encoder and its projection."""
+        super().__init__(config)
+        clip_cfg = config.vision_config
+        frame_count = config.num_image_with_embedding
+
+        # Git modules
+        self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
+        self.visual_projection = GitProjection(config)
+        if frame_count is not None:
+            # One learnable offset per frame, added to that frame's visual features in forward.
+            self.img_temporal_embedding = nn.ParameterList(
+                nn.Parameter(torch.zeros(1, 1, clip_cfg.hidden_size)) for _ in range(frame_count)
+            )
+
+        # Visual tokens per image: (image_size / patch_size) squared, plus one.
+        per_image = int((clip_cfg.image_size / clip_cfg.patch_size) ** 2 + 1)
+        self.image_patch_tokens = per_image if frame_count is None else per_image * frame_count
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+ # def get_input_embeddings(self):
+ # return self.decoder.embed_in
+
+ # def set_input_embeddings(self, value):
+ # self.decoder.embed_in = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prune heads; `heads_to_prune` maps a layer index to the list of heads to prune there.
+        NOTE(review): `self.encoder` is not assigned in the visible part of this class — confirm.
+        """
+        for layer_index, head_ids in heads_to_prune.items():
+            self.encoder.layer[layer_index].attention.prune_heads(head_ids)
+
+    def _generate_future_mask(
+        self, size: int, dtype: torch.dtype, device: torch.device
+    ) -> torch.Tensor:
+        """Return a `(size, size)` additive mask: 0 on/below the diagonal, -inf strictly above."""
+        strictly_upper = torch.ones(size, size, device=device, dtype=dtype).triu(diagonal=1)
+        # Cells equal to 1 lie above the diagonal (later positions) and get -inf.
+        return strictly_upper.masked_fill(strictly_upper == 1, float("-inf"))
+
+    def create_attention_mask(
+        self,
+        tgt,
+        memory,
+        tgt_mask,
+        past_key_values_length,
+        memory_key_padding_mask=None,
+    ):
+        """Assemble the additive attention mask over [memory tokens | text tokens].
+
+        Layout before batching (rows are queries, columns are keys):
+            memory -> memory: 0       memory -> text: -inf
+            text   -> memory: 0       text   -> text: `tgt_mask`
+        When `past_key_values_length > 0`, `tgt_mask` is replaced by zeros widened by that length.
+        `memory_key_padding_mask` (bool, True = padded) adds -inf to the memory key columns.
+
+        Returns a 4-D tensor (batch, 1, num_memory + num_tgt, num_memory + past + num_tgt).
+        Raises ValueError if `memory_key_padding_mask` is not a boolean tensor.
+        """
+        n_text = tgt.shape[1]
+        n_mem = memory.shape[1]
+        n_batch = memory.shape[0]
+        device, dtype = tgt.device, tgt.dtype
+        minus_inf = float("-inf")
+
+        if past_key_values_length > 0:
+            side = tgt_mask.shape[0]
+            tgt_mask = torch.zeros(
+                (side, side + past_key_values_length), dtype=dtype, device=tgt_mask.device
+            )
+
+        # Left column block: every query may look at every memory token.
+        upper_left = torch.zeros((n_mem, n_mem), device=device, dtype=dtype)
+        lower_left = torch.zeros((n_text, n_mem), dtype=dtype, device=tgt_mask.device)
+        # Right column block: memory queries never look at text keys.
+        upper_right = torch.full(
+            (n_mem, n_text + past_key_values_length), minus_inf, device=device, dtype=dtype
+        )
+        left_cols = torch.cat((upper_left, lower_left), dim=0)
+        right_cols = torch.cat((upper_right, tgt_mask.to(dtype)), dim=0)
+        grid = torch.cat((left_cols, right_cols), dim=1)[None, :]
+
+        if memory_key_padding_mask is None:
+            memory_key_padding_mask = torch.full((n_batch, n_mem), fill_value=False, device=device)
+        # if it is False, it means valid. That is, it is not a padding
+        if memory_key_padding_mask.dtype != torch.bool:
+            raise ValueError("Memory key padding mask must be a boolean tensor.")
+        pad_bias = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype)
+        pad_bias[memory_key_padding_mask] = minus_inf
+
+        # expand() returns a view, so clone before writing into the memory columns.
+        grid = grid.expand(
+            (
+                memory_key_padding_mask.shape[0],
+                n_mem + n_text,
+                n_mem + past_key_values_length + n_text,
+            )
+        ).clone()
+        grid[:, :, :n_mem] = grid[:, :, :n_mem] + pad_bias[:, None, :]
+
+        # add axis for multi-head
+        return grid[:, None, :, :]
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
+        r"""
+        Encode optional images or video frames, prepend them to the text embeddings and run the decoder layers.
+
+        Args:
+            input_ids (`torch.Tensor`, *optional*):
+                Token ids of shape `(batch_size, sequence_length)`. Mutually exclusive with `inputs_embeds`.
+            attention_mask (`torch.Tensor`, *optional*):
+                Text attention mask. When omitted, an all-ones mask covering the cached and the new positions is used.
+            position_ids (`torch.Tensor`, *optional*):
+                Positions for the visual + text sequence. When omitted they are generated with `torch.arange`,
+                starting at the cached length.
+            head_mask (`torch.FloatTensor`, *optional*):
+                Per-head mask, expanded with `self.get_head_mask`.
+            pixel_values (`torch.Tensor`, *optional*):
+                Rank-4 input (assumed `(batch_size, num_channels, height, width)`) or rank-5 input (assumed
+                `(batch_size, num_frames, num_channels, height, width)`). Only encoded when `past_key_values` is
+                `None`.
+            inputs_embeds (`torch.Tensor`, *optional*):
+                Pre-computed text embeddings, used instead of `input_ids`.
+            past_key_values (*optional*):
+                Per-layer cached key/value states. The cached length is read from `past_key_values[0][0].shape[2]`.
+            use_cache (`bool`, *optional*):
+                If `True`, the per-layer key/value states are returned. Defaults to `config.use_cache`; forced to
+                `False` when gradient checkpointing is active during training.
+            output_attentions (`bool`, *optional*):
+                Return the attention weights of every layer. Defaults to `config.output_attentions`.
+            output_hidden_states (`bool`, *optional*):
+                Return the hidden states entering every layer plus the final normalized states. Defaults to
+                `config.output_hidden_states`.
+            return_dict (`bool`, *optional*):
+                Return a `BaseModelOutputWithPast` instead of a tuple. Defaults to `config.use_return_dict`.
+
+        Returns:
+            A `BaseModelOutputWithPast`, or (when `return_dict` is false) a tuple holding, in order, the last hidden
+            state, the cache, all hidden states and all attentions, with the entries that are `None` left out.
+
+        Raises:
+            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, or if `pixel_values` is not of
+                rank 4 or 5.
+        """
+        # Fall back to the matching config field (not `output_hidden_states`) when the caller gives no value.
+        output_attentions = (
+            output_attentions if output_attentions is not None else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time"
+            )
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+        seq_length_with_past = seq_length
+
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        # GIT Vision Encoder part
+        # Visual tokens are only computed on the first step; later steps rely on the cache.
+        projected_visual_features = None
+        if pixel_values is not None and past_key_values is None:
+            if pixel_values.ndim == 4:
+                # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
+                visual_features = self.image_encoder(pixel_values).last_hidden_state
+
+            elif pixel_values.ndim == 5:
+                # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
+                visual_features = []
+                for frame_idx in range(pixel_values.shape[1]):
+                    visual_features_frame = self.image_encoder(
+                        pixel_values[:, frame_idx, :, :]
+                    ).last_hidden_state
+                    visual_features_frame += self.img_temporal_embedding[frame_idx]
+                    visual_features.append(visual_features_frame)
+
+                # finally, concatenate all features along sequence dimension
+                visual_features = torch.cat(visual_features, dim=1)
+            else:
+                raise ValueError("pixel_values must be of rank 4 or 5")
+
+            projected_visual_features = self.visual_projection(visual_features)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_in(input_ids)
+
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+
+        embedding_output = inputs_embeds
+
+        if projected_visual_features is None:
+            # No image on this step: use a zero-length visual prefix so the code below stays uniform.
+            projected_visual_features = torch.zeros(
+                (embedding_output.shape[0], 0, embedding_output.shape[2]),
+                dtype=embedding_output.dtype,
+                device=embedding_output.device,
+            )
+
+        # Repeat visual features to match embedding batch size.
+        projected_visual_features = projected_visual_features.repeat(
+            embedding_output.size(0) // projected_visual_features.size(0), 1, 1
+        )
+
+        # concatenate patch token and text token embeddings
+        hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length,
+                seq_length + projected_visual_features.shape[1] + past_key_values_length,
+                dtype=torch.long,
+                device=device,
+            )
+            position_ids = position_ids.unsqueeze(0).view(
+                -1, seq_length + projected_visual_features.shape[1]
+            )
+        else:
+            position_ids = position_ids.view(
+                -1, seq_length + projected_visual_features.shape[1]
+            ).long()
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # By default, an additive causal mask is created
+        # for masking the future (one direction).
+        tgt_mask = self._generate_future_mask(
+            seq_length, embedding_output.dtype, embedding_output.device
+        )
+
+        # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
+        combined_attention_mask = self.create_attention_mask(
+            tgt=embedding_output,
+            memory=projected_visual_features,
+            tgt_mask=tgt_mask,
+            past_key_values_length=past_key_values_length,
+        )
+
+        if attention_mask is not None:
+            # if the user provides an attention mask, we add it to the default one
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(
+                attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]
+            ).to(embedding_output.device)
+            if past_key_values_length > 0:
+                # NOTE(review): the sliced mask is not added to `combined_attention_mask` in this
+                # branch, so the user mask has no effect on cached decoding steps; confirm intended.
+                expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
+            else:
+                combined_attention_mask[
+                    :, :, -input_shape[1] :, -input_shape[1] :
+                ] += expanded_attn_mask
+
+        if past_key_values is None:
+            past_key_values = tuple([None] * self.config.num_hidden_layers)
+
+        # decoder layers
+        presents = () if use_cache else None
+        all_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
+        for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for layer_past
+                        return module(*inputs, use_cache, None, output_attentions)
+
+                    return custom_forward
+
+                outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer),
+                    hidden_states,
+                    combined_attention_mask,
+                    position_ids,
+                    head_mask[i],
+                )
+            else:
+                outputs = layer(
+                    hidden_states,
+                    attention_mask=combined_attention_mask,
+                    position_ids=position_ids,
+                    head_mask=head_mask[i],
+                    layer_past=layer_past,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )
+            hidden_states = outputs[0]
+            if use_cache is True:
+                presents = presents + (outputs[1],)
+            if output_attentions:
+                # The attention weights follow the cache entry when one is returned.
+                all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
+
+        hidden_states = self.final_layer_norm(hidden_states)
+        # Add last hidden state
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, presents, all_hidden_states, all_attentions]
+                if v is not None
+            )
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=presents,
+            hidden_states=all_hidden_states,
+            attentions=all_attentions,
+        )
+
+
+class GitJapaneseStableLMAlphaForCausalLM(JapaneseStableLMAlphaForCausalLM):
+ """Causal language-modeling head on top of `GitJapaneseStableLMAlphaModel`; `forward` additionally accepts `pixel_values`."""
+
+ config_class = GitJapaneseStableLMAlphaConfig
+
+ def __init__(
+ self,
+ config,
+ ):
+ super(GitJapaneseStableLMAlphaForCausalLM, self).__init__(config)
+ # Use the GIT variant of the backbone as `self.transformer`.
+ self.transformer = GitJapaneseStableLMAlphaModel(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.Tensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
+ r"""
+ Run the backbone, project its output to vocabulary logits and optionally compute the next-token loss.
+
+ pixel_values (`torch.Tensor`, *optional*):
+ Forwarded unchanged to `self.transformer`.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+
+ Returns:
+ A `CausalLMOutputWithPast`, or (when `return_dict` is false) a tuple `(logits, *backbone_extras)` with the
+ loss prepended when `labels` is given.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ # The key/value cache is switched off whenever a loss is requested.
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.transformer(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ pixel_values=pixel_values,
+ inputs_embeds=inputs_embeds,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ logits = self.embed_out(sequence_output)
+
+ loss = None
+ if labels is not None:
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ # The first `num_image_tokens` logit positions and the last one are dropped, and the first label is
+ # dropped, so each remaining logit is scored against the following text token.
+ # NOTE(review): the image-token offset is applied regardless of whether `pixel_values` was passed;
+ # confirm that training always supplies images when `labels` is set.
+ num_image_tokens = self.transformer.image_patch_tokens
+ shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ use_cache=None,
+ **kwargs,
+ ):
+ """Build the keyword arguments for one generation step.
+
+ With a cache only the last token of `input_ids` is kept. A missing `attention_mask` is replaced by an
+ all-ones tensor shaped like the (possibly truncated) `input_ids`. `pixel_values` is read from `kwargs`
+ and defaults to `None`.
+ """
+ # cut decoder_input_ids if past_key_values is used
+ if past_key_values is not None:
+ input_ids = input_ids[:, -1:]
+
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ input_shape = input_ids.shape
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_shape)
+
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "pixel_values": kwargs.get("pixel_values", None),
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ }
+
+ def _reorder_cache(self, past_key_values, beam_idx):
+ """Return the cache with every tensor re-indexed along dim 0 by `beam_idx`, one tuple per layer."""
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),
+ )
+ return reordered_past
diff --git a/heron/models/git_llm/git_japanese_stablelm_alpha/modeling_japanese_stablelm_alpha.py b/heron/models/git_llm/git_japanese_stablelm_alpha/modeling_japanese_stablelm_alpha.py
new file mode 100644
index 0000000..a9a042e
--- /dev/null
+++ b/heron/models/git_llm/git_japanese_stablelm_alpha/modeling_japanese_stablelm_alpha.py
@@ -0,0 +1,712 @@
+# coding=utf-8
+# Copyright 2023 Stability and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch JapaneseStableLMAlpha model. """
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+
+from .configuration_japanese_stablelm_alpha import JapaneseStableLMAlphaConfig
+
+logger = logging.get_logger(__name__)
+
+
+class JapaneseStableLMAlphaPreTrainedModel(PreTrainedModel):
+    """
+    Shared base for the JapaneseStableLMAlpha models: binds the config class, declares the loading-related
+    class attributes and supplies weight initialization.
+    """
+
+    config_class = JapaneseStableLMAlphaConfig
+    base_model_prefix = "transformer"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["DecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
+
+    def _init_weights(self, module):
+        """Initialize `module` in place according to its layer type; other module types are left untouched."""
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Both layer kinds draw their weights from N(0, initializer_range).
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if isinstance(module, nn.Linear):
+                if module.bias is not None:
+                    module.bias.data.zero_()
+            elif module.padding_idx is not None:
+                # Keep the padding row at zero.
+                module.weight.data[module.padding_idx].zero_()
+            return
+
+        if not isinstance(module, nn.LayerNorm):
+            return
+
+        # LayerNorm starts as the identity transform; either parameter may be absent.
+        if module.bias is not None:
+            module.bias.data.zero_()
+        if module.weight is not None:
+            module.weight.data.fill_(1.0)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        """Set the `gradient_checkpointing` flag on `JapaneseStableLMAlphaModel` instances only."""
+        if not isinstance(module, JapaneseStableLMAlphaModel):
+            return
+        module.gradient_checkpointing = value
+
+
+class JapaneseStableLMAlphaModel(JapaneseStableLMAlphaPreTrainedModel):
+ """Decoder backbone: token embedding (`embed_in`), a stack of `DecoderLayer`s and a final LayerNorm."""
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.config = config
+
+ self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
+ self.layers = nn.ModuleList(
+ [DecoderLayer(config) for _ in range(config.num_hidden_layers)]
+ )
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ # Toggled through `_set_gradient_checkpointing` on the base class.
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ """Return the token embedding module."""
+ return self.embed_in
+
+ def set_input_embeddings(self, value):
+ """Replace the token embedding module."""
+ self.embed_in = value
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ r"""
+ Embed the input (unless `inputs_embeds` is given), run every decoder layer and apply the final LayerNorm.
+
+ Exactly one of `input_ids` and `inputs_embeds` must be provided; a `ValueError` is raised otherwise.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+
+ Returns a `BaseModelOutputWithPast`, or (when `return_dict` is false) a tuple of the non-`None` values among
+ the last hidden state, the cache, all hidden states and all attentions, in that order.
+ """
+ # Unset flags fall back to the corresponding config fields.
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+
+ # Without a cache every layer receives `None` as its past; with one, the cached length offsets the positions.
+ if past_key_values is None:
+ past_length = 0
+ past_key_values = tuple([None] * self.config.num_hidden_layers)
+ else:
+ past_length = past_key_values[0][0].size(-2)
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_length, seq_length + past_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # Attention mask.
+ if attention_mask is not None:
+ assert batch_size > 0, "batch_size has to be defined and > 0"
+ attention_mask = attention_mask.view(batch_size, -1)
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and the dtype's smallest value for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_in(input_ids)
+
+ hidden_states = inputs_embeds
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # Accumulators stay `None` when the corresponding output was not requested.
+ presents = () if use_cache else None
+ all_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+ for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+
+ # The wrapper appends (use_cache, layer_past=None, output_attentions) positionally, matching the
+ # parameter order of `DecoderLayer.forward`.
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for layer_past
+ return module(*inputs, use_cache, None, output_attentions)
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer),
+ hidden_states,
+ attention_mask,
+ position_ids,
+ head_mask[i],
+ )
+ else:
+ outputs = layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ head_mask=head_mask[i],
+ layer_past=layer_past,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+ if output_attentions:
+ # The attention weights come after the cache entry when one is returned.
+ all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
+
+ hidden_states = self.final_layer_norm(hidden_states)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, presents, all_hidden_states, all_attentions]
+ if v is not None
+ )
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_attentions,
+ )
+
+
+class DecoderLayer(nn.Module):
+    """One decoder block whose attention and MLP branches both read the block input and are summed onto it."""
+
+    def __init__(self, config):
+        super().__init__()
+        width = config.hidden_size
+        eps = config.layer_norm_eps
+
+        # Stored from the config; not read anywhere else in this class.
+        self.use_parallel_residual = config.use_parallel_residual
+        # The attention branch is normalized without learnable scale/shift.
+        self.input_layernorm = nn.LayerNorm(width, eps=eps, elementwise_affine=False)
+        self.post_attention_layernorm = nn.LayerNorm(width, eps=eps)
+        self.attention = Attention(config)
+        self.mlp = MLP(config)
+
+    def forward(
+        self,
+        hidden_states: Optional[torch.FloatTensor],
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: Optional[bool] = False,
+    ):
+        """Return `(hidden_states, present, attn_weights?)` when `use_cache` is truthy, else `(hidden_states, attn_weights?)`."""
+        attention_input = self.input_layernorm(hidden_states)
+        attention_result = self.attention(
+            attention_input,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            layer_past=layer_past,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
+        # attention_result: attn_output, present, (attn_weights)
+        attn_output = attention_result[0]
+        trailing = attention_result[1:]
+
+        # The MLP branch is fed from the block input, not from the attention output.
+        mlp_input = self.post_attention_layernorm(hidden_states)
+        mlp_output = self.mlp(mlp_input)
+        hidden_states = hidden_states + mlp_output + attn_output
+
+        if not use_cache:
+            # Drop the `present` slot so only the optional attention weights remain.
+            trailing = trailing[1:]
+
+        return (hidden_states,) + trailing
+
+
+class MLP(nn.Module):
+    """Gated feed-forward block: one packed projection split into value and gate halves, SiLU on the gate."""
+
+    def __init__(self, config: JapaneseStableLMAlphaConfig):
+        super().__init__()
+        width = config.hidden_size
+        granularity = 256
+        # Take 8/3 of the hidden size and round it up to the next multiple of `granularity`.
+        target = int(8 * width / 3)
+        intermediate_size = granularity * -(-target // granularity)
+
+        # Value and gate projections share one weight matrix, hence the factor of two.
+        self.packed_input_proj = torch.nn.Linear(width, 2 * intermediate_size, bias=False)
+        self.out_proj = nn.Linear(intermediate_size, width, bias=False)
+        self.act = nn.SiLU()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Project `x`, gate the first half with SiLU of the second half, and project back."""
+        packed = self.packed_input_proj(x)
+        ff, ff_gate = torch.chunk(packed, 2, dim=-1)
+        gated = ff * self.act(ff_gate)
+        return self.out_proj(gated)
+
+
+class RotaryEmbedding(torch.nn.Module):
+    """Cached rotary cos/sin tables with an xPos-style per-position scale.
+
+    Based on Tri Dao's XPos: https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/layers/rotary.py
+
+    All four cached tables have shape `(seq_len_cached, dim // 2)`. The query tables are multiplied by the
+    scale and the key tables divided by it. Widening to the full rotary dimension is left to the caller
+    (`apply_rotary_pos_emb` does it).
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        max_position_embeddings: int,
+        base: int = 10_000,
+        scale_base: int = 512,
+        device: str = None,
+    ):
+        """
+        Args:
+            dim: number of rotary dimensions per head; half as many frequencies are created.
+            max_position_embeddings: number of positions to pre-compute.
+            base: base of the geometric frequency progression.
+            scale_base: divisor of the position-dependent scale exponent; `None` disables scaling.
+            device: device on which the buffers are created.
+        """
+        super().__init__()
+        self.dim = dim
+        self.seq_len_cached = max_position_embeddings
+
+        # Set up `inv_freq` term
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+        )
+        self.register_buffer("inv_freq", inv_freq)
+
+        # Set up `scale` term
+        self.scale_base = scale_base
+        scale = (
+            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
+            if scale_base is not None
+            else None
+        )
+        self.register_buffer("scale", scale)
+
+        # Set up `cos..` and `sin...` cache terms
+        self._set_cos_sin_cache(self.seq_len_cached)
+
+    def _set_cos_sin_cache(self, seq_len: int) -> None:
+        """(Re)build the four half-width tables for positions `0 .. seq_len - 1`."""
+        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=torch.float32)
+        freqs = torch.outer(t, self.inv_freq)
+
+        if self.scale is None:
+            # No xPos scaling configured: query and key tables are the plain cos/sin values.
+            scale_cached = 1.0
+        else:
+            seq_range = torch.arange(seq_len, dtype=self.scale.dtype, device=self.scale.device)
+            # The exponent is centred on the middle of the cached range.
+            power = (seq_range - seq_len // 2) / self.scale_base
+            scale_cached = self.scale ** power.unsqueeze(-1)
+
+        # The tables stay half-width on every build. The original rebuild concatenated them to
+        # full width, which no longer matched the shape produced by `__init__`.
+        self.register_buffer("cos_cached", torch.cos(freqs) * scale_cached, persistent=False)
+        self.register_buffer("sin_cached", torch.sin(freqs) * scale_cached, persistent=False)
+        self.register_buffer("cos_k_cached", torch.cos(freqs) / scale_cached, persistent=False)
+        self.register_buffer("sin_k_cached", torch.sin(freqs) / scale_cached, persistent=False)
+
+    def forward(self, x, seq_len=None):
+        """Return `(cos, sin, cos_k, sin_k)` truncated to `seq_len` rows, growing the cache when needed.
+
+        Args:
+            x: reference tensor; only used to infer `seq_len` (from `x.shape[-2]`) when it is not given.
+            seq_len: number of positions required.
+        """
+        if seq_len is None:
+            seq_len = x.shape[-2]
+        if seq_len > self.seq_len_cached:
+            self.seq_len_cached = seq_len
+            self._set_cos_sin_cache(seq_len)
+        return (
+            self.cos_cached[:seq_len, ...],
+            self.sin_cached[:seq_len, ...],
+            self.cos_k_cached[:seq_len, ...],
+            self.sin_k_cached[:seq_len, ...],
+        )
+
+
+def rotate_half(x):
+    """Split the last dimension into two chunks and return `[-second, first]` concatenated along it."""
+    first, second = torch.chunk(x, 2, dim=-1)
+    negated_second = -second
+    return torch.cat((negated_second, first), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, cos_k=None, sin_k=None):
+    """
+    Rotate the query and key tensors with position-dependent cos/sin tables.
+
+    q, k: [bs, num_heads, seq_len, rot_dim]
+    cos, sin: [seq_len, rot_dim / 2]
+    position_ids: [bs, seq_len]
+    cos_k, sin_k: [seq_len, rot_dim / 2] tables applied to the keys; when omitted, `cos` / `sin` are used
+        for the keys as well.
+
+    Returns the rotated `(q_embed, k_embed)`, each shaped like its input.
+    """
+    # Keys share the query tables unless separate key tables are supplied.
+    if cos_k is None:
+        cos_k = cos
+    if sin_k is None:
+        sin_k = sin
+
+    def _gather(table):
+        # Tile the half-width table to the full rotary width (same layout as the former
+        # einops "s r -> s (2 r)" repeat), then pick the rows for each position.
+        full_width = torch.cat((table, table), dim=-1)
+        return full_width[position_ids].unsqueeze(1)  # [bs, 1, seq_len, rot_dim]
+
+    cos, sin = _gather(cos), _gather(sin)
+    cos_k, sin_k = _gather(cos_k), _gather(sin_k)
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos_k) + (rotate_half(k) * sin_k)
+    return q_embed, k_embed
+
+
+class Attention(nn.Module):
+ """Multi-head self-attention with a fused QKV projection, rotary embedding on the first `rotary_ndims`
+ dimensions of each head, an optional key/value cache and a causal mask read from the `bias` buffer."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.num_attention_heads = config.num_attention_heads
+ self.hidden_size = config.hidden_size
+ if self.hidden_size % self.num_attention_heads != 0:
+ raise ValueError(
+ "The hidden size is not divisble by the number of attention heads! Make sure to update them"
+ )
+ self.head_size = self.hidden_size // self.num_attention_heads
+
+ # Lower-triangular boolean matrix used as the causal mask, sized for the longest supported sequence.
+ max_positions = config.max_position_embeddings
+ self.register_buffer(
+ "bias",
+ torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
+ 1, 1, max_positions, max_positions
+ ),
+ persistent=False,
+ )
+ # Registered here but not read anywhere else in this class.
+ self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
+
+ # Only the first `rotary_pct` fraction of each head receives rotary embeddings.
+ self.rotary_ndims = int(self.head_size * config.rotary_pct)
+ self.rotary_emb = RotaryEmbedding(
+ self.rotary_ndims,
+ max_position_embeddings=config.max_position_embeddings,
+ base=config.rotary_emb_base,
+ scale_base=config.rotary_scale_base,
+ )
+
+ # sqrt(head_size); the attention scores are divided by it in `_attn`.
+ self.register_buffer(
+ "norm_factor",
+ torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(
+ torch.get_default_dtype()
+ ),
+ persistent=False,
+ )
+
+ self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
+ self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: torch.FloatTensor,
+ position_ids: torch.LongTensor,
+ head_mask: Optional[torch.FloatTensor] = None,
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ):
+ """Return `(attn_output, present)` plus `attn_weights` when `output_attentions` is set.
+
+ `present` is the `(key, value)` pair including any cached past when `use_cache` is truthy, else `None`.
+ """
+ has_layer_past = layer_past is not None
+
+ # Compute QKV
+ # Attention heads [batch, seq_len, hidden_size]
+ # --> [batch, seq_len, (np * 3 * head_size)]
+ qkv = self.query_key_value(hidden_states)
+
+ # [batch, seq_len, (num_heads * 3 * head_size)]
+ # --> [batch, seq_len, num_heads, 3 * head_size]
+ new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
+ qkv = qkv.view(*new_qkv_shape)
+
+ # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
+ query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
+ key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
+ value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
+
+ # Compute rotary embeddings on rotary_ndims
+ # The `*_pass` slices are left unrotated and concatenated back below.
+ query_rot = query[..., : self.rotary_ndims]
+ query_pass = query[..., self.rotary_ndims :]
+ key_rot = key[..., : self.rotary_ndims]
+ key_pass = key[..., self.rotary_ndims :]
+
+ # Compute token offset for rotary embeddings (when decoding)
+ kv_seq_len = key.shape[-2]
+ if has_layer_past:
+ kv_seq_len += layer_past[0].shape[-2]
+
+ # Add rotary embeddings to query and key
+ # TODO: Check if using xpos
+ cos, sin, cos_k, sin_k = self.rotary_emb(value, seq_len=kv_seq_len)
+ query, key = apply_rotary_pos_emb(
+ query_rot, key_rot, cos, sin, position_ids, cos_k=cos_k, sin_k=sin_k
+ )
+
+ query = torch.cat((query, query_pass), dim=-1)
+ key = torch.cat((key, key_pass), dim=-1)
+
+ # Cache QKV values
+ # Cached keys/values are prepended along the sequence dimension.
+ if has_layer_past:
+ past_key = layer_past[0]
+ past_value = layer_past[1]
+ key = torch.cat((past_key, key), dim=-2)
+ value = torch.cat((past_value, value), dim=-2)
+ present = (key, value) if use_cache else None
+
+ # Compute attention
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+ # Merge attn_head_size dim and num_attn_heads dim into hidden dim
+ # [bs, seq_len, num_attention_heads, attn_head_size]
+ attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
+ attn_output = attn_output.view(
+ attn_output.size(0), attn_output.size(1), self.num_attention_heads * self.head_size
+ )
+
+ attn_output = self.dense(attn_output)
+
+ outputs = (attn_output, present)
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+ """Scaled dot-product attention with causal masking; returns `(attn_output, attn_weights)`."""
+ # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
+ # compute causal mask from causal mask buffer
+
+ batch_size, num_attention_heads, query_length, attn_head_size = query.size()
+ key_length = key.size(-2)
+
+ # Rows of the triangular buffer for the last `query_length` positions, restricted to `key_length` columns.
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+
+ # Fold batch and heads together so a single batched matmul computes all scores.
+ query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
+ key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
+ attn_scores = torch.zeros(
+ batch_size * num_attention_heads,
+ query_length,
+ key_length,
+ dtype=query.dtype,
+ device=key.device,
+ )
+ # attn_scores = 1.0 * zeros + (1 / norm_factor) * (query @ key^T)
+ attn_scores = torch.baddbmm(
+ attn_scores,
+ query,
+ key.transpose(1, 2),
+ beta=1.0,
+ alpha=(
+ torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device)
+ / self.norm_factor
+ ),
+ )
+ attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
+
+ mask_value = torch.finfo(attn_scores.dtype).min
+ # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+ # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+ mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype, device=attn_scores.device)
+ attn_scores = torch.where(causal_mask, attn_scores, mask_value)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ # The mask is additive: it is summed onto the raw scores before the softmax.
+ attn_scores = attn_scores + attention_mask
+
+ # NOTE: Upcast to float32
+ attn_weights = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).type_as(
+ value
+ )
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = torch.matmul(attn_weights, value)
+ return attn_output, attn_weights
+
+
+def attention_mask_func(attention_scores, ltor_mask):
+    """Overwrite, in place, every score whose `ltor_mask` entry is False with the dtype's minimum; return the same tensor."""
+    fill_value = torch.finfo(attention_scores.dtype).min
+    blocked = ~ltor_mask
+    attention_scores.masked_fill_(blocked, fill_value)
+    return attention_scores
+
+
+class JapaneseStableLMAlphaForCausalLM(JapaneseStableLMAlphaPreTrainedModel):
+ _tied_weights_keys = ["embed_out.weight"]
+
+    def __init__(self, config):
+        """Create the transformer backbone and the bias-free output projection.
+
+        Args:
+            config: configuration object; `hidden_size` and `vocab_size` are read
+                here to size the output projection.
+        """
+        super().__init__(config)
+
+        self.transformer = JapaneseStableLMAlphaModel(config)
+
+        # Linear map from hidden states to vocabulary logits (no bias term).
+        in_features, out_features = config.hidden_size, config.vocab_size
+        self.embed_out = nn.Linear(in_features, out_features, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        """Return the module stored in `self.embed_out`."""
+        output_projection = self.embed_out
+        return output_projection
+
+    def set_output_embeddings(self, new_embeddings):
+        """Store `new_embeddings` as `self.embed_out`."""
+        setattr(self, "embed_out", new_embeddings)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Run the backbone, project its first output to vocabulary logits and, when
+        `labels` is given, compute the next-token cross-entropy loss (logits at
+        position `t` are scored against `labels` at position `t + 1`).
+
+        Returns a `CausalLMOutputWithPast` when `return_dict` resolves to true,
+        otherwise a tuple `(loss?, logits, *backbone_outputs[1:])` where the loss
+        is only present if it was computed.
+
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from transformers import LlamaTokenizer, JapaneseStableLMAlphaForCausalLM, JapaneseStableLMAlphaConfig
+
+        >>> tokenizer = LlamaTokenizer.from_pretrained("novelai/nerdstash-tokenizer-v1")
+        >>> config = JapaneseStableLMAlphaConfig.from_pretrained("stabilityai/stablelm-ja-base-alpha-7b")
+        >>> config.is_decoder = True
+        >>> model = JapaneseStableLMAlphaForCausalLM.from_pretrained("stabilityai/stablelm-ja-base-alpha-7b", config=config, trust_remote_code=True)
+
+        >>> inputs = tokenizer("日本語の美しいところは、", return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> prediction_logits = outputs.logits
+        ```"""
+        if return_dict is None:
+            return_dict = self.config.use_return_dict
+
+        outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        lm_logits = self.embed_out(outputs[0])
+
+        lm_loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            targets = labels.to(lm_logits.device)
+            # next-token prediction: drop the last prediction and the first target
+            predictions = lm_logits[:, :-1, :].contiguous()
+            targets = targets[:, 1:].contiguous()
+            criterion = CrossEntropyLoss()
+            lm_loss = criterion(
+                predictions.view(-1, predictions.size(-1)),
+                targets.view(-1),
+            )
+
+        if return_dict:
+            return CausalLMOutputWithPast(
+                loss=lm_loss,
+                logits=lm_logits,
+                past_key_values=outputs.past_key_values,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+
+        output = (lm_logits,) + outputs[1:]
+        if lm_loss is None:
+            return output
+        return (lm_loss,) + output
+
+    def prepare_inputs_for_generation(
+        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+    ):
+        """Assemble the keyword arguments for one generation step.
+
+        - With a non-empty cache, only the last token of `input_ids` is kept.
+        - `position_ids` is taken from `kwargs` if present; otherwise it is derived
+          from `attention_mask` (cumulative sum minus one, masked positions set
+          to 1) and reduced to the last column when a cache is present.
+        - A missing `attention_mask` defaults to all ones with the shape
+          `input_ids` had on entry.
+        - `inputs_embeds` is forwarded only when `past_key_values` is None.
+
+        Returns:
+            dict with either `inputs_embeds` or `input_ids`, plus `attention_mask`,
+            `past_key_values` and `position_ids`.
+        """
+        # Captured before `input_ids` is possibly trimmed below.
+        full_shape = input_ids.shape
+
+        # cut decoder_input_ids if past is used
+        if past_key_values and past_key_values[0] is not None:
+            input_ids = input_ids[:, -1:]
+
+        position_ids = kwargs.get("position_ids")
+        if position_ids is None and attention_mask is not None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1:]
+
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(full_shape)
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is None or past_key_values is not None:
+            model_inputs = {"input_ids": input_ids}
+        else:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+
+        model_inputs["attention_mask"] = attention_mask
+        model_inputs["past_key_values"] = past_key_values
+        model_inputs["position_ids"] = position_ids
+        return model_inputs
+
+    def _reorder_cache(self, past_key_values, beam_idx):
+        """Re-index the cached states along dim 0 with `beam_idx`.
+
+        For each layer, the first two entries are re-selected with
+        `index_select(0, beam_idx)`; any further entries are carried over
+        unchanged. Returns a tuple with one tuple per layer.
+        """
+        return tuple(
+            tuple(state.index_select(0, beam_idx) for state in layer_past[:2]) + layer_past[2:]
+            for layer_past in past_key_values
+        )
diff --git a/heron/models/git_llm/git_llama/__init__.py b/heron/models/git_llm/git_llama/__init__.py
new file mode 100644
index 0000000..68c23ef
--- /dev/null
+++ b/heron/models/git_llm/git_llama/__init__.py
@@ -0,0 +1 @@
+from .modeling_git_llama import GitLlamaConfig, GitLlamaForCausalLM, GitLlamaModel
diff --git a/heron/models/git_llm/git_llama/modeling_git_llama.py b/heron/models/git_llm/git_llama/modeling_git_llama.py
new file mode 100644
index 0000000..955717a
--- /dev/null
+++ b/heron/models/git_llm/git_llama/modeling_git_llama.py
@@ -0,0 +1,519 @@
+"""PyTorch GIT LLaMA model."""
+import copy
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+from transformers import (
+ CLIPVisionConfig,
+ CLIPVisionModel,
+ LlamaConfig,
+ LlamaForCausalLM,
+ LlamaModel,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ BaseModelOutputWithPooling,
+ CausalLMOutputWithPast,
+)
+from transformers.models.git.modeling_git import GitProjection
+
+
+class GitLlamaConfig(LlamaConfig):
+    """`LlamaConfig` extended with the settings of the GIT vision modules.
+
+    Extra attributes:
+        vision_config: a `CLIPVisionConfig` describing the image encoder.
+        num_image_with_embedding: number of frames that get a temporal embedding,
+            or None.
+        vision_model_name: set by `set_vision_configs`; the name passed to
+            `CLIPVisionConfig.from_pretrained`.
+    """
+
+    model_type = "git_llama"
+
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        """Build the config, honouring vision settings supplied through `kwargs`.
+
+        `vision_config` may be a `CLIPVisionConfig`, a dict (the form `to_dict`
+        writes) or absent; `num_image_with_embedding` defaults to None.
+        """
+        # Take the vision entries out before delegating. Assigning the defaults
+        # unconditionally after `super().__init__` would discard values coming
+        # from a serialized config, so a reloaded model would be built with the
+        # wrong vision dimensions / frame count.
+        vision_config = kwargs.pop("vision_config", None)
+        num_image_with_embedding = kwargs.pop("num_image_with_embedding", None)
+
+        super().__init__(**kwargs)
+
+        if vision_config is None:
+            vision_config = CLIPVisionConfig()
+        elif isinstance(vision_config, dict):
+            vision_config = CLIPVisionConfig(**vision_config)
+        self.vision_config = vision_config
+        self.num_image_with_embedding = num_image_with_embedding
+
+    def set_vision_configs(
+        self,
+        num_image_with_embedding: Union[int, None] = None,
+        vision_model_name: Union[str, None] = None,
+    ):
+        """Record the vision settings and load the matching `CLIPVisionConfig`.
+
+        Args:
+            num_image_with_embedding: stored as-is on the config.
+            vision_model_name: stored on the config and passed to
+                `CLIPVisionConfig.from_pretrained`.
+        """
+        self.num_image_with_embedding = num_image_with_embedding
+        self.vision_model_name = vision_model_name
+        self.vision_config = CLIPVisionConfig.from_pretrained(vision_model_name)
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
+            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
+        """
+        output = copy.deepcopy(self.__dict__)
+        # Nested config is stored in its dict form; `__init__` converts it back.
+        output["vision_config"] = self.vision_config.to_dict()
+        output["model_type"] = self.__class__.model_type
+        return output
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Turn a `[bsz, seq_len]` attention mask into an additive `[bsz, 1, tgt_seq_len, src_seq_len]` mask.
+
+    Entries equal to 1 in `mask` become 0 in the result; every other entry becomes
+    the most negative finite value of `dtype`. `tgt_len` defaults to `seq_len`.
+    """
+    batch, source_len = mask.size()
+    if tgt_len is None:
+        tgt_len = source_len
+
+    keep = mask[:, None, None, :].expand(batch, 1, tgt_len, source_len).to(dtype)
+    blocked = 1.0 - keep
+
+    lowest = torch.finfo(dtype).min
+    return blocked.masked_fill(blocked.to(torch.bool), lowest)
+
+
+class GitLlamaModel(LlamaModel):
+ config_class = GitLlamaConfig
+
+    def __init__(self, config: LlamaConfig):
+        """Set up the base model plus the GIT vision modules.
+
+        Reads `config.vision_model_name`, `config.vision_config` and
+        `config.num_image_with_embedding`. When the latter is not None, one
+        `(1, 1, vision hidden_size)` zero-initialised parameter per frame is
+        registered as `img_temporal_embedding`. `image_patch_tokens` is
+        `(image_size / patch_size) ** 2 + 1`, multiplied by the frame count when
+        one is configured.
+        """
+        super().__init__(config)
+
+        # Git modules
+        self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
+        self.visual_projection = GitProjection(config)
+
+        frame_count = config.num_image_with_embedding
+        vision_cfg = config.vision_config
+
+        if frame_count is not None:
+            self.img_temporal_embedding = nn.ParameterList(
+                [
+                    nn.Parameter(torch.zeros(1, 1, vision_cfg.hidden_size))
+                    for _ in range(frame_count)
+                ]
+            )
+
+        tokens_per_image = int((vision_cfg.image_size / vision_cfg.patch_size) ** 2 + 1)
+        if frame_count is None:
+            self.image_patch_tokens = tokens_per_image
+        else:
+            self.image_patch_tokens = tokens_per_image * frame_count
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        """Return the token embedding module.
+
+        Uses `self.embed_tokens`, the same attribute `forward` calls to embed
+        `input_ids`; nothing in this class defines a `decoder` sub-module.
+        """
+        return self.embed_tokens
+
+    def set_input_embeddings(self, value):
+        """Replace the token embedding module.
+
+        Assigns to `self.embed_tokens`, the attribute `forward` reads when it
+        embeds `input_ids`; nothing in this class defines a `decoder` sub-module.
+        """
+        self.embed_tokens = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prunes heads of the model.
+
+        Args:
+            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}.
+                See base class PreTrainedModel.
+        """
+        # NOTE(review): this walks `self.encoder.layer[...]`, whereas `forward` in this
+        # class iterates `self.layers`, and nothing in this class assigns `self.encoder`.
+        # Confirm this path is valid for the LLaMA backbone before relying on it.
+        for layer, heads in heads_to_prune.items():
+            self.encoder.layer[layer].attention.prune_heads(heads)
+
+    def _generate_future_mask(
+        self, size: int, dtype: torch.dtype, device: torch.device
+    ) -> torch.Tensor:
+        """Return a `(size, size)` additive causal mask.
+
+        Entries strictly above the main diagonal are `-inf`; all other entries are 0.
+        """
+        # Default mask is for forward direction. Flip for backward direction.
+        strictly_upper = torch.ones(size, size, device=device, dtype=dtype).triu(diagonal=1)
+        return strictly_upper.masked_fill(strictly_upper == 1, float("-inf"))
+
+    def create_attention_mask(
+        self,
+        tgt,
+        memory,
+        tgt_mask,
+        past_key_values_length,
+        memory_key_padding_mask=None,
+    ):
+        """Build the additive attention mask over [memory tokens | target tokens].
+
+        Layout of the 2-D mask before batching:
+            memory rows: 0 over memory columns, -inf over all remaining columns
+            target rows: 0 over memory columns, `tgt_mask` over the remaining columns
+        When `past_key_values_length > 0`, `tgt_mask` is replaced by an all-zero
+        block that is `past_key_values_length` columns wider.
+
+        Args:
+            tgt: tensor whose `shape[1]`, dtype and device define the target part.
+            memory: tensor whose `shape[0]` and `shape[1]` define the memory part.
+            tgt_mask: 2-D additive mask for the target tokens.
+            past_key_values_length: number of cached key positions.
+            memory_key_padding_mask: optional boolean tensor; True entries add
+                -inf to the corresponding memory column. Defaults to all False.
+
+        Returns:
+            Tensor of shape `(batch, 1, num_memory + num_tgt,
+            num_memory + past_key_values_length + num_tgt)`.
+
+        Raises:
+            ValueError: if `memory_key_padding_mask` is not a boolean tensor.
+        """
+        n_text = tgt.shape[1]
+        n_visual = memory.shape[1]
+        device, dtype = tgt.device, tgt.dtype
+        neg_inf = float("-inf")
+
+        # Columns belonging to the memory keys: open (0) for every query row.
+        memory_columns = torch.cat(
+            (
+                torch.zeros((n_visual, n_visual), device=device, dtype=dtype),
+                torch.zeros((n_text, n_visual), dtype=dtype, device=tgt_mask.device),
+            ),
+            dim=0,
+        )
+
+        if past_key_values_length > 0:
+            side = tgt_mask.shape[0]
+            tgt_mask = torch.zeros(
+                (side, side + past_key_values_length),
+                dtype=dtype,
+                device=tgt_mask.device,
+            )
+
+        # Remaining columns: closed (-inf) for memory rows, `tgt_mask` for target rows.
+        closed_for_memory = torch.full(
+            (n_visual, n_text + past_key_values_length),
+            neg_inf,
+            device=device,
+            dtype=dtype,
+        )
+        target_columns = torch.cat((closed_for_memory, tgt_mask.to(dtype)), dim=0)
+
+        full_attention_mask = torch.cat((memory_columns, target_columns), dim=1).unsqueeze(0)
+
+        if memory_key_padding_mask is None:
+            memory_key_padding_mask = torch.full(
+                (memory.shape[0], memory.shape[1]), fill_value=False, device=device
+            )
+        # if it is False, it means valid. That is, it is not a padding
+        if memory_key_padding_mask.dtype != torch.bool:
+            raise ValueError("Memory key padding mask must be a boolean tensor.")
+        padding_penalty = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype).masked_fill(
+            memory_key_padding_mask, neg_inf
+        )
+
+        batch = memory_key_padding_mask.shape[0]
+        total_rows = n_visual + n_text
+        total_cols = n_visual + past_key_values_length + n_text
+        # `expand` yields a broadcast view; clone so the in-place update below is safe.
+        full_attention_mask = full_attention_mask.expand((batch, total_rows, total_cols)).clone()
+        full_attention_mask[:, :, :n_visual] += padding_penalty[:, None, :]
+
+        # add axis for multi-head
+        return full_attention_mask.unsqueeze(1)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
+        r"""
+        Encode the optional images, prepend the projected visual tokens to the text
+        embeddings and run the decoder layers over the combined sequence.
+
+        Args:
+            input_ids / inputs_embeds: exactly one of the two must be given.
+            attention_mask: `[batch, seq]` mask; defaults to all ones.
+            position_ids: optional positions covering visual + text tokens; generated
+                with `torch.arange` (offset by the cache length) when omitted.
+            pixel_values: rank-4 `(batch, channels, height, width)` or rank-5
+                `(batch, frames, channels, height, width)` images. Only encoded when
+                `past_key_values` is None.
+            past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+                Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+                don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+                `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+            use_cache (`bool`, *optional*):
+                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+                `past_key_values`).
+            output_attentions / output_hidden_states / return_dict: fall back to the
+                corresponding `self.config` fields when None.
+
+        Returns:
+            `BaseModelOutputWithPast`, or a tuple of its non-None fields when
+            `return_dict` resolves to false.
+
+        Raises:
+            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given,
+                or if `pixel_values` is not of rank 4 or 5.
+        """
+        # Each flag falls back to its own config field.
+        output_attentions = (
+            output_attentions if output_attentions is not None else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time"
+            )
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+        seq_length_with_past = seq_length
+
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        # GIT Vision Encoder part
+        projected_visual_features = None
+        if pixel_values is not None and past_key_values is None:
+            if pixel_values.ndim == 4:
+                # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
+                visual_features = self.image_encoder(pixel_values).last_hidden_state
+
+            elif pixel_values.ndim == 5:
+                # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
+                visual_features = []
+                for frame_idx in range(pixel_values.shape[1]):
+                    visual_features_frame = self.image_encoder(
+                        pixel_values[:, frame_idx, :, :]
+                    ).last_hidden_state
+                    visual_features_frame += self.img_temporal_embedding[frame_idx]
+                    visual_features.append(visual_features_frame)
+
+                # finally, concatenate all features along sequence dimension
+                visual_features = torch.cat(visual_features, dim=1)
+            else:
+                raise ValueError("pixel_values must be of rank 4 or 5")
+
+            projected_visual_features = self.visual_projection(visual_features)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+
+        embedding_output = inputs_embeds
+
+        if projected_visual_features is None:
+            # No images (or cached decoding): use an empty visual sequence so the
+            # concatenation and mask construction below need no special case.
+            projected_visual_features = torch.zeros(
+                (embedding_output.shape[0], 0, embedding_output.shape[2]),
+                dtype=embedding_output.dtype,
+                device=embedding_output.device,
+            )
+
+        # Repeat visual features to match embedding batch size.
+        projected_visual_features = projected_visual_features.repeat(
+            embedding_output.size(0) // projected_visual_features.size(0), 1, 1
+        )
+
+        # concatenate patch token and text token embeddings
+        hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length,
+                seq_length + projected_visual_features.shape[1] + past_key_values_length,
+                dtype=torch.long,
+                device=device,
+            )
+            position_ids = position_ids.unsqueeze(0).view(
+                -1, seq_length + projected_visual_features.shape[1]
+            )
+        else:
+            position_ids = position_ids.view(
+                -1, seq_length + projected_visual_features.shape[1]
+            ).long()
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                # Imported locally: this module defines no module-level `logger`.
+                from transformers.utils import logging as hf_logging
+
+                hf_logging.get_logger(__name__).warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # By default, an additive causal mask is created
+        # for masking the future (one direction).
+        tgt_mask = self._generate_future_mask(
+            seq_length, embedding_output.dtype, embedding_output.device
+        )
+
+        # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
+        combined_attention_mask = self.create_attention_mask(
+            tgt=embedding_output,
+            memory=projected_visual_features,
+            tgt_mask=tgt_mask,
+            past_key_values_length=past_key_values_length,
+        )
+
+        if attention_mask is not None:
+            # if the user provides an attention mask, we add it to the default one
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(
+                attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]
+            ).to(embedding_output.device)
+            if past_key_values_length > 0:
+                # NOTE(review): in the cached branch the sliced mask is computed but not
+                # added to `combined_attention_mask` — confirm this is intended.
+                expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
+            else:
+                combined_attention_mask[
+                    :, :, -input_shape[1] :, -input_shape[1] :
+                ] += expanded_attn_mask
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, output_attentions, None)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(decoder_layer),
+                    hidden_states,
+                    combined_attention_mask,
+                    position_ids,
+                    None,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=combined_attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                # The cache entry sits after the attention weights when those are returned.
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                if v is not None
+            )
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class GitLlamaForCausalLM(LlamaForCausalLM):
+ config_class = GitLlamaConfig
+
+    def __init__(
+        self,
+        config,
+    ):
+        """Initialise the causal-LM wrapper and use `GitLlamaModel` as `self.model`."""
+        super().__init__(config)
+
+        git_backbone = GitLlamaModel(config)
+        self.model = git_backbone
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
+        r"""
+        Run the GIT model, project to vocabulary logits and optionally compute the
+        next-token loss over the text positions.
+
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
+            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+            Passing labels forces `use_cache=False`.
+        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+            `past_key_values`).
+
+        Returns:
+            `CausalLMOutputWithPast`, or the tuple `(loss?, logits, *model_outputs[1:])`
+            when `return_dict` resolves to false.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            pixel_values=pixel_values,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        logits = self.lm_head(sequence_output)
+
+        loss = None
+        if labels is not None:
+            # move labels to the logits' device so the loss can be computed there
+            labels = labels.to(logits.device)
+            # The model prepends the visual tokens to the text tokens, so the logits are
+            # longer than the labels by the number of visual tokens fed in this call.
+            # Deriving the offset from the shapes (instead of the configured
+            # `image_patch_tokens`) keeps batches without `pixel_values`, or with a
+            # different number of frames, correctly aligned.
+            num_image_tokens = logits.size(1) - labels.size(1)
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+            shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
+            labels = labels[:, 1:].contiguous()
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shifted_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        use_cache=None,
+        **kwargs,
+    ):
+        """Assemble the keyword arguments for one generation step.
+
+        When `past_key_values` is not None only the last token of `input_ids` is
+        kept. A missing `attention_mask` defaults to all ones with the shape of the
+        (possibly trimmed) `input_ids`. `pixel_values` is taken from `kwargs`
+        (None when absent).
+        """
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_ids.shape)
+
+        return dict(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            pixel_values=kwargs.get("pixel_values"),
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+        )
+
+    def _reorder_cache(self, past_key_values, beam_idx):
+        """Re-index every cached state along dim 0 with `beam_idx`.
+
+        Returns a tuple with one tuple of re-selected states per layer.
+        """
+        return tuple(
+            tuple(state.index_select(0, beam_idx) for state in layer_past)
+            for layer_past in past_key_values
+        )
diff --git a/heron/models/git_llm/git_mpt/__init__.py b/heron/models/git_llm/git_mpt/__init__.py
new file mode 100644
index 0000000..9a36f12
--- /dev/null
+++ b/heron/models/git_llm/git_mpt/__init__.py
@@ -0,0 +1,2 @@
+# from .git_opt_trainer import GitOPTTrainer
+from .modeling_git_mpt import GitMptConfig, GitMptForCausalLM, GitMptModel
diff --git a/heron/models/git_llm/git_mpt/modeling_git_mpt.py b/heron/models/git_llm/git_mpt/modeling_git_mpt.py
new file mode 100644
index 0000000..8c1fdc7
--- /dev/null
+++ b/heron/models/git_llm/git_mpt/modeling_git_mpt.py
@@ -0,0 +1,533 @@
+"""PyTorch GIT MPT model.
+This code is based on modeling_mpt.py in transformers.
+See: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py
+"""
+import copy
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+from transformers import CLIPVisionConfig, CLIPVisionModel, MptConfig, MptForCausalLM, MptModel
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPooling,
+ CausalLMOutputWithPast,
+)
+from transformers.models.git.modeling_git import GitProjection
+
+
+class GitMptConfig(MptConfig):
+    """`MptConfig` extended with the settings of the GIT vision modules.
+
+    Extra attributes:
+        vision_config: a `CLIPVisionConfig` describing the image encoder.
+        num_image_with_embedding: number of frames that get a temporal embedding,
+            or None.
+        vision_model_name: set by `set_vision_configs`; the name passed to
+            `CLIPVisionConfig.from_pretrained`.
+    """
+
+    model_type = "git_mpt"
+
+    def __init__(
+        self,
+        **kwargs,
+    ):
+        """Build the config, honouring vision settings supplied through `kwargs`.
+
+        `vision_config` may be a `CLIPVisionConfig`, a dict (the form `to_dict`
+        writes) or absent; `num_image_with_embedding` defaults to None.
+        """
+        # Take the vision entries out before delegating. Assigning the defaults
+        # unconditionally after `super().__init__` would discard values coming
+        # from a serialized config, so a reloaded model would be built with the
+        # wrong vision dimensions / frame count.
+        vision_config = kwargs.pop("vision_config", None)
+        num_image_with_embedding = kwargs.pop("num_image_with_embedding", None)
+
+        super().__init__(**kwargs)
+
+        if vision_config is None:
+            vision_config = CLIPVisionConfig()
+        elif isinstance(vision_config, dict):
+            vision_config = CLIPVisionConfig(**vision_config)
+        self.vision_config = vision_config
+        self.num_image_with_embedding = num_image_with_embedding
+
+    def set_vision_configs(
+        self,
+        num_image_with_embedding: Union[int, None] = None,
+        vision_model_name: Union[str, None] = None,
+    ):
+        """Record the vision settings and load the matching `CLIPVisionConfig`.
+
+        Args:
+            num_image_with_embedding: stored as-is on the config.
+            vision_model_name: stored on the config and passed to
+                `CLIPVisionConfig.from_pretrained`.
+        """
+        self.num_image_with_embedding = num_image_with_embedding
+        self.vision_model_name = vision_model_name
+        self.vision_config = CLIPVisionConfig.from_pretrained(vision_model_name)
+
+    def to_dict(self):
+        """
+        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
+            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
+        """
+        output = copy.deepcopy(self.__dict__)
+        # Nested configs are stored in their dict form.
+        output["attn_config"] = output["attn_config"].to_dict()
+        output["vision_config"] = self.vision_config.to_dict()
+        output["model_type"] = self.__class__.model_type
+        return output
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Turn a `[bsz, seq_len]` attention mask into an additive `[bsz, 1, tgt_seq_len, src_seq_len]` mask.
+
+    Entries equal to 1 in `mask` become 0 in the result; every other entry becomes
+    the most negative finite value of `dtype`. `tgt_len` defaults to `seq_len`.
+    """
+    batch, source_len = mask.size()
+    target_len = source_len if tgt_len is None else tgt_len
+
+    broadcast = mask[:, None, None, :].expand(batch, 1, target_len, source_len)
+    blocked = 1.0 - broadcast.to(dtype)
+
+    return blocked.masked_fill(blocked.to(torch.bool), torch.finfo(dtype).min)
+
+
+class GitMptModel(MptModel):
+ config_class = GitMptConfig
+
+    def __init__(self, config: MptConfig):
+        """Set up the base model plus the GIT vision modules.
+
+        Reads `config.vision_model_name`, `config.vision_config` and
+        `config.num_image_with_embedding`. When the latter is not None, one
+        `(1, 1, vision hidden_size)` zero-initialised parameter per frame is
+        registered as `img_temporal_embedding`. `image_patch_tokens` is
+        `(image_size / patch_size) ** 2 + 1`, multiplied by the frame count when
+        one is configured.
+        """
+        super().__init__(config)
+
+        # Git modules
+        self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
+        self.visual_projection = GitProjection(config)
+
+        frame_count = config.num_image_with_embedding
+        vision_cfg = config.vision_config
+
+        if frame_count is not None:
+            self.img_temporal_embedding = nn.ParameterList(
+                [
+                    nn.Parameter(torch.zeros(1, 1, vision_cfg.hidden_size))
+                    for _ in range(frame_count)
+                ]
+            )
+
+        tokens_per_image = int((vision_cfg.image_size / vision_cfg.patch_size) ** 2 + 1)
+        if frame_count is None:
+            self.image_patch_tokens = tokens_per_image
+        else:
+            self.image_patch_tokens = tokens_per_image * frame_count
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        """Return the module stored in `self.wte` (used by `forward` to embed `input_ids`)."""
+        token_embedding = self.wte
+        return token_embedding
+
+    def set_input_embeddings(self, value):
+        """Store `value` as `self.wte`."""
+        setattr(self, "wte", value)
+
+    def _generate_future_mask(
+        self, size: int, dtype: torch.dtype, device: torch.device
+    ) -> torch.Tensor:
+        """Return a `(size, size)` additive causal mask.
+
+        Entries strictly above the main diagonal are `-inf`; all other entries are 0.
+        """
+        # Default mask is for forward direction. Flip for backward direction.
+        all_ones = torch.ones(size, size, device=device, dtype=dtype)
+        future = all_ones.triu(diagonal=1)
+        return future.masked_fill(future == 1, float("-inf"))
+
+    def create_attention_mask(
+        self,
+        tgt,
+        memory,
+        tgt_mask,
+        past_key_values_length,
+        memory_key_padding_mask=None,
+    ):
+        """Build the additive attention mask over [memory tokens | target tokens].
+
+        Layout of the 2-D mask before batching:
+            memory rows: 0 over memory columns, -inf over all remaining columns
+            target rows: 0 over memory columns, `tgt_mask` over the remaining columns
+        When `past_key_values_length > 0`, `tgt_mask` is replaced by an all-zero
+        block that is `past_key_values_length` columns wider.
+
+        Args:
+            tgt: tensor whose `shape[1]`, dtype and device define the target part.
+            memory: tensor whose `shape[0]` and `shape[1]` define the memory part.
+            tgt_mask: 2-D additive mask for the target tokens.
+            past_key_values_length: number of cached key positions.
+            memory_key_padding_mask: optional boolean tensor; True entries add
+                -inf to the corresponding memory column. Defaults to all False.
+
+        Returns:
+            Tensor of shape `(batch, 1, num_memory + num_tgt,
+            num_memory + past_key_values_length + num_tgt)`.
+
+        Raises:
+            ValueError: if `memory_key_padding_mask` is not a boolean tensor.
+        """
+        n_text = tgt.shape[1]
+        n_visual = memory.shape[1]
+        device, dtype = tgt.device, tgt.dtype
+        neg_inf = float("-inf")
+
+        # Columns belonging to the memory keys: open (0) for every query row.
+        memory_columns = torch.cat(
+            (
+                torch.zeros((n_visual, n_visual), device=device, dtype=dtype),
+                torch.zeros((n_text, n_visual), dtype=dtype, device=tgt_mask.device),
+            ),
+            dim=0,
+        )
+
+        if past_key_values_length > 0:
+            side = tgt_mask.shape[0]
+            tgt_mask = torch.zeros(
+                (side, side + past_key_values_length),
+                dtype=dtype,
+                device=tgt_mask.device,
+            )
+
+        # Remaining columns: closed (-inf) for memory rows, `tgt_mask` for target rows.
+        closed_for_memory = torch.full(
+            (n_visual, n_text + past_key_values_length),
+            neg_inf,
+            device=device,
+            dtype=dtype,
+        )
+        target_columns = torch.cat((closed_for_memory, tgt_mask.to(dtype)), dim=0)
+
+        full_attention_mask = torch.cat((memory_columns, target_columns), dim=1).unsqueeze(0)
+
+        if memory_key_padding_mask is None:
+            memory_key_padding_mask = torch.full(
+                (memory.shape[0], memory.shape[1]), fill_value=False, device=device
+            )
+        # if it is False, it means valid. That is, it is not a padding
+        if memory_key_padding_mask.dtype != torch.bool:
+            raise ValueError("Memory key padding mask must be a boolean tensor.")
+        padding_penalty = torch.zeros_like(memory_key_padding_mask, dtype=tgt.dtype).masked_fill(
+            memory_key_padding_mask, neg_inf
+        )
+
+        batch = memory_key_padding_mask.shape[0]
+        total_rows = n_visual + n_text
+        total_cols = n_visual + past_key_values_length + n_text
+        # `expand` yields a broadcast view; clone so the in-place update below is safe.
+        full_attention_mask = full_attention_mask.expand((batch, total_rows, total_cols)).clone()
+        full_attention_mask[:, :, :n_visual] += padding_penalty[:, None, :]
+
+        # add axis for multi-head
+        return full_attention_mask.unsqueeze(1)
+
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
+ r"""
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+
+ Returns:"""
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_hidden_states
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+ seq_length_with_past = seq_length
+
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ # GIT Vision Encoder part
+ projected_visual_features = None
+ if pixel_values is not None and past_key_values is None:
+ if pixel_values.ndim == 4:
+ # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
+ visual_features = self.image_encoder(pixel_values).last_hidden_state
+
+ elif pixel_values.ndim == 5:
+ # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
+ visual_features = []
+ for frame_idx in range(pixel_values.shape[1]):
+ visual_features_frame = self.image_encoder(
+ pixel_values[:, frame_idx, :, :]
+ ).last_hidden_state
+ visual_features_frame += self.img_temporal_embedding[frame_idx]
+ visual_features.append(visual_features_frame)
+
+ # finally, concatenate all features along sequence dimension
+ visual_features = torch.cat(visual_features, dim=1)
+ else:
+ raise ValueError("pixel_values must be of rank 4 or 5")
+
+ projected_visual_features = self.visual_projection(visual_features)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+
+ # embed positions
+ if attention_mask is None:
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+ )
+
+ embedding_output = inputs_embeds
+
+ if projected_visual_features is None:
+ projected_visual_features = torch.zeros(
+ (embedding_output.shape[0], 0, embedding_output.shape[2]),
+ dtype=embedding_output.dtype,
+ device=embedding_output.device,
+ )
+
+ # Repeat visual features to match embedding batch size.
+ projected_visual_features = projected_visual_features.repeat(
+ embedding_output.size(0) // projected_visual_features.size(0), 1, 1
+ )
+
+ # concatenate patch token and text token embeddings
+ hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # By default, an additive causal mask is created
+ # for masking the future (one direction).
+ tgt_mask = self._generate_future_mask(
+ seq_length, embedding_output.dtype, embedding_output.device
+ )
+
+ # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len)
+ combined_attention_mask = self.create_attention_mask(
+ tgt=embedding_output,
+ memory=projected_visual_features,
+ tgt_mask=tgt_mask,
+ past_key_values_length=past_key_values_length,
+ )
+
+ if attention_mask is not None:
+ # if the user provides an attention mask, we add it to the default one
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ expanded_attn_mask = _expand_mask(
+ attention_mask, embedding_output.dtype, tgt_len=input_shape[-1]
+ ).to(embedding_output.device)
+ if past_key_values_length > 0:
+ expanded_attn_mask = expanded_attn_mask[:, :, -past_key_values_length:, :]
+ else:
+ combined_attention_mask[
+ :, :, -input_shape[1] :, -input_shape[1] :
+ ] += expanded_attn_mask
+
+ # MPT mask should be bool, mask pos is True, un-mask pos is False
+ # See: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py#L59
+ combined_attention_mask = torch.where(combined_attention_mask == 0, False, True)
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ presents = () if use_cache else None
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # Compute alibi tensor: check build_alibi_tensor documentation
+ alibi = self.build_mpt_alibi_tensor(
+ self.num_heads, self.config.max_seq_len, device=hidden_states.device
+ )
+
+ for idx, block in enumerate(self.blocks):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ layer_past = past_key_values[idx] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ # None for past_key_value
+ return module(
+ *inputs, use_cache=use_cache, output_attentions=output_attentions
+ )
+
+ return custom_forward
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ alibi,
+ combined_attention_mask,
+ layer_past,
+ )
+ else:
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=combined_attention_mask,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ position_bias=alibi,
+ )
+
+ hidden_states = outputs[0]
+
+ if use_cache:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+ # Add last hidden state
+ hidden_states = self.norm_f(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, presents, all_hidden_states, all_self_attentions]
+ if v is not None
+ )
+
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class GitMptForCausalLM(MptForCausalLM):
+    """MPT causal-LM head on top of a `GitMptModel` backbone.
+
+    The backbone concatenates projected visual tokens in front of the text
+    token embeddings, so the logits returned here cover the visual positions
+    followed by the text positions.
+    """
+
+    config_class = GitMptConfig
+
+    def __init__(
+        self,
+        config,
+    ):
+        super(GitMptForCausalLM, self).__init__(config)
+        # Use the vision-aware backbone as `self.transformer`.
+        self.transformer = GitMptModel(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        """Return the module that projects hidden states to vocabulary logits."""
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        """Replace the module that projects hidden states to vocabulary logits."""
+        self.lm_head = new_embeddings
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
+        r"""
+        Run the backbone, project to logits and optionally compute the LM loss.
+
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Tokens with
+            indices set to `-100` are ignored (masked) by `CrossEntropyLoss`; the loss is only computed for the
+            tokens with labels in `[0, ..., config.vocab_size]`. When labels are given, `use_cache` is forced to
+            `False`.
+        past_key_values (*optional*):
+            Cached key/value states from a previous call. They are forwarded unchanged to the backbone and can be
+            used to speed up decoding; in that case only the last token of `input_ids` needs to be passed.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding.
+
+        Returns:
+            A `CausalLMOutputWithPast`, or a plain tuple when `return_dict` is false. The tuple starts with the
+            loss only when `labels` were given.
+
+        Raises:
+            ValueError: if `labels` is longer than the sequence of logits.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        logits = self.lm_head(sequence_output)
+
+        loss = None
+        if labels is not None:
+            # The visual tokens sit in front of the text tokens and have no labels.
+            # Derive how many there are from the actual tensors instead of the
+            # configured `image_patch_tokens`, so text-only batches (no
+            # `pixel_values`) and inputs with a different number of frames are
+            # sliced correctly as well.
+            num_image_tokens = logits.size(1) - labels.size(1)
+            if num_image_tokens < 0:
+                raise ValueError(
+                    f"labels has length {labels.size(1)} but only {logits.size(1)} positions were predicted"
+                )
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+            shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
+            labels = labels[:, 1:].contiguous()
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shifted_logits.view(-1, shifted_logits.size(-1)), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        use_cache=None,
+        **kwargs,
+    ):
+        """Build the keyword arguments for one generation step.
+
+        When a cache is present only the most recent token is fed to the
+        model. A missing `attention_mask` defaults to all ones with the shape
+        of the (possibly truncated) `input_ids`. `pixel_values` is taken from
+        `kwargs` when present.
+        """
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_ids.shape)
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "pixel_values": kwargs.get("pixel_values"),
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+        }
+
+    def _reorder_cache(
+        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
+        """
+        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
+        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
+        beam_idx at every generation step.
+
+        Each layer entry is reduced to its first two tensors, reordered along dimension 0.
+        """
+        # Get a copy of `beam_idx` on all the devices where we need those indices.
+        device_to_beam_idx = {
+            past_state.device: beam_idx.to(past_state.device)
+            for layer_past in past
+            for past_state in layer_past
+        }
+        # Look up the index tensor by the device of the tensor being reordered;
+        # the two tensors of a layer are not guaranteed to share a device.
+        reordered_past = tuple(
+            (
+                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
+                layer_past[1].index_select(0, device_to_beam_idx[layer_past[1].device]),
+            )
+            for layer_past in past
+        )
+        return reordered_past
diff --git a/heron/models/git_llm/git_opt/__init__.py b/heron/models/git_llm/git_opt/__init__.py
new file mode 100644
index 0000000..9c96f44
--- /dev/null
+++ b/heron/models/git_llm/git_opt/__init__.py
@@ -0,0 +1,2 @@
+# from .git_opt_trainer import GitOPTTrainer
+from .modeling_git_opt import GitOPTConfig, GitOPTForCausalLM, GitOPTModel
diff --git a/heron/models/git_llm/git_opt/modeling_git_opt.py b/heron/models/git_llm/git_opt/modeling_git_opt.py
new file mode 100644
index 0000000..e7d6710
--- /dev/null
+++ b/heron/models/git_llm/git_opt/modeling_git_opt.py
@@ -0,0 +1,535 @@
+"""PyTorch GIT OPT model."""
+import copy
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+from transformers import CLIPVisionConfig, CLIPVisionModel, OPTConfig, OPTForCausalLM, OPTModel
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ BaseModelOutputWithPooling,
+ CausalLMOutputWithPast,
+)
+from transformers.models.git.modeling_git import GitProjection
+from transformers.models.opt.modeling_opt import OPTLearnedPositionalEmbedding
+
+
+class GitOPTConfig(OPTConfig):
+    """OPT configuration extended with the settings of the vision branch.
+
+    Attributes set here:
+        vision_config: a `CLIPVisionConfig` describing the image encoder.
+        num_image_with_embedding: number of frames that get a temporal
+            embedding, or `None` for single-image input.
+        vision_model_name: name or path passed to
+            `CLIPVisionConfig.from_pretrained` in `set_vision_configs`.
+    """
+
+    model_type = "git_opt"
+
+    def __init__(
+        self,
+        *,
+        vision_config=None,
+        num_image_with_embedding: Optional[int] = None,
+        vision_model_name: Optional[str] = None,
+        **kwargs,
+    ):
+        """Create the configuration.
+
+        The three vision settings are explicit keyword arguments so that the
+        values written by `to_dict` survive a reload; previously they were
+        overwritten with defaults after `super().__init__(**kwargs)`.
+
+        Args:
+            vision_config: a `CLIPVisionConfig`, a dict of its fields (as
+                produced by `to_dict`), or `None` for the CLIP defaults.
+            num_image_with_embedding: see class docstring.
+            vision_model_name: see class docstring.
+            **kwargs: forwarded to `OPTConfig`.
+        """
+        super().__init__(**kwargs)
+        if vision_config is None:
+            vision_config = CLIPVisionConfig()
+        elif isinstance(vision_config, dict):
+            # Serialized form: rebuild the nested config object.
+            vision_config = CLIPVisionConfig(**vision_config)
+        self.vision_config = vision_config
+        self.num_image_with_embedding = num_image_with_embedding
+        self.vision_model_name = vision_model_name
+
+    def set_vision_configs(
+        self,
+        num_image_with_embedding: Union[int, None] = None,
+        vision_model_name: Union[str, None] = None,
+    ):
+        """Store the vision settings and load the matching `CLIPVisionConfig`."""
+        self.num_image_with_embedding = num_image_with_embedding
+        self.vision_model_name = vision_model_name
+        self.vision_config = CLIPVisionConfig.from_pretrained(vision_model_name)
+
+    def to_dict(self):
+        """
+        Serialize this instance to a Python dictionary, overriding the default [`~PretrainedConfig.to_dict`].
+
+        Returns:
+            `Dict[str, any]`: a deep copy of the instance attributes in which `vision_config` is replaced by its own
+            dictionary form and `model_type` is set from the class.
+        """
+        output = copy.deepcopy(self.__dict__)
+        output["vision_config"] = self.vision_config.to_dict()
+        output["model_type"] = self.__class__.model_type
+        return output
+
+
+# Copied from transformers.models.bart.modeling_bart._expand_mask
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Turn a `[bsz, seq_len]` attention mask into an additive `[bsz, 1, tgt_seq_len, src_seq_len]` mask.
+
+    Positions where `mask` is 1 become 0; positions where it is 0 become the most negative finite value of
+    `dtype`. `tgt_len` defaults to the source length.
+    """
+    batch, source_len = mask.size()
+    if tgt_len is None:
+        tgt_len = source_len
+
+    # Broadcast the per-key mask over every query position.
+    keep = mask[:, None, None, :].expand(batch, 1, tgt_len, source_len).to(dtype)
+    blocked = 1.0 - keep
+
+    return blocked.masked_fill(blocked.to(torch.bool), torch.finfo(dtype).min)
+
+
+class GitOPTModel(OPTModel):
+ config_class = GitOPTConfig
+
+    def __init__(self, config: OPTConfig):
+        """Build the OPT backbone plus the vision encoder, projection and temporal embeddings.
+
+        `config` must provide `vision_model_name`, `vision_config` and
+        `num_image_with_embedding`.
+        """
+        super().__init__(config)
+
+        vision_cfg = config.vision_config
+        num_frames = config.num_image_with_embedding
+
+        # Git modules
+        self.image_encoder = CLIPVisionModel.from_pretrained(config.vision_model_name)
+        self.visual_projection = GitProjection(config)
+
+        if num_frames is not None:
+            # One learned offset per frame, added to that frame's visual features.
+            self.img_temporal_embedding = nn.ParameterList(
+                [
+                    nn.Parameter(torch.zeros(1, 1, vision_cfg.hidden_size))
+                    for _ in range(num_frames)
+                ]
+            )
+
+        # Patches of a square grid plus one extra token per image.
+        patches_per_side = vision_cfg.image_size / vision_cfg.patch_size
+        tokens_per_image = int(patches_per_side**2 + 1)
+        if num_frames is None:
+            self.image_patch_tokens = tokens_per_image
+        else:
+            self.image_patch_tokens = tokens_per_image * num_frames
+
+        # NOTE(review): this positional embedding lives on the wrapper, separate from
+        # anything inside `self.decoder`; confirm pretrained OPT checkpoints
+        # populate it as intended.
+        self.embed_positions = OPTLearnedPositionalEmbedding(
+            config.max_position_embeddings, config.hidden_size
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        """Return the token embedding module held by the decoder."""
+        token_embedding = self.decoder.embed_tokens
+        return token_embedding
+
+    def set_input_embeddings(self, value):
+        """Install `value` as the decoder's token embedding module."""
+        decoder = self.decoder
+        decoder.embed_tokens = value
+
+    def _prune_heads(self, heads_to_prune):
+        """
+        Prune attention heads. `heads_to_prune` maps a layer index to the list of heads to prune in that layer.
+        See base class PreTrainedModel.
+
+        NOTE(review): this walks `self.encoder.layer[...]`, but nothing visible in this class defines an
+        `encoder` attribute (the rest of the class uses `self.decoder`) - confirm this path is reachable.
+        """
+        for layer_index in heads_to_prune:
+            attention = self.encoder.layer[layer_index].attention
+            attention.prune_heads(heads_to_prune[layer_index])
+
+    def _generate_future_mask(
+        self, size: int, dtype: torch.dtype, device: torch.device
+    ) -> torch.Tensor:
+        """Return a `(size, size)` additive causal mask: `-inf` strictly above the diagonal, 0 elsewhere."""
+        # Ones above the diagonal mark the future positions of each row.
+        future = torch.ones(size, size, device=device, dtype=dtype).triu(diagonal=1)
+        return future.masked_fill(future.eq(1), float("-inf"))
+
+    def create_attention_mask(
+        self,
+        tgt,
+        memory,
+        tgt_mask,
+        past_key_values_length,
+        memory_key_padding_mask=None,
+    ):
+        """Assemble the additive attention mask over `[memory tokens | cached + current text tokens]`.
+
+        The mask is built from four quadrants: memory rows see all memory
+        columns and no text column; text rows see all memory columns and the
+        text columns allowed by `tgt_mask`. When `past_key_values_length > 0`
+        the given `tgt_mask` is replaced by zeros, i.e. the current text rows
+        may attend to every cached and current column.
+
+        Args:
+            tgt: text embeddings; dimension 1 is the number of text tokens.
+            memory: visual embeddings; dimension 1 is the number of memory tokens.
+            tgt_mask: additive mask between the text tokens.
+            past_key_values_length: number of cached key positions.
+            memory_key_padding_mask: optional boolean tensor over
+                `memory.shape[:2]`; `True` marks padding. Defaults to all `False`.
+
+        Returns:
+            Tensor of shape `(batch, 1, num_memory + num_tgt, num_memory + past_key_values_length + num_tgt)`.
+
+        Raises:
+            ValueError: if `memory_key_padding_mask` is not boolean.
+        """
+        text_len = tgt.shape[1]
+        memory_len = memory.shape[1]
+        device, dtype = tgt.device, tgt.dtype
+        neg_inf = float("-inf")
+
+        if past_key_values_length > 0:
+            # With a cache the current rows are unrestricted over cached + current columns.
+            tgt_mask = torch.zeros(
+                (tgt_mask.shape[0], tgt_mask.shape[0] + past_key_values_length),
+                dtype=dtype,
+                device=tgt_mask.device,
+            )
+
+        # Left column block: every row may attend to the memory tokens.
+        memory_to_memory = torch.zeros((memory_len, memory_len), device=device, dtype=dtype)
+        text_to_memory = torch.zeros((text_len, memory_len), dtype=dtype, device=tgt_mask.device)
+        left_block = torch.cat((memory_to_memory, text_to_memory), dim=0)
+
+        # Right column block: memory rows never attend to text columns.
+        memory_to_text = torch.full(
+            (memory_len, text_len + past_key_values_length),
+            neg_inf,
+            device=device,
+            dtype=dtype,
+        )
+        right_block = torch.cat((memory_to_text, tgt_mask.to(dtype)), dim=0)
+
+        mask = torch.cat((left_block, right_block), dim=1)[None, :]
+
+        if memory_key_padding_mask is None:
+            memory_key_padding_mask = torch.zeros(
+                (memory.shape[0], memory.shape[1]), dtype=torch.bool, device=device
+            )
+        # if it is False, it means valid. That is, it is not a padding
+        if memory_key_padding_mask.dtype != torch.bool:
+            raise ValueError("Memory key padding mask must be a boolean tensor.")
+
+        padding_bias = torch.zeros_like(memory_key_padding_mask, dtype=dtype).masked_fill(
+            memory_key_padding_mask, neg_inf
+        )
+
+        # One private copy per batch element, then hide padded memory columns.
+        mask = mask.expand(
+            (
+                memory_key_padding_mask.shape[0],
+                memory_len + text_len,
+                memory_len + past_key_values_length + text_len,
+            )
+        ).clone()
+        mask[:, :, :memory_len] += padding_bias[:, None, :]
+
+        # add axis for multi-head
+        return mask[:, None, :, :]
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
+        r"""
+        Run the OPT decoder layers over the visual tokens followed by the text tokens.
+
+        Args:
+            input_ids: token ids of shape `(batch_size, seq_length)`; mutually exclusive with `inputs_embeds`.
+            attention_mask: padding mask over the text tokens (1 = attend, 0 = masked). Defaults to all ones.
+                Its length must equal the number of cached plus current text tokens.
+            position_ids: accepted for interface compatibility; not used by this method.
+            pixel_values: rank-4 input (assumed `(batch_size, num_channels, height, width)`) or rank-5 input
+                (assumed `(batch_size, num_frames, num_channels, height, width)`). Only encoded when no
+                `past_key_values` is given.
+            head_mask: passed through `get_head_mask` and handed to each decoder layer.
+            inputs_embeds: pre-computed token embeddings used instead of `input_ids`.
+            past_key_values: per-layer cached key/value states. Their length is assumed to include
+                `self.image_patch_tokens` visual positions.
+            use_cache: if `True`, the new key/value states are returned. Forced to `False` when gradient
+                checkpointing is active during training.
+            output_attentions / output_hidden_states / return_dict: default to the corresponding config values.
+
+        Returns:
+            `BaseModelOutputWithPast`, or a tuple of its non-`None` fields when `return_dict` is false.
+
+        Raises:
+            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, if `pixel_values` is not
+                of rank 4 or 5, or if `attention_mask` has an unexpected length.
+        """
+        # Fall back to the matching config entry (the attention flag previously read
+        # `config.output_hidden_states` by mistake).
+        output_attentions = (
+            output_attentions if output_attentions is not None else self.config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        use_checkpointing = self.decoder.gradient_checkpointing and self.decoder.training
+        if use_checkpointing and use_cache:
+            # The checkpointed call below never asks the layer for a cache entry,
+            # so collecting one afterwards would index a missing output.
+            use_cache = False
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time"
+            )
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        batch_size, seq_length = input_shape
+
+        # Number of cached key positions (visual + text).
+        past_key_values_length = (
+            past_key_values[0][0].shape[2] if past_key_values is not None else 0
+        )
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicate we keep the head
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        # Vision branch: only run on the first step; later steps reuse the cache.
+        projected_visual_features = None
+        if pixel_values is not None and past_key_values is None:
+            if pixel_values.ndim == 4:
+                # here we assume pixel_values is of shape (batch_size, num_channels, height, width)
+                visual_features = self.image_encoder(pixel_values).last_hidden_state
+
+            elif pixel_values.ndim == 5:
+                # here we assume pixel_values is of shape (batch_size, num_frames, num_channels, height, width)
+                visual_features = []
+                for frame_idx in range(pixel_values.shape[1]):
+                    visual_features_frame = self.image_encoder(
+                        pixel_values[:, frame_idx, :, :]
+                    ).last_hidden_state
+                    visual_features_frame += self.img_temporal_embedding[frame_idx]
+                    visual_features.append(visual_features_frame)
+
+                # finally, concatenate all features along sequence dimension
+                visual_features = torch.cat(visual_features, dim=1)
+            else:
+                raise ValueError("pixel_values must be of rank 4 or 5")
+
+            projected_visual_features = self.visual_projection(visual_features)
+
+        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py#L634-L658
+        # Honour caller-supplied embeddings instead of always embedding `input_ids`.
+        if inputs_embeds is None:
+            inputs_embeds = self.decoder.embed_tokens(input_ids)
+
+        # Positions and the padding mask only cover text tokens, so the visual
+        # positions are excluded from the cached length used for them.
+        # NOTE(review): this assumes every cache was created with an image; a
+        # text-only cache would make this length too small - confirm with callers.
+        text_past_length = past_key_values_length
+        if past_key_values is not None:
+            text_past_length = text_past_length - self.image_patch_tokens
+        mask_seq_length = text_past_length + seq_length
+
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+        elif attention_mask.shape[1] != mask_seq_length:
+            raise ValueError(
+                f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be "
+                f"{mask_seq_length} (sum of the lengths of current and past inputs)"
+            )
+        pos_embeds = self.embed_positions(attention_mask, text_past_length)
+
+        if self.decoder.project_in is not None:
+            inputs_embeds = self.decoder.project_in(inputs_embeds)
+
+        embedding_output = inputs_embeds + pos_embeds
+
+        if projected_visual_features is None:
+            # Zero-length placeholder so the concatenation below is uniform.
+            projected_visual_features = torch.zeros(
+                (embedding_output.shape[0], 0, embedding_output.shape[2]),
+                dtype=embedding_output.dtype,
+                device=embedding_output.device,
+            )
+
+        # Repeat visual features to match embedding batch size.
+        projected_visual_features = projected_visual_features.repeat(
+            embedding_output.size(0) // projected_visual_features.size(0), 1, 1
+        )
+
+        # concatenate patch token and text token embeddings
+        hidden_states = torch.cat((projected_visual_features, embedding_output), dim=1)
+
+        # By default, an additive causal mask is created
+        # for masking the future (one direction).
+        tgt_mask = self._generate_future_mask(
+            seq_length, embedding_output.dtype, embedding_output.device
+        )
+
+        # Create an attention mask of shape (batch_size, 1, tgt_seq_len, src_seq_len);
+        # here the full cached length (visual + text) is needed.
+        combined_attention_mask = self.create_attention_mask(
+            tgt=embedding_output,
+            memory=projected_visual_features,
+            tgt_mask=tgt_mask,
+            past_key_values_length=past_key_values_length,
+        )
+
+        if past_key_values_length == 0:
+            # Merge the padding mask into the text-to-text quadrant:
+            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+            expanded_attn_mask = _expand_mask(
+                attention_mask, embedding_output.dtype, tgt_len=seq_length
+            ).to(embedding_output.device)
+            combined_attention_mask[:, :, -seq_length:, -seq_length:] += expanded_attn_mask
+        # NOTE(review): with a cache the padding mask is not merged. The previous
+        # code sliced the expanded mask in that case but never used the result;
+        # behaviour is unchanged here.
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        # LayerDrop is not applied in this loop.
+        for idx, decoder_layer in enumerate(self.decoder.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+            layer_head_mask = head_mask[idx] if head_mask is not None else None
+
+            if use_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, output_attentions, None)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(decoder_layer),
+                    hidden_states,
+                    combined_attention_mask,
+                    layer_head_mask,
+                    None,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=combined_attention_mask,
+                    layer_head_mask=layer_head_mask,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                # The cache entry follows the attention weights when those are returned.
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        if self.decoder.final_layer_norm is not None:
+            hidden_states = self.decoder.final_layer_norm(hidden_states)
+
+        if self.decoder.project_out is not None:
+            hidden_states = self.decoder.project_out(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                if v is not None
+            )
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class GitOPTForCausalLM(OPTForCausalLM):
+    """OPT causal-LM head on top of a `GitOPTModel` backbone.
+
+    The backbone concatenates projected visual tokens in front of the text
+    token embeddings, so the logits returned here cover the visual positions
+    followed by the text positions.
+    """
+
+    config_class = GitOPTConfig
+
+    def __init__(
+        self,
+        config,
+    ):
+        super(GitOPTForCausalLM, self).__init__(config)
+        # Use the vision-aware backbone as `self.model`.
+        self.model = GitOPTModel(config)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_output_embeddings(self):
+        """Return the module that projects hidden states to vocabulary logits."""
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        """Replace the module that projects hidden states to vocabulary logits."""
+        self.lm_head = new_embeddings
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
+        r"""
+        Run the backbone, project to logits and optionally compute the LM loss.
+
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the left-to-right language modeling loss (next word prediction). Tokens with
+            indices set to `-100` are ignored (masked) by `CrossEntropyLoss`; the loss is only computed for the
+            tokens with labels in `[0, ..., config.vocab_size]`. When labels are given, `use_cache` is forced to
+            `False`.
+        past_key_values (*optional*):
+            Cached key/value states from a previous call. They are forwarded unchanged to the backbone and can be
+            used to speed up decoding; in that case only the last token of `input_ids` needs to be passed.
+        use_cache (`bool`, *optional*):
+            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding.
+
+        Returns:
+            A `CausalLMOutputWithPast`, or a plain tuple when `return_dict` is false. The tuple starts with the
+            loss only when `labels` were given.
+
+        Raises:
+            ValueError: if `labels` is longer than the sequence of logits.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if labels is not None:
+            use_cache = False
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            pixel_values=pixel_values,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        sequence_output = outputs[0]
+        logits = self.lm_head(sequence_output)
+
+        loss = None
+        if labels is not None:
+            # The visual tokens sit in front of the text tokens and have no labels.
+            # Derive how many there are from the actual tensors instead of the
+            # configured `image_patch_tokens`, so text-only batches (no
+            # `pixel_values`) and inputs with a different number of frames are
+            # sliced correctly as well.
+            num_image_tokens = logits.size(1) - labels.size(1)
+            if num_image_tokens < 0:
+                raise ValueError(
+                    f"labels has length {labels.size(1)} but only {logits.size(1)} positions were predicted"
+                )
+            # we are doing next-token prediction; shift prediction scores and input ids by one
+            shifted_logits = logits[:, num_image_tokens:-1, :].contiguous()
+            labels = labels[:, 1:].contiguous()
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(shifted_logits.view(-1, shifted_logits.size(-1)), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        use_cache=None,
+        **kwargs,
+    ):
+        """Build the keyword arguments for one generation step.
+
+        When a cache is present only the most recent token is fed to the
+        model. A missing `attention_mask` defaults to all ones with the shape
+        of the (possibly truncated) `input_ids`. `pixel_values` is taken from
+        `kwargs` when present.
+        """
+        # cut decoder_input_ids if past_key_values is used
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_ids.shape)
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "pixel_values": kwargs.get("pixel_values"),
+            "past_key_values": past_key_values,
+            "use_cache": use_cache,
+        }
+
+    def _reorder_cache(self, past_key_values, beam_idx):
+        """Reorder every cached tensor along dimension 0 according to `beam_idx`."""
+        # Move the indices to each tensor's device so caches spread over
+        # several devices can be reordered too.
+        return tuple(
+            tuple(
+                past_state.index_select(0, beam_idx.to(past_state.device))
+                for past_state in layer_past
+            )
+            for layer_past in past_key_values
+        )
diff --git a/heron/models/prepare_processors.py b/heron/models/prepare_processors.py
new file mode 100644
index 0000000..e69de29
diff --git a/heron/models/utils.py b/heron/models/utils.py
new file mode 100644
index 0000000..33680b9
--- /dev/null
+++ b/heron/models/utils.py
@@ -0,0 +1,135 @@
+import glob
+from typing import Any, Optional
+
+import numpy as np
+import torch
+from peft import LoraConfig, get_peft_config, get_peft_model
+
+from .git_llm.git_japanese_stablelm_alpha import (
+ GitJapaneseStableLMAlphaConfig,
+ GitJapaneseStableLMAlphaForCausalLM,
+)
+from .git_llm.git_llama import GitLlamaConfig, GitLlamaForCausalLM
+from .git_llm.git_mpt import GitMptConfig, GitMptForCausalLM
+from .git_llm.git_opt import GitOPTConfig, GitOPTForCausalLM
+from .git_llm.git_gpt_neox import GitGPTNeoXConfig, GitGPTNeoXForCausalLM
+
+GitLLMForCausalLM = Any
+
+
+def _select_git_llm_classes(model_name: str):
+    """Return the `(config_class, model_class)` pair matching `model_name`, or `None` if nothing matches."""
+    # Substring checks, evaluated in the same order as before.
+    if "opt" in model_name:
+        return GitOPTConfig, GitOPTForCausalLM
+    if "llama" in model_name:
+        return GitLlamaConfig, GitLlamaForCausalLM
+    if "mpt" in model_name:
+        return GitMptConfig, GitMptForCausalLM
+    if "japanese-stablelm" in model_name:
+        return GitJapaneseStableLMAlphaConfig, GitJapaneseStableLMAlphaForCausalLM
+    gpt_neox_families = (
+        "line-corporation/japanese-large-lm",
+        "matsuo-lab/weblab",
+        "cyberagent/open-calm-7b",
+    )
+    if any(family in model_name for family in gpt_neox_families):
+        return GitGPTNeoXConfig, GitGPTNeoXForCausalLM
+    return None
+
+
+def load_model(
+    model_name: str,
+    vision_model_name: str,
+    num_image_with_embedding: Optional[int],
+    is_fp16: bool,
+) -> GitLLMForCausalLM:
+    """Load the GIT-LLM variant selected by `model_name`.
+
+    Args:
+        model_name: name or path of the language model; the model family is
+            chosen by substring match on this value.
+        vision_model_name: name or path of the vision encoder, stored in the
+            config via `set_vision_configs`.
+        num_image_with_embedding: number of frames with a temporal embedding,
+            or `None`.
+        is_fp16: load the weights as `torch.float16` when true, otherwise as
+            `torch.float32`.
+
+    Returns:
+        The model returned by the selected class's `from_pretrained`.
+
+    Raises:
+        ValueError: if `model_name` matches no supported family (previously
+            this surfaced as an `UnboundLocalError` at the final `return`).
+    """
+    torch_dtype = torch.float16 if is_fp16 else torch.float32
+
+    selected = _select_git_llm_classes(model_name)
+    if selected is None:
+        raise ValueError(f"Unsupported model name for GIT-LLM: {model_name!r}")
+    config_class, model_class = selected
+
+    git_config = config_class.from_pretrained(model_name)
+    git_config.set_vision_configs(
+        num_image_with_embedding=num_image_with_embedding, vision_model_name=vision_model_name
+    )
+    return model_class.from_pretrained(model_name, config=git_config, torch_dtype=torch_dtype)
+
+
+def load_pretrained_weight(model: GitLLMForCausalLM, weight_path: str):
+    """Load every `pytorch*.bin` shard found in `weight_path` into `model`.
+
+    The shards are merged into one state dict (later shards win on duplicate
+    keys) and applied with `strict=False`, so keys missing from or unknown to
+    the model are tolerated.
+
+    Raises:
+        FileNotFoundError: if the directory contains no matching shard;
+            previously this case silently loaded nothing.
+    """
+    # Sorted for a deterministic merge order.
+    shard_paths = sorted(glob.glob(f"{weight_path}/pytorch*.bin"))
+    if not shard_paths:
+        raise FileNotFoundError(f"No 'pytorch*.bin' weight files found in {weight_path!r}")
+
+    state_dict = {}
+    for shard_path in shard_paths:
+        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
+        state_dict.update(torch.load(shard_path, map_location="cpu"))
+    model.load_state_dict(state_dict, strict=False)
+
+
+def apply_lora_model(model: GitLLMForCausalLM, model_name: str, config: dict) -> GitLLMForCausalLM:
+    """Wrap the language-model part of `model` with LoRA adapters.
+
+    The LoRA settings are read from `config["lora"]`. The model family is
+    chosen by substring match on `model_name`:
+
+    * "opt": only `model.model.decoder` is wrapped, in place.
+    * "llama": the configured target modules are expanded to one entry per
+      layer under `model.layers.<i>.self_attn`, then the whole model is wrapped.
+    * "mpt" and the GPT-NeoX style families: the whole model is wrapped.
+
+    For every family except "opt" the output head (`lm_head` or `embed_out`)
+    is re-assigned on the inner model and the inner model is returned, i.e.
+    the peft wrapper itself is dropped. A name matching no family returns
+    `model` unchanged.
+    """
+    peft_config = LoraConfig(**config["lora"])
+
+    # apply lora only to LLM
+    if "opt" in model_name:
+        model.model.decoder = get_peft_model(model.model.decoder, peft_config)
+        return model
+
+    if "llama" in model_name:
+        layer_count = len(model.model.layers)
+        peft_config.target_modules = [
+            f"model.layers.{i}.self_attn.{m}"
+            for m in peft_config.target_modules
+            for i in range(layer_count)
+        ]
+        head_name = "lm_head"
+    elif "mpt" in model_name:
+        head_name = "lm_head"
+    elif (
+        "japanese-stablelm" in model_name
+        or "line-corporation/japanese-large-lm" in model_name
+        or "matsuo-lab/weblab" in model_name
+        or "cyberagent/open-calm-7b" in model_name
+    ):
+        head_name = "embed_out"
+    else:
+        return model
+
+    wrapped = get_peft_model(model, peft_config)
+    inner = wrapped.base_model.model
+    setattr(inner, head_name, getattr(wrapped, head_name))
+    # remove peft wrapper
+    return wrapped.base_model.model
+
+
+def set_trainable_params(model: GitLLMForCausalLM, model_name: str, keys_finetune: list) -> None:
+    """Enable gradients only for parameters whose name contains one of `keys_finetune`.
+
+    Args:
+        model: object exposing `named_parameters()`.
+        model_name: unused; kept for interface compatibility.
+        keys_finetune: substrings to look for in each parameter name. An empty
+            list freezes every parameter.
+    """
+    for name, param in model.named_parameters():
+        param.requires_grad = any(key in name for key in keys_finetune)
diff --git a/images/heron_image.png b/images/heron_image.png
new file mode 100644
index 0000000..9e8cc18
Binary files /dev/null and b/images/heron_image.png differ
diff --git a/notebooks/inference.ipynb b/notebooks/inference.ipynb
new file mode 100644
index 0000000..f65eab3
--- /dev/null
+++ b/notebooks/inference.ipynb
@@ -0,0 +1,1423 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "0b513f29-6f5f-4524-82f7-04177b835c3b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "exp = \"exp050_llama\"\n",
+ "device_id = 0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "3b156306-373d-41a9-82b6-4f0e85c737e7",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[2023-08-09 03:11:29,783] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import glob\n",
+ "import os\n",
+ "import sys\n",
+ "from base64 import b64decode\n",
+ "from io import BytesIO\n",
+ "from PIL import Image\n",
+ "\n",
+ "from transformers import (\n",
+ " AutoTokenizer,\n",
+ " CLIPImageProcessor,\n",
+ " AutoProcessor,\n",
+ " TrainingArguments,\n",
+ " Trainer,\n",
+ " AutoModelForCausalLM\n",
+ ")\n",
+ "import datasets\n",
+ "import torch\n",
+ "from torch.utils.data import Dataset\n",
+ "import yaml\n",
+ "import deepspeed\n",
+ "import fire\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import japanize_matplotlib\n",
+ "from peft import LoraConfig, get_peft_config, get_peft_model\n",
+ "\n",
+ "\n",
+ "from git_llm.git_opt import GitOPTForCausalLM, GitOPTConfig\n",
+ "from git_llm.git_llama import GitLlamaForCausalLM, GitLlamaConfig"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "7b5cb602-f80d-4881-9408-6b6be3369a33",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sys.path.append('..')\n",
+ "from train import load_model, apply_lora_model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "87dd43fe-0005-4db2-a3f5-744f6be7fc2c",
+ "metadata": {},
+ "source": [
+ "# Functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "535b178f-f679-434a-872b-5e53a0236b6f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def generate_text(model, data):\n",
+ " eos_token_id_list = [\n",
+ " supervised_test_dataset.processor.tokenizer.pad_token_id,\n",
+ " supervised_test_dataset.processor.tokenizer.eos_token_id,\n",
+ " ]\n",
+ " with torch.no_grad():\n",
+ " out = model.generate(**data, max_length=256, do_sample=False, temperature=0., eos_token_id=eos_token_id_list)\n",
+ " return supervised_test_dataset.processor.tokenizer.batch_decode(out)\n",
+ "\n",
+ "# SupervisedDataset\n",
+ "class SupervisedDataset(Dataset):\n",
+ " \"\"\"Dataset for supervised learning\"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " model_name: str,\n",
+ " vision_model_name: str,\n",
+ " loaded_dataset: datasets.GeneratorBasedBuilder,\n",
+ " max_length: int = 128,\n",
+ " ):\n",
+ " super(SupervisedDataset, self).__init__()\n",
+ " self.loaded_dataset = loaded_dataset\n",
+ " self.max_length = max_length\n",
+ "\n",
+ " self.processor = AutoProcessor.from_pretrained(\"microsoft/git-base\")\n",
+ " self.processor.image_processor = CLIPImageProcessor.from_pretrained(vision_model_name)\n",
+ " self.processor.tokenizer = AutoTokenizer.from_pretrained(\n",
+ " model_name, padding_side=\"right\", use_fast=False\n",
+ " )\n",
+ " if \"llama\" in model_name:\n",
+ " self.processor.tokenizer.pad_token = self.processor.tokenizer.eos_token\n",
+ "\n",
+ " def __len__(self) -> int:\n",
+ " return len(self.loaded_dataset)\n",
+ "\n",
+ " def __getitem__(self, index) -> dict:\n",
+ " # cf: https://huggingface.co/datasets/MMInstruction/M3IT#data-instances\n",
+ " row = self.loaded_dataset[index]\n",
+ "\n",
+ " instruction = row[\"instruction\"] # str\n",
+ " question = row[\"inputs\"] # str\n",
+ " answer = row[\"outputs\"] # str\n",
+ " full_text = f\"##Instruction: {instruction} ##Question: {question} ##Answer: {answer}\"\n",
+ " text = f\"##Instruction: {instruction} ##Question: {question} ##Answer:\"\n",
+ "\n",
+ " # imageのロード\n",
+ " image_base64_str_list = row[\"image_base64_str\"] # str (base64)\n",
+ " img = Image.open(BytesIO(b64decode(image_base64_str_list[0])))\n",
+ "\n",
+ " inputs = self.process_data(text, img)\n",
+ "\n",
+ " return inputs, img, text, full_text\n",
+ "\n",
+ " def process_data(self, text, img):\n",
+ " inputs = self.processor(\n",
+ " text,\n",
+ " img,\n",
+ " return_tensors=\"pt\",\n",
+ " # max_length=self.max_length,\n",
+ " # padding=\"max_length\",\n",
+ " truncation=True,\n",
+ " )\n",
+ " inputs = {k: v.to(f\"cuda:{device_id}\") for k, v in inputs.items()}\n",
+ " inputs[\"pixel_values\"] = inputs[\"pixel_values\"].to(torch.float16)\n",
+ " inputs[\"labels\"] = None\n",
+ " return inputs"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "974b1a15-0574-4d61-b167-05e74aadad09",
+ "metadata": {},
+ "source": [
+ "# Load configs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "d85d3447-cb64-4ee1-83b9-32ea808c323c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config_file = f\"../configs/training_config_{exp}.yml\"\n",
+ "\n",
+ "# get config\n",
+ "with open(config_file, \"r\") as i_:\n",
+ " config = yaml.safe_load(i_)\n",
+ "\n",
+ "\n",
+ "# model\n",
+ "model_name = config[\"settings\"][\"model_name\"]\n",
+ "vision_model_name = config[\"settings\"][\"vision_model_name\"]\n",
+ "num_image_with_embedding = config[\"settings\"][\"num_image_with_embedding\"]\n",
+ "\n",
+ "keys_finetune = config[\"settings\"][\"keys_finetune\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2b9a7fb1-1594-472c-8ead-116b4924a0f3",
+ "metadata": {},
+ "source": [
+ "# Load a pretrained model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "3fc3a60d-f92e-4e74-a504-114ab42ac1df",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "You are using a model of type llama to instantiate a model of type git_llama. This is not supported for all configurations of models and can yield errors.\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "adf8224c06d14ca6b89975b5669018f2",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Loading checkpoint shards: 0%| | 0/2 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of GitLlamaForCausalLM were not initialized from the model checkpoint at meta-llama/Llama-2-7b-chat-hf and are newly initialized: ['model.image_encoder.vision_model.encoder.layers.9.mlp.fc2.weight', 'model.image_encoder.vision_model.encoder.layers.1.mlp.fc2.bias', 'model.image_encoder.vision_model.encoder.layers.4.self_attn.q_proj.bias', 'model.image_encoder.vision_model.encoder.layers.4.self_attn.out_proj.bias', 'model.image_encoder.vision_model.encoder.layers.4.layer_norm2.bias', 'model.image_encoder.vision_model.encoder.layers.3.layer_norm1.weight', 'model.image_encoder.vision_model.encoder.layers.2.layer_norm1.weight', 'model.image_encoder.vision_model.encoder.layers.9.layer_norm1.weight', 'model.image_encoder.vision_model.encoder.layers.6.self_attn.v_proj.weight', 'model.image_encoder.vision_model.encoder.layers.2.self_attn.out_proj.bias', 'model.image_encoder.vision_model.encoder.layers.3.mlp.fc2.bias', 'model.image_encoder.vision_model.encoder.layers.9.self_attn.out_proj.weight', 'model.image_encoder.vision_model.encoder.layers.7.mlp.fc1.bias', 'model.image_encoder.vision_model.encoder.layers.1.layer_norm2.bias', 'model.image_encoder.vision_model.encoder.layers.9.self_attn.k_proj.weight', 'model.image_encoder.vision_model.encoder.layers.9.layer_norm1.bias', 'model.image_encoder.vision_model.encoder.layers.6.self_attn.q_proj.bias', 'model.image_encoder.vision_model.encoder.layers.6.mlp.fc2.bias', 'model.image_encoder.vision_model.encoder.layers.9.self_attn.out_proj.bias', 'model.image_encoder.vision_model.encoder.layers.9.self_attn.v_proj.bias', 'model.image_encoder.vision_model.embeddings.patch_embedding.weight', 'model.image_encoder.vision_model.encoder.layers.11.layer_norm2.bias', 'model.image_encoder.vision_model.encoder.layers.6.self_attn.q_proj.weight', 'model.image_encoder.vision_model.encoder.layers.11.self_attn.out_proj.weight', 'model.image_encoder.vision_model.encoder.layers.4.mlp.fc1.bias', 
'model.image_encoder.vision_model.encoder.layers.10.layer_norm1.weight', 'model.image_encoder.vision_model.encoder.layers.0.mlp.fc1.bias', 'model.image_encoder.vision_model.encoder.layers.0.self_attn.out_proj.weight', 'model.image_encoder.vision_model.encoder.layers.10.layer_norm2.weight', 'model.image_encoder.vision_model.encoder.layers.11.layer_norm2.weight', 'model.image_encoder.vision_model.encoder.layers.9.mlp.fc1.weight', 'model.image_encoder.vision_model.encoder.layers.7.self_attn.out_proj.bias', 'model.image_encoder.vision_model.encoder.layers.7.mlp.fc1.weight', 'model.image_encoder.vision_model.encoder.layers.1.self_attn.k_proj.bias', 'model.image_encoder.vision_model.encoder.layers.7.self_attn.q_proj.weight', 'model.image_encoder.vision_model.pre_layrnorm.bias', 'model.image_encoder.vision_model.encoder.layers.3.self_attn.q_proj.bias', 'model.image_encoder.vision_model.encoder.layers.9.layer_norm2.bias', 'model.image_encoder.vision_model.embeddings.position_embedding.weight', 'model.image_encoder.vision_model.encoder.layers.11.self_attn.k_proj.weight', 'model.image_encoder.vision_model.encoder.layers.8.self_attn.q_proj.weight', 'model.image_encoder.vision_model.encoder.layers.10.self_attn.k_proj.bias', 'model.image_encoder.vision_model.embeddings.class_embedding', 'model.image_encoder.vision_model.encoder.layers.3.self_attn.v_proj.bias', 'model.image_encoder.vision_model.encoder.layers.1.mlp.fc1.bias', 'model.image_encoder.vision_model.encoder.layers.5.mlp.fc1.weight', 'model.visual_projection.visual_projection.1.weight', 'model.image_encoder.vision_model.encoder.layers.7.self_attn.v_proj.weight', 'model.image_encoder.vision_model.encoder.layers.7.mlp.fc2.bias', 'model.image_encoder.vision_model.encoder.layers.9.layer_norm2.weight', 'model.image_encoder.vision_model.encoder.layers.8.mlp.fc2.weight', 'model.image_encoder.vision_model.encoder.layers.11.self_attn.q_proj.bias', 'model.image_encoder.vision_model.encoder.layers.3.layer_norm2.weight', 
'model.image_encoder.vision_model.encoder.layers.5.self_attn.q_proj.weight', 'model.image_encoder.vision_model.encoder.layers.6.mlp.fc1.weight', 'model.image_encoder.vision_model.encoder.layers.7.layer_norm2.weight', 'model.image_encoder.vision_model.encoder.layers.4.layer_norm1.weight', 'model.image_encoder.vision_model.encoder.layers.8.self_attn.v_proj.bias', 'model.image_encoder.vision_model.encoder.layers.10.self_attn.q_proj.bias', 'model.image_encoder.vision_model.encoder.layers.1.self_attn.k_proj.weight', 'model.image_encoder.vision_model.encoder.layers.9.self_attn.q_proj.weight', 'model.image_encoder.vision_model.encoder.layers.5.self_attn.out_proj.bias', 'model.image_encoder.vision_model.encoder.layers.11.self_attn.q_proj.weight', 'model.image_encoder.vision_model.encoder.layers.11.self_attn.out_proj.bias', 'model.image_encoder.vision_model.encoder.layers.1.self_attn.v_proj.weight', 'model.image_encoder.vision_model.encoder.layers.7.self_attn.v_proj.bias', 'model.image_encoder.vision_model.encoder.layers.10.self_attn.v_proj.bias', 'model.image_encoder.vision_model.encoder.layers.11.mlp.fc2.weight', 'model.image_encoder.vision_model.encoder.layers.5.self_attn.k_proj.bias', 'model.visual_projection.visual_projection.0.weight', 'model.image_encoder.vision_model.encoder.layers.7.layer_norm2.bias', 'model.image_encoder.vision_model.encoder.layers.5.layer_norm1.weight', 'model.image_encoder.vision_model.encoder.layers.5.self_attn.v_proj.bias', 'model.image_encoder.vision_model.encoder.layers.11.mlp.fc1.weight', 'model.image_encoder.vision_model.encoder.layers.0.mlp.fc2.weight', 'model.image_encoder.vision_model.encoder.layers.10.self_attn.out_proj.weight', 'model.image_encoder.vision_model.encoder.layers.3.layer_norm1.bias', 'model.image_encoder.vision_model.encoder.layers.8.self_attn.k_proj.weight', 'model.image_encoder.vision_model.encoder.layers.11.self_attn.v_proj.bias', 'model.image_encoder.vision_model.encoder.layers.6.mlp.fc2.weight', 
'model.image_encoder.vision_model.encoder.layers.1.self_attn.q_proj.bias', 'model.image_encoder.vision_model.encoder.layers.0.layer_norm2.weight', 'model.image_encoder.vision_model.encoder.layers.11.mlp.fc2.bias', 'model.image_encoder.vision_model.encoder.layers.5.self_attn.k_proj.weight', 'model.image_encoder.vision_model.encoder.layers.10.self_attn.q_proj.weight', 'model.visual_projection.visual_projection.1.bias', 'model.image_encoder.vision_model.encoder.layers.0.self_attn.q_proj.bias', 'model.image_encoder.vision_model.encoder.layers.2.mlp.fc2.bias', 'model.image_encoder.vision_model.encoder.layers.3.mlp.fc2.weight', 'model.image_encoder.vision_model.encoder.layers.11.layer_norm1.weight', 'model.image_encoder.vision_model.encoder.layers.8.self_attn.out_proj.weight', 'model.image_encoder.vision_model.encoder.layers.8.layer_norm1.bias', 'model.image_encoder.vision_model.encoder.layers.2.mlp.fc2.weight', 'model.image_encoder.vision_model.encoder.layers.3.self_attn.k_proj.weight', 'model.image_encoder.vision_model.encoder.layers.3.self_attn.out_proj.weight', 'model.image_encoder.vision_model.post_layernorm.weight', 'model.image_encoder.vision_model.encoder.layers.8.mlp.fc2.bias', 'model.image_encoder.vision_model.encoder.layers.10.layer_norm1.bias', 'model.image_encoder.vision_model.encoder.layers.6.self_attn.out_proj.weight', 'model.image_encoder.vision_model.encoder.layers.0.mlp.fc1.weight', 'model.image_encoder.vision_model.encoder.layers.2.self_attn.k_proj.bias', 'model.image_encoder.vision_model.encoder.layers.2.layer_norm1.bias', 'model.image_encoder.vision_model.encoder.layers.1.mlp.fc1.weight', 'model.image_encoder.vision_model.encoder.layers.7.self_attn.k_proj.weight', 'model.image_encoder.vision_model.encoder.layers.2.self_attn.out_proj.weight', 'model.image_encoder.vision_model.encoder.layers.2.self_attn.q_proj.weight', 'model.image_encoder.vision_model.encoder.layers.2.layer_norm2.weight', 
'model.image_encoder.vision_model.encoder.layers.8.self_attn.out_proj.bias', 'model.image_encoder.vision_model.encoder.layers.6.mlp.fc1.bias', 'model.image_encoder.vision_model.encoder.layers.5.self_attn.q_proj.bias', 'model.image_encoder.vision_model.encoder.layers.6.layer_norm2.bias', 'model.image_encoder.vision_model.encoder.layers.5.layer_norm2.bias', 'model.image_encoder.vision_model.encoder.layers.2.self_attn.v_proj.weight', 'model.image_encoder.vision_model.encoder.layers.9.mlp.fc1.bias', 'model.image_encoder.vision_model.encoder.layers.4.self_attn.out_proj.weight', 'model.image_encoder.vision_model.encoder.layers.3.self_attn.q_proj.weight', 'model.image_encoder.vision_model.encoder.layers.1.mlp.fc2.weight', 'model.image_encoder.vision_model.encoder.layers.4.mlp.fc1.weight', 'model.image_encoder.vision_model.encoder.layers.4.layer_norm2.weight', 'model.image_encoder.vision_model.encoder.layers.5.layer_norm2.weight', 'model.image_encoder.vision_model.encoder.layers.5.mlp.fc2.bias', 'model.image_encoder.vision_model.pre_layrnorm.weight', 'model.image_encoder.vision_model.encoder.layers.8.layer_norm2.bias', 'model.image_encoder.vision_model.encoder.layers.10.mlp.fc2.bias', 'model.image_encoder.vision_model.encoder.layers.7.layer_norm1.weight', 'model.image_encoder.vision_model.encoder.layers.8.self_attn.v_proj.weight', 'model.image_encoder.vision_model.encoder.layers.10.mlp.fc2.weight', 'model.image_encoder.vision_model.encoder.layers.1.self_attn.out_proj.bias', 'model.image_encoder.vision_model.encoder.layers.3.mlp.fc1.weight', 'model.image_encoder.vision_model.encoder.layers.8.layer_norm2.weight', 'model.image_encoder.vision_model.encoder.layers.0.mlp.fc2.bias', 'model.image_encoder.vision_model.encoder.layers.10.self_attn.v_proj.weight', 'model.image_encoder.vision_model.encoder.layers.4.self_attn.k_proj.weight', 'model.image_encoder.vision_model.encoder.layers.9.mlp.fc2.bias', 'model.image_encoder.vision_model.encoder.layers.2.self_attn.k_proj.weight', 
'model.image_encoder.vision_model.encoder.layers.4.self_attn.v_proj.bias', 'model.image_encoder.vision_model.encoder.layers.7.self_attn.k_proj.bias', 'model.image_encoder.vision_model.encoder.layers.2.self_attn.v_proj.bias', 'model.image_encoder.vision_model.encoder.layers.5.mlp.fc1.bias', 'model.image_encoder.vision_model.encoder.layers.4.self_attn.q_proj.weight', 'model.image_encoder.vision_model.encoder.layers.7.mlp.fc2.weight', 'model.image_encoder.vision_model.encoder.layers.8.self_attn.k_proj.bias', 'model.image_encoder.vision_model.encoder.layers.10.self_attn.k_proj.weight', 'model.image_encoder.vision_model.encoder.layers.2.mlp.fc1.bias', 'model.image_encoder.vision_model.encoder.layers.0.layer_norm1.weight', 'model.image_encoder.vision_model.encoder.layers.10.mlp.fc1.weight', 'model.image_encoder.vision_model.encoder.layers.4.mlp.fc2.weight', 'model.image_encoder.vision_model.encoder.layers.5.self_attn.out_proj.weight', 'model.image_encoder.vision_model.encoder.layers.6.self_attn.k_proj.bias', 'model.image_encoder.vision_model.encoder.layers.4.self_attn.v_proj.weight', 'model.image_encoder.vision_model.encoder.layers.8.self_attn.q_proj.bias', 'model.image_encoder.vision_model.encoder.layers.0.layer_norm2.bias', 'model.image_encoder.vision_model.encoder.layers.8.mlp.fc1.bias', 'model.image_encoder.vision_model.encoder.layers.10.self_attn.out_proj.bias', 'model.image_encoder.vision_model.encoder.layers.4.self_attn.k_proj.bias', 'model.image_encoder.vision_model.encoder.layers.6.layer_norm1.weight', 'model.image_encoder.vision_model.encoder.layers.6.self_attn.k_proj.weight', 'model.image_encoder.vision_model.encoder.layers.1.layer_norm2.weight', 'model.image_encoder.vision_model.encoder.layers.5.self_attn.v_proj.weight', 'model.image_encoder.vision_model.encoder.layers.1.layer_norm1.weight', 'model.image_encoder.vision_model.encoder.layers.5.layer_norm1.bias', 'model.image_encoder.vision_model.encoder.layers.1.self_attn.v_proj.bias', 
'model.image_encoder.vision_model.encoder.layers.0.self_attn.v_proj.weight', 'model.image_encoder.vision_model.encoder.layers.5.mlp.fc2.weight', 'model.image_encoder.vision_model.encoder.layers.0.layer_norm1.bias', 'model.image_encoder.vision_model.encoder.layers.9.self_attn.k_proj.bias', 'model.image_encoder.vision_model.encoder.layers.9.self_attn.q_proj.bias', 'model.image_encoder.vision_model.encoder.layers.8.mlp.fc1.weight', 'model.image_encoder.vision_model.encoder.layers.2.layer_norm2.bias', 'model.visual_projection.visual_projection.0.bias', 'model.image_encoder.vision_model.encoder.layers.6.layer_norm1.bias', 'model.image_encoder.vision_model.encoder.layers.9.self_attn.v_proj.weight', 'model.image_encoder.vision_model.encoder.layers.11.self_attn.v_proj.weight', 'model.image_encoder.vision_model.encoder.layers.2.self_attn.q_proj.bias', 'model.image_encoder.vision_model.encoder.layers.0.self_attn.q_proj.weight', 'model.image_encoder.vision_model.encoder.layers.4.mlp.fc2.bias', 'model.image_encoder.vision_model.encoder.layers.0.self_attn.k_proj.weight', 'model.image_encoder.vision_model.encoder.layers.1.self_attn.q_proj.weight', 'model.image_encoder.vision_model.encoder.layers.8.layer_norm1.weight', 'model.image_encoder.vision_model.encoder.layers.11.mlp.fc1.bias', 'model.image_encoder.vision_model.post_layernorm.bias', 'model.image_encoder.vision_model.encoder.layers.0.self_attn.v_proj.bias', 'model.image_encoder.vision_model.encoder.layers.6.layer_norm2.weight', 'model.image_encoder.vision_model.encoder.layers.6.self_attn.v_proj.bias', 'model.image_encoder.vision_model.encoder.layers.3.layer_norm2.bias', 'model.image_encoder.vision_model.encoder.layers.0.self_attn.out_proj.bias', 'model.image_encoder.vision_model.encoder.layers.6.self_attn.out_proj.bias', 'model.image_encoder.vision_model.encoder.layers.10.mlp.fc1.bias', 'model.image_encoder.vision_model.encoder.layers.4.layer_norm1.bias', 'model.image_encoder.vision_model.encoder.layers.1.layer_norm1.bias', 
'model.image_encoder.vision_model.encoder.layers.2.mlp.fc1.weight', 'model.image_encoder.vision_model.encoder.layers.3.mlp.fc1.bias', 'model.image_encoder.vision_model.encoder.layers.10.layer_norm2.bias', 'model.image_encoder.vision_model.encoder.layers.11.self_attn.k_proj.bias', 'model.image_encoder.vision_model.encoder.layers.3.self_attn.v_proj.weight', 'model.image_encoder.vision_model.encoder.layers.1.self_attn.out_proj.weight', 'model.image_encoder.vision_model.encoder.layers.7.self_attn.q_proj.bias', 'model.image_encoder.vision_model.encoder.layers.0.self_attn.k_proj.bias', 'model.image_encoder.vision_model.encoder.layers.7.self_attn.out_proj.weight', 'model.image_encoder.vision_model.encoder.layers.11.layer_norm1.bias', 'model.image_encoder.vision_model.encoder.layers.3.self_attn.out_proj.bias', 'model.image_encoder.vision_model.encoder.layers.7.layer_norm1.bias', 'model.image_encoder.vision_model.encoder.layers.3.self_attn.k_proj.bias']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Complete preparing an eval model\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = load_model(model_name, vision_model_name, num_image_with_embedding)\n",
+ "\n",
+ "# lora\n",
+ "if config[\"use_lora\"]:\n",
+ " keys_finetune.append(\"lora\")\n",
+ " model = apply_lora_model(model, model_name, config)\n",
+ "\n",
+ "weight = {}\n",
+ "weight_path = glob.glob(f'../output/{exp}/checkpoint*/pytorch*.bin')\n",
+ "for w in weight_path:\n",
+ " weight_temp = torch.load(w, map_location=\"cpu\")\n",
+ " weight.update(weight_temp)\n",
+ "model.load_state_dict(weight, strict=False)\n",
+ "\n",
+ "model.eval()\n",
+ "model.to(f\"cuda:{device_id}\")\n",
+ "print(\"Complete preparing an eval model\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "70c5259a-b26c-45ef-b941-86b74fbd6d75",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a17c3b9c-724f-4d38-9d36-19823ea3ddc6",
+ "metadata": {},
+ "source": [
+ "# Inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "ca00dcf9-6334-4b3f-a7cc-da0f4150dd55",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Found cached dataset m3_it (/home/y_inoue/.cache/huggingface/datasets/MMInstruction___m3_it/coco/1.0.0/631dfd20153e0fbacb50b0239d4a71727503813fa0e821ba5ab399bed706034e)\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3269a01d9ed14a96a3b00384b9c75e37",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "--------------------\n",
+ "[' ##Instruction: Write a succinct description of the image, capturing its main components, the relationships between them, and any notable details. ##Question: ##Answer: A street sign with a street name and a parking sign.']\n",
+ "##Instruction: Write a succinct description of the image, capturing its main components, the relationships between them, and any notable details. ##Question: ##Answer: A large wooden pole with a green street sign hanging from it.\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "