Skip to content

Commit

Permalink
Hashed cross (#587)
Browse files Browse the repository at this point in the history
* Add hashed cross.

* Only hashed cross.

Co-authored-by: mengyao <[email protected]>
Co-authored-by: Marc Romeyn <[email protected]>
Co-authored-by: Gabriel Moreira <[email protected]>
  • Loading branch information
4 people authored Jul 19, 2022
1 parent 4d95bca commit 8e8144e
Show file tree
Hide file tree
Showing 3 changed files with 416 additions and 0 deletions.
2 changes: 2 additions & 0 deletions merlin/models/tf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
AsSparseFeatures,
CategoricalOneHot,
ExpandDims,
HashedCross,
LabelToOneHot,
)

Expand Down Expand Up @@ -151,6 +152,7 @@
"AsRaggedFeatures",
"AsSparseFeatures",
"CategoricalOneHot",
"HashedCross",
"ElementwiseSum",
"ElementwiseSumItemMulti",
"AsTabular",
Expand Down
151 changes: 151 additions & 0 deletions merlin/models/tf/core/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import warnings
from typing import Dict, Optional, Sequence, Union

import tensorflow as tf
from keras.layers.preprocessing import preprocessing_utils

from merlin.models.config.schema import requires_schema
from merlin.models.tf.core.base import Block, PredictionOutput
Expand Down Expand Up @@ -639,3 +641,152 @@ def _check_items_cardinality(self, item_freq_probs):
f"(expected {cardinalities[item_id_feature_name]}"
f", got {tf.shape(item_freq_probs)[0]})"
)


@Block.registry.register("hashed_cross")
@tf.keras.utils.register_keras_serializable(package="merlin.models")
class HashedCross(TabularBlock):
"""A transformation block which crosses categorical features using the "hasing trick".
Conceptually, the transformation can be thought of as: hash(concatenation of features) %
num_bins
Example usage::
model_body = ParallelBlock(
TabularBlock.from_schema(schema=cross_schema, pre=ml.HashedCross(cross_schema,
num_bins = 1000)),
is_input=True).connect(ml.MLPBlock([64, 32]))
model = ml.Model(model_body, ml.BinaryClassificationTask("click"))
Parameters
----------
schema : Schema
The `Schema` with the input features
num_bins : int
Number of hash bins.
output_mode: string
Specification for the output of the layer. Defaults to
`"int"`. Values can be `"int"`, or `"one_hot"` configuring the layer as
follows:
- `"int"`: Return the integer bin indices directly.
- `"one_hot"`: Encodes each individual element in the input into an
array the same size as `num_bins`, containing a 1 at the input's bin
index.
sparse : bool
Boolean. Only applicable to `"one_hot"` mode. If True, returns a
`SparseTensor` instead of a dense `Tensor`. Defaults to False.
output_name : string
Name of output feature, if not specified, default would be
cross_<feature_name>_<feature_name>_<...>
"""

def __init__(
self,
schema: Schema,
num_bins: int,
sparse: bool = False,
output_mode: str = "int",
output_name: str = None,
**kwargs,
):
super().__init__(**kwargs)

if not (output_mode in ["int", "one_hot"]):
raise ValueError("output_mode must be 'int' or 'one_hot'")
self.schema = schema
self.num_bins = num_bins
self.output_mode = output_mode
self.sparse = sparse
if not output_name:
self.output_name = "cross"
for name in self.schema.column_names:
self.output_name = self.output_name + "_" + name
else:
self.output_name = output_name

def call(self, inputs):
self._check_at_least_two_inputs()
_inputs = {}
for name in self.schema.column_names:
_inputs[name] = inputs[name]
rank = _inputs[name].shape.rank
if rank < 2:
_inputs[name] = tf.expand_dims(_inputs[name], -1)
if rank < 1:
_inputs[name] = tf.expand_dims(_inputs[name], -1)

# Perform the cross and convert to dense
output = tf.sparse.cross_hashed(list(_inputs.values()), self.num_bins)
output = tf.sparse.to_dense(output)

# Fix output shape and downrank to match input rank.
if rank == 2:
# tf.sparse.cross_hashed output shape will always be None on the last
# dimension. Given our input shape restrictions, we want to force shape 1
# instead.
output = tf.reshape(output, [-1, 1])
elif rank == 1:
output = tf.reshape(output, [-1])
elif rank == 0:
output = tf.reshape(output, [])

# Encode outputs.
outputs = {}
outputs[self.output_name] = preprocessing_utils.encode_categorical_inputs(
output,
output_mode=self.output_mode,
depth=self.num_bins,
sparse=self.sparse,
)
return outputs

def compute_output_shape(self, input_shapes):
self._check_at_least_two_inputs()
self._check_input_shape_and_type(input_shapes)
output_shape = {}
one_input = list(input_shapes.values())[0]
output_shape[self.output_name] = preprocessing_utils.compute_shape_for_encode_categorical(
shape=one_input, output_mode=self.output_mode, depth=self.num_bins
)
return output_shape

def get_config(self):
config = super().get_config()
config.update(
{
"num_bins": self.num_bins,
"output_mode": self.output_mode,
"sparse": self.sparse,
"output_name": self.output_name,
}
)
if self.schema:
config["schema"] = schema_utils.schema_to_tensorflow_metadata_json(self.schema)
return config

def _check_at_least_two_inputs(self):
if len(self.schema) < 2:
raise ValueError(
"`HashedCrossing` should be called on at least two features. "
f"Received: {len(self.schema)} schemas"
)
for name, column_schema in self.schema.column_schemas.items():
if Tags.CATEGORICAL not in column_schema.tags:
warnings.warn(
f"Please make sure input features to be categorical, detect {name} "
"has no categorical tag"
)

def _check_input_shape_and_type(self, inputs_shapes) -> TabularData:
_inputs_shapes = []
for name in self.schema.column_names:
_inputs_shapes.append(inputs_shapes[name])
first_shape = _inputs_shapes[0].as_list()
rank = len(first_shape)
if rank > 2 or (rank == 2 and first_shape[-1] != 1):
raise ValueError(
"All `HashedCrossing` inputs should have shape `[]`, `[batch_size]` "
f"or `[batch_size, 1]`. Received: input {name} with shape={first_shape}"
)
if not all(x.as_list() == first_shape for x in _inputs_shapes):
raise ValueError(
"All `HashedCrossing` inputs should have equal shape. "
f"Received: inputs={_inputs_shapes}"
)
Loading

0 comments on commit 8e8144e

Please sign in to comment.