Add Japanese LLaVA-Bench (COCO) Evaluation (#29)
* Add llava-bench
* update directory
1 parent 79cd99f, commit c4809ff
Showing 10 changed files with 1,022 additions and 0 deletions.
@@ -0,0 +1,119 @@
import argparse
import json
import os
from openai import OpenAI
import time

NUM_SECONDS_TO_SLEEP = 0.5


# Query the OpenAI API, retrying after a short sleep whenever a request fails.
def get_eval(content: str, max_tokens: int):
    while True:
        try:
            client = OpenAI()
            response = client.chat.completions.create(
                model="gpt-4-0314",
                messages=[{
                    'role': 'system',
                    'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
                }, {
                    'role': 'user',
                    'content': content,
                }],
                temperature=0.2,
                max_tokens=max_tokens,
            )
            break
        except Exception as e:
            print(e)
        time.sleep(NUM_SECONDS_TO_SLEEP)
    return response.choices[0].message.content


# The review is expected to start with a line holding two scores, e.g. "8 7" or "8,7".
def parse_score(review):
    try:
        score_pair = review.split('\n')[0]
        score_pair = score_pair.replace(',', ' ')
        sp = score_pair.split(' ')
        if len(sp) == 2:
            return [float(sp[0]), float(sp[1])]
        else:
            print('error', review)
            return [-1, -1]
    except Exception as e:
        print(e)
        print('error', review)
        return [-1, -1]


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
    parser.add_argument('-q', '--question')
    parser.add_argument('-c', '--context')
    parser.add_argument('-a', '--answer-list', nargs='+', default=[])
    parser.add_argument('-r', '--rule')
    parser.add_argument('-o', '--output')
    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
    args = parser.parse_args()

    f_q = open(os.path.expanduser(args.question))
    f_ans1 = open(os.path.expanduser(args.answer_list[0]))
    f_ans2 = open(os.path.expanduser(args.answer_list[1]))
    rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))

    # Load any reviews already written so an interrupted run can be resumed.
    if os.path.isfile(os.path.expanduser(args.output)):
        cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
    else:
        cur_reviews = []

    review_file = open(f'{args.output}', 'a')

    context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
    image_to_context = {context['image']: context for context in context_list}

    handles = []
    idx = 0
    for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
        ques = json.loads(ques_js)
        ans1 = json.loads(ans1_js)
        ans2 = json.loads(ans2_js)

        # Build the [Context] block from the image's captions and bounding boxes.
        inst = image_to_context[ques['image']]
        cap_str = '\n'.join(inst['captions'])
        box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])

        category = json.loads(ques_js)['category']
        if category in rule_dict:
            rule = rule_dict[category]
        else:
            assert False, f"Visual QA category not found in rule file: {category}."
        prompt = rule['prompt']
        role = rule['role']
        print("ans1", ans1)
        print("ans2", ans2)
        content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
                   f'[Question]\n{ques["text"]}\n\n'
                   f'[{role} 1]\n{ans1["text_JA"]}\n\n[End of {role} 1]\n\n'
                   f'[{role} 2]\n{ans2["answer"]}\n\n[End of {role} 2]\n\n'
                   f'[System]\n{prompt}\n\n')
        print("content", content)
        cur_js = {
            'id': idx+1,
            'question_id': ques['question_id'],
            'answer1_id': ans1.get('question_id', ans1['question_id']),
            'answer2_id': ans2.get('question_id', ans2['question_id']),
            'category': category
        }
        if idx >= len(cur_reviews):
            review = get_eval(content, args.max_tokens)
            scores = parse_score(review)
            cur_js['content'] = review
            cur_js['tuple'] = scores
            review_file.write(json.dumps(cur_js) + '\n')
            review_file.flush()
            print("review", review)
        else:
            print(f'Skipping {idx} as we already have it.')
        idx += 1
        print("idx", idx)
    review_file.close()
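The review script above only records the raw GPT-4 judgments: parse_score() extracts the two numeric scores from the first line of each review and they are written to the output JSONL as the 'tuple' field, alongside 'category'. The commit itself does not include a summarization step, but a minimal sketch of how those scores could be aggregated into per-category averages and a LLaVA-Bench-style relative score might look like this (the input path is an illustrative placeholder, not a file from this commit):

# Hypothetical summary step, not part of this commit.
# Reads the JSONL written by the review script above and prints per-category
# averages plus the relative score of answer 2 against answer 1.
import json
from collections import defaultdict

scores = defaultdict(list)
with open("gpt4_reviews.jsonl") as f:  # illustrative path; use whatever was passed as --output
    for line in f:
        review = json.loads(line)
        s1, s2 = review["tuple"]
        if s1 == -1 or s2 == -1:  # parse_score() returned its failure sentinel
            continue
        scores[review["category"]].append((s1, s2))
        scores["all"].append((s1, s2))

for category, pairs in scores.items():
    avg1 = sum(p[0] for p in pairs) / len(pairs)
    avg2 = sum(p[1] for p in pairs) / len(pairs)
    # Relative score in the usual LLaVA-Bench sense: answer 2 as a percentage of answer 1.
    print(f"{category}: answer1={avg1:.2f}  answer2={avg2:.2f}  relative={100 * avg2 / avg1:.1f}%")

Judging from the keys the script reads (ans1["text_JA"] versus ans2["answer"]), the first answer file appears to hold the Japanese reference answers and the second the model outputs, so the relative score compares the evaluated model against the reference.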
@@ -0,0 +1,279 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bef69431",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c2410943",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2024-01-08 06:38:34,481] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from heron.models.video_blip import VideoBlipForConditionalGeneration, VideoBlipProcessor\n",
    "from transformers import LlamaTokenizer\n",
    "import wandb\n",
    "\n",
    "device_id = 0\n",
    "device = f\"cuda:{device_id}\"\n",
    "\n",
    "max_length = 512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "13242f99",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d5b69d63148f4820bf917f955aedc853",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading config.json: 0%| | 0.00/577 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3fe2676936bc4ac090a8fe6554b28f73",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of VideoBlipForConditionalGeneration were not initialized from the model checkpoint at /mnt/disks/disk2/model_out/stablelm-beta/abci-exp001 and are newly initialized because the shapes did not match:\n",
      "- text_projection.bias: found shape torch.Size([2560]) in the checkpoint and torch.Size([4096]) in the model instantiated\n",
      "- text_projection.weight: found shape torch.Size([2560, 768]) in the checkpoint and torch.Size([4096, 768]) in the model instantiated\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
     ]
    }
   ],
   "source": [
    "MODEL_NAME = \"turing-motors/heron-chat-blip-ja-stablelm-base-7b-v1\"\n",
    "model = VideoBlipForConditionalGeneration.from_pretrained(\n",
    "    MODEL_NAME, torch_dtype=torch.float16, ignore_mismatched_sizes=True\n",
    ")\n",
    "tokenizer = LlamaTokenizer.from_pretrained(\"novelai/nerdstash-tokenizer-v1\", additional_special_tokens=['▁▁'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5321773e",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model.half()\n",
    "model.eval()\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f8edaeb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare a processor\n",
    "processor = VideoBlipProcessor.from_pretrained(\"Salesforce/blip2-opt-2.7b\")\n",
    "processor.tokenizer = tokenizer\n",
    "\n",
    "import requests\n",
    "from PIL import Image\n",
    "\n",
    "def generate_response(question, image):\n",
    "    # prepare inputs\n",
    "    text = f\"##human: {question}\\n##gpt: \"\n",
    "\n",
    "    # do preprocessing\n",
    "    inputs = processor(\n",
    "        text=text,\n",
    "        images=image,\n",
    "        return_tensors=\"pt\",\n",
    "        add_special_tokens=False,\n",
    "        truncation=True,\n",
    "    )\n",
    "\n",
    "    inputs = {k: v.to(device) for k, v in inputs.items()}\n",
    "    inputs[\"pixel_values\"] = inputs[\"pixel_values\"].to(device, torch.float16)\n",
    "\n",
    "    # set eos token\n",
    "    eos_token_id_list = [\n",
    "        processor.tokenizer.pad_token_id,\n",
    "        processor.tokenizer.eos_token_id,\n",
    "        int(tokenizer.convert_tokens_to_ids(\"\\n\"))\n",
    "    ]\n",
    "\n",
    "    # do inference\n",
    "    with torch.no_grad():\n",
    "        out = model.generate(**inputs, max_length=256, do_sample=False, temperature=0., eos_token_id=eos_token_id_list, no_repeat_ngram_size=2)\n",
    "    res = processor.tokenizer.batch_decode(out, skip_special_tokens=True)\n",
    "    return res[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3e6c4b34",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from PIL import Image\n",
    "\n",
    "def load_q(p):\n",
    "    data = []\n",
    "    for line in open(p):\n",
    "        data.append(json.loads(line))\n",
    "    return data\n",
    "\n",
    "q_data = load_q(\"qa90_questions_ja.jsonl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "993fcaca",
   "metadata": {},
   "outputs": [],
   "source": [
    "def answer_data(q_data):\n",
    "    result = []\n",
    "    for q in q_data:\n",
    "        image = Image.open(\"val2014/COCO_val2014_\" + q[\"image\"])\n",
    "        question = q[\"text_JA\"]\n",
    "        display(image)\n",
    "        res = generate_response(question, image)\n",
    "        print(question)\n",
    "        if \"##\" in res:\n",
    "            res = res.split(\"##\")[0]\n",
    "        print(\"final\", res)\n",
    "        q[\"answer\"] = res\n",
    "        result.append(q)\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f659386e",
   "metadata": {},
   "outputs": [],
   "source": [
    "result = answer_data(q_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "id": "5596a621-706b-4252-8441-2737c784e805",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"stablelm-alpha-exp001\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a467b8c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# if you want to upload the results to wandb\n",
    "def upload_result(result, name):\n",
    "    wandb.init(project=\"heron-eval\", name=name)\n",
    "    table = wandb.Table(columns=['ID', 'Image', 'Question', 'Answer'])\n",
    "    for r in result:\n",
    "        image = Image.open(\"val2014/COCO_val2014_\" + r[\"image\"])\n",
    "        answer = r[\"answer\"]\n",
    "        img = wandb.Image(image, caption=answer)\n",
    "        idx = r[\"question_id\"]\n",
    "        table.add_data(idx, img, r[\"text_JA\"], answer)\n",
    "    wandb.log({\"Table\" : table})\n",
    "\n",
    "upload_result(result, model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 129,
   "id": "7246349b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_jsonl(jsonl, model_name):\n",
    "    with open(f\"{model_name}_answer.jsonl\", \"w\") as f:\n",
    "        for r in jsonl:\n",
    "            f.write(json.dumps(r)+\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 130,
   "id": "1b930312",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_jsonl(result, model_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
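The notebook saves its generations to f"{model_name}_answer.jsonl" via save_jsonl(), and the review script above consumes such a file as the second entry of --answer-list. As a rough illustration of how the two pieces could be wired together, the sketch below shows a hypothetical invocation: the review script's file name and the context, rule, and reference-answer paths are placeholders (they are not shown in this diff); only qa90_questions_ja.jsonl and stablelm-alpha-exp001_answer.jsonl come from the notebook itself, and OPENAI_API_KEY must be set for the OpenAI client the script uses.

# Hypothetical glue step; every file name marked "assumed" is a placeholder, not a path from this commit.
import subprocess

subprocess.run(
    [
        "python", "eval_gpt_review.py",            # assumed name for the review script shown above
        "--question", "qa90_questions_ja.jsonl",   # Japanese questions, as loaded in the notebook
        "--context", "coco_context.jsonl",         # assumed: per-image 'captions' and 'instances'
        "--rule", "rule.json",                     # assumed: per-category 'prompt' and 'role'
        "--answer-list",
        "qa90_reference_answers_ja.jsonl",         # assumed: reference answers containing 'text_JA'
        "stablelm-alpha-exp001_answer.jsonl",      # produced by save_jsonl() in the notebook
        "--output", "gpt4_reviews.jsonl",
        "--max-tokens", "1024",
    ],
    check=True,  # raise if the review script exits with an error
)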