Skip to content

Commit

Permalink
More cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
dhan0779 committed Dec 2, 2023
1 parent ac9c2db commit 8d73ba0
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 45 deletions.
32 changes: 5 additions & 27 deletions src/model/Siamese_Network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import random
import tensorflow as tf
from pathlib import Path

from utils import preprocess_image, preprocess_triplets

from tensorflow.keras import layers
from tensorflow.keras import optimizers
from tensorflow.keras import metrics
Expand All @@ -16,32 +19,6 @@
anchor_images_path = cache_dir / "left"
positive_images_path = cache_dir / "right"

def preprocess_image(filename):
"""
Load the specified file as a JPEG image, preprocess it and
resize it to the target shape.
"""

image_string = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image_string, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, target_shape)
return image


def preprocess_triplets(anchor, positive, negative):
"""
Given the filenames corresponding to the three images, load and
preprocess them.
"""

return (
preprocess_image(anchor),
preprocess_image(positive),
preprocess_image(negative),
)


class DistanceLayer(layers.Layer):
"""
This layer is responsible for computing the distance between the anchor
Expand Down Expand Up @@ -190,4 +167,5 @@ def metrics(self):
siamese_model = SiameseModel(siamese_network)
siamese_model.compile(optimizer=optimizers.Adam(0.0001)) # 0..0001
siamese_model.fit(train_dataset, epochs=7, validation_data=val_dataset)
embedding.save("siamese_feature.h5")
embedding.save_weights("weights/embedding_weights.h5")
siamese_model.save_weights("weights/siamese_weights.h5")
25 changes: 7 additions & 18 deletions src/model/Siamese_Predictor.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,17 @@
# from Siamese_Network import preprocess_image
import threading
import tensorflow as tf
import time

from utils import preprocess_image

from tensorflow.keras.applications import resnet
import argparse
from tensorflow.keras import metrics
import time

from watchdog.observers import Observer
from watchdog.events import PatternMatchingEventHandler
import threading

def preprocess_image(filename):
"""
Load the specified file as a JPEG image, preprocess it and
resize it to the target shape.
"""
target_shape = (200, 200)

image_string = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image_string, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, target_shape)
return image

# load the model
embedding = tf.keras.models.load_model("siamese_feature.h5", compile=False)
embedding = tf.keras.models.load_model("siamese_weights.h5", compile=False)
print("Model is done loading")

# load david base
Expand Down
28 changes: 28 additions & 0 deletions src/model/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import tensorflow as tf

target_shape = (200, 200)

def preprocess_image(filename):
"""
Load the specified file as a JPEG image, preprocess it and
resize it to the target shape.
"""

image_string = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image_string, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, target_shape)
return image


def preprocess_triplets(anchor, positive, negative):
"""
Given the filenames corresponding to the three images, load and
preprocess them.
"""

return (
preprocess_image(anchor),
preprocess_image(positive),
preprocess_image(negative),
)

0 comments on commit 8d73ba0

Please sign in to comment.