From ee03bb1cd2eb7b36ebe1983f0ea151551cbc6927 Mon Sep 17 00:00:00 2001 From: edknv <109497216+edknv@users.noreply.github.com> Date: Tue, 28 Feb 2023 11:44:24 -0800 Subject: [PATCH] Increase tolerance in retrieval transformer test and random seed (#1007) --- tests/unit/tf/transformers/test_block.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/tf/transformers/test_block.py b/tests/unit/tf/transformers/test_block.py index 35d4e59a61..9a140765bd 100644 --- a/tests/unit/tf/transformers/test_block.py +++ b/tests/unit/tf/transformers/test_block.py @@ -3,6 +3,7 @@ import numpy as np import pytest import tensorflow as tf +from tensorflow.keras.utils import set_random_seed from transformers import BertConfig import merlin.models.tf as mm @@ -27,6 +28,7 @@ def test_import(): @pytest.mark.parametrize("run_eagerly", [True]) def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly): + set_random_seed(42) sequence_testing_data.schema = sequence_testing_data.schema.select_by_tag( Tags.SEQUENCE @@ -78,7 +80,7 @@ def test_retrieval_transformer(sequence_testing_data: Dataset, run_eagerly): assert list(item_embeddings.shape) == [51997, d_model] predicitons_2 = np.dot(query_embeddings, item_embeddings.T) - np.testing.assert_allclose(predictions, predicitons_2, atol=1e-7) + np.testing.assert_allclose(predictions, predicitons_2, atol=1e-6) def test_transformer_encoder():