Commit
Oversample the criminal role
HiroshigeAoki committed Dec 2, 2023
1 parent dec9cd8 commit 75aba4b
Showing 1 changed file with 54 additions and 39 deletions.
src/workers/download.py (93 changes: 54 additions & 39 deletions)
@@ -68,7 +68,8 @@ def aggregate_results(results: tuple) -> None:
 
 @celery.task(name="after_download_task")
 def after_download_task() -> None:
-    subprocess.run(["sh", settings.AFTER_DOWNLOAD_TASK_SCRIPT])
+    if os.path.exists(settings.AFTER_DOWNLOAD_TASK_SCRIPT):
+        subprocess.run(["bash", settings.AFTER_DOWNLOAD_TASK_SCRIPT])
     slack_client.send_direct_message(user_id=settings.STAFF_ID, message="データのダウンロードが完了しました。")
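The hunk above replaces the unconditional sh call with an existence check and a bash invocation. A minimal standalone sketch of the same guard pattern, assuming a placeholder script path and an optional check=True hardening that are not part of the commit:

import os
import subprocess

# Placeholder path for illustration; the project reads the real value from settings.AFTER_DOWNLOAD_TASK_SCRIPT.
SCRIPT_PATH = "/opt/app/scripts/after_download.sh"

def run_after_download_script(script_path: str = SCRIPT_PATH) -> None:
    # Only run the hook when the script has actually been deployed.
    if os.path.exists(script_path):
        # check=True surfaces a non-zero exit status as CalledProcessError (optional hardening).
        subprocess.run(["bash", script_path], check=True)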


@@ -93,40 +93,40 @@ def process_record(game_info_record: GameInfoTable, session: Any) -> tuple:
         annotation_col_names.extend(whole_log_col_names)
         points[participant_record.char_name] = point
 
+    # For jsonl file
+    game_info_col_names = [
+        "channel_name",
+        "criminal",
+        "has_criminal_won",
+        "points",
+        "scenario_title",
+        "scenario_url",
+    ]
+    game_info_cols = [
+        game_info_record.channel_name,
+        game_info_sheet.criminal,
+        int(game_info_sheet.has_criminal_won),
+        points,
+        game_info_record.title,
+        scenario_sheet.scenario_url,
+    ]
+
     ts = df.iloc[:, AnnotationSheet.TS_COL_NUM - 1].to_list()
     speakers = df.iloc[:, AnnotationSheet.NAME_COL_NUM - 1].to_list()
     messages = df.iloc[:, AnnotationSheet.MESSAGE_COL_NUM - 1].to_list()
 
     whole_log_col_names = (
-        ["ts", "speakers", "messages"]
-        + annotation_col_names
-        + [
-            "channel_name",
-            "criminal",
-            "has_criminal_won",
-            "points",
-            "scenario_title",
-            "scenario_url",
-        ]
-    )
-    whole_log_cols = (
-        [ts, speakers, messages]
-        + annotation_cols
-        + [
-            game_info_record.channel_name,
-            int(game_info_sheet.has_criminal_won),
-            points,
-            game_info_record.title,
-            scenario_sheet.scenario_url,
-        ]
+        ["ts", "speakers", "messages"] + annotation_col_names + game_info_col_names
     )
+    whole_log_cols = [ts, speakers, messages] + annotation_cols + game_info_cols
     whole_log = dict(zip(whole_log_col_names, whole_log_cols))
 
     list_data = [messages, speakers] + annotation_cols[:-2]
     list_columns = ["nested_utters", "speakers"] + annotation_col_names[:-2]
     logger.debug(f"list_data={list_data}")
     logger.debug(f"columns={list_columns}")
 
+    # For pkl file
     df = pd.DataFrame(dict(zip(list_columns, list_data)))
     test_data = []
     for participant_record in participant_records:
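The rewritten hunk factors the shared game-info columns into game_info_col_names / game_info_cols and rebuilds whole_log with dict(zip(...)) instead of repeating the literal lists. A toy illustration of that record-building pattern; the values below are hypothetical stand-ins for the real record and sheet objects, and the annotation columns are omitted:

# Hypothetical stand-ins for game_info_record, game_info_sheet and scenario_sheet values.
game_info_col_names = ["channel_name", "criminal", "has_criminal_won", "points", "scenario_title", "scenario_url"]
game_info_cols = ["mm-game-01", "Alice", 1, {"Alice": 3, "Bob": 5}, "The Locked Room", "https://example.com/scenario"]

ts = [1701475200.0, 1701475260.0]
speakers = ["Bob", "Alice"]
messages = ["I was in the library.", "I saw Bob near the vault."]

whole_log_col_names = ["ts", "speakers", "messages"] + game_info_col_names
whole_log_cols = [ts, speakers, messages] + game_info_cols

# dict(zip(...)) pairs each column name with its value to form one JSON-serialisable record.
whole_log = dict(zip(whole_log_col_names, whole_log_cols))
# whole_log["criminal"] == "Alice"; whole_log["messages"] holds the full message list.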
@@ -141,27 +142,41 @@ def process_record(game_info_record: GameInfoTable, session: Any) -> tuple:
             zip(annotation_col_names, (test_annotation_list_cols + annotation_cols[-2:]))
         )
 
-        test_data_col_names = [
-            "nested_utters",
-            "labels",
-            "annotations",
-            "channel_name",
+        game_info_col_names = [
             "speaker",
+            "channel_name",
+            "criminal",
             "has_criminal_won",
-            "point",
+            "points",
             "scenario_title",
         ]
-        test_data_cols = [
-            utterances,
-            label,
-            annotations,
-            game_info_record.channel_name,
-            participant_record.char_name,
-            int(game_info_sheet.has_criminal_won),
-            participant_record.point,
-            game_info_record.title,
-        ]
-        test_data.append(dict(zip(test_data_col_names, test_data_cols)))
+
+        game_info = dict(
+            speaker=participant_record.char_name,
+            channel_name=game_info_record.channel_name,
+            has_criminal_won=int(game_info_sheet.has_criminal_won),
+            point=participant_record.point,
+            scenario_title=game_info_record.title,
+        )
+
+        test_data.append(
+            dict(
+                nested_utters=utterances, labels=label, annotations=annotations, game_info=game_info
+            )
+        )
+
+        # Oversampling (assumes there is exactly one criminal)
+        if participant_record.is_criminal:
+            num_non_criminal = len(participant_records) - 2
+            for _ in range(num_non_criminal):
+                test_data.append(
+                    dict(
+                        nested_utters=utterances,
+                        labels=label,
+                        annotations=annotations,
+                        game_info=game_info,
+                    )
+                )
 
     return whole_log, test_data
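The added block duplicates the criminal's record: when participant_record.is_criminal is true, the same entry is appended len(participant_records) - 2 extra times, so with a single criminal the criminal and non-criminal samples in test_data end up roughly balanced. A self-contained sketch of the same duplication-based oversampling; the labels-based criminal check and the record fields are assumptions for illustration, not the project's exact schema:

import copy
from typing import Any

def oversample_criminal_records(records: list[dict[str, Any]], criminal_label: int = 1) -> list[dict[str, Any]]:
    # Assumption: labels == criminal_label marks the criminal's record.
    criminals = [r for r in records if r["labels"] == criminal_label]
    non_criminals = [r for r in records if r["labels"] != criminal_label]
    # e.g. 4 participants, 1 criminal -> 2 extra copies, giving 3 criminal vs 3 non-criminal records.
    n_copies = max(len(non_criminals) - len(criminals), 0)
    extra = [copy.deepcopy(rec) for rec in criminals for _ in range(n_copies)]
    return records + extra

Duplication keeps the class counts even without touching the text itself, at the cost of identical, highly correlated criminal samples.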
