From afbb3c031ab88fd4ecb47481ab41e9c2c6722ed7 Mon Sep 17 00:00:00 2001 From: Kirushikesh Date: Mon, 2 Sep 2024 12:30:30 -0400 Subject: [PATCH] updated get_pattern function --- gptcache/processor/pre.py | 41 +++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/gptcache/processor/pre.py b/gptcache/processor/pre.py index 59db670c..bca41a0d 100644 --- a/gptcache/processor/pre.py +++ b/gptcache/processor/pre.py @@ -49,28 +49,27 @@ def last_content_without_prompt(data: Dict[str, Any], **params: Dict[str, Any]) def _get_pattern_value(pattern_str: str, value_str: str): - literal_text_arr = [] - field_name_arr = [] - for literal_text, field_name, _, _ in string.Formatter().parse(pattern_str): - literal_text_arr.append(literal_text) - if field_name is not None: - field_name_arr.append( - field_name if field_name else str(len(field_name_arr)) - ) - - pattern_values = {} + parts = list(string.Formatter().parse(pattern_str)) + field_names = [field_name for _, field_name, _, _ in parts if field_name is not None] + + pattern_values = {field: '' for field in field_names} # Initialize all fields with empty strings + last_end = 0 - for i, literal_text in enumerate(literal_text_arr): - start = value_str.find(literal_text, last_end) - if i == len(literal_text_arr) - 1: - end = len(value_str) - else: - end = value_str.find(literal_text_arr[i + 1], start + 1) - if start == -1 or end == -1: - break - start += len(literal_text) - pattern_values[field_name_arr[i]] = value_str[start:end] - last_end = end + for i, (literal_text, field_name, _, _) in enumerate(parts): + if literal_text: + start = value_str.find(literal_text, last_end) + if start == -1: + break + if i > 0 and field_names[i-1]: + pattern_values[field_names[i-1]] = value_str[last_end:start].strip() + last_end = start + len(literal_text) + elif i > 0 and field_name: # Handle consecutive fields + pattern_values[field_names[i-1]] = '' + + # Handle the last field if it exists + if field_names and last_end < len(value_str): + pattern_values[field_names[-1]] = value_str[last_end:].strip() + return pattern_values