-
Notifications
You must be signed in to change notification settings - Fork 0
/
caption.py
135 lines (101 loc) · 3.85 KB
/
caption.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import base64
import requests
import os
import sys
import argparse
# OpenAI API key, read from the environment rather than hard-coded.
api_key = os.environ.get("OPENAI_API_KEY")

# Parse the arguments first so that `--help` and usage errors work even when
# OPENAI_API_KEY is not set.
parser = argparse.ArgumentParser(description="Generate captions for a folder of images")
parser.add_argument("keyword", help="The token/keyword for the session")
parser.add_argument("image_folder", help="The folder containing the images")
# Optional arguments
parser.add_argument("--ext", help="The extension for the caption file", default=".txt")
parser.add_argument("--overwrite", help="Overwrite existing caption files", action="store_true", default=False)
args = parser.parse_args()

# Session keyword the captions must start with / refer to the subject by.
keyword = args.keyword
image_folder = args.image_folder
# Overwrite existing caption files, or only caption images that have none yet.
overwrite = args.overwrite
# Extension for the caption file, stored WITHOUT the leading dot because the
# caption path is built as "<image-stem>." + caption_extension. Accepts both
# "--ext .txt" and "--ext txt".
caption_extension = args.ext.lstrip(".")
if not caption_extension:
    parser.error("--ext must not be empty")

if api_key is None:
    print("Error: OPENAI_API_KEY environment variable not set", file=sys.stderr)
    sys.exit(1)
def encode_image(image_path):
    """Read the file at *image_path* and return its bytes as base64 text.

    The result is suitable for embedding in a ``data:`` URL.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')
# MIME type to declare in the data URL, keyed by lower-cased file extension.
# Unknown extensions fall back to JPEG (the type this script always sent).
_IMAGE_MIME_TYPES = {
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".png": "image/png",
}


def generate_caption(image_path, model="gpt-4-vision-preview", max_tokens=500, timeout=120):
    """Ask the OpenAI chat-completions API to caption one image.

    The prompt tells the model to begin the caption with, and refer to the
    subject by, the module-level session ``keyword``. It also asks the model
    to never change that keyword's capitalization.

    Args:
        image_path: Path to the image file on disk.
        model: Vision-capable chat model name.
            NOTE(review): OpenAI has deprecated "gpt-4-vision-preview"; pass
            a current vision model if requests are rejected.
        max_tokens: Upper bound on the length of the generated caption.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The caption text from the first choice in the response.

    Raises:
        RuntimeError: If the API returns an error status or a body without
            ``choices``. The API's own error message is included when present.
        requests.RequestException: On network failures or timeouts.
    """
    # Getting the base64 string
    base64_image = encode_image(image_path)
    # Declare the real image type; PNGs were previously labelled image/jpeg.
    extension = os.path.splitext(image_path)[1].lower()
    mime_type = _IMAGE_MIME_TYPES.get(extension, "image/jpeg")
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    payload = {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": '''Classify image with precision using session keywords for object/subject captions for Stable Diffusion. Begin captions with the session keyword, focusing on actions, clothing, photo style, and scenery. Add descriptions of the photo itself such as "selfie" or "full body shot". If the photo does not have many details or is very blurry mention that it is low quality. Refer to the subject by the keyword. Avoid artistic interpretation, text, and meta commentary. IMPORTANT: NEVER CHANGE THE CAPITALIZATION OF THE SESSION KEYWORD. The session keyword is "{keyword}" '''.format(keyword=keyword),
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:{mime_type};base64,{base64_image}",
                        },
                    },
                ],
            }
        ],
        "max_tokens": max_tokens,
    }
    # Without a timeout, requests waits indefinitely on a stalled connection.
    response = requests.post(
        "https://api.openai.com/v1/chat/completions",
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    try:
        body = response.json()
    except ValueError:
        # Non-JSON body (e.g. an HTML error page from a proxy or gateway).
        body = None
    if not response.ok or not isinstance(body, dict) or "choices" not in body:
        # Surface the API's error message instead of a bare KeyError.
        error = body.get("error") if isinstance(body, dict) else None
        if isinstance(error, dict) and error.get("message"):
            detail = error["message"]
        else:
            detail = response.text
        raise RuntimeError(
            f"OpenAI API request for {image_path} failed "
            f"(HTTP {response.status_code}): {detail}"
        )
    return body["choices"][0]["message"]["content"]
# Grab a folder of images and generate captions for each image.

# Ensure the folder exists and is actually a directory; exit non-zero so
# calling scripts can detect the failure.
if not os.path.isdir(image_folder):
    print(f"Error: {image_folder} does not exist or is not a directory")
    sys.exit(1)

# Existing caption files are never deleted up front. With --overwrite, each
# image's caption file is replaced when its new caption is written (mode "w").
# Without it, images that already have a caption file are skipped. The old
# behaviour of removing every file whose name ended in the extension could
# destroy unrelated files such as README.txt.

valid_extensions = (".jpg", ".jpeg", ".png")
# Sorted for a deterministic processing order; isfile() skips directories
# that happen to be named like images.
images = [
    os.path.join(image_folder, name)
    for name in sorted(os.listdir(image_folder))
    if name.lower().endswith(valid_extensions)
    and os.path.isfile(os.path.join(image_folder, name))
]
if not images:
    print(f"No images found in {image_folder}")
    sys.exit()
print(f"Found {len(images)} images")

images_captioned = 0
for image in images:
    # The caption file sits next to the image: "<stem>.<caption_extension>".
    # NOTE(review): images sharing a stem (a.jpg + a.png) map to the same
    # caption file.
    caption_file = os.path.splitext(image)[0] + "." + caption_extension
    if not overwrite and os.path.exists(caption_file):
        # Already captioned; keep the existing caption.
        continue
    caption = generate_caption(image)
    print(f"Caption for {image}: {caption}")
    # Model output may contain non-ASCII text; write UTF-8 explicitly instead
    # of relying on the platform's default encoding.
    with open(caption_file, "w", encoding="utf-8") as f:
        f.write(caption)
    images_captioned += 1
print(f"Generated captions for {images_captioned} images")