Ported to ultralytics for better POSE estimation
This commit is contained in:
47
main.py
47
main.py
@@ -1,43 +1,51 @@
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import cv2
|
import cv2
|
||||||
import mediapipe as mp
|
from ultralytics import YOLO
|
||||||
|
|
||||||
from exercises import EXERCISES
|
from exercises import EXERCISES
|
||||||
|
|
||||||
mp_pose = mp.solutions.pose
|
|
||||||
mp_draw = mp.solutions.drawing_utils
|
|
||||||
|
|
||||||
BEEP_FILE = os.path.join(os.path.dirname(__file__), 'beep.mp3')
|
BEEP_FILE = os.path.join(os.path.dirname(__file__), 'beep.mp3')
|
||||||
|
MODEL_PATH = os.path.join(os.path.dirname(__file__), 'yolo11x-pose.pt')
|
||||||
|
|
||||||
|
HINT = '[P]ush [U]pull [B]ench [C]url [S]itup [L]plank [R]eset [Q]uit'
|
||||||
|
|
||||||
|
|
||||||
def beep():
|
def beep():
|
||||||
subprocess.Popen(['afplay', BEEP_FILE],
|
subprocess.Popen(['afplay', BEEP_FILE],
|
||||||
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||||
|
|
||||||
|
|
||||||
|
model = YOLO(MODEL_PATH)
|
||||||
|
|
||||||
cap = cv2.VideoCapture(0)
|
cap = cv2.VideoCapture(0)
|
||||||
count = 0
|
count = 0
|
||||||
stage = None
|
stage = None
|
||||||
mode = 'p' # default: push-ups
|
mode = 'p'
|
||||||
|
|
||||||
HINT = '[P]ush [U]pull [B]ench [C]url [S]itup [L]plank [R]eset [Q]uit'
|
while cap.isOpened():
|
||||||
|
|
||||||
with mp_pose.Pose(min_detection_confidence=0.6, min_tracking_confidence=0.6) as pose:
|
|
||||||
while cap.isOpened():
|
|
||||||
ret, frame = cap.read()
|
ret, frame = cap.read()
|
||||||
if not ret:
|
if not ret:
|
||||||
break
|
break
|
||||||
|
|
||||||
image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
results = model(frame, verbose=False)
|
||||||
image.flags.writeable = False
|
result = results[0]
|
||||||
results = pose.process(image)
|
|
||||||
image.flags.writeable = True
|
|
||||||
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
|
||||||
|
|
||||||
|
# Draw YOLO skeleton on a copy of the frame (no bounding boxes or conf labels)
|
||||||
|
image = result.plot(boxes=False, conf=False, labels=False)
|
||||||
|
|
||||||
|
# Extract the first detected person's keypoints
|
||||||
|
kps = conf = None
|
||||||
|
if (result.keypoints is not None
|
||||||
|
and result.keypoints.xyn is not None
|
||||||
|
and len(result.keypoints.xyn) > 0):
|
||||||
|
kps = result.keypoints.xyn[0].cpu().numpy() # (17, 2) normalised
|
||||||
|
conf = result.keypoints.conf[0].cpu().numpy() # (17,)
|
||||||
|
|
||||||
|
if kps is not None:
|
||||||
try:
|
try:
|
||||||
lms = results.pose_landmarks.landmark
|
|
||||||
ex = EXERCISES[mode]
|
ex = EXERCISES[mode]
|
||||||
new_stage, counted = ex['fn'](lms, stage)
|
new_stage, counted = ex['fn'](kps, conf, stage)
|
||||||
if new_stage != stage:
|
if new_stage != stage:
|
||||||
stage = new_stage
|
stage = new_stage
|
||||||
beep()
|
beep()
|
||||||
@@ -46,7 +54,7 @@ with mp_pose.Pose(min_detection_confidence=0.6, min_tracking_confidence=0.6) as
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# --- UI ---
|
# ── UI overlay ───────────────────────────────────────────────────────────
|
||||||
h, w = image.shape[:2]
|
h, w = image.shape[:2]
|
||||||
ex = EXERCISES[mode]
|
ex = EXERCISES[mode]
|
||||||
color = ex['color']
|
color = ex['color']
|
||||||
@@ -63,12 +71,9 @@ with mp_pose.Pose(min_detection_confidence=0.6, min_tracking_confidence=0.6) as
|
|||||||
cv2.putText(image, HINT, (8, h - 8),
|
cv2.putText(image, HINT, (8, h - 8),
|
||||||
cv2.FONT_HERSHEY_SIMPLEX, 0.48, (160, 160, 160), 1)
|
cv2.FONT_HERSHEY_SIMPLEX, 0.48, (160, 160, 160), 1)
|
||||||
|
|
||||||
if results.pose_landmarks:
|
|
||||||
mp_draw.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
|
|
||||||
|
|
||||||
cv2.imshow('Exercise Counter', image)
|
cv2.imshow('Exercise Counter', image)
|
||||||
|
|
||||||
key = cv2.waitKey(10) & 0xFF
|
key = cv2.waitKey(1) & 0xFF
|
||||||
if key == ord('q'):
|
if key == ord('q'):
|
||||||
break
|
break
|
||||||
elif key == ord('r'):
|
elif key == ord('r'):
|
||||||
|
|||||||
Reference in New Issue
Block a user