tracks angles for multiple players
audreywangg committed Oct 19, 2023
1 parent 57f0f1d commit a29f6df
Showing 1 changed file with 85 additions and 36 deletions.
121 changes: 85 additions & 36 deletions src/pose_estimation/pose_estimate.py
@@ -6,24 +6,45 @@
import json
from ultralytics import YOLO # Ensure ultralytics is installed and configured


class PoseEstimator:
def __init__(self, model_path='yolov8m-pose.pt', video_path='test_video.mp4', combinations=None):
def __init__(
self,
model_path="yolov8m-pose.pt",
video_path="short_training_data.mp4",
combinations=None,
):
# Initialize paths, model, and combinations of keypoints to calculate angles
self.model_path = model_path
self.video_path = video_path
self.model = YOLO(model_path) # Load the YOLO model

# Adjusted combinations of points to calculate 8 angles
self.combinations = combinations if combinations is not None else [
(5, 7, 9), (6, 8, 10), (11, 13, 15), (12, 14, 16),
(5, 6, 8), (6, 5, 7), (11, 12, 14), (12, 11, 13)
]
self.combinations = (
combinations
if combinations is not None
else [
(5, 7, 9),
(6, 8, 10),
(11, 13, 15),
(12, 14, 16),
(5, 6, 8),
(6, 5, 7),
(11, 12, 14),
(12, 11, 13),
]
)

# Names corresponding to the adjusted 8 angle types for better understanding
self.angle_names = [
"left_arm", "right_arm", "left_leg", "right_leg",
"shoulder_left_right_elbow", "shoulder_right_left_elbow",
"hip_left_right_knee", "hip_right_left_knee"
"left_arm",
"right_arm",
"left_leg",
"right_leg",
"shoulder_left_right_elbow",
"shoulder_right_left_elbow",
"hip_left_right_knee",
"hip_right_left_knee",
]

@staticmethod
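The compute_angle helper used below is collapsed in this diff. A minimal sketch of such a three-keypoint angle computation, assuming each keypoint is a 2-D (x, y) torch tensor as yielded by results[0].keypoints.xy; it is written as a free function for brevity (in the class it is a @staticmethod) and is not necessarily the implementation committed here:

import torch

def compute_angle(p1, p2, p3):
    # Angle at p2, in degrees, between the segments p2->p1 and p2->p3.
    v1 = p1 - p2
    v2 = p3 - p2
    cos_theta = torch.dot(v1, v2) / (torch.norm(v1) * torch.norm(v2) + 1e-8)
    angle = torch.acos(torch.clamp(cos_theta, -1.0, 1.0))
    return torch.rad2deg(angle)  # 0-dim tensor, so .item() works as used in estimate_pose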
@@ -55,18 +76,24 @@ def estimate_pose(self):

# Names corresponding to angle types for better understanding
angle_names = [
"left_arm", "right_arm", "left_leg", "right_leg",
"shoulder_left_right_elbow", "shoulder_right_left_elbow",
"hip_left_right_knee", "hip_right_left_knee",
"shoulder_left_hip_right_hip", "shoulder_right_hip_left_hip"
"left_arm",
"right_arm",
"left_leg",
"right_leg",
"shoulder_left_right_elbow",
"shoulder_right_left_elbow",
"hip_left_right_knee",
"hip_right_left_knee",
"shoulder_left_hip_right_hip",
"shoulder_right_hip_left_hip",
]

while cap.isOpened():
# Read frame-by-frame
success, frame = cap.read()
if success:
# Measure start time (to calculate FPS later)
start_time = time.time()
# start_time = time.time()

# Run pose estimation model on the frame
results = self.model(frame, verbose=False)
@@ -75,47 +102,69 @@ def estimate_pose(self):
keypoints = results[0].keypoints.xy

# Create annotated frame visualization
annotated_frame = results[0].plot()
# annotated_frame = results[0].plot()

# Store frame number in data
frame_pose_data = {'frame': cap.get(cv2.CAP_PROP_POS_FRAMES)}
frame_pose_data = {"frame": cap.get(cv2.CAP_PROP_POS_FRAMES)}

# Calculate and display angles based on keypoint combinations
for idx, combination in enumerate(self.combinations):
p1, p2, p3 = keypoints[0][combination[0]], keypoints[0][combination[1]], keypoints[0][combination[2]]
angle_degrees = self.compute_angle(p1, p2, p3)

# Add angle data to frame_pose_data
frame_pose_data[angle_names[idx]] = angle_degrees.item()

# Display angle on the annotated frame
cv2.putText(annotated_frame, f"{angle_degrees:.2f}°",
(int(p2[0]), int(p2[1])), cv2.FONT_HERSHEY_COMPLEX,
0.5, (255, 255, 0), 1, cv2.LINE_AA)
for person in range(len(keypoints)):
for idx, combination in enumerate(self.combinations):
p1, p2, p3 = (
keypoints[person][combination[0]],
keypoints[person][combination[1]],
keypoints[person][combination[2]],
)
angle_degrees = self.compute_angle(p1, p2, p3)

# Add angle data to frame_pose_data
frame_pose_data[
"person_" + str(person) + "_" + angle_names[idx]
] = angle_degrees.item()

# Display angle on the annotated frame
# cv2.putText(
# annotated_frame,
# f"{angle_degrees:.2f}°",
# (int(p2[0]), int(p2[1])),
# cv2.FONT_HERSHEY_COMPLEX,
# 0.5,
# (255, 255, 0),
# 1,
# cv2.LINE_AA,
# )

# Append frame's pose data to the list
pose_data.append(frame_pose_data)

# Calculate and display FPS
end_time = time.time()
fps = 1 / (end_time - start_time)
cv2.putText(annotated_frame, f"FPS: {int(fps)}", (10, 50),
cv2.FONT_HERSHEY_COMPLEX, 1.2, (255, 0, 255), 1, cv2.LINE_AA)
# end_time = time.time()
# fps = 1 / (end_time - start_time)
# cv2.putText(
# annotated_frame,
# f"FPS: {int(fps)}",
# (10, 50),
# cv2.FONT_HERSHEY_COMPLEX,
# 1.2,
# (255, 0, 255),
# 1,
# cv2.LINE_AA,
# )

# Write annotated frame to output video
annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
writer.append_data(annotated_frame)
# annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
# writer.append_data(annotated_frame)

# Break loop on 'q' key press
if cv2.waitKey(1) & 0xFF == ord('q'):
if cv2.waitKey(1) & 0xFF == ord("q"):
break
else:
break

# Cleanup and save data
writer.close()
# writer.close()
cap.release()
cv2.destroyAllWindows()
# cv2.destroyAllWindows()

with open("tmp/pose_data.json", "w") as f:
json.dump(pose_data, f)
json.dump(pose_data, f)
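
With this change, each frame entry written to tmp/pose_data.json carries one key per detected player and angle, in the form person_<i>_<angle_name>. A short consumption sketch, assuming the file is importable as pose_estimate (inferred from src/pose_estimation/pose_estimate.py) and that the tmp/ directory already exists:

import json
from pose_estimate import PoseEstimator  # assumed module name

estimator = PoseEstimator(video_path="short_training_data.mp4")
estimator.estimate_pose()  # writes tmp/pose_data.json

with open("tmp/pose_data.json") as f:
    pose_data = json.load(f)

# Left-arm angle of player 0 across the frames in which that player was detected.
left_arm_p0 = [
    frame["person_0_left_arm"] for frame in pose_data if "person_0_left_arm" in frame
]
print(f"player 0 tracked in {len(left_arm_p0)} of {len(pose_data)} frames")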
