Skip to content

Commit

Permalink
Merge pull request #9 from CornellDataScience/siamese-model
Browse files Browse the repository at this point in the history
Model cleanup
  • Loading branch information
dhan0779 authored Dec 2, 2023
2 parents 40d9da2 + 92f0cee commit b6b1877
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 147 deletions.
197 changes: 68 additions & 129 deletions src/model/Siamese_Network.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# https://keras.io/examples/vision/siamese_network/

import matplotlib.pyplot as plt
import numpy as np
import os
import random
import tensorflow as tf
from pathlib import Path
from tensorflow.keras import applications

from utils import preprocess_image, preprocess_triplets

from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import optimizers
from tensorflow.keras import metrics
from tensorflow.keras import Model
Expand All @@ -20,113 +19,6 @@
anchor_images_path = cache_dir / "left"
positive_images_path = cache_dir / "right"


def preprocess_image(filename):
"""
Load the specified file as a JPEG image, preprocess it and
resize it to the target shape.
"""

image_string = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image_string, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, target_shape)
return image


def preprocess_triplets(anchor, positive, negative):
"""
Given the filenames corresponding to the three images, load and
preprocess them.
"""

return (
preprocess_image(anchor),
preprocess_image(positive),
preprocess_image(negative),
)


# We need to make sure both the anchor and positive images are loaded in
# sorted order so we can match them together.
anchor_images = sorted(
[str(anchor_images_path / f) for f in os.listdir(anchor_images_path)]
)

positive_images = sorted(
[str(positive_images_path / f) for f in os.listdir(positive_images_path)]
)

all_images = anchor_images + positive_images
random.shuffle(all_images)
print(all_images)

image_count = len(anchor_images)
print(image_count)

anchor_dataset = tf.data.Dataset.from_tensor_slices(anchor_images)
positive_dataset = tf.data.Dataset.from_tensor_slices(positive_images)
negative_dataset = tf.data.Dataset.from_tensor_slices(all_images[:100])

print(len(negative_dataset))
print(len(anchor_dataset))

dataset = tf.data.Dataset.zip((anchor_dataset, positive_dataset, negative_dataset))
dataset = dataset.shuffle(buffer_size=1024)
dataset = dataset.map(preprocess_triplets)

# Let's now split our dataset in train and validation.
train_dataset = dataset.take(round(image_count * 0.8))
val_dataset = dataset.skip(round(image_count * 0.8))

batch_size = 16

train_dataset = train_dataset.batch(batch_size, drop_remainder=False)
train_dataset = train_dataset.prefetch(8)

val_dataset = val_dataset.batch(batch_size, drop_remainder=False)
val_dataset = val_dataset.prefetch(8)


def visualize(anchor, positive, negative):
"""Visualize a few triplets from the supplied batches."""

def show(ax, image):
ax.imshow(image)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)

fig = plt.figure(figsize=(9, 9))

axs = fig.subplots(3, 3)
for i in range(3):
show(axs[i, 0], anchor[i])
show(axs[i, 1], positive[i])
show(axs[i, 2], negative[i])


visualize(*list(train_dataset.take(1).as_numpy_iterator())[0])

base_cnn = resnet.ResNet50(
weights="imagenet", input_shape=target_shape + (3,), include_top=False
)

flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.Dense(512, activation="relu")(flatten)
dense1 = layers.BatchNormalization()(dense1)
dense2 = layers.Dense(256, activation="relu")(dense1)
dense2 = layers.BatchNormalization()(dense2)
output = layers.Dense(256)(dense2)

embedding = Model(base_cnn.input, output, name="Embedding")

trainable = False
for layer in base_cnn.layers:
if layer.name == "conv5_block1_out":
trainable = True
layer.trainable = trainable


class DistanceLayer(layers.Layer):
"""
This layer is responsible for computing the distance between the anchor
Expand All @@ -142,22 +34,6 @@ def call(self, anchor, positive, negative):
an_distance = tf.reduce_sum(tf.square(anchor - negative), -1)
return (ap_distance, an_distance)


anchor_input = layers.Input(name="anchor", shape=target_shape + (3,))
positive_input = layers.Input(name="positive", shape=target_shape + (3,))
negative_input = layers.Input(name="negative", shape=target_shape + (3,))

distances = DistanceLayer()(
embedding(resnet.preprocess_input(anchor_input)),
embedding(resnet.preprocess_input(positive_input)),
embedding(resnet.preprocess_input(negative_input)),
)

siamese_network = Model(
inputs=[anchor_input, positive_input, negative_input], outputs=distances
)


class SiameseModel(Model):
"""The Siamese Network model with a custom training and testing loops.
Expand Down Expand Up @@ -224,9 +100,72 @@ def metrics(self):
return [self.loss_tracker]


# Training
# We need to make sure both the anchor and positive images are loaded in
# sorted order so we can match them together.
anchor_images = sorted(
[str(anchor_images_path / f) for f in os.listdir(anchor_images_path)]
)

positive_images = sorted(
[str(positive_images_path / f) for f in os.listdir(positive_images_path)]
)

all_images = anchor_images + positive_images
random.shuffle(all_images)

image_count = len(anchor_images)

anchor_dataset = tf.data.Dataset.from_tensor_slices(anchor_images)
positive_dataset = tf.data.Dataset.from_tensor_slices(positive_images)
negative_dataset = tf.data.Dataset.from_tensor_slices(all_images[:100])

dataset = tf.data.Dataset.zip((anchor_dataset, positive_dataset, negative_dataset))
dataset = dataset.shuffle(buffer_size=1024)
dataset = dataset.map(preprocess_triplets)

# Let's now split our dataset in train and validation.
train_dataset = dataset.take(round(image_count * 0.8))
val_dataset = dataset.skip(round(image_count * 0.8))

batch_size = 16

train_dataset = train_dataset.batch(batch_size, drop_remainder=False).prefetch(8)
val_dataset = val_dataset.batch(batch_size, drop_remainder=False).prefetch(8)

base_cnn = resnet.ResNet50(
weights="imagenet", input_shape=target_shape + (3,), include_top=False
)

flatten = layers.Flatten()(base_cnn.output)
dense1 = layers.BatchNormalization()(layers.Dense(512, activation="relu")(flatten))
dense2 = layers.BatchNormalization()(layers.Dense(256, activation="relu")(dense1))
output = layers.Dense(256)(dense2)

embedding = Model(base_cnn.input, output, name="Embedding")

trainable = False
for layer in base_cnn.layers:
if layer.name == "conv5_block1_out":
trainable = True
layer.trainable = trainable

anchor_input = layers.Input(name="anchor", shape=target_shape + (3,))
positive_input = layers.Input(name="positive", shape=target_shape + (3,))
negative_input = layers.Input(name="negative", shape=target_shape + (3,))

distances = DistanceLayer()(
embedding(resnet.preprocess_input(anchor_input)),
embedding(resnet.preprocess_input(positive_input)),
embedding(resnet.preprocess_input(negative_input)),
)

siamese_network = Model(
inputs=[anchor_input, positive_input, negative_input], outputs=distances
)

# Training
siamese_model = SiameseModel(siamese_network)
siamese_model.compile(optimizer=optimizers.Adam(0.0001)) # 0..0001
siamese_model.fit(train_dataset, epochs=7, validation_data=val_dataset)
embedding.save("siamese_feature.h5")
embedding.save_weights("weights/embedding_weights.h5")
siamese_model.save_weights("weights/siamese_weights.h5")
25 changes: 7 additions & 18 deletions src/model/Siamese_Predictor.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,17 @@
# from Siamese_Network import preprocess_image
import threading
import tensorflow as tf
import time

from utils import preprocess_image

from tensorflow.keras.applications import resnet
import argparse
from tensorflow.keras import metrics
import time

from watchdog.observers import Observer
from watchdog.events import PatternMatchingEventHandler
import threading

def preprocess_image(filename):
"""
Load the specified file as a JPEG image, preprocess it and
resize it to the target shape.
"""
target_shape = (200, 200)

image_string = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image_string, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, target_shape)
return image

# load the model
embedding = tf.keras.models.load_model("siamese_feature.h5", compile=False)
embedding = tf.keras.models.load_model("siamese_weights.h5", compile=False)
print("Model is done loading")

# load david base
Expand Down
28 changes: 28 additions & 0 deletions src/model/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import tensorflow as tf

target_shape = (200, 200)

def preprocess_image(filename):
"""
Load the specified file as a JPEG image, preprocess it and
resize it to the target shape.
"""

image_string = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image_string, channels=3)
image = tf.image.convert_image_dtype(image, tf.float32)
image = tf.image.resize(image, target_shape)
return image


def preprocess_triplets(anchor, positive, negative):
"""
Given the filenames corresponding to the three images, load and
preprocess them.
"""

return (
preprocess_image(anchor),
preprocess_image(positive),
preprocess_image(negative),
)

0 comments on commit b6b1877

Please sign in to comment.