Skip to content

Commit

Permalink
Threaded model similarity score calculation
Browse files Browse the repository at this point in the history
Co-authored-by: JamesZhang2<[email protected]>
Co-authored-by: lisarli <[email protected]>
  • Loading branch information
srishagaur and lisarli committed Dec 2, 2023
1 parent 6bd431f commit 0a0a716
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions src/model/Siamese_Predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
from watchdog.observers import Observer
from watchdog.events import PatternMatchingEventHandler
import threading

def preprocess_image(filename):
"""
Expand All @@ -22,32 +23,36 @@ def preprocess_image(filename):

# load the model
embedding = tf.keras.models.load_model("siamese_feature.h5", compile=False)
print("Model is done loading")

# load david base
david_2 = preprocess_image("david_base.jpg")
david_2 = tf.expand_dims(david_2, axis=0)

def get_similarity_score(img_path):
print("Wait a bit for image to write?")
time.sleep(5);
# TODO: lol, fix this
print("After 5 seconds")
time.sleep(0.5);
david_1 = preprocess_image(img_path)
david_2 = preprocess_image("david_base.jpg")

# add a dimension to the tensor
david_1 = tf.expand_dims(david_1, axis=0)
david_2 = tf.expand_dims(david_2, axis=0)


# get embeddings
anchor_embedding, positive_embedding = (
embedding(resnet.preprocess_input(david_1)),
embedding(resnet.preprocess_input(david_2)),
)


# get similarity score
cosine_similarity = metrics.CosineSimilarity()

positive_similarity = cosine_similarity(anchor_embedding, positive_embedding)
print("Similarity score: ", positive_similarity.numpy())
print(f"Similarity score for {img_path}: ", positive_similarity.numpy())

# set up on created
def on_created(event):
print(f"hey, {event.src_path} has been created!")
get_similarity_score(event.src_path)
model_thread = threading.Thread(target=lambda: get_similarity_score(event.src_path))
model_thread.start()

# set up watchdog
patterns = ["*"]
Expand All @@ -58,7 +63,7 @@ def on_created(event):
my_event_handler.on_created = on_created

# create observer
path = "../../Images"
path = "../../Images/bounding_boxes"
go_recursively = True
my_observer = Observer()
my_observer.schedule(my_event_handler, path, recursive=go_recursively)
Expand Down

0 comments on commit 0a0a716

Please sign in to comment.