forked from iwyoo/tf_ThinPlateSpline
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test2.py
30 lines (25 loc) · 790 Bytes
/
test2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import tensorflow as tf
import numpy as np
from PIL import Image
from ThinPlateSpline2 import ThinPlateSpline2 as stn
# Demo: warp an image with ThinPlateSpline2 by moving four control points
# from `source` to `target`, then save the result. Uses the TF 1.x graph API
# (tf.Session / tf.global_variables_initializer).


def image_to_batch(img):
    """Return ``(batch, out_size)`` for feeding *img* to the TPS warp.

    ``batch`` is a float32 array of shape ``(1, H, W, C)``; a 2-D (grayscale)
    image gets a single trailing channel axis. ``out_size`` is ``[H, W]``.

    Raises:
        ValueError: if *img* is not a 2-D or 3-D array.
    """
    if img.ndim not in (2, 3):
        raise ValueError(
            "expected a 2-D or 3-D image array, got shape {}".format(img.shape))
    height, width = img.shape[:2]
    # Normalize to (H, W, C) so grayscale and colour share one code path.
    pixels = img.reshape(height, width, -1)
    return pixels[np.newaxis].astype(np.float32), [height, width]


def batch_to_image(warped, shape):
    """Convert a warped ``(1, H, W, C)`` float batch to a uint8 image of *shape*.

    Values are rounded to the nearest integer and clipped to [0, 255]: a bare
    uint8 cast would truncate toward zero and not guard against overshoot.
    Assumes 8-bit pixel data.
    """
    pixels = np.clip(np.rint(warped), 0, 255).astype(np.uint8)
    return pixels.reshape(shape)


def main(src_path="original.png", dst_path="transformed2.png"):
    """Warp the image at *src_path* and write the result to *dst_path*."""
    # Close the file handle as soon as the pixels are copied out.
    with Image.open(src_path) as pil_img:
        img = np.array(pil_img)
    batch, out_size = image_to_batch(img)

    # Control points; presumably in the normalized [-1, 1] grid expected by
    # ThinPlateSpline2 -- confirm against that module.
    source = np.array([  # source position
        [-0.5, -0.5],
        [0.5, -0.5],
        [-0.5, 0.5],
        [0.5, 0.5]])
    target = np.array([  # target position
        [-0.3, -0.3],
        [0.3, -0.3],
        [-0.3, 0.3],
        [0.3, 0.3]])
    if source.shape != target.shape:
        raise ValueError("source and target must have the same number of points")

    # Batch of one: (1, num_points, 2); -1 lets the point count vary freely.
    s = tf.constant(source.reshape([1, -1, 2]), dtype=tf.float32)
    t = tf.constant(target.reshape([1, -1, 2]), dtype=tf.float32)
    warped = stn(tf.constant(batch, dtype=tf.float32), s, t, out_size)

    with tf.Session() as sess:
        # This script defines no variables itself; kept in case stn does.
        sess.run(tf.global_variables_initializer())
        result = sess.run(warped)

    Image.fromarray(batch_to_image(result, img.shape)).save(dst_path)


if __name__ == "__main__":
    main()