Skip to content

Commit

Permalink
Integrated StrongSORT -- issue now is that the video is just the rand…
Browse files Browse the repository at this point in the history
…om saved video output from YOLOv5 after putting benson video through website
  • Loading branch information
Eric Guo authored and Mikonooooo committed Sep 2, 2023
1 parent c425148 commit e0fe33f
Show file tree
Hide file tree
Showing 8 changed files with 33 additions and 43 deletions.
Binary file modified model/StrongSORT-YOLO/testing.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added model/StrongSORT-YOLO/weights/yolov5m.pt
Binary file not shown.
Binary file added model/StrongSORT-YOLO/yolov5/models/yolov5m.pt
Binary file not shown.
2 changes: 1 addition & 1 deletion model/StrongSORT-YOLO/yolov5/utils/downloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def curl_download(url, filename, *, silent: bool = False) -> bool:

def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
# Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
from utils.general import LOGGER
from yolov5.utils.general import LOGGER

file = Path(file)
assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
Expand Down
Empty file modified model/temp/user.mp4
100755 → 100644
Empty file.
8 changes: 4 additions & 4 deletions src/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ model_vars:
output_file_name: 'user.mp4'
file_name: 'new.mp4'
frame_reduction_factor: 2
yolo_weights: 'yolov5m.pt'
conf_thres: 0.6
iou_thres: 0.45
yolo_weights: 'model/StrongSORT-YOLO/yolov5/models/yolov5m.pt'
conf_thres: '0.6'
iou_thres: '0.45'
classes: '0 32'
exist_ok: True
exist_ok: 'True'
21 changes: 15 additions & 6 deletions src/modelrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@
import sys
import os
import cv2
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../model/StrongSORT-YOLO')))
import importlib
strongsort = importlib.import_module("model.StrongSORT-YOLO")
import track_v5
from pathlib import Path
# strongsort = importlib.import_module("StrongSORT-YOLO")
import subprocess

# run the Python file as a script



def run_models(user_file_path, model_configs):
Expand All @@ -22,17 +28,20 @@ def run_models(user_file_path, model_configs):
file_name = model_vars['file_name']
frame_reduction_factor = model_vars['frame_reduction_factor']
weights = model_vars['yolo_weights']
weights = Path(weights)
conf_threshold = model_vars['conf_thres']
iou_threshold = model_vars['iou_thres']
classes = model_vars['classes']
exist_ok = model_vars['exist_ok']
drop_frames(f'{export_path}/{output_file_name}', frame_reduction_factor)
# detect.run(source=user_file_path, project=export_path,
# name=name, exist_ok=True)
track_v5.run(source=user_file_path, yolo_weights=weights, conf_thres=conf_threshold, iou_thres=ios_threshold,
project=export_path, classes=classes, name=name, exist_ok=exist_ok)
reencode(f'{export_path}/{name}/{output_file_name}',
f'{export_path}/{name}/{file_name}')
subprocess.run(['python', '/track_v5.py', '--source', user_file_path, '--conf_thres', conf_threshold, '--iou_thres', iou_threshold,
'--project', export_path, '--classes', classes, '--name', name, '--exist_ok', exist_ok])
# track_v5.run(source=user_file_path, conf_thres=conf_threshold, iou_thres=iou_threshold,
# project=export_path, classes=classes, name=name, exist_ok=exist_ok)
# reencode(f'{export_path}/{name}/{output_file_name}',
# f'{export_path}/{name}/{file_name}')
return os.path.join(export_path, name, file_name)


Expand Down
45 changes: 13 additions & 32 deletions view/app.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""
Frontend
"""

import io
import sys
import os
Expand All @@ -11,6 +10,7 @@
import requests
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src import main

# Set up tab title and favicon
st.set_page_config(page_title='HoopTracker', page_icon=':basketball:')

Expand All @@ -21,8 +21,7 @@
if 'state' not in st.session_state:
st.session_state.state = 0
st.session_state.logo = 'view/media/basketball.png'
st.session_state.video_file = io.BytesIO(
open('view/media/demo_basketball.mov', 'rb').read())
st.session_state.video_file = io.BytesIO(open('view/media/demo_basketball.mov', 'rb').read())
st.session_state.results = pd.read_csv('view/media/demo_results.csv')
st.session_state.is_downloaded = False
st.session_state.upload_name = None
Expand All @@ -31,38 +30,31 @@
# Backend Connection -----------------------------------------
SERVER_URL = "http://127.0.0.1:8000/"
# Send request to Google Compute Machine


def process_video(video_file):
response = requests.post(SERVER_URL+"upload",
response = requests.post(SERVER_URL+"upload",
files={"video_file": video_file}, timeout=30)
if response.status_code == 200:
data = response.json()
st.session_state.upload_name = data.get('message')
else:
# maybe make an error handler in frontend
print('error uploading file')
print('error uploading file') # maybe make an error handler in frontend
st.session_state.is_downloaded = False
return video_file is not None


def fetch_csv():
if not st.session_state.is_downloaded:
out = requests.get(SERVER_URL+"results", timeout=30)
st.session_state.results = pd.read_csv(
io.StringIO(out.content.decode()))
st.session_state.results = pd.read_csv(io.StringIO(out.content.decode()))
st.session_state.is_downloaded = True
return st.session_state.results.to_csv()


def fetch_processed_video():
return main.main(st.session_state.upload_name)

# ------------------------------------------------------------

# Main Page


def main_page():
st.markdown('''
# HoopTracker
Expand All @@ -74,9 +66,9 @@ def main_page():
st.button(label="Having Trouble?", on_click=change_state, args=(-1,))

# Basketball Icon Filler
_, col2, _ = st.columns([0.5, 5, 0.5])
_, col2, _ = st.columns([0.5,5,0.5])
with col2:
st.image(image=st.session_state.logo, use_column_width=True)
st.image(image=st.session_state.logo,use_column_width=True)


# Tips Page
Expand All @@ -93,8 +85,6 @@ def tips_page():
st.button(label='Back to Home', on_click=change_state, args=(0,))

# Loading Screen


def processing_page():
st.markdown('''
# Processing...
Expand All @@ -110,8 +100,6 @@ def processing_page():
st.experimental_rerun()

# Display Data


def results_page():
st.video(open(fetch_processed_video(), 'rb').read())

Expand All @@ -136,10 +124,11 @@ def results_page():
y=('TO')
)


st.bar_chart(
data=st.session_state.results,
x='PLAYER',
y=('2PA', '3PA')
y=('2PA','3PA')
)

st.markdown('### Raw Data')
Expand All @@ -150,16 +139,13 @@ def results_page():
file_name="results.csv")

# Error Page


def error_page():
st.markdown('''
# Error: Webpage Not Found
Try reloading the page to fix the error.
''')
st.button(label='Back to Home', on_click=change_state, args=(0,))


def setup_sidebar():
# Display upload file widget
st.sidebar.markdown('# Upload')
Expand All @@ -175,15 +161,15 @@ def setup_sidebar():
st.sidebar.video(data=st.session_state.video_file)

# Process options to move to next state
col1, col2 = st.sidebar.columns([1, 17])
col1, col2 = st.sidebar.columns([1,17])
consent_check = col1.checkbox(label=" ", label_visibility='hidden')
col2.caption('''
I have read and agree to HoopTracker's
[terms of services.](https://github.com/CornellDataScience/Ball-101)
''')

st.sidebar.button(label='Upload & Process Video',
disabled=not consent_check,
disabled= not consent_check,
use_container_width=True,
on_click=change_state, args=(1,),
type='primary')
Expand All @@ -198,18 +184,13 @@ def setup_sidebar():
file_name="results.csv")

# Call back function to change page


def change_state(state: int):
def change_state(state:int):
st.session_state.state = state

# Updates video on screen


def update_video(video_file):
st.session_state.video_file = video_file


# Set home info page test
if st.session_state.state == -1:
tips_page()
Expand All @@ -222,4 +203,4 @@ def update_video(video_file):
else:
error_page()

setup_sidebar()
setup_sidebar()

0 comments on commit e0fe33f

Please sign in to comment.