Skip to content

Commit

Permalink
Avoid Google Sheets API quota-exceeded errors while downloading
Browse files Browse the repository at this point in the history
  • Loading branch information
HiroshigeAoki committed Dec 25, 2023
1 parent 84a5cc3 commit 468927c
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 31 deletions.
92 changes: 71 additions & 21 deletions src/gspread_client_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@
)
from gspread_formatting.dataframe import BasicFormatter, format_with_dataframe
from oauth2client.service_account import ServiceAccountCredentials
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from src.gsheet import (
AnnotationCriminalSheet,
AnnotationNotCriminalSheet,
Expand All @@ -29,7 +35,6 @@
ScenarioSheet,
)
from src.slack import SlackClientWrapper
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

slack_client = SlackClientWrapper()

Expand All @@ -40,7 +45,14 @@
retry_decorator = retry(
stop=stop_after_attempt(5), # 最大5回
wait=wait_exponential(multiplier=1, max=10), # リトライの間隔は1秒から始まり、最大10秒まで指数関数的に増加
retry=retry_if_exception_type((APIError, SSLError, requests.exceptions.RequestException)),
retry=retry_if_exception_type(
(
APIError,
SSLError,
requests.exceptions.RequestException,
gspread.exceptions.GSpreadException,
)
),
)


Expand Down Expand Up @@ -102,8 +114,12 @@ def _get_or_fetch_master_worksheet_data(
logger.debug(f"Updated master_cache[{sheet_name}] with {new_data}")
return new_data

def get_game_info_data(self, channel_name: str, use_catch: bool = False) -> Dict[str, Any]:
worksheet = self._get_or_fetch_master_worksheet_data("game_info", use_catch=use_catch)
def get_game_info_data(
self, channel_name: str, use_catch: bool = False
) -> Dict[str, Any]:
worksheet = self._get_or_fetch_master_worksheet_data(
"game_info", use_catch=use_catch
)
df = pd.DataFrame(worksheet.get_all_records())
if len(df) == 0:
raise ValueError("Game_infoスプレッドシートの取得に失敗しました。")
Expand All @@ -118,19 +134,27 @@ def get_game_info_data(self, channel_name: str, use_catch: bool = False) -> Dict
return {"game_info_sheet": GameInfoSheet(**game_info), "row_index": row_index}

def get_player_df(self, use_catch: bool = False) -> pd.DataFrame:
    """Return the "player" master sheet as a DataFrame.

    Args:
        use_catch: Reuse the cached worksheet when True (parameter name kept
            for backward compatibility; presumably a typo of ``use_cache`` —
            TODO confirm before renaming).

    Raises:
        ValueError: If the "player" worksheet could not be fetched.
    """
    worksheet = self._get_or_fetch_master_worksheet_data(
        "player", use_catch=use_catch
    )
    if worksheet is None:
        raise ValueError("Playerスプレッドシートの取得に失敗しました。")
    return pd.DataFrame(worksheet.get_all_records())

def get_scenario_data(
    self, row_index: int, use_catch: bool = False
) -> ScenarioSheet:
    """Build a ScenarioSheet from one row of the "scenario" master sheet.

    Args:
        row_index: Spreadsheet row number; ``row_index - 2`` converts it to
            the DataFrame position (assumes one header row plus 1-based
            numbering — TODO confirm against the sheet layout).
        use_catch: Reuse the cached worksheet when True.

    Raises:
        ValueError: If the worksheet could not be fetched, or the row's
            columns do not match the ScenarioSheet fields.
    """
    worksheet = self._get_or_fetch_master_worksheet_data(
        "scenario", use_catch=use_catch
    )
    if worksheet is None:
        raise ValueError("Scenarioスプレッドシートの取得に失敗しました。")
    df = pd.DataFrame(worksheet.get_all_records())
    try:
        scenario = ScenarioSheet(
            row_index=row_index, **df.iloc[row_index - 2].to_dict()
        )
    except TypeError as e:
        # Chain the original TypeError so the root cause stays visible.
        raise ValueError(f"スプレッドシートに誤りがあります。\n{e}") from e
    return scenario
Expand All @@ -154,7 +178,9 @@ def has_player_played_scenario(
self, player_row_index: int, scenario_row_index: int, use_catch: bool = False
) -> bool:
sheet_name = "player_scenario_mapping"
worksheet = self._get_or_fetch_master_worksheet_data(sheet_name, use_catch=use_catch)
worksheet = self._get_or_fetch_master_worksheet_data(
sheet_name, use_catch=use_catch
)
df = pd.DataFrame(worksheet.get_all_records())
matching_row = df[
(df["player_row_index"] == player_row_index)
Expand All @@ -174,7 +200,9 @@ def save_player_scenario_mapping(
with self._lock:
sheet_name = "player_scenario_mapping"
worksheet = self._get_or_fetch_master_worksheet_data(sheet_name)
worksheet.append_row([player_row_index, scenario_row_index, player_name, scenario_name])
worksheet.append_row(
[player_row_index, scenario_row_index, player_name, scenario_name]
)
return worksheet

# master sheet更新関連
Expand All @@ -188,15 +216,19 @@ def save_vote_results(
for col, vote in enumerate(votes, start=GameInfoSheet.VOTE1_COL):
worksheet.update_cell(player_row_index, col, vote)

for col, reason in enumerate(reasons, start=GameInfoSheet.REASON1_COL): # Column Oからスタート
for col, reason in enumerate(
reasons, start=GameInfoSheet.REASON1_COL
): # Column Oからスタート
worksheet.update_cell(player_row_index, col, reason)
return worksheet

@retry_decorator
def update_sheet_cell(
self, sheet_name: str, row_index: int, col_index: int, value: Any
) -> gspread.Worksheet:
logger.info(f"Updating {sheet_name} cell ({row_index}, {col_index}) with {value}")
logger.info(
f"Updating {sheet_name} cell ({row_index}, {col_index}) with {value}"
)
worksheet = self._get_or_fetch_master_worksheet_data(sheet_name)
logger.debug(f"worksheet: {worksheet}, type: {type(worksheet)}")
worksheet.update_cell(row_index, col_index, value)
Expand All @@ -205,12 +237,18 @@ def update_sheet_cell(
# annotation用spreadsheet関連
@retry_decorator
def share_spreadsheet(self, sheet: gspread.Spreadsheet, email: str) -> None:
    """Grant *email* writer access to *sheet* without sending a notification.

    Retried with exponential backoff (via ``retry_decorator``) on transient
    API / network errors.
    """
    sheet.share(
        email, perm_type="user", role="writer", with_link=False, notify=False
    )
    logger.info(f"Shared spreadsheet with {email}")

@retry_decorator
def create_player_gsheet(self, player_name: str, player_email: str) -> gspread.Spreadsheet:
sheet = self.client.create(f"{player_name}", folder_id=settings.ANNOTATION_FOLDER_KEY)
def create_player_gsheet(
self, player_name: str, player_email: str
) -> gspread.Spreadsheet:
sheet = self.client.create(
f"{player_name}", folder_id=settings.ANNOTATION_FOLDER_KEY
)
emails = [player_email, settings.STAFF_EMAIL, settings.BOT_EMAIL]
logger.info(f"Created spreadsheet for {player_name}. url: {sheet.url} ")
for email in emails:
Expand All @@ -237,15 +275,19 @@ def create_or_get_worksheet(
return worksheet

@retry_decorator
def _format_headers(
    self, worksheet: gspread.Worksheet, dialogue_df: pd.DataFrame
) -> None:
    """Apply header formatting from *dialogue_df* and freeze the header row."""
    header_formatter = BasicFormatter(freeze_headers=True)
    format_with_dataframe(worksheet, dialogue_df, header_formatter)

@retry_decorator
def _format_text_columns(
self, worksheet: gspread.worksheet, num_rows: int, is_criminal: bool
) -> None:
message_column_format = cellFormat(horizontalAlignment="LEFT", wrapStrategy="WRAP")
message_column_format = cellFormat(
horizontalAlignment="LEFT", wrapStrategy="WRAP"
)
columns_to_format = [
(
AnnotationSheet.MESSAGE_COL,
Expand Down Expand Up @@ -284,12 +326,18 @@ def _format_checkboxes(
validation_rule,
)
else:
suspicious_range = f"{AnnotationNotCriminalSheet.SUSPICIOUS_RANGE}{num_rows + 1}"
suspicious_range = (
f"{AnnotationNotCriminalSheet.SUSPICIOUS_RANGE}{num_rows + 1}"
)
deciding_factor_range = (
f"{AnnotationNotCriminalSheet.DECIDING_FACTOR_RANGE}{num_rows + 1}"
)
set_data_validation_for_cell_range(worksheet, suspicious_range, validation_rule)
set_data_validation_for_cell_range(worksheet, deciding_factor_range, validation_rule)
set_data_validation_for_cell_range(
worksheet, suspicious_range, validation_rule
)
set_data_validation_for_cell_range(
worksheet, deciding_factor_range, validation_rule
)

@retry_decorator
def apply_formatting_to_worksheet(
Expand All @@ -299,7 +347,9 @@ def apply_formatting_to_worksheet(
self._format_text_columns(worksheet, dialogue_df.shape[0], is_criminal)
self._format_checkboxes(worksheet, dialogue_df.shape[0], is_criminal)

def apply_protected_range(self, worksheet: gspread.Worksheet, is_criminal: bool) -> None:
def apply_protected_range(
self, worksheet: gspread.Worksheet, is_criminal: bool
) -> None:
worksheet.add_protected_range(
name=AnnotationSheet.PROTECT_RANGE + str(worksheet.row_count),
editor_users_emails=[settings.STAFF_BOT_ID_GMAILS],
Expand Down
88 changes: 78 additions & 10 deletions src/workers/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os
import subprocess
import time
from typing import Any, Dict, List

import pandas as pd
Expand All @@ -25,31 +26,75 @@
logger = logging.getLogger("celery")


# def load_previous_data(file_name):
# if os.path.exists(file_name):
# if file_name.endswith(".pkl"):
# test = pd.read_pickle(file_name)
# saved_channels = list(
# set([row["channel_name"] for row in test["game_info"]])
# )
# return [i[1].to_dict() for i in test.iterrows()], saved_channels
#
# elif file_name.endswith(".jsonl"):
# with open(file_name, "r") as file:
# return [json.loads(line) for line in file]
# else:
# if file_name.endswith(".pkl"):
# return [], []
# elif file_name.endswith(".jsonl"):
# return []
#
#
# old_whole_log = load_previous_data(os.path.join(settings.TMP_DIR, "whole_log.jsonl"))
# old_test_data, saved_channels = load_previous_data(
# os.path.join(settings.TMP_DIR, "test.pkl")
# )


@celery.task(name="download_task", time_limit=CELERY_TIME_LIMIT)
@handle_worker_errors
def download_task() -> None:
    """Fan out one download subtask per eligible channel, then aggregate.

    Each subtask receives a ``wait_time`` of ``i * 60`` seconds so that the
    group's Google Sheets API calls are staggered and stay under the quota.
    """
    with session_scope() as session:
        game_info_records = game_info_repo.get_all_records(session=session)

    # Filter BEFORE enumerating so the 60 s stagger slots are contiguous;
    # enumerating the unfiltered records would leave pointless gaps in the
    # schedule whenever a record is skipped.
    eligible = [
        record
        for record in game_info_records
        if "murder-mystery" in record.channel_name
        and is_game_ready_to_proceed(record.channel_id, next_step="download")
    ]
    task_group = group(
        process_game_info_record.s(record.channel_id, i * 60)
        for i, record in enumerate(eligible)
    )
    callback = aggregate_results.s().set(countdown=1)
    chord(task_group)(callback)


@celery.task(name="process_game_info_record")
def process_game_info_record(channel_id: str, wait_time: int) -> tuple:
    """Download and process the game data for one channel.

    Args:
        channel_id: Channel whose game-info record is loaded and processed.
        wait_time: Seconds to sleep before touching the Sheets API — the
            dispatcher staggers subtasks with this to avoid quota errors.

    Returns:
        Whatever tuple ``process_record`` produces for this record (see that
        function for the exact shape).
    """
    # Stagger API access across the task group to stay under the quota.
    time.sleep(wait_time)
    with session_scope() as session:
        game_info_record = game_info_repo.get_game_info(
            session=session, channel_id=channel_id
        )
        # Process the record and return the results
        return process_record(game_info_record, session)


@celery.task(name="aggregate_results")
def aggregate_results(results: tuple) -> None:
# whole_log, test_data = old_whole_log, old_test_data
whole_log, test_data = [], []
for result in results:
# Unpack result, which is returned by process_game_info_record
Expand All @@ -66,13 +111,16 @@ def aggregate_results(results: tuple) -> None:
file_name = "whole_log.jsonl"
save_to_jsonl(whole_log, file_name=file_name)
after_download_task.delay()
# old_whole_log, old_test_data = whole_log, test_data


@celery.task(name="after_download_task")
def after_download_task() -> None:
    """Run the optional post-download shell script, then notify the staff.

    NOTE(review): ``subprocess.run`` is called without ``check=True``, so a
    failing script is silently ignored — confirm whether that is intentional.
    """
    if os.path.exists(settings.AFTER_DOWNLOAD_TASK_SCRIPT):
        subprocess.run(["bash", settings.AFTER_DOWNLOAD_TASK_SCRIPT])
    slack_client.send_direct_message(
        user_id=settings.STAFF_ID, message="データのダウンロードが完了しました。"
    )


def process_record(game_info_record: GameInfoTable, session: Any) -> tuple:
Expand Down Expand Up @@ -130,7 +178,18 @@ def process_record(game_info_record: GameInfoTable, session: Any) -> tuple:
logger.debug(f"columns={list_columns}")

# For pkl file
df = pd.DataFrame(dict(zip(list_columns, list_data)))
for i in range(len(list_columns)):
if i < len(list_data):
logger.debug(f"Column: {list_columns[i]}, Length: {len(list_data[i])}")
else:
logger.error(f"Missing data for column: {list_columns[i]}")

try:
df = pd.DataFrame(dict(zip(list_columns, list_data)))
except ValueError as e:
logger.error(f"Error creating DataFrame: {e}")
return {}, []

test_data = []
for participant_record in participant_records:
speaker_df = df[df["speakers"] == participant_record.char_name]
Expand All @@ -141,7 +200,9 @@ def process_record(game_info_record: GameInfoTable, session: Any) -> tuple:
test_annotation_list_cols.append(speaker_df[col_name].to_list())

annotations = dict(
zip(annotation_col_names, (test_annotation_list_cols + annotation_cols[-2:]))
zip(
annotation_col_names, (test_annotation_list_cols + annotation_cols[-2:])
)
)

game_info_col_names = [
Expand All @@ -163,7 +224,10 @@ def process_record(game_info_record: GameInfoTable, session: Any) -> tuple:

test_data.append(
dict(
nested_utters=utterances, labels=label, annotations=annotations, game_info=game_info
nested_utters=utterances,
labels=label,
annotations=annotations,
game_info=game_info,
)
)

Expand Down Expand Up @@ -191,7 +255,9 @@ def get_annotation_data(
sheet_name += "_criminal"

annotation_worksheet = gsheet_client.get_annotation_sheet(
spreadsheet_key=participant_record.annotation_sheets_key, sheet_name=sheet_name
spreadsheet_key=participant_record.annotation_sheets_key,
sheet_name=sheet_name,
use_catch=True,
)
data = annotation_worksheet.get_all_values()
return pd.DataFrame(data[1:], columns=data[0])
Expand Down Expand Up @@ -223,7 +289,9 @@ def process_annotation_data(
)
cols.append(
convert_bool_str_list_to_int_list(
df.iloc[:, AnnotationNotCriminalSheet.DECIDING_FACTOR_COL_NUM - 1].to_list()
df.iloc[
:, AnnotationNotCriminalSheet.DECIDING_FACTOR_COL_NUM - 1
].to_list()
)
)
col_names.extend(
Expand Down

0 comments on commit 468927c

Please sign in to comment.