Pose Estimation
- Update requirements for Pose Estimation
- Add comments
- Clean code (get rid of prints, format output json)
dweizzz committed Oct 15, 2023
1 parent 8b16843 commit 57f0f1d
Showing 4 changed files with 65 additions and 41 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -2,7 +2,7 @@ venv
.env
__pycache__
yolov8m-pose.pt
*.mp4
ball/lib/python3.11/site-packages/torch/lib/libtorch_cpu.dylib
venv/lib/python3.11/site-packages/torch/lib/libtorch_cpu.dylib
ball/
14 changes: 10 additions & 4 deletions requirements.txt
@@ -25,8 +25,8 @@ opencv-python==4.7.0.72
matplotlib>=3.2.2
Pillow>=7.1.2
PyYAML>=5.3.1
torch==2.0.1
torchvision==0.15.2
tqdm>=4.41.0
seaborn
scipy
@@ -46,7 +46,13 @@ imageio

# View
streamlit>=1.18.1
hydralit_components>=1.0.10

# Misc
pylint

# Additional Dependencies for Pose Estimation
json5
ultralytics
imageio==2.9.0
imageio-ffmpeg>=0.4.3
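# Usage sketch (assumes pip is available in the active Python 3.11 environment,
# matching the venv paths in .gitignore):
#     python -m pip install -r requirements.txt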
90 changes: 54 additions & 36 deletions src/pose_estimation/pose_estimate.py
@@ -4,100 +4,118 @@
import torch
import math
import json
from ultralytics import YOLO

class PoseEstimator:
    def __init__(self, model_path='yolov8m-pose.pt', video_path='test_video.mp4', combinations=None):
        # Initialize paths, model, and combinations of keypoints to calculate angles
        self.model_path = model_path
        self.video_path = video_path
        self.model = YOLO(model_path)  # Load the YOLO model

        # Combinations of keypoints used to compute the ten tracked angles
        self.combinations = combinations if combinations is not None else [
            (5, 7, 9),     # Left arm: shoulder, elbow, wrist
            (6, 8, 10),    # Right arm: shoulder, elbow, wrist
            (11, 13, 15),  # Left leg: hip, knee, ankle
            (12, 14, 16),  # Right leg: hip, knee, ankle
            (5, 6, 8),     # Left shoulder, right shoulder, right elbow
            (6, 5, 7),     # Right shoulder, left shoulder, left elbow
            (11, 12, 14),  # Left hip, right hip, right knee
            (12, 11, 13),  # Right hip, left hip, left knee
            (5, 11, 12),   # Left shoulder, left hip, right hip
            (6, 12, 11)    # Right shoulder, right hip, left hip
        ]

        # Names corresponding to the angle types above, in the same order
        self.angle_names = [
            "left_arm", "right_arm", "left_leg", "right_leg",
            "shoulder_left_right_elbow", "shoulder_right_left_elbow",
            "hip_left_right_knee", "hip_right_left_knee",
            "shoulder_left_hip_right_hip", "shoulder_right_hip_left_hip"
        ]
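        # For reference, the indices above follow the standard COCO-17 keypoint
        # ordering that Ultralytics YOLOv8 pose models emit:
        #   0 nose, 1 left eye, 2 right eye, 3 left ear, 4 right ear,
        #   5 left shoulder, 6 right shoulder, 7 left elbow, 8 right elbow,
        #   9 left wrist, 10 right wrist, 11 left hip, 12 right hip,
        #   13 left knee, 14 right knee, 15 left ankle, 16 right ankle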

    @staticmethod
    def compute_angle(p1, p2, p3):
        # Calculate angle given 3 points using the dot product and arc cosine
        vector_a = p1 - p2
        vector_b = p3 - p2

        # Normalize the vectors (to make them unit vectors)
        vector_a = vector_a / torch.norm(vector_a)
        vector_b = vector_b / torch.norm(vector_b)

        # Compute the angle
        cosine_angle = torch.dot(vector_a, vector_b)
        angle_radians = torch.acos(cosine_angle)
        angle_degrees = angle_radians * 180 / math.pi

        return angle_degrees
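    # Note, a minimal hardening sketch: rounding can push the dot product of two
    # unit vectors slightly outside [-1, 1], making torch.acos return NaN for
    # nearly collinear keypoints. Clamping guards against that:
    #
    #     cosine_angle = torch.clamp(torch.dot(vector_a, vector_b), -1.0, 1.0)
    #
    # Quick sanity check of compute_angle on a known right angle:
    #
    #     p1 = torch.tensor([1.0, 0.0])
    #     p2 = torch.tensor([0.0, 0.0])
    #     p3 = torch.tensor([0.0, 1.0])
    #     PoseEstimator.compute_angle(p1, p2, p3)  # -> tensor(90.)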

    def estimate_pose(self):
        # Start video capture
        cap = cv2.VideoCapture(self.video_path)

        # Initialize video writer
        writer = imageio.get_writer("tmp/test_result.mp4", mode="I")

        # Initialize an empty list to store pose data;
        # angle keys come from self.angle_names defined in __init__
        pose_data = []

        while cap.isOpened():
            # Read frame-by-frame
            success, frame = cap.read()
            if success:
                # Measure start time (to calculate FPS later)
                start_time = time.time()

                # Run pose estimation model on the frame
                results = self.model(frame, verbose=False)

                # Extract the keypoints
                keypoints = results[0].keypoints.xy

                # Create annotated frame visualization
                annotated_frame = results[0].plot()

                # Store frame number in data
                frame_pose_data = {'frame': cap.get(cv2.CAP_PROP_POS_FRAMES)}

                # Calculate and display angles based on keypoint combinations
                for idx, combination in enumerate(self.combinations):
                    p1, p2, p3 = keypoints[0][combination[0]], keypoints[0][combination[1]], keypoints[0][combination[2]]
                    angle_degrees = self.compute_angle(p1, p2, p3)

                    # Add angle data to frame_pose_data under a descriptive name
                    frame_pose_data[self.angle_names[idx]] = angle_degrees.item()

                    # Display angle on the annotated frame
                    cv2.putText(annotated_frame, f"{angle_degrees:.2f}°",
                                (int(p2[0]), int(p2[1])), cv2.FONT_HERSHEY_COMPLEX,
                                0.5, (255, 255, 0), 1, cv2.LINE_AA)

                # Append frame's pose data to the list
                pose_data.append(frame_pose_data)

                # Calculate and display FPS
                end_time = time.time()
                fps = 1 / (end_time - start_time)

                cv2.putText(annotated_frame, f"FPS: {int(fps)}", (10, 50),
                            cv2.FONT_HERSHEY_COMPLEX, 1.2, (255, 0, 255), 1, cv2.LINE_AA)

                # Write annotated frame to output video
                annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
                writer.append_data(annotated_frame)

                # Break loop on 'q' key press
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
            else:
                break

        # Cleanup and save data
        writer.close()
        cap.release()
        cv2.destroyAllWindows()

        with open("tmp/pose_data.json", "w") as f:
            json.dump(pose_data, f, indent=4)
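# Usage sketch (assumes the constructor defaults above: yolov8m-pose.pt and
# test_video.mp4 next to the script, and an existing tmp/ directory):
#
#     estimator = PoseEstimator()
#     estimator.estimate_pose()
#
#     # Inspect the per-frame angles written by estimate_pose
#     with open("tmp/pose_data.json") as f:
#         print(json.load(f)[0])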
Binary file modified tmp/court_video.mp4
