diff --git a/.gitignore b/.gitignore
index 58b7753..7a98139 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,6 +12,8 @@
 wandb
 *egg-info
 
+playground/data/llava-bench-ja/val2014/*
+
 poetry.lock
 .env
 
diff --git a/eval/llava-bench-ja/gpt_review.py b/eval/llava-bench-ja/gpt_review.py
deleted file mode 100644
index 14eb3cd..0000000
--- a/eval/llava-bench-ja/gpt_review.py
+++ /dev/null
@@ -1,119 +0,0 @@
-import argparse
-import json
-import os
-from openai import OpenAI
-import time
-
-NUM_SECONDS_TO_SLEEP = 0.5
-
-
-def get_eval(content: str, max_tokens: int):
-    while True:
-        try:
-            client = OpenAI()
-            response = client.chat.completions.create(
-                model="gpt-4-0314",
-                messages=[{
-                    'role': 'system',
-                    'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
-                }, {
-                    'role': 'user',
-                    'content': content,
-                }],
-                temperature=0.2,
-                max_tokens=max_tokens,
-            )
-            break
-        except Exception as e:
-            print(e)
-        time.sleep(NUM_SECONDS_TO_SLEEP)
-    return response.choices[0].message.content
-
-
-def parse_score(review):
-    try:
-        score_pair = review.split('\n')[0]
-        score_pair = score_pair.replace(',', ' ')
-        sp = score_pair.split(' ')
-        if len(sp) == 2:
-            return [float(sp[0]), float(sp[1])]
-        else:
-            print('error', review)
-            return [-1, -1]
-    except Exception as e:
-        print(e)
-        print('error', review)
-        return [-1, -1]
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
-    parser.add_argument('-q', '--question')
-    parser.add_argument('-c', '--context')
-    parser.add_argument('-a', '--answer-list', nargs='+', default=[])
-    parser.add_argument('-r', '--rule')
-    parser.add_argument('-o', '--output')
-    parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
-    args = parser.parse_args()
-
-    f_q = open(os.path.expanduser(args.question))
-    f_ans1 = open(os.path.expanduser(args.answer_list[0]))
-    f_ans2 = open(os.path.expanduser(args.answer_list[1]))
-    rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
-
-    if os.path.isfile(os.path.expanduser(args.output)):
-        cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
-    else:
-        cur_reviews = []
-
-    review_file = open(f'{args.output}', 'a')
-
-    context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
-    image_to_context = {context['image']: context for context in context_list}
-
-    handles = []
-    idx = 0
-    for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
-        ques = json.loads(ques_js)
-        ans1 = json.loads(ans1_js)
-        ans2 = json.loads(ans2_js)
-
-        inst = image_to_context[ques['image']]
-        cap_str = '\n'.join(inst['captions'])
-        box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])
-
-        category = json.loads(ques_js)['category']
-        if category in rule_dict:
-            rule = rule_dict[category]
-        else:
-            assert False, f"Visual QA category not found in rule file: {category}."
-        prompt = rule['prompt']
-        role = rule['role']
-        print("ans1", ans1)
-        print("ans2", ans2)
-        content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
-                   f'[Question]\n{ques["text"]}\n\n'
-                   f'[{role} 1]\n{ans1["text_JA"]}\n\n[End of {role} 1]\n\n'
-                   f'[{role} 2]\n{ans2["answer"]}\n\n[End of {role} 2]\n\n'
-                   f'[System]\n{prompt}\n\n')
-        print("content", content)
-        cur_js = {
-            'id': idx+1,
-            'question_id': ques['question_id'],
-            'answer1_id': ans1.get('question_id', ans1['question_id']),
-            'answer2_id': ans2.get('question_id', ans2['question_id']),
-            'category': category
-        }
-        if idx >= len(cur_reviews):
-            review = get_eval(content, args.max_tokens)
-            scores = parse_score(review)
-            cur_js['content'] = review
-            cur_js['tuple'] = scores
-            review_file.write(json.dumps(cur_js) + '\n')
-            review_file.flush()
-            print("review", review)
-        else:
-            print(f'Skipping {idx} as we already have it.')
-        idx += 1
-        print("idx", idx)
-    review_file.close()
diff --git a/eval/llava-bench-ja/readme.md b/eval/llava-bench-ja/readme.md
deleted file mode 100644
index 859fa4a..0000000
--- a/eval/llava-bench-ja/readme.md
+++ /dev/null
@@ -1,32 +0,0 @@
-# LLaVA-Bench (COCO) 日本語版
-
-このプログラムは本家LLaVAのリポジトリのLLaVA-Benchを日本語に対応させたものです。
-具体的には、比較対象の正解回答を日本語に翻訳し、プロンプトを日本語に対応させたものであり、それ以外は基本的に本家のプログラムと同じです。
-
-## 実行手順
-
-1. COCO(2014)データのダウンロード
-
-採点に使う画像のデータをCOCOからダウンロードしてください。
-
-```
-wget http://images.cocodataset.org/zips/val2014.zip
-unzip val2014.zip
-```
-
-2. 回答文の推論
-
-`inference_coco_bench.ipynb`ノートブックを使って、評価対象の画像に対して回答文を生成する推論を行なってください。
-
-3. 評価プログラムの実行
-
-以下のコマンドによって採点スクリプトを実行してください。answer.jsonlはステップ1で出力した回答文です。score.jsonは採点プログラムによって出力されるスコアファイルです。
-
-```
-OPENAI_API_KEY="sk-..." python gpt_review.py --question qa90_questions_ja.jsonl --contex caps_boxes_coco2014_val_80.jsonl --answer-list qa90_gpt4_answer_ja_v2.jsonl sample_answer.jsonl --rule rule.json --output sample_review.json
-```
-
-4. スコアの計算と可視化
-
-visualize.ipynbを用いて3の結果からLLaVA-Benchのスコアを算出したり、結果を比較して可視化することが可能です。
-
diff --git a/heron/eval/eval_gpt_review_visual.py b/heron/eval/eval_gpt_review_visual.py
new file mode 100644
index 0000000..6d559a6
--- /dev/null
+++ b/heron/eval/eval_gpt_review_visual.py
@@ -0,0 +1,245 @@
+import argparse
+import asyncio
+import json
+import os
+from collections import defaultdict
+
+import aiohttp
+import matplotlib.pyplot as plt
+import numpy as np
+import tqdm
+
+NUM_SECONDS_TO_SLEEP = 0.5
+
+
+async def get_eval(session, content: str, max_tokens: int):
+    url = "https://api.openai.com/v1/chat/completions"
+    headers = {
+        "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}",
+        "Content-Type": "application/json",
+    }
+    payload = {
+        "model": "gpt-4-0613",
+        "messages": [
+            {
+                "role": "system",
+                "content": "You are a helpful and precise assistant for checking the quality of the answer.",
+            },
+            {
+                "role": "user",
+                "content": content,
+            },
+        ],
+        "temperature": 0.2,
+        "max_tokens": max_tokens,
+    }
+
+    while True:
+        try:
+            async with session.post(url, headers=headers, json=payload) as resp:
+                if resp.status == 429:
+                    await asyncio.sleep(NUM_SECONDS_TO_SLEEP)
+                    continue
+                resp.raise_for_status()
+                data = await resp.json()
+                return data["choices"][0]["message"]["content"]
+        except Exception as e:
+            print(e)
+
+
+def parse_score(review):
+    try:
+        score_pair = review.split("\n")[0]
+        score_pair = score_pair.replace(",", " ")
+        sp = score_pair.split(" ")
+        if len(sp) == 2:
+            return [float(sp[0]), float(sp[1])]
+        else:
+            print("error", review)
+            return [-1, -1]
+    except Exception as e:
+        print(e)
+        print("error", review)
+        return [-1, -1]
+
+
+def load_jsonl(path, num):
+    scores = defaultdict(list)
+    for i, line in enumerate(open(path)):
+        if i > num:
+            break
+        d = json.loads(line)
+        scores[d["category"]].append(d["tuple"][1])
+        scores[d["category"] + "_ref"].append(d["tuple"][0])
+    return scores
+
+
+def load_model_results(model_results, num=90):
+    results = {}
+    for model_name, result_path in model_results.items():
+        scores = load_jsonl(result_path, num)
+        result = {}
+        for c, s in scores.items():
+            if "ref" not in c:
+                # The LLaVA-Bench score is the ratio of the target's mean score to the reference's mean score
+                rel_score = 100 * np.mean(s) / np.mean(scores[c + "_ref"])
+                result[c] = rel_score
+        results[model_name] = result
+    return results
+
+
+def plot_result(model_results, save_plot_name, min_value=0, max_value=110):
+    # Prepare the data
+    labels = list(model_results[list(model_results.keys())[0]].keys())
+    model_scores = {}
+    for model_name, result in model_results.items():
+        model_scores[model_name] = [max(0, result[label]) for label in labels]
+        model_scores[model_name] += model_scores[model_name][:1]
+
+    # Compute the angles used to draw the radar chart
+    angles = np.linspace(0, 2 * np.pi, len(labels), endpoint=False).tolist()
+    angles += angles[:1]  # Append the first angle to the end of the list to close the polygon
+
+    # Draw the radar chart
+    fig, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(polar=True))
+
+    colorlist = ["r", "g", "b", "c", "m", "y", "k", "w"]
+    for i, (model_name, score) in enumerate(model_scores.items()):
+        ax.plot(angles, score, color=colorlist[i % len(colorlist)], linewidth=2, label=model_name)
+        ax.fill(angles, score, color=colorlist[i % len(colorlist)], alpha=0.1)
+
+    # Adjust the appearance of the chart
+    # Add radial tick marks
+    yticks = np.linspace(min_value, max_value, num=5)  # Five evenly spaced tick values from min_value to max_value
+    ax.set_yticks(yticks)
+    ax.set_yticklabels([str(round(ytick, 2)) for ytick in yticks])  # Tick labels, rounded to two decimal places
+
+    # ax.set_yticklabels([])
+    ax.set_ylim([min_value, max_value])
+    ax.set_xticks(angles[:-1])
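+    # Note: loc="upper right" with bbox_to_anchor=(0.1, 0.1) anchors the legend's corner near the lower-left of the axes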
+    ax.set_xticklabels(labels)
+    ax.legend(loc="upper right", bbox_to_anchor=(0.1, 0.1))
+
+    # plt.show()
+    plt.savefig(save_plot_name)
+
+
+async def main(args):
+    async with aiohttp.ClientSession() as session:
+        f_q = open(os.path.expanduser(args.question))
+        f_ans1 = open(os.path.expanduser(args.answer_list[0]))
+        f_ans2 = open(os.path.expanduser(args.answer_list[1]))
+        rule_dict = json.load(open(os.path.expanduser(args.rule), "r"))
+
+        if os.path.isfile(os.path.expanduser(args.output)):
+            cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
+        else:
+            cur_reviews = []
+
+        review_file = open(f"{args.output}", "a")
+
+        context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
+        image_to_context = {context["image"]: context for context in context_list}
+
+        tasks = []
+        cur_js_list = []
+        for idx, (ques_js, ans1_js, ans2_js) in enumerate(
+            tqdm.tqdm(zip(f_q, f_ans1, f_ans2), total=90)
+        ):
+            ques = json.loads(ques_js)
+            ans1 = json.loads(ans1_js)
+            ans2 = json.loads(ans2_js)
+
+            inst = image_to_context[ques["image"]]
+            cap_str = "\n".join(inst["captions"])
+            box_str = "\n".join(
+                [f'{instance["category"]}: {instance["bbox"]}' for instance in inst["instances"]]
+            )
+
+            category = json.loads(ques_js)["category"]
+            if category in rule_dict:
+                rule = rule_dict[category]
+            else:
+                assert False, f"Visual QA category not found in rule file: {category}."
+            prompt = rule["prompt"]
+            role = rule["role"]
+            content = (
+                f"[Context]\n{cap_str}\n\n{box_str}\n\n"
+                f'[Question]\n{ques["text"]}\n\n'
+                f"[{role} 1]\n{ans1[args.gpt4_answer_col]}\n\n[End of {role} 1]\n\n"
+                f"[{role} 2]\n{ans2[args.answer_col]}\n\n[End of {role} 2]\n\n"
+                f"[System]\n{prompt}\n\n"
+            )
+            cur_js = {
+                "id": idx + 1,
+                "question_id": ques["question_id"],
+                "answer1_id": ans1.get("answer_id", ans1["question_id"]),
+                "answer2_id": ans2.get("answer_id", ans2["question_id"]),
+                "category": category,
+            }
+            if idx >= len(cur_reviews):
+                task = asyncio.create_task(get_eval(session, content, args.max_tokens))
+                tasks.append(task)
+                cur_js_list.append(cur_js)
+            else:
+                print(f"Skipping {idx} as we already have it.")
+
+        # Wait for all tasks to complete
+        results = await asyncio.gather(*tasks)
+
+        # Process results and write to file as before
+        for result, cur_js in zip(results, cur_js_list):
+            review = result
+            scores = parse_score(review)
+            # Assuming `cur_js` is prepared as before:
+            cur_js["content"] = review
+            cur_js["tuple"] = scores
+            review_file.write(json.dumps(cur_js, ensure_ascii=False) + "\n")
+            review_file.flush()
+
+        review_file.close()
+
+    name = args.output.split("/")[-1].split(".")[0]
+    model_results_json = {
+        name: args.output,
+    }
+    model_results = load_model_results(model_results_json)
+    plot_result(model_results, args.output.replace("json", "png"), 0, 110)
+    print(f"result: {model_results}")
+    if args.is_upload_result:
+        import wandb
+
+        project_name = os.getenv("WANDB_PROJECT_NAME", "default-project")
+        wandb.init(project=project_name, name=name)
+        table = wandb.Table(columns=["Name", "mean", "conv", "detail", "complex"])
+        for name, ret in model_results.items():
+            table.add_data(
+                name,
+                (ret["conv"] + ret["detail"] + ret["complex"]) / 3,
+                ret["conv"],
+                ret["detail"],
+                ret["complex"],
+            )
+        wandb.log({"LB: LLaVA Bench Japanese": table})
+        print("Upload results to wandb")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="ChatGPT-based QA evaluation.")
+    parser.add_argument("-q", "--question")
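+    # -q: question JSONL, -c: caption/bbox context JSONL, -a: GPT-4 reference answers followed by the model answers to grade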
"--context") + parser.add_argument("-a", "--answer-list", nargs="+", default=[]) + parser.add_argument("-r", "--rule") + parser.add_argument("-o", "--output") + parser.add_argument("-gc", "--gpt4_answer_col", type=str, default="text_JA") + parser.add_argument("-ac", "--answer_col", type=str, default="answer") + parser.add_argument("--is_upload_result", action="store_true") + parser.add_argument( + "--max-tokens", + type=int, + default=1024, + help="maximum number of tokens produced in the output", + ) + args = parser.parse_args() + + asyncio.run(main(args)) diff --git a/heron/eval/gpt_review.py b/heron/eval/gpt_review.py new file mode 100644 index 0000000..4b08c7e --- /dev/null +++ b/heron/eval/gpt_review.py @@ -0,0 +1,132 @@ +import argparse +import json +import os +import time + +from openai import OpenAI + +NUM_SECONDS_TO_SLEEP = 0.5 + + +def get_eval(content: str, max_tokens: int): + while True: + try: + client = OpenAI() + response = client.chat.completions.create( + model="gpt-4-0314", + messages=[ + { + "role": "system", + "content": "You are a helpful and precise assistant for checking the quality of the answer.", + }, + { + "role": "user", + "content": content, + }, + ], + temperature=0.2, + max_tokens=max_tokens, + ) + break + except Exception as e: + print(e) + time.sleep(NUM_SECONDS_TO_SLEEP) + return response.choices[0].message.content + + +def parse_score(review): + try: + score_pair = review.split("\n")[0] + score_pair = score_pair.replace(",", " ") + sp = score_pair.split(" ") + if len(sp) == 2: + return [float(sp[0]), float(sp[1])] + else: + print("error", review) + return [-1, -1] + except Exception as e: + print(e) + print("error", review) + return [-1, -1] + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ChatGPT-based QA evaluation.") + parser.add_argument("-q", "--question") + parser.add_argument("-c", "--context") + parser.add_argument("-a", "--answer-list", nargs="+", default=[]) + parser.add_argument("-r", "--rule") + parser.add_argument("-o", "--output") + parser.add_argument( + "--max-tokens", + type=int, + default=1024, + help="maximum number of tokens produced in the output", + ) + args = parser.parse_args() + + f_q = open(os.path.expanduser(args.question)) + f_ans1 = open(os.path.expanduser(args.answer_list[0])) + f_ans2 = open(os.path.expanduser(args.answer_list[1])) + rule_dict = json.load(open(os.path.expanduser(args.rule), "r")) + + if os.path.isfile(os.path.expanduser(args.output)): + cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] + else: + cur_reviews = [] + + review_file = open(f"{args.output}", "a") + + context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] + image_to_context = {context["image"]: context for context in context_list} + + handles = [] + idx = 0 + for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): + ques = json.loads(ques_js) + ans1 = json.loads(ans1_js) + ans2 = json.loads(ans2_js) + + inst = image_to_context[ques["image"]] + cap_str = "\n".join(inst["captions"]) + box_str = "\n".join( + [f'{instance["category"]}: {instance["bbox"]}' for instance in inst["instances"]] + ) + + category = json.loads(ques_js)["category"] + if category in rule_dict: + rule = rule_dict[category] + else: + assert False, f"Visual QA category not found in rule file: {category}." 
+ prompt = rule["prompt"] + role = rule["role"] + print("ans1", ans1) + print("ans2", ans2) + content = ( + f"[Context]\n{cap_str}\n\n{box_str}\n\n" + f'[Question]\n{ques["text"]}\n\n' + f'[{role} 1]\n{ans1["text_JA"]}\n\n[End of {role} 1]\n\n' + f'[{role} 2]\n{ans2["answer"]}\n\n[End of {role} 2]\n\n' + f"[System]\n{prompt}\n\n" + ) + print("content", content) + cur_js = { + "id": idx + 1, + "question_id": ques["question_id"], + "answer1_id": ans1.get("question_id", ans1["question_id"]), + "answer2_id": ans2.get("question_id", ans2["question_id"]), + "category": category, + } + if idx >= len(cur_reviews): + review = get_eval(content, args.max_tokens) + scores = parse_score(review) + cur_js["content"] = review + cur_js["tuple"] = scores + review_file.write(json.dumps(cur_js) + "\n") + review_file.flush() + print("review", review) + else: + print(f"Skipping {idx} as we already have it.") + idx += 1 + print("idx", idx) + review_file.close() diff --git a/heron/eval/inference_llava_bench.py b/heron/eval/inference_llava_bench.py new file mode 100644 index 0000000..4be7694 --- /dev/null +++ b/heron/eval/inference_llava_bench.py @@ -0,0 +1,138 @@ +import json +import os + +import fire +import torch +import yaml +from PIL import Image +from tqdm import tqdm + +import wandb +from heron.models.prepare_processors import get_processor +from heron.models.utils import load_model, load_pretrained_weight + + +def generate_response(question, image, model, processor, device): + """ + Generates a response for a given question and image. + """ + text = f"##human: {question}\n##gpt: " + inputs = processor(text=text, images=image, return_tensors="pt", truncation=True) + inputs = {k: v.to(device) for k, v in inputs.items()} + inputs["pixel_values"] = inputs["pixel_values"].to(device).half() + + eos_token_id_list = [ + processor.tokenizer.pad_token_id, + processor.tokenizer.eos_token_id, + int(processor.tokenizer.convert_tokens_to_ids("\n")), + ] + + with torch.no_grad(): + out = model.generate( + **inputs, + max_length=256, + do_sample=False, + temperature=0.0, + eos_token_id=eos_token_id_list, + no_repeat_ngram_size=2, + ) + return processor.tokenizer.batch_decode(out, skip_special_tokens=True)[0] + + +def load_questions(path): + """ + Loads questions from a JSONL file. + """ + with open(path, "r") as file: + return [json.loads(line) for line in file] + + +def process_questions(img_root, questions, model, processor, device, verbose): + """ + Processes a list of questions, generating answers for each. + """ + results = [] + for q in tqdm(questions): + image = Image.open(os.path.join(img_root, f"COCO_val2014_{q['image']}")) + question = q["text_JA"] + answer = generate_response(question, image, model, processor, device) + if verbose: + print( + f"### ID: {q['question_id']}\n## question: {q['text_JA']}\n## answer: {answer}\n" + ) + q["answer"] = answer + results.append(q) + return results + + +def upload_results(img_root, results, name): + """ + Uploads the results to Weights & Biases. + """ + project_name = os.getenv("WANDB_PROJECT_NAME", "default-project") + wandb.init(project=project_name, name=name) + table = wandb.Table(columns=["ID", "Name", "Image", "Question", "Answer"]) + for r in results: + image = wandb.Image( + Image.open(os.path.join(img_root, f"COCO_val2014_{r['image']}")), caption=r["answer"] + ) + table.add_data(r["question_id"], name, image, r["text_JA"], r["answer"]) + wandb.log({"Table": table}) + + +def save_results(results, output_path, model_name): + """ + Saves the results to a JSONL file. 
+ """ + with open(os.path.join(output_path, f"{model_name}_answers.jsonl"), "w") as file: + for r in results: + file.write(json.dumps(r, ensure_ascii=False) + "\n") + + +def main( + config_file: str, + questions_path: str, + img_root: str, + output_path: str, + device: int = 0, + is_upload_result: bool = False, + verbose: bool = False, +): + with open(config_file, "r") as i_: + config = yaml.safe_load(i_) + model_config = config["model_config"] + + # make output dir + os.makedirs(output_path, exist_ok=True) + + # load model + model = load_model(model_config).to(device) + print("Model loaded") + + # load pretrained weight + if model_config.get("pretrained_path") is not None: + print("load pretrained") + load_pretrained_weight(model, model_config["pretrained_path"]) + print(f'Successfully loading pretrained weights from {model_config["pretrained_path"]}') + + # get preprocessor + processor = get_processor(model_config) + print("Processor loaded") + + questions = load_questions(questions_path) + + print("Start inference") + results = process_questions(img_root, questions, model, processor, device, verbose) + print("Done inference") + + output_model_name = config_file.split("/")[-1].split(".yml")[0] + print("Saving results...") + save_results(results, output_path, output_model_name) + if is_upload_result: + print("Upload to wandb...") + upload_results(img_root, results, output_model_name) + print("Done all evaluation") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/eval/llava-bench-ja/inference_coco_bench.ipynb b/heron/eval/notebook/inference_coco_bench.ipynb similarity index 95% rename from eval/llava-bench-ja/inference_coco_bench.ipynb rename to heron/eval/notebook/inference_coco_bench.ipynb index 1891e1a..2de0434 100644 --- a/eval/llava-bench-ja/inference_coco_bench.ipynb +++ b/heron/eval/notebook/inference_coco_bench.ipynb @@ -13,7 +13,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 42, "id": "c2410943", "metadata": {}, "outputs": [ @@ -39,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 45, "id": "13242f99", "metadata": {}, "outputs": [ @@ -93,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 46, "id": "5321773e", "metadata": {}, "outputs": [], @@ -105,7 +105,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 47, "id": "f8edaeb7", "metadata": {}, "outputs": [], @@ -149,7 +149,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 48, "id": "3e6c4b34", "metadata": {}, "outputs": [], @@ -163,12 +163,12 @@ " data.append(json.loads(line))\n", " return data\n", "\n", - "q_data = load_q(\"qa90_questions_ja.jsonl\")" + "q_data = load_q(\"../../../playground/data/llava-bench-ja/qa90_questions_ja.jsonl\")" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 49, "id": "993fcaca", "metadata": {}, "outputs": [], @@ -176,7 +176,7 @@ "def answer_data(q_data):\n", " result = []\n", " for q in q_data:\n", - " image = Image.open(\"val2014/COCO_val2014_\" + q[\"image\"])\n", + " image = Image.open(\"../../../playground/data/llava-bench-ja/val2014/COCO_val2014_\" + q[\"image\"])\n", " question = q[\"text_JA\"]\n", " display(image)\n", " res = generate_response(question, image)\n", @@ -191,7 +191,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 51, "id": "f659386e", "metadata": {}, "outputs": [], @@ -271,7 +271,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - 
"version": "3.9.16" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/eval/llava-bench-ja/visualize.ipynb b/heron/eval/notebook/visualize.ipynb similarity index 100% rename from eval/llava-bench-ja/visualize.ipynb rename to heron/eval/notebook/visualize.ipynb diff --git a/heron/models/prepare_processors.py b/heron/models/prepare_processors.py index ce65917..8bf53b2 100644 --- a/heron/models/prepare_processors.py +++ b/heron/models/prepare_processors.py @@ -21,6 +21,8 @@ LlamaTokenizer, ) +from heron.models.video_blip import VideoBlipProcessor + def get_tokenizer(language_model_name: str) -> "Tokenizer": if "stablelm" in language_model_name: @@ -91,8 +93,8 @@ def get_processor(model_config: Dict) -> "Processor": model_config["vision_model_name"] ) - elif model_type == "video_blip": - processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") + elif "video_blip" in model_type: + processor = VideoBlipProcessor.from_pretrained("Salesforce/blip2-opt-2.7b") else: raise NotImplementedError(f"Processor for model_type: {model_type} is not implemented.") diff --git a/heron/models/utils.py b/heron/models/utils.py index 3aa0a5c..dd641fd 100644 --- a/heron/models/utils.py +++ b/heron/models/utils.py @@ -109,8 +109,11 @@ def load_model( elif model_type == "video_blip": from .video_blip import VideoBlipForConditionalGeneration - model = VideoBlipForConditionalGeneration.create( - language_model, num_frames=num_image_with_embedding, torch_dtype=torch_dtype + model = VideoBlipForConditionalGeneration.from_pretrained( + language_model, + num_frames=num_image_with_embedding, + torch_dtype=torch_dtype, + ignore_mismatched_sizes=False, ) else: diff --git a/images/COCO_val2014_000000441147.jpg b/images/COCO_val2014_000000441147.jpg new file mode 100644 index 0000000..98dc763 Binary files /dev/null and b/images/COCO_val2014_000000441147.jpg differ diff --git a/playground/data/llava-bench-ja/README.md b/playground/data/llava-bench-ja/README.md new file mode 100644 index 0000000..69d732b --- /dev/null +++ b/playground/data/llava-bench-ja/README.md @@ -0,0 +1,52 @@ +