Skip to content

Commit

Permalink
chore: address reviewer comments in HuggingFaceChatTarget and update …
Browse files Browse the repository at this point in the history
…tests
  • Loading branch information
KutalVolkan committed Nov 27, 2024
1 parent 8028d13 commit 20ed6ac
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
34 changes: 20 additions & 14 deletions pyrit/prompt_target/hugging_face/hugging_face_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ def __init__(
) -> None:
super().__init__()

if (model_id is None) == (model_path is None):
raise ValueError("Provide exactly one of `model_id` or `model_path`.")
if not model_id and not model_path:
raise ValueError("Either `model_id` or `model_path` must be provided.")
if model_id and model_path:
raise ValueError("Provide only one of `model_id` or `model_path`, not both.")

self.model_id = model_id
self.model_path = model_path
Expand All @@ -62,13 +64,14 @@ def __init__(
self.trust_remote_code = trust_remote_code

# Only get the Hugging Face token if a model ID is provided
self.huggingface_token = (
default_values.get_required_value(
env_var_name=self.HUGGINGFACE_TOKEN_ENVIRONMENT_VARIABLE, passed_value=hf_access_token
if model_id:
self.huggingface_token = default_values.get_required_value(
env_var_name=self.HUGGINGFACE_TOKEN_ENVIRONMENT_VARIABLE,
passed_value=hf_access_token
)
if model_id
else None
)
else:
self.huggingface_token = None


try:
import torch
Expand All @@ -93,6 +96,14 @@ def __init__(
raise RuntimeError("CUDA requested but not available.")

self.load_model_and_tokenizer_task = asyncio.create_task(self.load_model_and_tokenizer())

def _load_from_path(self, path: str):
"""
Helper function to load the model and tokenizer from a given path.
"""
logger.info(f"Loading model and tokenizer from path: {path}...")
self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=self.trust_remote_code)
self.model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=self.trust_remote_code)

def is_model_id_valid(self) -> bool:
"""
Expand Down Expand Up @@ -130,12 +141,7 @@ async def load_model_and_tokenizer(self):
if self.model_path:
# Load the tokenizer and model from the local directory
logger.info(f"Loading model from local path: {self.model_path}...")
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_path, trust_remote_code=self.trust_remote_code
)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_path, trust_remote_code=self.trust_remote_code
)
self._load_from_path(self.model_path)
else:
# Define the default Hugging Face cache directory
cache_dir = os.path.join(
Expand Down
4 changes: 2 additions & 2 deletions tests/target/test_huggingface_chat_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,12 +254,12 @@ def test_init_with_both_model_id_and_model_path_raises():
"""Ensure providing both `model_id` and `model_path` raises an error."""
with pytest.raises(ValueError) as excinfo:
HuggingFaceChatTarget(model_id="test_model", model_path="./mock_local_model_path", use_cuda=False)
assert "Provide exactly one of `model_id` or `model_path`." in str(excinfo.value)
assert "Provide only one of `model_id` or `model_path`, not both." in str(excinfo.value)


@pytest.mark.skipif(not is_torch_installed(), reason="torch is not installed")
def test_load_model_without_model_id_or_path():
"""Ensure initializing without `model_id` or `model_path` raises an error."""
with pytest.raises(ValueError) as excinfo:
HuggingFaceChatTarget(use_cuda=False)
assert "Provide exactly one of `model_id` or `model_path`." in str(excinfo.value)
assert "Either `model_id` or `model_path` must be provided." in str(excinfo.value)

0 comments on commit 20ed6ac

Please sign in to comment.