Skip to content

Commit

Permalink
feat: update heron's datasets (#31)
Browse files Browse the repository at this point in the history
* feat: update llava dataset for ignoring padding loss, and add llava instruct dataset designed for only apply losses to gpt responses

* fix m3it ignore index

* add m3it instruction dataset

* minor fix and add comments

* fix ignore index of japanese csv dataset, and add instruction tuning dataset of japanese csv

* -100 to IGNORE_INDEX

* refactor get_each_dataset function

* remove redundant conversion to RGB

* minor fix: change order of dataset_classes dict
  • Loading branch information
Ino-Ichan authored Mar 4, 2024
1 parent 5a6e018 commit 1c4d50b
Show file tree
Hide file tree
Showing 17 changed files with 840 additions and 91 deletions.
5 changes: 5 additions & 0 deletions configs/datasets/japanese_csv_instruct.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
dataset_type: japanese_csv_instruct
dataset_root: "./"
dataset_names:
- coco
- visual_genome
6 changes: 6 additions & 0 deletions configs/datasets/llava_both.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
dataset_type: llava
dataset_root: ./
language: "both"
jsonl_path:
n_train: 157000
n_val: 712
6 changes: 6 additions & 0 deletions configs/datasets/llava_both_instruct.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
dataset_type: llava_instruct
dataset_root: ./
jsonl_path:
language: "both"
n_train: 157000
n_val: 712
5 changes: 4 additions & 1 deletion configs/datasets/llava_en.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
dataset_type: llava
dataset_root: "./"
dataset_root: ./
language: "en"
jsonl_path:
n_train: 157000
n_val: 712
6 changes: 6 additions & 0 deletions configs/datasets/llava_en_instruct.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
dataset_type: llava_instruct
dataset_root: ./
jsonl_path:
language: "en"
n_train: 157000
n_val: 712
5 changes: 4 additions & 1 deletion configs/datasets/llava_ja.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
dataset_type: llava
dataset_root: "./"
dataset_root: ./
language: "ja"
jsonl_path:
n_train: 157000
n_val: 712
6 changes: 6 additions & 0 deletions configs/datasets/llava_ja_instruct.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
dataset_type: llava_instruct
dataset_root: ./
jsonl_path:
language: "ja"
n_train: 157000
n_val: 712
22 changes: 22 additions & 0 deletions configs/datasets/m3it_instruct.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
dataset_type: m3it
dataset_names:
- textcap
- image-paragraph-captioning
- coco-goi
- coco-itm
- vqa-v2
- shapes
- docvqa
- ocr-vqa
- st-vqa
- text-vqa
- gqa
- okvqa
- a-okvqa
- viquae
- clevr
- nlvr
- vcr
- visual-mrc
- visual-dialog
- multi30k
3 changes: 3 additions & 0 deletions configs/datasets/m3it_ipc_instruct.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
dataset_type: m3it_instruct
dataset_names:
- image-paragraph-captioning
11 changes: 6 additions & 5 deletions heron/datasets/base_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@


import abc
import traceback

from torch.utils.data import Dataset
import traceback

IGNORE_INDEX = -100


class BaseDataset(Dataset):
def __init__(self, is_inference: bool = False):
Expand Down Expand Up @@ -44,13 +47,11 @@ def _get_item_inference(self, index):


class ResilientDataset(BaseDataset):

def __init__(self, is_inference: bool = False, max_trials: int = 5):
def __init__(self, is_inference: bool = False, max_trials: int = 5):
super().__init__(is_inference)
self.max_trials = max_trials

def __getitem__(self, index: int):

if self.is_inference:
return self._get_item_inference(index)
else:
Expand Down
51 changes: 24 additions & 27 deletions heron/datasets/ja_csv_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset

from .base_datasets import BaseDataset
from .base_datasets import IGNORE_INDEX, BaseDataset

HFProcessor = "HFProcessor"

Expand Down Expand Up @@ -91,8 +92,8 @@ def create(
def __len__(self) -> int:
return len(self.unique_img_path)

def preprocess_image(self, image):
return self.processor(images=[image], return_tensors="pt")["pixel_values"][0]
def preprocess_image(self, images):
return self.processor(images=images, return_tensors="pt")["pixel_values"][0]

def tokenize(self, text):
if self.is_inference:
Expand All @@ -108,7 +109,12 @@ def _get_item_train(self, index):
df_interest = self.loaded_dataset[self.loaded_dataset.img_path == img_path].reset_index(
drop=True
)
text = ""
# imageのロード
image = Image.open(os.path.join(self.dataset_root, img_path)).convert("RGB")
image = np.array(image)
images = [image]

prompt = ""

# concatenate text data
order = list(range(len(df_interest)))
Expand All @@ -117,30 +123,23 @@ def _get_item_train(self, index):
row = df_interest.iloc[i]
question = row["question"] # str
answer = row["caption"] # str
text += f"##human: {question}\n##gpt: {answer}\n"

# remove final space
text = text[: len(text) - 1]
prompt += f"##human: {question}\n##gpt: {answer}\n"

# imageのロード
image = Image.open(os.path.join(self.dataset_root, img_path)).convert("RGB")
img = np.array(image)
if img.shape[2] != 3:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
tokenized = self.tokenize(prompt)
tokenized_prompt = tokenized["input_ids"][0]
labels = torch.full_like(tokenized_prompt, IGNORE_INDEX)
prompt_attn_mask = tokenized["attention_mask"][0]

inputs = self.processor(
images=img,
text=text,
return_tensors="pt",
max_length=self.max_length,
padding="max_length",
truncation=True,
)
index_ignore_loss = prompt_attn_mask.sum().item() + 1
labels[:index_ignore_loss] = tokenized_prompt[:index_ignore_loss]

# batch size 1 -> unbatch
inputs = {k: v[0] for k, v in inputs.items()}
inputs["labels"] = inputs["input_ids"]
return inputs
return_dict = {
"input_ids": tokenized_prompt,
"labels": labels,
"attention_mask": prompt_attn_mask,
"pixel_values": self.preprocess_image(images),
}
return return_dict

def _get_item_inference(self, index):
# cf: https://huggingface.co/datasets/MMInstruction/M3IT#data-instances
Expand All @@ -159,8 +158,6 @@ def _get_item_inference(self, index):
# imageのロード
img = Image.open(os.path.join(self.dataset_root, img_path)).convert("RGB")
img = np.array(img)
if img.shape[2] != 3:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

inputs = self.processor(
text,
Expand Down
Loading

0 comments on commit 1c4d50b

Please sign in to comment.