forked from ILikeAI/AlwaysReddy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
144 lines (112 loc) · 4.11 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import re
import clipboard
import tiktoken
import re
def read_clipboard():
    """
    Fetch the current text content of the system clipboard.

    Returns:
        str: The value reported by `clipboard.paste()`.
    """
    return clipboard.paste()
def to_clipboard(text):
    """
    Place text on the system clipboard.

    The text is first passed through `extract_code_if_only_code_block`, so a
    string that consists solely of one fenced code block is copied without
    its markdown fences; anything else is copied unchanged.

    Args:
        text (str): The text to put on the clipboard.
    """
    payload = extract_code_if_only_code_block(text)
    clipboard.copy(payload)
# Translation table that deletes every disallowed character in one pass.
_DISALLOWED_CHARS_TABLE = str.maketrans('', '', '"<>[]{}|\\~`^*!#$()_;')

# Space-delimited symbols and the words that replace them. The surrounding
# spaces are part of the match, so "a&b" is left alone while "a & b" is not.
_SYMBOL_TEXT_PAIRS = (
    (' & ', ' and '),
    (' % ', ' percent '),
    (' @ ', ' at '),
    (' = ', ' equals '),
    (' + ', ' plus '),
    (' / ', ' slash '),
)


def sanitize_text(text):
    """
    Remove disallowed characters from a string and replace certain symbols with their text equivalents.

    Disallowed characters are stripped first; the symbol replacements then run
    on the stripped result, so "a (&) b" becomes "a and b".

    Args:
        text (str): The text to be sanitized.

    Returns:
        str: The sanitized text.
    """
    sanitized_text = text.translate(_DISALLOWED_CHARS_TABLE)
    for symbol, text_equivalent in _SYMBOL_TEXT_PAIRS:
        sanitized_text = sanitized_text.replace(symbol, text_equivalent)
    return sanitized_text
def _trim_messages(messages, max_tokens):
    """
    Trim the messages to fit within the maximum token limit.

    The oldest non-system message is removed repeatedly until the token count
    reported by `_count_tokens` is at most `max_tokens`. Afterwards any
    assistant messages at the front of the non-system messages are dropped so
    that the first non-system message is not from the assistant.

    System messages are never removed. If only system messages remain and they
    still exceed the limit, trimming stops and the list is returned as is
    rather than looping forever.

    The list is modified in place and also returned.

    Args:
        messages (list): A list of message dicts to be trimmed; each is read
            via `.get('role')`.
        max_tokens (int): The maximum number of tokens allowed.

    Returns:
        list: The trimmed list of messages (the same list object that was passed in).
    """
    while _count_tokens(messages) > max_tokens:
        # Locate the oldest message that is allowed to be removed
        oldest_index = next(
            (i for i, message in enumerate(messages) if message.get('role') != 'system'),
            None,
        )
        if oldest_index is None:
            # Nothing but system messages left: no further trimming is possible
            break
        del messages[oldest_index]

    # Ensure the first non-system message is not from the assistant
    while True:
        first_index = next(
            (i for i, message in enumerate(messages) if message.get('role') != 'system'),
            None,
        )
        if first_index is None or messages[first_index].get('role') != 'assistant':
            break
        del messages[first_index]

    return messages
def _count_tokens(messages, model="gpt-3.5-turbo"):
    """
    Count the tokens in the given messages using the specified model.

    Every value of every message dict is encoded and counted (not only the
    message text). Values are handed straight to the tokenizer's `encode`,
    so they are presumably expected to be strings.

    Args:
        messages (list): A list of message dicts to count tokens from.
        model (str): The model to use for token counting. Defaults to "gpt-3.5-turbo".

    Returns:
        int: The total count of tokens in the messages.
    """
    try:
        enc = tiktoken.encoding_for_model(model)
    except KeyError:
        # tiktoken has no mapping for this model name; fall back to a general
        # encoding so the count is an approximation instead of a crash
        enc = tiktoken.get_encoding("cl100k_base")

    return sum(
        len(enc.encode(value))
        for message in messages
        for value in message.values()
    )
def maintain_token_limit(messages, max_tokens):
    """
    Keep a message list within a token budget.

    When the token count from `_count_tokens` is already at or below
    `max_tokens`, the list is returned untouched; otherwise it is handed to
    `_trim_messages`.

    Args:
        messages (list): The messages to check.
        max_tokens (int): The largest token count permitted.

    Returns:
        list: The original list if it fits, otherwise the result of `_trim_messages`.
    """
    within_budget = _count_tokens(messages) <= max_tokens
    if within_budget:
        return messages
    return _trim_messages(messages, max_tokens)
# Matches text that starts with a ``` fence line (optionally carrying an info
# string such as "python" or "c++") and ends with a closing ``` fence.
_FENCED_TEXT_RE = re.compile(r'```[^\n`]*\n([\s\S]*?)```')


def extract_code_if_only_code_block(markdown_text):
    """
    Extracts the code from a markdown text if the text only contains a single code block.

    Leading and trailing whitespace around the block is ignored when matching.
    The text after the opening fence on the same line (the language tag) is
    discarded. If the text has anything outside the block, or contains more
    than one fenced block, it is returned unchanged.

    Args:
        markdown_text (str): The markdown text to extract the code from.

    Returns:
        str: The extracted code if the markdown text only contains a single code block,
        otherwise the original markdown text.
    """
    match = _FENCED_TEXT_RE.fullmatch(markdown_text.strip())
    if match is None:
        # Not wrapped in a fence at both ends
        return markdown_text

    code = match.group(1)
    if '```' in code:
        # Another fence inside the span means the text holds several blocks
        # (with prose between them); stripping only the outer fences would
        # mangle it, so leave the text alone
        return markdown_text

    return code