Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
HiroshigeAoki committed Dec 25, 2023
1 parent 468927c commit d5de264
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 119 deletions.
83 changes: 20 additions & 63 deletions src/gspread_client_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,6 @@
)
from gspread_formatting.dataframe import BasicFormatter, format_with_dataframe
from oauth2client.service_account import ServiceAccountCredentials
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from src.gsheet import (
AnnotationCriminalSheet,
AnnotationNotCriminalSheet,
Expand All @@ -35,6 +29,7 @@
ScenarioSheet,
)
from src.slack import SlackClientWrapper
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

slack_client = SlackClientWrapper()

Expand Down Expand Up @@ -114,12 +109,8 @@ def _get_or_fetch_master_worksheet_data(
logger.debug(f"Updated master_cache[{sheet_name}] with {new_data}")
return new_data

def get_game_info_data(
self, channel_name: str, use_catch: bool = False
) -> Dict[str, Any]:
worksheet = self._get_or_fetch_master_worksheet_data(
"game_info", use_catch=use_catch
)
def get_game_info_data(self, channel_name: str, use_catch: bool = False) -> Dict[str, Any]:
worksheet = self._get_or_fetch_master_worksheet_data("game_info", use_catch=use_catch)
df = pd.DataFrame(worksheet.get_all_records())
if len(df) == 0:
raise ValueError("Game_infoスプレッドシートの取得に失敗しました。")
Expand All @@ -134,27 +125,19 @@ def get_game_info_data(
return {"game_info_sheet": GameInfoSheet(**game_info), "row_index": row_index}

def get_player_df(self, use_catch: bool = False) -> pd.DataFrame:
worksheet = self._get_or_fetch_master_worksheet_data(
"player", use_catch=use_catch
)
worksheet = self._get_or_fetch_master_worksheet_data("player", use_catch=use_catch)
if worksheet is None:
raise ValueError("Playerスプレッドシートの取得に失敗しました。")
df = pd.DataFrame(worksheet.get_all_records())
return df

def get_scenario_data(
self, row_index: int, use_catch: bool = False
) -> ScenarioSheet:
worksheet = self._get_or_fetch_master_worksheet_data(
"scenario", use_catch=use_catch
)
def get_scenario_data(self, row_index: int, use_catch: bool = False) -> ScenarioSheet:
worksheet = self._get_or_fetch_master_worksheet_data("scenario", use_catch=use_catch)
if worksheet is None:
raise ValueError("Scenarioスプレッドシートの取得に失敗しました。")
df = pd.DataFrame(worksheet.get_all_records())
try:
scenario = ScenarioSheet(
row_index=row_index, **df.iloc[row_index - 2].to_dict()
)
scenario = ScenarioSheet(row_index=row_index, **df.iloc[row_index - 2].to_dict())
except TypeError as e:
raise ValueError(f"スプレッドシートに誤りがあります。\n{e}")
return scenario
Expand All @@ -178,9 +161,7 @@ def has_player_played_scenario(
self, player_row_index: int, scenario_row_index: int, use_catch: bool = False
) -> bool:
sheet_name = "player_scenario_mapping"
worksheet = self._get_or_fetch_master_worksheet_data(
sheet_name, use_catch=use_catch
)
worksheet = self._get_or_fetch_master_worksheet_data(sheet_name, use_catch=use_catch)
df = pd.DataFrame(worksheet.get_all_records())
matching_row = df[
(df["player_row_index"] == player_row_index)
Expand All @@ -200,9 +181,7 @@ def save_player_scenario_mapping(
with self._lock:
sheet_name = "player_scenario_mapping"
worksheet = self._get_or_fetch_master_worksheet_data(sheet_name)
worksheet.append_row(
[player_row_index, scenario_row_index, player_name, scenario_name]
)
worksheet.append_row([player_row_index, scenario_row_index, player_name, scenario_name])
return worksheet

# master sheet更新関連
Expand All @@ -216,19 +195,15 @@ def save_vote_results(
for col, vote in enumerate(votes, start=GameInfoSheet.VOTE1_COL):
worksheet.update_cell(player_row_index, col, vote)

for col, reason in enumerate(
reasons, start=GameInfoSheet.REASON1_COL
): # Column Oからスタート
for col, reason in enumerate(reasons, start=GameInfoSheet.REASON1_COL): # Column Oからスタート
worksheet.update_cell(player_row_index, col, reason)
return worksheet

@retry_decorator
def update_sheet_cell(
self, sheet_name: str, row_index: int, col_index: int, value: Any
) -> gspread.Worksheet:
logger.info(
f"Updating {sheet_name} cell ({row_index}, {col_index}) with {value}"
)
logger.info(f"Updating {sheet_name} cell ({row_index}, {col_index}) with {value}")
worksheet = self._get_or_fetch_master_worksheet_data(sheet_name)
logger.debug(f"worksheet: {worksheet}, type: {type(worksheet)}")
worksheet.update_cell(row_index, col_index, value)
Expand All @@ -237,18 +212,12 @@ def update_sheet_cell(
# annotation用spreadsheet関連
@retry_decorator
def share_spreadsheet(self, sheet: gspread.Spreadsheet, email: str) -> None:
sheet.share(
email, perm_type="user", role="writer", with_link=False, notify=False
)
sheet.share(email, perm_type="user", role="writer", with_link=False, notify=False)
logger.info(f"Shared spreadsheet with {email}")

@retry_decorator
def create_player_gsheet(
self, player_name: str, player_email: str
) -> gspread.Spreadsheet:
sheet = self.client.create(
f"{player_name}", folder_id=settings.ANNOTATION_FOLDER_KEY
)
def create_player_gsheet(self, player_name: str, player_email: str) -> gspread.Spreadsheet:
sheet = self.client.create(f"{player_name}", folder_id=settings.ANNOTATION_FOLDER_KEY)
emails = [player_email, settings.STAFF_EMAIL, settings.BOT_EMAIL]
logger.info(f"Created spreadsheet for {player_name}. url: {sheet.url} ")
for email in emails:
Expand All @@ -275,19 +244,15 @@ def create_or_get_worksheet(
return worksheet

@retry_decorator
def _format_headers(
self, worksheet: gspread.Worksheet, dialogue_df: pd.DataFrame
) -> None:
def _format_headers(self, worksheet: gspread.Worksheet, dialogue_df: pd.DataFrame) -> None:
header_formatter = BasicFormatter(freeze_headers=True)
format_with_dataframe(worksheet, dialogue_df, header_formatter)

@retry_decorator
def _format_text_columns(
self, worksheet: gspread.worksheet, num_rows: int, is_criminal: bool
) -> None:
message_column_format = cellFormat(
horizontalAlignment="LEFT", wrapStrategy="WRAP"
)
message_column_format = cellFormat(horizontalAlignment="LEFT", wrapStrategy="WRAP")
columns_to_format = [
(
AnnotationSheet.MESSAGE_COL,
Expand Down Expand Up @@ -326,18 +291,12 @@ def _format_checkboxes(
validation_rule,
)
else:
suspicious_range = (
f"{AnnotationNotCriminalSheet.SUSPICIOUS_RANGE}{num_rows + 1}"
)
suspicious_range = f"{AnnotationNotCriminalSheet.SUSPICIOUS_RANGE}{num_rows + 1}"
deciding_factor_range = (
f"{AnnotationNotCriminalSheet.DECIDING_FACTOR_RANGE}{num_rows + 1}"
)
set_data_validation_for_cell_range(
worksheet, suspicious_range, validation_rule
)
set_data_validation_for_cell_range(
worksheet, deciding_factor_range, validation_rule
)
set_data_validation_for_cell_range(worksheet, suspicious_range, validation_rule)
set_data_validation_for_cell_range(worksheet, deciding_factor_range, validation_rule)

@retry_decorator
def apply_formatting_to_worksheet(
Expand All @@ -347,9 +306,7 @@ def apply_formatting_to_worksheet(
self._format_text_columns(worksheet, dialogue_df.shape[0], is_criminal)
self._format_checkboxes(worksheet, dialogue_df.shape[0], is_criminal)

def apply_protected_range(
self, worksheet: gspread.Worksheet, is_criminal: bool
) -> None:
def apply_protected_range(self, worksheet: gspread.Worksheet, is_criminal: bool) -> None:
worksheet.add_protected_range(
name=AnnotationSheet.PROTECT_RANGE + str(worksheet.row_count),
editor_users_emails=[settings.STAFF_BOT_ID_GMAILS],
Expand Down
60 changes: 4 additions & 56 deletions src/workers/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,56 +26,17 @@
logger = logging.getLogger("celery")


# def load_previous_data(file_name):
# if os.path.exists(file_name):
# if file_name.endswith(".pkl"):
# test = pd.read_pickle(file_name)
# saved_channels = list(
# set([row["channel_name"] for row in test["game_info"]])
# )
# return [i[1].to_dict() for i in test.iterrows()], saved_channels
#
# elif file_name.endswith(".jsonl"):
# with open(file_name, "r") as file:
# return [json.loads(line) for line in file]
# else:
# if file_name.endswith(".pkl"):
# return [], []
# elif file_name.endswith(".jsonl"):
# return []
#
#
# old_whole_log = load_previous_data(os.path.join(settings.TMP_DIR, "whole_log.jsonl"))
# old_test_data, saved_channels = load_previous_data(
# os.path.join(settings.TMP_DIR, "test.pkl")
# )


@celery.task(name="download_task", time_limit=CELERY_TIME_LIMIT)
@handle_worker_errors
def download_task() -> None:
with session_scope() as session:
game_info_records = game_info_repo.get_all_records(session=session)

results = []
# for i, record in enumerate(game_info_records):
# if (
# "murder-mystery" in record.channel_name
# and is_game_ready_to_proceed(record.channel_id, next_step="download")
# # and record.channel_id not in saved_channels
# ):
# wait_time = 60 if not i == 0 else 0
# result = process_game_info_record.delay(record.channel_id, wait_time)
# results.append(result)
#
# aggregate_results.delay(results)

task_group = group(
process_game_info_record.s(record.channel_id, i * 60)
for i, record in enumerate(game_info_records)
if "murder-mystery" in record.channel_name
and is_game_ready_to_proceed(record.channel_id, next_step="download")
# and record.channel_id not in saved_channels
)
callback = aggregate_results.s().set(countdown=1)
chord(task_group)(callback)
Expand All @@ -85,19 +46,14 @@ def download_task() -> None:
def process_game_info_record(channel_id: str, wait_time: int) -> tuple:
time.sleep(wait_time)
with session_scope() as session:
game_info_record = game_info_repo.get_game_info(
session=session, channel_id=channel_id
)
# Process the record and return the results
game_info_record = game_info_repo.get_game_info(session=session, channel_id=channel_id)
return process_record(game_info_record, session)


@celery.task(name="aggregate_results")
def aggregate_results(results: tuple) -> None:
# whole_log, test_data = old_whole_log, old_test_data
whole_log, test_data = [], []
for result in results:
# Unpack result, which is returned by process_game_info_record
log_entry, test_entries = result
for i in range(len(test_entries)):
test_entries[i]["nested_utters"] = pd.DataFrame(
Expand All @@ -106,21 +62,17 @@ def aggregate_results(results: tuple) -> None:
whole_log.append(log_entry)
test_data.extend(test_entries)

# Save final aggregated results
save_to_pkl(test_data, file_name="test.pkl")
file_name = "whole_log.jsonl"
save_to_jsonl(whole_log, file_name=file_name)
after_download_task.delay()
# old_whole_log, old_test_data = whole_log, test_data


@celery.task(name="after_download_task")
def after_download_task() -> None:
if os.path.exists(settings.AFTER_DOWNLOAD_TASK_SCRIPT):
subprocess.run(["bash", settings.AFTER_DOWNLOAD_TASK_SCRIPT])
slack_client.send_direct_message(
user_id=settings.STAFF_ID, message="データのダウンロードが完了しました。"
)
slack_client.send_direct_message(user_id=settings.STAFF_ID, message="データのダウンロードが完了しました。")


def process_record(game_info_record: GameInfoTable, session: Any) -> tuple:
Expand Down Expand Up @@ -200,9 +152,7 @@ def process_record(game_info_record: GameInfoTable, session: Any) -> tuple:
test_annotation_list_cols.append(speaker_df[col_name].to_list())

annotations = dict(
zip(
annotation_col_names, (test_annotation_list_cols + annotation_cols[-2:])
)
zip(annotation_col_names, (test_annotation_list_cols + annotation_cols[-2:]))
)

game_info_col_names = [
Expand Down Expand Up @@ -289,9 +239,7 @@ def process_annotation_data(
)
cols.append(
convert_bool_str_list_to_int_list(
df.iloc[
:, AnnotationNotCriminalSheet.DECIDING_FACTOR_COL_NUM - 1
].to_list()
df.iloc[:, AnnotationNotCriminalSheet.DECIDING_FACTOR_COL_NUM - 1].to_list()
)
)
col_names.extend(
Expand Down

0 comments on commit d5de264

Please sign in to comment.