Merge pull request #204 from vespa-engine/tgm/remove-ml-dependency-from-package

Remove ml module dependency from package module
lesters authored Sep 14, 2021
2 parents ce7d623 + 9b495bd commit 94074e0
Showing 2 changed files with 167 additions and 145 deletions.
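In practice the change moves the package-facing glue onto the ml side: BertModelConfig now produces the OnnxModel, QueryTypeField, Field and RankProfile objects itself, so vespa.package no longer needs to import vespa.ml. A minimal usage sketch of the new helpers follows; the constructor arguments and the model/tokenizer ids are assumptions for illustration and are not part of this diff (the rank_profile helper is sketched separately after the diff).

from vespa.ml import BertModelConfig

# Sketch only: the constructor arguments and hub ids below are assumed;
# this diff introduces the helper methods, not the constructor.
bert_config = BertModelConfig(
    model_id="pretrained_bert_tiny",
    tokenizer="google/bert_uncased_L-2_H-128_A-2",
    model="google/bert_uncased_L-2_H-128_A-2",
    query_input_size=32,
    doc_input_size=96,
)

onnx_model = bert_config.onnx_model()                    # OnnxModel, exported to "<model_id>.onnx"
query_fields = bert_config.query_profile_type_fields()   # [QueryTypeField] for the query token ids
doc_fields = bert_config.document_fields(None)           # [Field]; None selects ["attribute", "summary"] indexing
# Each of these is a plain vespa.package object, ready to be added to an application package.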
129 changes: 122 additions & 7 deletions vespa/ml.py
@@ -4,6 +4,15 @@
from pathlib import Path
from urllib.parse import urlencode

from vespa.package import (
ModelConfig,
Task,
OnnxModel,
QueryTypeField,
Field,
Function,
RankProfile,
)
from vespa.json_serialization import ToJson, FromJson

#
@@ -23,7 +32,7 @@
raise Exception("Use pip install pyvespa[ml] to install ml dependencies.")


class TextTask(object):
class TextTask(Task):
def __init__(
self,
model_id: str,
@@ -39,7 +48,7 @@ def __init__(
:param tokenizer: Id of the tokenizer as used by the model hub.
:param output_file: Output file to write output messages.
"""
self.model_id = model_id
super().__init__(model_id=model_id)
self.model = model
self.tokenizer = tokenizer
if not self.tokenizer:
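The two one-line hunks above make TextTask derive from the Task base class that now lives in vespa.package (imported at the top of the file), with model_id handed to the base constructor. A minimal sketch of the dependency direction this enables, assuming Task simply stores the model_id it receives:

from vespa.package import Task

def describe(task: Task) -> str:
    # Only the vespa.package base class is needed here; nothing from
    # vespa.ml has to be imported on the package side.
    return "task for model {}".format(task.model_id)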
@@ -138,11 +147,6 @@ def create_url_encoded_tokens(self, x):
return encoded_tokens


class ModelConfig(object):
def __init__(self, model_id) -> None:
self.model_id = model_id


class BertModelConfig(ModelConfig, ToJson, FromJson["BertModelConfig"]):
def __init__(
self,
@@ -368,6 +372,117 @@ def export_to_onnx(self, output_path: str) -> None:
else:
raise ValueError("No BERT model found to be exported.")

def onnx_model(self):
model_file_path = self.model_id + ".onnx"
self.export_to_onnx(output_path=model_file_path)

return OnnxModel(
model_name=self.model_id,
model_file_path=model_file_path,
inputs={
"input_ids": "input_ids",
"token_type_ids": "token_type_ids",
"attention_mask": "attention_mask",
},
outputs={"output_0": "logits"},
)

def query_profile_type_fields(self):
return [
QueryTypeField(
name="ranking.features.query({})".format(self.query_token_ids_name),
type="tensor<float>(d0[{}])".format(int(self.actual_query_input_size)),
)
]

def document_fields(self, document_field_indexing):
if not document_field_indexing:
document_field_indexing = ["attribute", "summary"]

return [
Field(
name=self.doc_token_ids_name,
type="tensor<float>(d0[{}])".format(int(self.actual_doc_input_size)),
indexing=document_field_indexing,
),
]

def rank_profile(self, include_model_summary_features, **kwargs):
constants = {"TOKEN_NONE": 0, "TOKEN_CLS": 101, "TOKEN_SEP": 102}
if "contants" in kwargs:
constants.update(kwargs.pop("contants"))

functions = [
Function(
name="question_length",
expression="sum(map(query({}), f(a)(a > 0)))".format(
self.query_token_ids_name
),
),
Function(
name="doc_length",
expression="sum(map(attribute({}), f(a)(a > 0)))".format(
self.doc_token_ids_name
),
),
Function(
name="input_ids",
expression="tokenInputIds({}, query({}), attribute({}))".format(
self.input_size,
self.query_token_ids_name,
self.doc_token_ids_name,
),
),
Function(
name="attention_mask",
expression="tokenAttentionMask({}, query({}), attribute({}))".format(
self.input_size,
self.query_token_ids_name,
self.doc_token_ids_name,
),
),
Function(
name="token_type_ids",
expression="tokenTypeIds({}, query({}), attribute({}))".format(
self.input_size,
self.query_token_ids_name,
self.doc_token_ids_name,
),
),
Function(
name="logit0",
expression="onnx(" + self.model_id + ").logits{d0:0,d1:0}",
),
Function(
name="logit1",
expression="onnx(" + self.model_id + ").logits{d0:0,d1:1}",
),
]
if "functions" in kwargs:
functions.extend(kwargs.pop("functions"))

summary_features = []
if include_model_summary_features:
summary_features.extend(
[
"logit0",
"logit1",
"input_ids",
"attention_mask",
"token_type_ids",
]
)
if "summary_features" in kwargs:
summary_features.extend(kwargs.pop("summary_features"))

return RankProfile(
name=self.model_id,
constants=constants,
functions=functions,
summary_features=summary_features,
**kwargs
)

@staticmethod
def from_dict(mapping: Mapping) -> "BertModelConfig":
return BertModelConfig(
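The new rank_profile helper shown above folds caller-supplied extras into the generated profile: the functions and summary_features keyword arguments are appended to the generated ones, and any remaining keyword arguments are forwarded to RankProfile unchanged. A hedged usage sketch, reusing the bert_config object from the first sketch; the extra function and the first_phase expression are chosen for illustration only:

from vespa.package import Function

profile = bert_config.rank_profile(
    include_model_summary_features=True,
    functions=[Function(name="mean_logit", expression="(logit0 + logit1) / 2")],  # appended after the generated functions
    summary_features=["mean_logit"],   # appended to the generated summary features
    first_phase="logit1",              # passed straight through to RankProfile via **kwargs
)

The resulting profile is named after model_id, so it lines up with the OnnxModel returned by onnx_model() when both are added to the same schema.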