Thin Plate Spline Spatial Transformer Network implementation with TensorFlow

shengzhang90/tf_ThinPlateSpline

 
 

tf_ThinPlateSpline

TensorFlow implementation of a Thin Plate Spline (TPS) Spatial Transformer Network (STN). The older implementation, based on the STN paper, is TPS_STN. This implementation, however, works with non-fixed control points: the TPS system is solved dynamically at every feed-forward step.
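
To illustrate what "solving the TPS system at every step" can look like, here is a minimal NumPy sketch, not this repository's TensorFlow code. It assumes the standard TPS formulation with kernel U(r) = r^2 log r and no regularization. The function names tps_kernel, solve_tps and eval_tps are hypothetical.

```python
import numpy as np

def tps_kernel(r2):
    """TPS radial basis U(r) = r^2 log r, given squared distances; U(0) = 0."""
    # r^2 log r == 0.5 * r^2 * log(r^2); replace zeros by 1 to avoid log(0).
    return 0.5 * r2 * np.log(np.where(r2 > 0.0, r2, 1.0))

def solve_tps(coord, vector):
    """Solve one TPS system per batch element (hypothetical helper).

    coord, vector: arrays of shape [num_batch, num_point, 2].
    Returns [num_batch, num_point + 3, 2]: kernel weights, then affine terms.
    """
    b, n, _ = coord.shape
    diff = coord[:, :, None, :] - coord[:, None, :, :]
    K = tps_kernel(np.sum(diff ** 2, axis=-1))               # [b, n, n]
    P = np.concatenate([np.ones((b, n, 1)), coord], axis=2)  # [b, n, 3]
    L = np.zeros((b, n + 3, n + 3))
    L[:, :n, :n] = K
    L[:, :n, n:] = P
    L[:, n:, :n] = np.transpose(P, (0, 2, 1))
    rhs = np.concatenate([vector, np.zeros((b, 3, 2))], axis=1)
    return np.linalg.solve(L, rhs)  # batched solve, one system per element

def eval_tps(weights, coord, query):
    """Evaluate the fitted spline at query points of shape [num_batch, m, 2]."""
    n = coord.shape[1]
    diff = query[:, :, None, :] - coord[:, None, :, :]
    K = tps_kernel(np.sum(diff ** 2, axis=-1))               # [b, m, n]
    P = np.concatenate([np.ones(query.shape[:2] + (1,)), query], axis=2)
    return K @ weights[:, :n] + P @ weights[:, n:]

# Two batch elements, each with its own (non-fixed) control points.
rng = np.random.default_rng(0)
coord = rng.uniform(-1.0, 1.0, size=(2, 5, 2))
vector = rng.uniform(-0.1, 0.1, size=(2, 5, 2))
weights = solve_tps(coord, vector)
fitted = eval_tps(weights, coord, coord)
print(np.allclose(fitted, vector))  # the spline interpolates the control values
```

Because the control points are part of the input, the matrix L differs per batch element and cannot be precomputed once, which is why the solve happens inside the forward pass.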

Usage :
  V = ThinPlateSpline(U, coord, vector, out_size)
  V = ThinPlateSpline2(U, source, target, out_size)

Args :
  U : float Tensor [num_batch, height, width, num_channels].
    Input Tensor.
  coord : float Tensor [num_batch, num_point, 2]
    Relative coordinates of the control points.
  vector : float Tensor [num_batch, num_point, 2]
    The 2-D vector attached to each control point.
  out_size : tuple of two integers (height, width)
    Size of the output of the network.

  For version 2 (ThinPlateSpline2) :
    source : float Tensor [num_batch, num_point, 2]
      Relative coordinates of the source points.
    target : float Tensor [num_batch, num_point, 2]
      Relative coordinates of the target points.
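
The point arguments all share the shape [num_batch, num_point, 2] and use relative coordinates. As a rough sketch of how such arrays might be prepared from pixel positions, the helpers below assume the usual STN convention that relative coordinates span [-1, 1] across the image and that points are ordered (x, y). Neither assumption is confirmed by this README, and pixel_to_relative and as_batch are hypothetical names.

```python
import numpy as np

def pixel_to_relative(points_xy, height, width):
    """Map pixel (x, y) positions to [-1, 1] (assumed convention)."""
    pts = np.asarray(points_xy, dtype=np.float32)
    scale = np.array([width - 1, height - 1], dtype=np.float32)
    return 2.0 * pts / scale - 1.0

def as_batch(points, num_batch=1):
    """Tile a [num_point, 2] array to [num_batch, num_point, 2]."""
    pts = np.asarray(points, dtype=np.float32)
    if pts.ndim != 2 or pts.shape[1] != 2:
        raise ValueError("expected shape [num_point, 2], got %s" % (pts.shape,))
    return np.tile(pts[None], (num_batch, 1, 1))

# Corner pixels of a 50 x 100 (height x width) image.
corners = [[0, 0], [99, 0], [0, 49], [99, 49]]
coord = as_batch(pixel_to_relative(corners, height=50, width=100), num_batch=3)
print(coord.shape)  # (3, 4, 2)
```

The resulting array can then be wrapped with tf.constant(..., dtype=tf.float32), as the examples in this README do.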

Example

Version 1

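# test.py
import numpy as np
import tensorflow as tf
# ThinPlateSpline is assumed to be imported from this repository. img (image
# array), shape ([num_batch, height, width, num_channels]) and out_size
# ((height, width)) are assumed to be defined beforehand.

p = np.array([  # control point coordinates
  [-0.5, -0.5],
  [0.5, -0.5],
  [-0.5, 0.5],
  [0.5, 0.5]])

v = np.array([  # vector on each control point
  [0.2, 0.2],
  [0.4, 0.4],
  [0.6, 0.6],
  [0.8, 0.8]])

# Add the batch dimension: [num_batch=1, num_point=4, 2].
p = tf.constant(p.reshape([1, 4, 2]), dtype=tf.float32)
v = tf.constant(v.reshape([1, 4, 2]), dtype=tf.float32)
t_img = tf.constant(img.reshape(shape), dtype=tf.float32)
t_img = ThinPlateSpline(t_img, p, v, out_size)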

[Image: original.png]

[Image: transformed.png]
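
As a side check on the Version 1 numbers, the short NumPy sketch below (independent of this repository's code) fits an affine function to the control values. The values in v turn out to be an exact affine function of the coordinates in p. An interpolating TPS without regularization reproduces affine data exactly, so for this particular input the spline would have no bending component. How the library applies v to the image is not covered here.

```python
import numpy as np

p = np.array([[-0.5, -0.5], [0.5, -0.5], [-0.5, 0.5], [0.5, 0.5]])
v = np.array([[0.2, 0.2], [0.4, 0.4], [0.6, 0.6], [0.8, 0.8]])

# Least-squares fit of v ~ [1, x, y] @ A, one column of A per output component.
P = np.hstack([np.ones((4, 1)), p])
A, *_ = np.linalg.lstsq(P, v, rcond=None)
residual = np.abs(P @ A - v).max()

print(np.round(A, 3))    # rows: offset 0.5, x-slope 0.2, y-slope 0.4
print(residual < 1e-12)  # True: the affine fit is exact
```

To see a genuinely non-affine warp, perturb a single entry of v so that the four values no longer lie on a plane.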

Version 2

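# test2.py
import numpy as np
import tensorflow as tf
# ThinPlateSpline2 is assumed to be imported from this repository. img (image
# array), shape ([num_batch, height, width, num_channels]) and out_size
# ((height, width)) are assumed to be defined beforehand.

s_ = np.array([  # source position
  [-0.5, -0.5],
  [0.5, -0.5],
  [-0.5, 0.5],
  [0.5, 0.5]])

t_ = np.array([  # target position
  [-0.3, -0.3],
  [0.3, -0.3],
  [-0.3, 0.3],
  [0.3, 0.3]])

# Add the batch dimension: [num_batch=1, num_point=4, 2].
s = tf.constant(s_.reshape([1, 4, 2]), dtype=tf.float32)
t = tf.constant(t_.reshape([1, 4, 2]), dtype=tf.float32)
t_img = tf.constant(img.reshape(shape), dtype=tf.float32)
t_img = ThinPlateSpline2(t_img, s, t, out_size)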

[Image: original.png]

[Image: transformed2.png]
