From 1c4d50b404e23df05f520311985937e30c859a96 Mon Sep 17 00:00:00 2001 From: Inoichan <37664066+Ino-Ichan@users.noreply.github.com> Date: Mon, 4 Mar 2024 18:23:22 +0900 Subject: [PATCH] feat: update heron's datasets (#31) * feat: update llava dataset for ignoring padding loss, and add llava instruct dataset designed for only apply losses to gpt responses * fix m3it ignore index * add m3it instruction dataset * minor fix and add comments * fix ignore index of japanese csv dataset, and add instruction tuning dataset of japanese csv * -100 to IGNORE_INDEX * refactor get_each_dataset function * remove redundant conversion to RGB * minor fix: change order of dataset_classes dict --- configs/datasets/japanese_csv_instruct.yaml | 5 + configs/datasets/llava_both.yaml | 6 + configs/datasets/llava_both_instruct.yaml | 6 + configs/datasets/llava_en.yaml | 5 +- configs/datasets/llava_en_instruct.yaml | 6 + configs/datasets/llava_ja.yaml | 5 +- configs/datasets/llava_ja_instruct.yaml | 6 + configs/datasets/m3it_instruct.yaml | 22 ++ configs/datasets/m3it_ipc_instruct.yaml | 3 + heron/datasets/base_datasets.py | 11 +- heron/datasets/ja_csv_datasets.py | 51 ++--- heron/datasets/ja_csv_instruct_datasets.py | 219 ++++++++++++++++++ heron/datasets/llava_datasets.py | 60 +++-- heron/datasets/llava_instruct_datasets.py | 239 ++++++++++++++++++++ heron/datasets/m3it_datasets.py | 67 +++--- heron/datasets/m3it_instruct_datasets.py | 186 +++++++++++++++ heron/datasets/utils.py | 34 +-- 17 files changed, 840 insertions(+), 91 deletions(-) create mode 100644 configs/datasets/japanese_csv_instruct.yaml create mode 100644 configs/datasets/llava_both.yaml create mode 100644 configs/datasets/llava_both_instruct.yaml create mode 100644 configs/datasets/llava_en_instruct.yaml create mode 100644 configs/datasets/llava_ja_instruct.yaml create mode 100644 configs/datasets/m3it_instruct.yaml create mode 100644 configs/datasets/m3it_ipc_instruct.yaml create mode 100644 heron/datasets/ja_csv_instruct_datasets.py create mode 100644 heron/datasets/llava_instruct_datasets.py create mode 100644 heron/datasets/m3it_instruct_datasets.py diff --git a/configs/datasets/japanese_csv_instruct.yaml b/configs/datasets/japanese_csv_instruct.yaml new file mode 100644 index 0000000..2a84bb4 --- /dev/null +++ b/configs/datasets/japanese_csv_instruct.yaml @@ -0,0 +1,5 @@ +dataset_type: japanese_csv_instruct +dataset_root: "./" +dataset_names: + - coco + - visual_genome diff --git a/configs/datasets/llava_both.yaml b/configs/datasets/llava_both.yaml new file mode 100644 index 0000000..2b12322 --- /dev/null +++ b/configs/datasets/llava_both.yaml @@ -0,0 +1,6 @@ +dataset_type: llava +dataset_root: ./ +language: "both" +jsonl_path: +n_train: 157000 +n_val: 712 diff --git a/configs/datasets/llava_both_instruct.yaml b/configs/datasets/llava_both_instruct.yaml new file mode 100644 index 0000000..5825c84 --- /dev/null +++ b/configs/datasets/llava_both_instruct.yaml @@ -0,0 +1,6 @@ +dataset_type: llava_instruct +dataset_root: ./ +jsonl_path: +language: "both" +n_train: 157000 +n_val: 712 diff --git a/configs/datasets/llava_en.yaml b/configs/datasets/llava_en.yaml index 1a6686f..7804080 100644 --- a/configs/datasets/llava_en.yaml +++ b/configs/datasets/llava_en.yaml @@ -1,3 +1,6 @@ dataset_type: llava -dataset_root: "./" +dataset_root: ./ language: "en" +jsonl_path: +n_train: 157000 +n_val: 712 diff --git a/configs/datasets/llava_en_instruct.yaml b/configs/datasets/llava_en_instruct.yaml new file mode 100644 index 0000000..e01fcf7 --- /dev/null +++ b/configs/datasets/llava_en_instruct.yaml @@ -0,0 +1,6 @@ +dataset_type: llava_instruct +dataset_root: ./ +jsonl_path: +language: "en" +n_train: 157000 +n_val: 712 diff --git a/configs/datasets/llava_ja.yaml b/configs/datasets/llava_ja.yaml index b85f47a..a005a94 100644 --- a/configs/datasets/llava_ja.yaml +++ b/configs/datasets/llava_ja.yaml @@ -1,3 +1,6 @@ dataset_type: llava -dataset_root: "./" +dataset_root: ./ language: "ja" +jsonl_path: +n_train: 157000 +n_val: 712 diff --git a/configs/datasets/llava_ja_instruct.yaml b/configs/datasets/llava_ja_instruct.yaml new file mode 100644 index 0000000..6fa24bc --- /dev/null +++ b/configs/datasets/llava_ja_instruct.yaml @@ -0,0 +1,6 @@ +dataset_type: llava_instruct +dataset_root: ./ +jsonl_path: +language: "ja" +n_train: 157000 +n_val: 712 diff --git a/configs/datasets/m3it_instruct.yaml b/configs/datasets/m3it_instruct.yaml new file mode 100644 index 0000000..16d465c --- /dev/null +++ b/configs/datasets/m3it_instruct.yaml @@ -0,0 +1,22 @@ +dataset_type: m3it +dataset_names: + - textcap + - image-paragraph-captioning + - coco-goi + - coco-itm + - vqa-v2 + - shapes + - docvqa + - ocr-vqa + - st-vqa + - text-vqa + - gqa + - okvqa + - a-okvqa + - viquae + - clevr + - nlvr + - vcr + - visual-mrc + - visual-dialog + - multi30k diff --git a/configs/datasets/m3it_ipc_instruct.yaml b/configs/datasets/m3it_ipc_instruct.yaml new file mode 100644 index 0000000..06af536 --- /dev/null +++ b/configs/datasets/m3it_ipc_instruct.yaml @@ -0,0 +1,3 @@ +dataset_type: m3it_instruct +dataset_names: + - image-paragraph-captioning diff --git a/heron/datasets/base_datasets.py b/heron/datasets/base_datasets.py index 0680789..7c05602 100644 --- a/heron/datasets/base_datasets.py +++ b/heron/datasets/base_datasets.py @@ -14,9 +14,12 @@ import abc +import traceback from torch.utils.data import Dataset -import traceback + +IGNORE_INDEX = -100 + class BaseDataset(Dataset): def __init__(self, is_inference: bool = False): @@ -44,13 +47,11 @@ def _get_item_inference(self, index): class ResilientDataset(BaseDataset): - - def __init__(self, is_inference: bool = False, max_trials: int = 5): + def __init__(self, is_inference: bool = False, max_trials: int = 5): super().__init__(is_inference) self.max_trials = max_trials - + def __getitem__(self, index: int): - if self.is_inference: return self._get_item_inference(index) else: diff --git a/heron/datasets/ja_csv_datasets.py b/heron/datasets/ja_csv_datasets.py index bbb679c..2f0249f 100644 --- a/heron/datasets/ja_csv_datasets.py +++ b/heron/datasets/ja_csv_datasets.py @@ -19,10 +19,11 @@ import cv2 import numpy as np import pandas as pd +import torch from PIL import Image from torch.utils.data import Dataset -from .base_datasets import BaseDataset +from .base_datasets import IGNORE_INDEX, BaseDataset HFProcessor = "HFProcessor" @@ -91,8 +92,8 @@ def create( def __len__(self) -> int: return len(self.unique_img_path) - def preprocess_image(self, image): - return self.processor(images=[image], return_tensors="pt")["pixel_values"][0] + def preprocess_image(self, images): + return self.processor(images=images, return_tensors="pt")["pixel_values"][0] def tokenize(self, text): if self.is_inference: @@ -108,7 +109,12 @@ def _get_item_train(self, index): df_interest = self.loaded_dataset[self.loaded_dataset.img_path == img_path].reset_index( drop=True ) - text = "" + # imageのロード + image = Image.open(os.path.join(self.dataset_root, img_path)).convert("RGB") + image = np.array(image) + images = [image] + + prompt = "" # concatenate text data order = list(range(len(df_interest))) @@ -117,30 +123,23 @@ def _get_item_train(self, index): row = df_interest.iloc[i] question = row["question"] # str answer = row["caption"] # str - text += f"##human: {question}\n##gpt: {answer}\n" - - # remove final space - text = text[: len(text) - 1] + prompt += f"##human: {question}\n##gpt: {answer}\n" - # imageのロード - image = Image.open(os.path.join(self.dataset_root, img_path)).convert("RGB") - img = np.array(image) - if img.shape[2] != 3: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + tokenized = self.tokenize(prompt) + tokenized_prompt = tokenized["input_ids"][0] + labels = torch.full_like(tokenized_prompt, IGNORE_INDEX) + prompt_attn_mask = tokenized["attention_mask"][0] - inputs = self.processor( - images=img, - text=text, - return_tensors="pt", - max_length=self.max_length, - padding="max_length", - truncation=True, - ) + index_ignore_loss = prompt_attn_mask.sum().item() + 1 + labels[:index_ignore_loss] = tokenized_prompt[:index_ignore_loss] - # batch size 1 -> unbatch - inputs = {k: v[0] for k, v in inputs.items()} - inputs["labels"] = inputs["input_ids"] - return inputs + return_dict = { + "input_ids": tokenized_prompt, + "labels": labels, + "attention_mask": prompt_attn_mask, + "pixel_values": self.preprocess_image(images), + } + return return_dict def _get_item_inference(self, index): # cf: https://huggingface.co/datasets/MMInstruction/M3IT#data-instances @@ -159,8 +158,6 @@ def _get_item_inference(self, index): # imageのロード img = Image.open(os.path.join(self.dataset_root, img_path)).convert("RGB") img = np.array(img) - if img.shape[2] != 3: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) inputs = self.processor( text, diff --git a/heron/datasets/ja_csv_instruct_datasets.py b/heron/datasets/ja_csv_instruct_datasets.py new file mode 100644 index 0000000..6a94578 --- /dev/null +++ b/heron/datasets/ja_csv_instruct_datasets.py @@ -0,0 +1,219 @@ +# Copyright 2023 Turing Inc. Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import os +import random + +import cv2 +import numpy as np +import pandas as pd +import torch +from PIL import Image +from torch.utils.data import Dataset + +from .base_datasets import IGNORE_INDEX, BaseDataset + +HFProcessor = "HFProcessor" + + +class JapaneseCSVInstructDataset(BaseDataset): + """Dataset for Custom Japanese CSV V&L Dataset learning + This dataset is designed for instruction tuning, meaning it considers the lossese associated with gpt responses. + """ + + def __init__( + self, + loaded_dataset: pd.DataFrame, + processor: HFProcessor, + max_length: int, + dataset_root: str, + is_inference: bool = False, + ): + super(JapaneseCSVInstructDataset, self).__init__() + self.loaded_dataset = loaded_dataset + self.unique_img_path = loaded_dataset.img_path.unique() + + self.max_length = max_length + self.processor = processor + self.is_inference = is_inference + self.dataset_root = dataset_root + + @classmethod + def create( + cls, + dataset_config: dict, + processor: HFProcessor, + max_length: int, + split: str = "train", + is_inference: bool = False, + ): + dataset_root = dataset_config["dataset_root"] + target_dataset_list = [] + if "coco" in dataset_config["dataset_names"]: + if split == "train": + df_train = pd.read_csv(os.path.join(dataset_root, "data/coco/df_train.csv")) + target_dataset_list.append(df_train) + else: + df_val = pd.read_csv(os.path.join(dataset_root, "data/coco/df_val.csv")) + target_dataset_list.append(df_val) + + if "visual_genome" in dataset_config["dataset_names"]: + df_vg = pd.read_csv(os.path.join(dataset_root, "data/visual_genome_ja/df_vg.csv")) + if split != "train": + val_ratio = 0.1 + num_val = int(len(df_vg) * val_ratio) + df_vg = df_vg[:num_val] + target_dataset_list.append(df_vg) + else: + raise ValueError( + f"dataset_type: {dataset_config.get('dataset_type')} is not supported." + ) + + target_dataframe = pd.concat(target_dataset_list, axis=0, ignore_index=True) + + return cls( + target_dataframe, + processor, + max_length, + is_inference=is_inference, + dataset_root=dataset_root, + ) + + def __len__(self) -> int: + return len(self.unique_img_path) + + def preprocess_image(self, images): + return self.processor(images=images, return_tensors="pt")["pixel_values"][0] + + def tokenize(self, text): + kwargs = {} + return self.processor.tokenizer(text=text, return_tensors="pt", **kwargs) + + def _get_item_train(self, index): + # cf: https://huggingface.co/datasets/MMInstruction/M3IT#data-instances + img_path = self.unique_img_path[index] + + df_interest = self.loaded_dataset[self.loaded_dataset.img_path == img_path].reset_index( + drop=True + ) + # imageのロード + image = Image.open(os.path.join(self.dataset_root, img_path)).convert("RGB") + image = np.array(image) + images = [image] + + tokenized_list = [] + labels_list = [] + attn_mask_list = [] + + # concatenate text data + order = list(range(len(df_interest))) + random.shuffle(order) + for i, c in enumerate(order): + if i > 0: + drop_eos_token = 1 + else: + drop_eos_token = 0 + + row = df_interest.iloc[i] + question = row["question"] # str + answer = row["caption"] # str + prompt_q = f"##human: {question}\n##gpt: " + prompt_a = f"{answer}" + + # ================================ + # tokenize question text + # ================================ + tokenized = self.tokenize(prompt_q) + tokenized_prompt = tokenized["input_ids"][0][drop_eos_token:] + # all label should be ignored + labels = torch.full_like(tokenized_prompt, IGNORE_INDEX) + prompt_attn_mask = tokenized["attention_mask"][0][drop_eos_token:] + + tokenized_list.append(tokenized_prompt) + labels_list.append(labels) + attn_mask_list.append(prompt_attn_mask) + + # ================================ + # tokenize answer text + # ================================ + tokenized = self.tokenize(prompt_a) + tokenized_prompt = tokenized["input_ids"][0][1:] + # all label should be included in loss + labels = tokenized_prompt + prompt_attn_mask = tokenized["attention_mask"][0][1:] + + tokenized_list.append(tokenized_prompt) + labels_list.append(labels) + attn_mask_list.append(prompt_attn_mask) + + # ================================================= + # concat question and answer, apply max_length + # ================================================= + tokenized_prompt = torch.cat(tokenized_list, dim=-1) + labels = torch.cat(labels_list, dim=-1) + prompt_attn_mask = torch.cat(attn_mask_list, dim=-1) + + if len(tokenized_prompt) < self.max_length: + pad_length = self.max_length - len(tokenized_prompt) + tokenized_prompt = torch.cat( + [ + tokenized_prompt, + torch.tensor([self.processor.tokenizer.pad_token_id] * pad_length), + ], + dim=-1, + ) + labels = torch.cat([labels, torch.tensor([IGNORE_INDEX] * pad_length)], dim=-1) + prompt_attn_mask = torch.cat( + [prompt_attn_mask, torch.tensor([0] * pad_length)], dim=-1 + ) + else: + tokenized_prompt = tokenized_prompt[: self.max_length] + labels = labels[: self.max_length] + prompt_attn_mask = prompt_attn_mask[: self.max_length] + + return_dict = { + "input_ids": tokenized_prompt, + "labels": labels, + "attention_mask": prompt_attn_mask, + "pixel_values": self.preprocess_image(images), + } + return return_dict + + def _get_item_inference(self, index): + # cf: https://huggingface.co/datasets/MMInstruction/M3IT#data-instances + img_path = self.unique_img_path[index] + + df_interest = self.loaded_dataset[self.loaded_dataset.img_path == img_path].reset_index( + drop=True + ) + text = "" + + row = df_interest.iloc[0] + question = row["question"] # str + answer = row["caption"] # str + text += f"##human: {question}\n##gpt: " + + # imageのロード + img = Image.open(os.path.join(self.dataset_root, img_path)).convert("RGB") + img = np.array(img) + + inputs = self.processor( + text, + img, + return_tensors="pt", + ) + + inputs["labels"] = None + return inputs, img, answer diff --git a/heron/datasets/llava_datasets.py b/heron/datasets/llava_datasets.py index 07b5e6d..95a4b54 100644 --- a/heron/datasets/llava_datasets.py +++ b/heron/datasets/llava_datasets.py @@ -15,11 +15,14 @@ import os from typing import Dict +import cv2 +import numpy as np +import torch from datasets import load_dataset from datasets.arrow_dataset import Dataset as HFDataset from PIL import Image -from .base_datasets import BaseDataset +from .base_datasets import IGNORE_INDEX, BaseDataset HFProcessor = "HFProcessor" @@ -62,8 +65,23 @@ def create( is_inference: inference mode or not language: "ja" or "en". """ - hf_dataset = load_dataset("turing-motors/LLaVA-Instruct-150K-JA") - split_datasets = hf_dataset["train"].train_test_split(test_size=0.05, seed=11) + jsonl_path = dataset_config.get("jsonl_path", None) + n_train = dataset_config["n_train"] + n_val = dataset_config["n_val"] + if jsonl_path is not None: + import json + + with open(jsonl_path) as f: + jsonl_datasets = json.load(f) + split_datasets = { + "train": jsonl_datasets[:n_train], + "test": jsonl_datasets[n_train : n_train + n_val], + } + else: + hf_dataset = load_dataset("turing-motors/LLaVA-Instruct-150K-JA") + split_datasets = hf_dataset["train"].train_test_split( + test_size=(n_val / (n_train + n_val)), seed=11 + ) if split == "train": return cls( @@ -100,14 +118,19 @@ def tokenize(self, text): def __len__(self) -> int: return len(self.loaded_dataset) - def get_message(self, c): + def get_language(self): if self.language == "ja": - message = c["jp"] + language = "jp" elif self.language == "en": - message = c["value"] + language = "value" + elif self.language == "both": + if np.random.rand() < 0.5: + language = "jp" + else: + language = "value" else: raise ValueError("invalid language") - return message + return language def _get_item_train(self, index): row = self.loaded_dataset[index] @@ -115,21 +138,28 @@ def _get_item_train(self, index): image_path = os.path.join( self.dataset_root, "coco/train2014/COCO_train2014_" + row["image"] ) - images = [Image.open(image_path)] + image = Image.open(image_path).convert("RGB") + image = np.array(image) + images = [image] prompt = "" + language = self.get_language() for c in row["conversations"]: agent = c["from"] - message = self.get_message(c) - prompt += f"##{agent}: {message}\n##gpt: \n" + message = c[language] + prompt += f"##{agent}: {message}\n" tokenized = self.tokenize(prompt) tokenized_prompt = tokenized["input_ids"][0] + labels = torch.full_like(tokenized_prompt, IGNORE_INDEX) prompt_attn_mask = tokenized["attention_mask"][0] + index_ignore_loss = prompt_attn_mask.sum().item() + 1 + labels[:index_ignore_loss] = tokenized_prompt[:index_ignore_loss] + return_dict = { "input_ids": tokenized_prompt, - "labels": tokenized_prompt, + "labels": labels, "attention_mask": prompt_attn_mask, "pixel_values": self.preprocess_image(images), } @@ -141,10 +171,12 @@ def _get_item_inference(self, index): image_path = os.path.join( self.dataset_root, "coco/train2014/COCO_train2014_" + row["image"] ) - images = [Image.open(image_path)] + image = Image.open(image_path).convert("RGB") + image = np.array(image) + images = [image] - message = self.get_message(row["conversations"][0]) - prompt = f"###human: {message}" + language = self.get_language() + prompt = f"##human: {row['conversations'][language]}\n##gpt: " tokenized = self.tokenize(prompt) tokenized_prompt = tokenized["input_ids"][0] diff --git a/heron/datasets/llava_instruct_datasets.py b/heron/datasets/llava_instruct_datasets.py new file mode 100644 index 0000000..16454e7 --- /dev/null +++ b/heron/datasets/llava_instruct_datasets.py @@ -0,0 +1,239 @@ +# Copyright 2023 Turing Inc. Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Dict + +import cv2 +import numpy as np +import torch +from datasets import load_dataset +from datasets.arrow_dataset import Dataset as HFDataset +from PIL import Image + +from .base_datasets import IGNORE_INDEX, BaseDataset + +HFProcessor = "HFProcessor" + + +class LlavaInstructDataset(BaseDataset): + """Dataset for LLaVA + This dataset is designed for instruction tuning, meaning it considers the lossese associated with gpt responses. + """ + + def __init__( + self, + loaded_dataset: HFDataset, + processor: HFProcessor, + max_length: int, + is_inference: bool, + language: str, + dataset_root: str, + ): + super(LlavaInstructDataset, self).__init__(is_inference) + assert language in ["ja", "en"], "given language is not supported" + self.loaded_dataset = loaded_dataset + self.max_length = max_length + self.processor = processor + self.is_inference = is_inference + self.language = language + self.dataset_root = dataset_root + + @classmethod + def create( + cls, + dataset_config: Dict, + processor: HFProcessor, + max_length: int, + split: str = "train", + is_inference: bool = False, + ): + """ + Args: + dataset_config: dataset configuration + processor: HuggingFace Processor class + split: data split. "train" or "val" + is_inference: inference mode or not + language: "ja" or "en". + """ + jsonl_path = dataset_config.get("jsonl_path", None) + n_train = dataset_config["n_train"] + n_val = dataset_config["n_val"] + if jsonl_path is not None: + import json + + with open(jsonl_path) as f: + jsonl_datasets = json.load(f) + split_datasets = { + "train": jsonl_datasets[:n_train], + "test": jsonl_datasets[n_train : n_train + n_val], + } + else: + hf_dataset = load_dataset("turing-motors/LLaVA-Instruct-150K-JA") + split_datasets = hf_dataset["train"].train_test_split( + test_size=(n_val / (n_train + n_val)), seed=11 + ) + + if split == "train": + return cls( + split_datasets["train"], + processor, + max_length, + is_inference, + dataset_config["language"], + dataset_config["dataset_root"], + ) + + elif split == "validation": + return cls( + split_datasets["test"], + processor, + max_length, + is_inference, + dataset_config["language"], + dataset_config["dataset_root"], + ) + else: + raise ValueError("given split is invalid") + + def preprocess_image(self, images): + return self.processor(images=images, return_tensors="pt")["pixel_values"][0] + + def tokenize(self, text): + kwargs = {} + return self.processor.tokenizer(text=text, return_tensors="pt", **kwargs) + + def __len__(self) -> int: + return len(self.loaded_dataset) + + def get_language(self): + if self.language == "ja": + language = "jp" + elif self.language == "en": + language = "value" + elif self.language == "both": + if np.random.rand() < 0.5: + language = "jp" + else: + language = "value" + else: + raise ValueError("invalid language") + return language + + def _get_item_train(self, index): + row = self.loaded_dataset[index] + + image_path = os.path.join( + self.dataset_root, "coco/train2014/COCO_train2014_" + row["image"] + ) + image = Image.open(image_path).convert("RGB") + image = np.array(image) + images = [image] + + # ================================ + # tokenize question and answer text + # ================================ + language = self.get_language() + + tokenized_list = [] + labels_list = [] + attn_mask_list = [] + input_text_all = "" + for i, c in enumerate(row["conversations"]): + if i > 0: + drop_eos_token = 1 + else: + drop_eos_token = 0 + agent = c["from"] + if agent == "gpt": + agent_prompt = "" + next_agent_prompt = f"{self.processor.tokenizer.eos_token}\n" + elif agent == "human": + agent_prompt = "##human: " + next_agent_prompt = "\n##gpt: " + message = c[language] + input_text = f"{agent_prompt}{message}{next_agent_prompt}" + input_text_all += input_text + tokenized = self.tokenize(input_text) + tokenized_prompt = tokenized["input_ids"][0][drop_eos_token:] + if agent == "gpt": + labels = tokenized_prompt + elif agent == "human": + labels = torch.full_like(tokenized_prompt, IGNORE_INDEX) + prompt_attn_mask = tokenized["attention_mask"][0][drop_eos_token:] + + tokenized_list.append(tokenized_prompt) + labels_list.append(labels) + attn_mask_list.append(prompt_attn_mask) + + # ================================================= + # concat question and answer, apply max_length + # ================================================= + tokenized_prompt = torch.cat(tokenized_list, dim=-1) + labels = torch.cat(labels_list, dim=-1) + prompt_attn_mask = torch.cat(attn_mask_list, dim=-1) + + if len(tokenized_prompt) < self.max_length: + pad_length = self.max_length - len(tokenized_prompt) + tokenized_prompt = torch.cat( + [ + tokenized_prompt, + torch.tensor([self.processor.tokenizer.pad_token_id] * pad_length), + ], + dim=-1, + ) + labels = torch.cat([labels, torch.tensor([IGNORE_INDEX] * pad_length)], dim=-1) + prompt_attn_mask = torch.cat( + [prompt_attn_mask, torch.tensor([0] * pad_length)], dim=-1 + ) + else: + tokenized_prompt = tokenized_prompt[: self.max_length] + labels = labels[: self.max_length] + prompt_attn_mask = prompt_attn_mask[: self.max_length] + + return_dict = { + "input_ids": tokenized_prompt, + "labels": labels, + "attention_mask": prompt_attn_mask, + "pixel_values": self.preprocess_image(images), + } + return return_dict + + def _get_item_inference(self, index): + row = self.loaded_dataset[index] + + image_path = os.path.join( + self.dataset_root, "coco/train2014/COCO_train2014_" + row["image"] + ) + image = Image.open(image_path).convert("RGB") + image = np.array(image) + images = [image] + + language = self.get_language() + prompt = f"##human: {row['conversations'][language]}\n##gpt: " + + tokenized = self.tokenize(prompt) + tokenized_prompt = tokenized["input_ids"][0] + prompt_attn_mask = tokenized["attention_mask"][0] + + return_dict = { + "input_ids": tokenized_prompt, + "labels": tokenized_prompt, + "attention_mask": prompt_attn_mask, + "pixel_values": self.preprocess_image(images), + "image": images[0], + "conversations": row["conversations"], + "prompt": prompt, + } + return return_dict diff --git a/heron/datasets/m3it_datasets.py b/heron/datasets/m3it_datasets.py index 06858e6..577800e 100644 --- a/heron/datasets/m3it_datasets.py +++ b/heron/datasets/m3it_datasets.py @@ -19,10 +19,11 @@ import cv2 import datasets import numpy as np +import torch from PIL import Image from torch.utils.data import ConcatDataset -from .base_datasets import ResilientDataset +from .base_datasets import IGNORE_INDEX, ResilientDataset HFProcessor = "HFProcessor" @@ -53,7 +54,8 @@ def create( is_inference: bool = False, ): dataset_list = [ - datasets.load_dataset("MMInstruction/M3IT", i, num_proc=16) for i in dataset_config["dataset_names"] + datasets.load_dataset("MMInstruction/M3IT", i, num_proc=16) + for i in dataset_config["dataset_names"] ] # some dataset have no validation @@ -67,6 +69,16 @@ def create( return cls(target_dataframe, processor, max_length, is_inference) + def preprocess_image(self, images): + return self.processor(images=images, return_tensors="pt")["pixel_values"][0] + + def tokenize(self, text): + if self.is_inference: + kwargs = {} + else: + kwargs = {"padding": "max_length", "max_length": self.max_length, "truncation": True} + return self.processor.tokenizer(text=text, return_tensors="pt", **kwargs) + def __len__(self) -> int: return len(self.loaded_dataset) @@ -74,31 +86,33 @@ def _get_item_train(self, index): # cf: https://huggingface.co/datasets/MMInstruction/M3IT#data-instances row = self.loaded_dataset[index] + # imageのロード + image_base64_str_list = row["image_base64_str"] # str (base64) + image = Image.open(BytesIO(b64decode(image_base64_str_list[0]))).convert("RGB") + image = np.array(image) + images = [image] + # some of nlvr data were broken instruction = row["instruction"] # str question = row["inputs"] # str answer = row["outputs"] # str - text = f"##human: {instruction} {question}\n##gpt: {answer}" + prompt = f"##human: {instruction} {question}\n##gpt: {answer}" - # imageのロード - image_base64_str_list = row["image_base64_str"] # str (base64) - img = Image.open(BytesIO(b64decode(image_base64_str_list[0]))).convert("RGB") - img = np.array(img) - if img.shape[2] != 3: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + tokenized = self.tokenize(prompt) + tokenized_prompt = tokenized["input_ids"][0] + labels = torch.full_like(tokenized_prompt, IGNORE_INDEX) + prompt_attn_mask = tokenized["attention_mask"][0] - inputs = self.processor( - images=img, - text=text, - return_tensors="pt", - max_length=self.max_length, - padding="max_length", - truncation=True, - ) - # batch size 1 -> unbatch - inputs = {k: v[0] for k, v in inputs.items()} - inputs["labels"] = inputs["input_ids"] - return inputs + index_ignore_loss = prompt_attn_mask.sum().item() + 1 + labels[:index_ignore_loss] = tokenized_prompt[:index_ignore_loss] + + return_dict = { + "input_ids": tokenized_prompt, + "labels": labels, + "attention_mask": prompt_attn_mask, + "pixel_values": self.preprocess_image(images), + } + return return_dict def _get_item_inference(self, index): # cf: https://huggingface.co/datasets/MMInstruction/M3IT#data-instances @@ -112,15 +126,14 @@ def _get_item_inference(self, index): # imageのロード image_base64_str_list = row["image_base64_str"] # str (base64) - img = Image.open(BytesIO(b64decode(image_base64_str_list[0]))).convert("RGB") - img = np.array(img) - if img.shape[2] != 3: - img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + image = Image.open(BytesIO(b64decode(image_base64_str_list[0]))).convert("RGB") + image = np.array(image) + images = [image] inputs = self.processor( text, - img, + images, return_tensors="pt", ) inputs["labels"] = None - return inputs, img, answer + return inputs, image, answer diff --git a/heron/datasets/m3it_instruct_datasets.py b/heron/datasets/m3it_instruct_datasets.py new file mode 100644 index 0000000..5a5ce5a --- /dev/null +++ b/heron/datasets/m3it_instruct_datasets.py @@ -0,0 +1,186 @@ +# Copyright 2023 Turing Inc. Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from base64 import b64decode +from io import BytesIO + +import cv2 +import datasets +import numpy as np +import torch +from PIL import Image +from torch.utils.data import ConcatDataset + +from .base_datasets import IGNORE_INDEX, ResilientDataset + +HFProcessor = "HFProcessor" + + +class M3ITInstructDataset(ResilientDataset): + """Dataset for M3IT Dataset learning + This dataset is designed for instruction tuning, meaning it considers the lossese associated with gpt responses. + """ + + def __init__( + self, + loaded_dataset: ConcatDataset, + processor: HFProcessor, + max_length: int, + is_inference: bool = False, + ): + super(M3ITInstructDataset, self).__init__(is_inference) + self.loaded_dataset = loaded_dataset + self.max_length = max_length + self.processor = processor + self.is_inference = is_inference + + @classmethod + def create( + cls, + dataset_config: dict, + processor: HFProcessor, + max_length: int, + split: str = "train", + is_inference: bool = False, + ): + dataset_list = [ + datasets.load_dataset("MMInstruction/M3IT", i, num_proc=16) + for i in dataset_config["dataset_names"] + ] + + # some dataset have no validation + target_dataset_list = [] + for d in dataset_list: + try: + target_dataset_list.append(d[split]) + except KeyError: + print(f"{d['train']._info.config_name} has no {split} set.") + target_dataframe = ConcatDataset(target_dataset_list) + + return cls(target_dataframe, processor, max_length, is_inference) + + def preprocess_image(self, images): + return self.processor(images=images, return_tensors="pt")["pixel_values"][0] + + def tokenize(self, text): + kwargs = {} + return self.processor.tokenizer(text=text, return_tensors="pt", **kwargs) + + def __len__(self) -> int: + return len(self.loaded_dataset) + + def _get_item_train(self, index): + # cf: https://huggingface.co/datasets/MMInstruction/M3IT#data-instances + row = self.loaded_dataset[index] + + # imageのロード + image_base64_str_list = row["image_base64_str"] # str (base64) + image = Image.open(BytesIO(b64decode(image_base64_str_list[0]))).convert("RGB") + image = np.array(image) + images = [image] + + tokenized_list = [] + labels_list = [] + attn_mask_list = [] + + # some of nlvr data were broken + instruction = row["instruction"] # str + question = row["inputs"] # str + answer = row["outputs"] # str + prompt_q = f"##human: {instruction} {question}\n##gpt: " + prompt_a = f"{answer}" + + # ================================ + # tokenize question text + # ================================ + tokenized = self.tokenize(prompt_q) + tokenized_prompt = tokenized["input_ids"][0] + # all label should be ignored + labels = torch.full_like(tokenized_prompt, IGNORE_INDEX) + prompt_attn_mask = tokenized["attention_mask"][0] + + tokenized_list.append(tokenized_prompt) + labels_list.append(labels) + attn_mask_list.append(prompt_attn_mask) + + # ================================ + # tokenize answer text + # ================================ + tokenized = self.tokenize(prompt_a) + tokenized_prompt = tokenized["input_ids"][0][1:] + # all label should be included in loss + labels = tokenized_prompt + prompt_attn_mask = tokenized["attention_mask"][0][1:] + + tokenized_list.append(tokenized_prompt) + labels_list.append(labels) + attn_mask_list.append(prompt_attn_mask) + + # ================================================= + # concat question and answer, apply max_length + # ================================================= + tokenized_prompt = torch.cat(tokenized_list, dim=-1) + labels = torch.cat(labels_list, dim=-1) + prompt_attn_mask = torch.cat(attn_mask_list, dim=-1) + + if len(tokenized_prompt) < self.max_length: + pad_length = self.max_length - len(tokenized_prompt) + tokenized_prompt = torch.cat( + [ + tokenized_prompt, + torch.tensor([self.processor.tokenizer.pad_token_id] * pad_length), + ], + dim=-1, + ) + labels = torch.cat([labels, torch.tensor([IGNORE_INDEX] * pad_length)], dim=-1) + prompt_attn_mask = torch.cat( + [prompt_attn_mask, torch.tensor([0] * pad_length)], dim=-1 + ) + else: + tokenized_prompt = tokenized_prompt[: self.max_length] + labels = labels[: self.max_length] + prompt_attn_mask = prompt_attn_mask[: self.max_length] + + return_dict = { + "input_ids": tokenized_prompt, + "labels": labels, + "attention_mask": prompt_attn_mask, + "pixel_values": self.preprocess_image(images), + } + return return_dict + + def _get_item_inference(self, index): + # cf: https://huggingface.co/datasets/MMInstruction/M3IT#data-instances + row = self.loaded_dataset[index] + + # some of nlvr data were broken + instruction = row["instruction"] # str + question = row["inputs"] # str + answer = row["outputs"] # str + text = f"##Instruction: {instruction} ##Question: {question} ##Answer: " + + # imageのロード + image_base64_str_list = row["image_base64_str"] # str (base64) + image = Image.open(BytesIO(b64decode(image_base64_str_list[0]))).convert("RGB") + image = np.array(image) + images = [image] + + inputs = self.processor( + text, + images, + return_tensors="pt", + ) + inputs["labels"] = None + return inputs, image, answer diff --git a/heron/datasets/utils.py b/heron/datasets/utils.py index 096dd94..076c83c 100644 --- a/heron/datasets/utils.py +++ b/heron/datasets/utils.py @@ -22,28 +22,30 @@ from ..models.prepare_processors import get_processor from .ja_csv_datasets import JapaneseCSVDataset +from .ja_csv_instruct_datasets import JapaneseCSVInstructDataset from .llava_datasets import LlavaDataset +from .llava_instruct_datasets import LlavaInstructDataset from .m3it_datasets import M3ITDataset +from .m3it_instruct_datasets import M3ITInstructDataset +dataset_classes = { + "japanese_csv": JapaneseCSVDataset, + "japanese_csv_instruct": JapaneseCSVInstructDataset, + "llava": LlavaDataset, + "llava_instruct": LlavaInstructDataset, + "m3it": M3ITDataset, + "m3it_instruct": M3ITInstructDataset, +} -def get_each_dataset(dataset_config: Dict, processor, max_length: int) -> Tuple[Dataset, Dataset]: - if dataset_config["dataset_type"] == "m3it": - train_dataset = M3ITDataset.create(dataset_config, processor, max_length, "train") - val_dataset = M3ITDataset.create(dataset_config, processor, max_length, "validation") - - elif dataset_config["dataset_type"] == "japanese_csv": - train_dataset = JapaneseCSVDataset.create(dataset_config, processor, max_length, "train") - val_dataset = JapaneseCSVDataset.create( - dataset_config, processor, max_length, "validation" - ) - elif dataset_config["dataset_type"] == "llava": - train_dataset = LlavaDataset.create(dataset_config, processor, max_length, "train") - val_dataset = LlavaDataset.create(dataset_config, processor, max_length, "validation") - - else: - raise ValueError(f"dataset_type: {dataset_config['dataset_type']} is not supported.") +def get_each_dataset(dataset_config: Dict, processor, max_length: int) -> Tuple[Dataset, Dataset]: + dataset_type = dataset_config["dataset_type"] + if dataset_type not in dataset_classes: + raise ValueError(f"dataset_type: {dataset_type} is not supported.") + DatasetClass = dataset_classes[dataset_type] + train_dataset = DatasetClass.create(dataset_config, processor, max_length, "train") + val_dataset = DatasetClass.create(dataset_config, processor, max_length, "validation") return train_dataset, val_dataset