Commit: Fix llava

pierre.delaunay committed Sep 5, 2024
1 parent ea44ea6 commit e5505ee
Showing 9 changed files with 196 additions and 19 deletions.
42 changes: 42 additions & 0 deletions benchmarks/llava/benchfile.py
@@ -0,0 +1,42 @@
from milabench.pack import Package
from milabench.commands import AccelerateAllNodes


class Llava(Package):
# Requirements file installed by install(). It can be empty or absent.
base_requirements = "requirements.in"

# The preparation script called by prepare(). It must be executable,
# but it can be any type of script. It can be empty or absent.
prepare_script = "prepare.py"

# The main script called by run(). It must be a Python file. It has to
# be present.
main_script = "main.py"

# You can remove the functions below if you don't need to modify them.

def make_env(self):
# Return a dict of environment variables for prepare_script and
# main_script.
return super().make_env()

async def install(self):
await super().install() # super() call installs the requirements

async def prepare(self):
await super().prepare() # super() call executes prepare_script

    def build_run_plan(self):
        from milabench.commands import PackCommand, VoirCommand

        main = self.dirs.code / self.main_script
        plan = PackCommand(self, *self.argv, lazy=True)

        # Voir instrumentation is intentionally disabled for this benchmark;
        # main.py reports through BenchObserver and bench_monitor on stdout.
        if False:
            plan = VoirCommand(plan, cwd=main.parent)

        return AccelerateAllNodes(plan).use_stdout()


__pack__ = Llava
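
The hooks above are deliberate pass-throughs so that packs can override them. As a hypothetical illustration (not part of this commit), a pack that needs extra environment variables for its scripts could extend make_env as follows; the HF_HOME value and the data directory are assumptions, not something this benchmark sets:

from milabench.pack import Package


class LlavaWithCache(Package):
    # Hypothetical subclass, shown only to illustrate the override point.
    base_requirements = "requirements.in"
    prepare_script = "prepare.py"
    main_script = "main.py"

    def make_env(self):
        env = super().make_env()
        # Assumption: route HuggingFace downloads into the pack's data directory.
        env["HF_HOME"] = str(self.dirs.data)
        return env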
53 changes: 40 additions & 13 deletions benchmarks/llava/main.py
mode change 100644 → 100755
@@ -1,8 +1,7 @@
# This is the script run by milabench run (by default)
#!/usr/bin/env python

import time
from dataclasses import dataclass

import numpy as np
import torch
from accelerate import Accelerator
from accelerate.utils import set_seed
@@ -12,6 +11,7 @@
from torch.utils.data.dataloader import default_collate
from transformers import AutoProcessor, LlavaForConditionalGeneration

import argklass
from benchmate.observer import BenchObserver


@@ -34,30 +34,45 @@ def custom_collate(batch):
return default_collate(batch)


@dataclass
class Arguments:
batch_size: int = 10
epochs: int = 10
seed: int = 42
num_workers: int = 5
gradient_accumulation_steps: int = 4


def main():
parser = argklass.ArgumentParser(description="llava")
parser.add_arguments(Arguments)
args = parser.parse_args()

accelerator = Accelerator(
mixed_precision="no",
gradient_accumulation_steps=4,
gradient_accumulation_steps=args.gradient_accumulation_steps,
log_with="all",
project_dir="logs",
)

set_seed(42)
batch_size = 1 # Set to 1 for now, but can be easily changed
num_epochs = 1
set_seed(args.seed)

# Load LLaVA model and processor with device_map="auto"
model = LlavaForConditionalGeneration.from_pretrained(
"llava-hf/llava-1.5-7b-hf",
torch_dtype=torch.float32, # Change to float32
torch_dtype=torch.bfloat16,
device_map="auto",
)
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Load dataset and create DataLoader
dataset = load_dataset("HuggingFaceM4/the_cauldron", "aokvqa")["train"]
dataloader = DataLoader(
dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate
dataset,
batch_size=args.batch_size,
shuffle=True,
collate_fn=custom_collate,
num_workers=args.num_workers
)

def batch_size_fn(batch):
@@ -68,13 +83,14 @@ def batch_size_fn(batch):
)

observer = BenchObserver(
batch_size_fn=batch_size_fn, earlystop=70, raise_stop_program=True
batch_size_fn=batch_size_fn, earlystop=70, raise_stop_program=True,
stdout=True,
)
optimizer = observer.optimizer(torch.optim.AdamW(model.parameters(), lr=5e-5))
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for epoch in range(num_epochs):
for batch in observer.iterate(dataloader):
for epoch in range(args.epochs):
for i, batch in enumerate(observer.iterate(dataloader)):
images = batch["images"][0] # Access the first item in the list of images
texts = batch["texts"]
prompt = apply_chat_template(texts)
@@ -93,7 +109,11 @@ def batch_size_fn(batch):
)
for k, v in inputs.items()
}

inputs["labels"] = inputs["input_ids"]

outputs = model(**inputs)

loss = outputs.loss
accelerator.backward(loss)

@@ -111,4 +131,11 @@ def batch_size_fn(batch):


if __name__ == "__main__":
main()
from voir.phase import StopProgram
from benchmate.monitor import bench_monitor

try:
with bench_monitor():
main()
except StopProgram:
pass
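
The training loop calls apply_chat_template(texts), which is defined elsewhere in main.py, outside the hunks shown here. A minimal sketch of what such a helper could look like, assuming each element of texts is a {"user": ..., "assistant": ...} turn and the LLaVA 1.5 USER:/ASSISTANT: prompt format:

def apply_chat_template(texts):
    # Sketch only: the real helper lives in main.py outside this diff.
    prompt = "USER: <image>\n"
    for turn in texts:
        prompt += f"{turn['user']} ASSISTANT: {turn['assistant']}"
    return prompt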
22 changes: 22 additions & 0 deletions benchmarks/llava/prepare.py
@@ -0,0 +1,22 @@
#!/usr/bin/env python

import torch
from datasets import load_dataset
from transformers import AutoProcessor, LlavaForConditionalGeneration


def main():
# Load LLaVA model and processor with device_map="auto"
_ = LlavaForConditionalGeneration.from_pretrained(
"llava-hf/llava-1.5-7b-hf",
        torch_dtype=torch.float32,  # any dtype works here; this load only warms the cache
device_map="auto",
)
_ = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Load dataset and create DataLoader
_ = load_dataset("HuggingFaceM4/the_cauldron", "aokvqa")["train"]


if __name__ == "__main__":
main()
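
prepare.py exists only to download the model, processor, and dataset before the timed run starts. As an optional sanity check (an assumption, not something this commit does), from_pretrained accepts local_files_only=True and fails fast when the cache is cold:

from transformers import AutoProcessor

# Run after prepare.py has completed; raises instead of downloading.
processor = AutoProcessor.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    local_files_only=True,
)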
7 changes: 7 additions & 0 deletions benchmarks/llava/requirements.in
@@ -0,0 +1,7 @@
voir>=0.2.19,<0.3
torch
numpy
accelerate
pillow
datasets
transformers
38 changes: 38 additions & 0 deletions benchmarks/llava/voirfile.py
@@ -0,0 +1,38 @@
from dataclasses import dataclass

from voir import configurable
from voir.instruments import dash, early_stop, log, rate
from benchmate.monitor import monitor_monogpu

@dataclass
class Config:
"""voir configuration"""

# Whether to display the dash or not
dash: bool = False

# How often to log the rates
interval: str = "1s"

# Number of rates to skip before logging
skip: int = 5

# Number of rates to log before stopping
stop: int = 20

# Number of seconds between each gpu poll
gpu_poll: int = 3


@configurable
def instrument_main(ov, options: Config):
yield ov.phases.init

if options.dash:
ov.require(dash)

ov.require(
log("value", "progress", "rate", "units", "loss", "gpudata", context="task"),
early_stop(n=options.stop, key="rate", task="train"),
monitor_monogpu(poll_interval=options.gpu_poll),
)
1 change: 0 additions & 1 deletion benchmarks/purejaxrl/main.py
@@ -4,7 +4,6 @@
# clone_subtree in the benchfile.py, in which case this file can simply
# be deleted.

import argparse
import argklass


43 changes: 43 additions & 0 deletions config/base.yaml
@@ -732,3 +732,46 @@ torchatari:
--num-envs: auto({cpu_per_gpu}, 128)
--total-timesteps: 1000000
--env-id: Breakout-v5


llava:
inherits: _defaults
definition: ../benchmarks/llava
install_group: torch
plan:
method: per_gpu

tags:
- llm
argv:
--batch_size: 1
--num_workers: 4


llava-single:
inherits: _defaults
definition: ../benchmarks/llava
install_group: torch
plan:
method: per_gpu

tags:
- llm
argv:
--batch_size: 1
--num_workers: 4

llava-gpus:
inherits: _defaults
definition: ../benchmarks/llava
install_group: torch
plan:
method: njobs
n: 1

tags:
- llm
argv:
--batch_size: 1
--num_workers: 4
--gradient_accumulation_steps: 1
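
Each argv entry above is passed to main.py as a command-line flag, where argklass maps it onto the Arguments dataclass. A self-contained sketch of that round trip (the dataclass is abbreviated from main.py):

from dataclasses import dataclass

import argklass


@dataclass
class Arguments:  # abbreviated from main.py
    batch_size: int = 10
    num_workers: int = 5


parser = argklass.ArgumentParser(description="llava")
parser.add_arguments(Arguments)
args = parser.parse_args(["--batch_size", "1", "--num_workers", "4"])
print(args.batch_size, args.num_workers)  # 1 4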
7 changes: 3 additions & 4 deletions milabench/_version.py
@@ -1,6 +1,5 @@
"""This file is generated, do not modify"""

__tag__ = "v0.1.0-28-g8069946"
__commit__ = "8069946d331fb92090057d7eedd598515249521d"
__date__ = "2024-08-01 12:39:13 -0400"

__tag__ = "v0.1.0-82-gea44ea63"
__commit__ = "ea44ea63be161bea2dd22c6dd23b1386474f09a7"
__date__ = "2024-09-05 12:03:19 -0400"
2 changes: 1 addition & 1 deletion scripts/article/run_cuda.sh
@@ -45,7 +45,7 @@ install_prepare() {
#
# Install milabench's benchmarks in their venv
#
# milabench pin --variant cuda --from-scratch $ARGS
milabench pin --variant cuda --from-scratch $ARGS
milabench install --system $MILABENCH_WORDIR/system.yaml $ARGS

which pip