-
Notifications
You must be signed in to change notification settings - Fork 1.3k
/
inference.py
179 lines (158 loc) · 5.78 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import os
import cv2
import argparse
import numpy as np
from hivision.error import FaceError
from hivision.utils import hex_to_rgb, resize_image_to_kb, add_background, save_image_dpi_to_bytes
from hivision import IDCreator
from hivision.creator.layout_calculator import (
generate_layout_array,
generate_layout_image,
)
from hivision.creator.choose_handler import choose_handler
from hivision.utils import hex_to_rgb, resize_image_to_kb
INFERENCE_TYPE = [
"idphoto",
"human_matting",
"add_background",
"generate_layout_photos",
"idphoto_crop",
]
MATTING_MODEL = [
"hivision_modnet",
"modnet_photographic_portrait_matting",
"mnn_hivision_modnet",
"rmbg-1.4",
"birefnet-v1-lite",
]
FACE_DETECT_MODEL = [
"mtcnn",
"face_plusplus",
"retinaface-resnet50",
]
RENDER = [0, 1, 2]
parser = argparse.ArgumentParser(description="HivisionIDPhotos 证件照制作推理程序。")
parser.add_argument(
"-t",
"--type",
help="请求 API 的种类",
choices=INFERENCE_TYPE,
default="idphoto",
)
parser.add_argument("-i", "--input_image_dir", help="输入图像路径", required=True)
parser.add_argument("-o", "--output_image_dir", help="保存图像路径", required=True)
parser.add_argument("--height", help="证件照尺寸-高", default=413)
parser.add_argument("--width", help="证件照尺寸-宽", default=295)
parser.add_argument("-c", "--color", help="证件照背景色", default="638cce")
parser.add_argument("--hd", type=bool, help="是否输出高清照", default=True)
parser.add_argument(
"-k", "--kb", help="输出照片的 KB 值,仅对换底和制作排版照生效", default=None
)
parser.add_argument(
"-r",
"--render",
type=int,
help="底色合成的模式,有 0:纯色、1:上下渐变、2:中心渐变 可选",
choices=RENDER,
default=0,
)
parser.add_argument(
"--dpi",
type=int,
help="输出照片的 DPI 值",
default=300,
)
parser.add_argument(
"--face_align",
type=bool,
help="是否进行人脸旋转矫正",
default=False,
)
parser.add_argument(
"--matting_model",
help="抠图模型权重",
default="modnet_photographic_portrait_matting",
choices=MATTING_MODEL,
)
parser.add_argument(
"--face_detect_model",
help="人脸检测模型",
default="mtcnn",
choices=FACE_DETECT_MODEL,
)
args = parser.parse_args()
# ------------------- 选择抠图与人脸检测模型 -------------------
creator = IDCreator()
choose_handler(creator, args.matting_model, args.face_detect_model)
root_dir = os.path.dirname(os.path.abspath(__file__))
input_image = cv2.imread(args.input_image_dir, cv2.IMREAD_UNCHANGED)
# 如果模式是生成证件照
if args.type == "idphoto":
# 将字符串转为元组
size = (int(args.height), int(args.width))
try:
result = creator(input_image, size=size, face_alignment=args.face_align)
except FaceError:
print("人脸数量不等于 1,请上传单张人脸的图像。")
else:
# 保存标准照
save_image_dpi_to_bytes(cv2.cvtColor(result.standard, cv2.COLOR_RGBA2BGRA), args.output_image_dir, dpi=args.dpi)
# 保存高清照
file_name, file_extension = os.path.splitext(args.output_image_dir)
new_file_name = file_name + "_hd" + file_extension
save_image_dpi_to_bytes(cv2.cvtColor(result.hd, cv2.COLOR_RGBA2BGRA), new_file_name, dpi=args.dpi)
# 如果模式是人像抠图
elif args.type == "human_matting":
result = creator(input_image, change_bg_only=True)
cv2.imwrite(args.output_image_dir, result.hd)
# 如果模式是添加背景
elif args.type == "add_background":
render_choice = ["pure_color", "updown_gradient", "center_gradient"]
# 将字符串转为元组
color = hex_to_rgb(args.color)
# 将元祖的 0 和 2 号数字交换
color = (color[2], color[1], color[0])
result_image = add_background(
input_image, bgr=color, mode=render_choice[args.render]
)
result_image = result_image.astype(np.uint8)
result_image = cv2.cvtColor(result_image, cv2.COLOR_RGBA2BGRA)
if args.kb:
resize_image_to_kb(result_image, args.output_image_dir, int(args.kb), dpi=args.dpi)
else:
save_image_dpi_to_bytes(cv2.cvtColor(result_image, cv2.COLOR_RGBA2BGRA), args.output_image_dir, dpi=args.dpi)
# 如果模式是生成排版照
elif args.type == "generate_layout_photos":
size = (int(args.height), int(args.width))
typography_arr, typography_rotate = generate_layout_array(
input_height=size[0], input_width=size[1]
)
result_layout_image = generate_layout_image(
input_image,
typography_arr,
typography_rotate,
height=size[0],
width=size[1],
)
if args.kb:
result_layout_image = cv2.cvtColor(result_layout_image, cv2.COLOR_RGB2BGR)
result_layout_image = resize_image_to_kb(
result_layout_image, args.output_image_dir, int(args.kb), dpi=args.dpi
)
else:
save_image_dpi_to_bytes(cv2.cvtColor(result_layout_image, cv2.COLOR_RGBA2BGRA), args.output_image_dir, dpi=args.dpi)
# 如果模式是证件照裁切
elif args.type == "idphoto_crop":
# 将字符串转为元组
size = (int(args.height), int(args.width))
try:
result = creator(input_image, size=size, crop_only=True)
except FaceError:
print("人脸数量不等于 1,请上传单张人脸的图像。")
else:
# 保存标准照
save_image_dpi_to_bytes(cv2.cvtColor(result.standard, cv2.COLOR_RGBA2BGRA), args.output_image_dir, dpi=args.dpi)
# 保存高清照
file_name, file_extension = os.path.splitext(args.output_image_dir)
new_file_name = file_name + "_hd" + file_extension
save_image_dpi_to_bytes(cv2.cvtColor(result.hd, cv2.COLOR_RGBA2BGRA), new_file_name, dpi=args.dpi)