Skip to content

Commit

Permalink
Moved frame creation, counting, filtering into state
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikonooooo committed Oct 22, 2023
1 parent f0a9139 commit 5cdb2ac
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 50 deletions.
48 changes: 8 additions & 40 deletions src/processing/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Parsing module for parsing all
models outputs into the state
"""
from state import *
from state import GameState, Frame, ObjectType


def parse_sort_output(state: GameState, sort_output) -> None:
Expand All @@ -21,11 +21,10 @@ def parse_sort_output(state: GameState, sort_output) -> None:
lines = [[int(x) for x in line.split()] for line in file.readlines()]
file.close()

sts = state.states
sts = state.frames
b = 0 # index of line in ball
s = 0 # index of state
while b < len(lines):
# TODO modify for BoT-SORT output
frame, obj_type, id, xmin, ymin, xwidth, ywidth = lines[b][:7]
sF: Frame = sts[s]
if s <= len(sts): # s-1 frameno < bframe, s = len(states)
Expand All @@ -40,51 +39,20 @@ def parse_sort_output(state: GameState, sort_output) -> None:

sF: Frame = sts[s] # frame to be updated DO NOT DELETE LINE
assert sF.frameno == frame
box = (xmin, ymin, xmin + xwidth, ymin + ywidth)
if obj_type is ObjectType.BALL.value:
bf = BallFrame(xmin, ymin, xmax=xmin + xwidth, ymax=ymin + ywidth)
sF.ball = bf
id = "ball_" + id
sF.balls.update({id: bf})
if id not in state.balls: # if new ball
state.balls.update({id: BallState()})
bs: BallState = state.balls.get(id)
bs.frames += 1
sF.set_ball_frame(id, *box)
elif obj_type is ObjectType.PLAYER.value:
pf = PlayerFrame(xmin, ymin, xmax=xmin + xwidth, ymax=ymin + ywidth)
id = "player_" + id
sF.players.update({id: pf})
if id not in state.players: # if new player
state.players.update({id: PlayerState()})
ps: PlayerState = state.players.get(id)
ps.frames += 1
sF.add_player_frame(id, *box)
elif obj_type is ObjectType.RIM.value:
box = Box(xmin, ymin, xmax=xmin + xwidth, ymax=ymin + ywidth)
sF.rim = box
sF.set_rim_box(id, *box)

b += 1 # process next line


def filter_players(state: GameState, threshold: int) -> None:
"removes all players which appear for less than [threshold] frames"
for k in state.players:
v: PlayerState = state.players.get(k)
if v.frames < threshold:
state.players.pop(k)


def filter_balls(state: GameState, threshold: int) -> None:
"removes all balls which appear for less than [threshold] frames"
for k in state.balls:
v: BallState = state.balls.get(k)
if v.frames < threshold:
state.balls.pop(k)


def clean(state: GameState, pframe_threshold: int, bframe_threshold: int):
def clean(state: GameState, pframe_threshold: int):
"""
Imputes missing data and filters outs noise after parsing
pframe_threshold: min frames a player should appear for in video
bframe_threshold: max frames a player should appear for in video
"""
filter_players(state, pframe_threshold)
filter_balls(state, bframe_threshold)
state.filter_players(pframe_threshold)
2 changes: 1 addition & 1 deletion src/processing/team_detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def team_split(state: GameState):
possession with start and finish frames
"""
player_list = state.players.keys()
pos_lst = possession_list(state.states, player_list, thresh=11)
pos_lst = possession_list(state.frames, player_list, thresh=11)
player_idx = {player: i for i, player in enumerate(player_list)}
connects = connections(pos_lst, player_list, player_idx)
teams = possible_teams(player_list)
Expand Down
8 changes: 4 additions & 4 deletions src/processrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def run_parse(self):
"Runs parse module over SORT (and pose later) outputs to update GameState"
parse.parse_sort_output(self.state, self.players_tracking)
parse.parse_sort_output(self.state, self.ball_tracking)
parse.clean(self.state, 100, 100)
parse.clean(self.state, 100)

def run_team_detect(self):
"""
Expand All @@ -47,7 +47,7 @@ def run_team_detect(self):
Splits identified players into teams, then curates:
ball state, passes, player possession, and team possession
"""
teams, pos_list, playerids = team_detect.team_split(self.state.states)
teams, pos_list, playerids = team_detect.team_split(self.state.frames)
self.state.possession_list = pos_list
for pid in playerids:
self.state.players[pid] = {
Expand All @@ -57,7 +57,7 @@ def run_team_detect(self):
"assists": 0,
}
self.state.ball_state = general_detect.ball_state_update(
pos_list, len(self.state.states) - 1
pos_list, len(self.state.frames) - 1
)
self.state.passes = general_detect.player_passes(pos_list)
self.state.possession = general_detect.player_possession(pos_list)
Expand Down Expand Up @@ -87,7 +87,7 @@ def run_video_render(self):
"""Runs video rendering and reencodes, stores to output_video_path_reenc."""
videoRender = video_render.VideoRender(self.homography)
videoRender.render_video(
self.state.states, self.state.players, self.output_video_path
self.state.frames, self.state.players, self.output_video_path
)
videoRender.reencode(self.output_video_path, self.output_video_path_reenc)

Expand Down
44 changes: 39 additions & 5 deletions src/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,12 +204,29 @@ def __init__(self, frameno: int) -> None:
# MUTABLE
self.players: dict = {} # ASSUMPTION: MULITPLE PEOPLE
"dictionary of form {player_[id] : PlayerFrame}"
self.ball: BallFrame = None # ASSUMPTION: SINGLE BALL
self.balls: dict = {} # ASSUMPTION: MULITPLE BALLS
self.ball: BallFrame = None # ASSUMPTION: SINGLE BALLS
"dictionary of form {ball_[id] : BallFrame}"
self.rim: Box = None # ASSUMPTION: SINGLE RIM
"bounding box of rim"

def add_player_frame(self, id: int, xmin: int, ymin: int, xmax: int, ymax: int):
"update players in frame given id and bounding boxes"
pf = PlayerFrame(xmin, ymin, xmax, ymax)
id = "player_" + id
self.players.update({id: pf})

def set_ball_frame(self, id: int, xmin: int, ymin: int, xmax: int, ymax: int):
"set ball in frame given id and bounding boxes"
bf = BallFrame(xmin, ymin, xmax, ymax)
id = "ball_" + id
self.ball = bf

def set_rim_box(self, id: int, xmin: int, ymin: int, xmax: int, ymax: int):
"set rim box given bounding boxes"
bf = BallFrame(xmin, ymin, xmax, ymax)
id = "ball_" + id
self.balls.update({id: bf})

def check(self) -> bool:
"verifies if well-defined"
try:
Expand Down Expand Up @@ -270,7 +287,7 @@ def __init__(self) -> None:
"""
# MUTABLE

self.states: list = []
self.frames: list = []
"list of frames: [Frame], each frame has player, ball, and rim info"

self.players: dict = {}
Expand All @@ -279,7 +296,8 @@ def __init__(self) -> None:
self.balls: dict = {}
"{ball_0 : BallState, ball_1 : BallState}"

self.shots
self.shots: list = []
" list of shots: [(player_[id],start,end)]"

# EVERYTHING BELOW THIS POINT IS OUT-OF-DATE

Expand All @@ -301,6 +319,22 @@ def __init__(self) -> None:
self.team1_pos = 0
self.team2_pos = 0

def recompute_frame_count(self):
"recompute frame count of all players in frames"
for ps in self.players.values(): # reset to 0
ps.frame = 0
for frame in self.frames:
for pid in frame.players:
if pid not in self.players:
self.players.update({pid : PlayerState()})
self.players.get(pid).frame += 1

def filter_players(self, threshold: int):
"removes all players which appear for less than [threshold] frames"
self.recompute_frame_count()
for k, v in enumerate(self.players):
if v.frames < threshold:
self.players.pop(k)

def update_scores(self, madeshot_list):
"""
Expand Down Expand Up @@ -346,7 +380,7 @@ def __repr__(self) -> str:
if len(self.rim) > 0
else "None",
"Court lines coordinates": "None",
"Number of frames": str(len(self.states)),
"Number of frames": str(len(self.frames)),
"Number of players": str(len(self.players)),
"Number of passes": str(len(self.passes)),
"Team 1": str(self.team1),
Expand Down

0 comments on commit 5cdb2ac

Please sign in to comment.