diff --git a/src/processing/parse.py b/src/processing/parse.py index 0cb13bdf..cad25ad9 100644 --- a/src/processing/parse.py +++ b/src/processing/parse.py @@ -2,7 +2,7 @@ Parsing module for parsing all models outputs into the state """ -from state import * +from state import GameState, Frame, ObjectType def parse_sort_output(state: GameState, sort_output) -> None: @@ -21,11 +21,10 @@ def parse_sort_output(state: GameState, sort_output) -> None: lines = [[int(x) for x in line.split()] for line in file.readlines()] file.close() - sts = state.states + sts = state.frames b = 0 # index of line in ball s = 0 # index of state while b < len(lines): - # TODO modify for BoT-SORT output frame, obj_type, id, xmin, ymin, xwidth, ywidth = lines[b][:7] sF: Frame = sts[s] if s <= len(sts): # s-1 frameno < bframe, s = len(states) @@ -40,51 +39,20 @@ def parse_sort_output(state: GameState, sort_output) -> None: sF: Frame = sts[s] # frame to be updated DO NOT DELETE LINE assert sF.frameno == frame + box = (xmin, ymin, xmin + xwidth, ymin + ywidth) if obj_type is ObjectType.BALL.value: - bf = BallFrame(xmin, ymin, xmax=xmin + xwidth, ymax=ymin + ywidth) - sF.ball = bf - id = "ball_" + id - sF.balls.update({id: bf}) - if id not in state.balls: # if new ball - state.balls.update({id: BallState()}) - bs: BallState = state.balls.get(id) - bs.frames += 1 + sF.set_ball_frame(id, *box) elif obj_type is ObjectType.PLAYER.value: - pf = PlayerFrame(xmin, ymin, xmax=xmin + xwidth, ymax=ymin + ywidth) - id = "player_" + id - sF.players.update({id: pf}) - if id not in state.players: # if new player - state.players.update({id: PlayerState()}) - ps: PlayerState = state.players.get(id) - ps.frames += 1 + sF.add_player_frame(id, *box) elif obj_type is ObjectType.RIM.value: - box = Box(xmin, ymin, xmax=xmin + xwidth, ymax=ymin + ywidth) - sF.rim = box + sF.set_rim_box(id, *box) b += 1 # process next line -def filter_players(state: GameState, threshold: int) -> None: - "removes all players which appear for less than [threshold] frames" - for k in state.players: - v: PlayerState = state.players.get(k) - if v.frames < threshold: - state.players.pop(k) - - -def filter_balls(state: GameState, threshold: int) -> None: - "removes all balls which appear for less than [threshold] frames" - for k in state.balls: - v: BallState = state.balls.get(k) - if v.frames < threshold: - state.balls.pop(k) - - -def clean(state: GameState, pframe_threshold: int, bframe_threshold: int): +def clean(state: GameState, pframe_threshold: int): """ Imputes missing data and filters outs noise after parsing pframe_threshold: min frames a player should appear for in video - bframe_threshold: max frames a player should appear for in video """ - filter_players(state, pframe_threshold) - filter_balls(state, bframe_threshold) + state.filter_players(pframe_threshold) diff --git a/src/processing/team_detect.py b/src/processing/team_detect.py index 39aae85e..6ca61431 100644 --- a/src/processing/team_detect.py +++ b/src/processing/team_detect.py @@ -174,7 +174,7 @@ def team_split(state: GameState): possession with start and finish frames """ player_list = state.players.keys() - pos_lst = possession_list(state.states, player_list, thresh=11) + pos_lst = possession_list(state.frames, player_list, thresh=11) player_idx = {player: i for i, player in enumerate(player_list)} connects = connections(pos_lst, player_list, player_idx) teams = possible_teams(player_list) diff --git a/src/processrunner.py b/src/processrunner.py index d6a44008..8d6f8df1 100644 --- a/src/processrunner.py +++ b/src/processrunner.py @@ -38,7 +38,7 @@ def run_parse(self): "Runs parse module over SORT (and pose later) outputs to update GameState" parse.parse_sort_output(self.state, self.players_tracking) parse.parse_sort_output(self.state, self.ball_tracking) - parse.clean(self.state, 100, 100) + parse.clean(self.state, 100) def run_team_detect(self): """ @@ -47,7 +47,7 @@ def run_team_detect(self): Splits identified players into teams, then curates: ball state, passes, player possession, and team possession """ - teams, pos_list, playerids = team_detect.team_split(self.state.states) + teams, pos_list, playerids = team_detect.team_split(self.state.frames) self.state.possession_list = pos_list for pid in playerids: self.state.players[pid] = { @@ -57,7 +57,7 @@ def run_team_detect(self): "assists": 0, } self.state.ball_state = general_detect.ball_state_update( - pos_list, len(self.state.states) - 1 + pos_list, len(self.state.frames) - 1 ) self.state.passes = general_detect.player_passes(pos_list) self.state.possession = general_detect.player_possession(pos_list) @@ -87,7 +87,7 @@ def run_video_render(self): """Runs video rendering and reencodes, stores to output_video_path_reenc.""" videoRender = video_render.VideoRender(self.homography) videoRender.render_video( - self.state.states, self.state.players, self.output_video_path + self.state.frames, self.state.players, self.output_video_path ) videoRender.reencode(self.output_video_path, self.output_video_path_reenc) diff --git a/src/state.py b/src/state.py index c14fdb9d..73f4191e 100644 --- a/src/state.py +++ b/src/state.py @@ -204,12 +204,29 @@ def __init__(self, frameno: int) -> None: # MUTABLE self.players: dict = {} # ASSUMPTION: MULITPLE PEOPLE "dictionary of form {player_[id] : PlayerFrame}" - self.ball: BallFrame = None # ASSUMPTION: SINGLE BALL - self.balls: dict = {} # ASSUMPTION: MULITPLE BALLS + self.ball: BallFrame = None # ASSUMPTION: SINGLE BALLS "dictionary of form {ball_[id] : BallFrame}" self.rim: Box = None # ASSUMPTION: SINGLE RIM "bounding box of rim" + def add_player_frame(self, id: int, xmin: int, ymin: int, xmax: int, ymax: int): + "update players in frame given id and bounding boxes" + pf = PlayerFrame(xmin, ymin, xmax, ymax) + id = "player_" + id + self.players.update({id: pf}) + + def set_ball_frame(self, id: int, xmin: int, ymin: int, xmax: int, ymax: int): + "set ball in frame given id and bounding boxes" + bf = BallFrame(xmin, ymin, xmax, ymax) + id = "ball_" + id + self.ball = bf + + def set_rim_box(self, id: int, xmin: int, ymin: int, xmax: int, ymax: int): + "set rim box given bounding boxes" + bf = BallFrame(xmin, ymin, xmax, ymax) + id = "ball_" + id + self.balls.update({id: bf}) + def check(self) -> bool: "verifies if well-defined" try: @@ -270,7 +287,7 @@ def __init__(self) -> None: """ # MUTABLE - self.states: list = [] + self.frames: list = [] "list of frames: [Frame], each frame has player, ball, and rim info" self.players: dict = {} @@ -279,7 +296,8 @@ def __init__(self) -> None: self.balls: dict = {} "{ball_0 : BallState, ball_1 : BallState}" - self.shots + self.shots: list = [] + " list of shots: [(player_[id],start,end)]" # EVERYTHING BELOW THIS POINT IS OUT-OF-DATE @@ -301,6 +319,22 @@ def __init__(self) -> None: self.team1_pos = 0 self.team2_pos = 0 + def recompute_frame_count(self): + "recompute frame count of all players in frames" + for ps in self.players.values(): # reset to 0 + ps.frame = 0 + for frame in self.frames: + for pid in frame.players: + if pid not in self.players: + self.players.update({pid : PlayerState()}) + self.players.get(pid).frame += 1 + + def filter_players(self, threshold: int): + "removes all players which appear for less than [threshold] frames" + self.recompute_frame_count() + for k, v in enumerate(self.players): + if v.frames < threshold: + self.players.pop(k) def update_scores(self, madeshot_list): """ @@ -346,7 +380,7 @@ def __repr__(self) -> str: if len(self.rim) > 0 else "None", "Court lines coordinates": "None", - "Number of frames": str(len(self.states)), + "Number of frames": str(len(self.frames)), "Number of players": str(len(self.players)), "Number of passes": str(len(self.passes)), "Team 1": str(self.team1),