Ported to ultralytics for better POSE estimation

This commit is contained in:
Jon
2026-05-13 17:07:43 +01:00
parent 3a57de19b2
commit 2a23eefaa3

45
main.py
View File

@@ -1,43 +1,51 @@
import os import os
import subprocess import subprocess
import cv2 import cv2
import mediapipe as mp from ultralytics import YOLO
from exercises import EXERCISES from exercises import EXERCISES
mp_pose = mp.solutions.pose
mp_draw = mp.solutions.drawing_utils
BEEP_FILE = os.path.join(os.path.dirname(__file__), 'beep.mp3') BEEP_FILE = os.path.join(os.path.dirname(__file__), 'beep.mp3')
MODEL_PATH = os.path.join(os.path.dirname(__file__), 'yolo11x-pose.pt')
HINT = '[P]ush [U]pull [B]ench [C]url [S]itup [L]plank [R]eset [Q]uit'
def beep(): def beep():
subprocess.Popen(['afplay', BEEP_FILE], subprocess.Popen(['afplay', BEEP_FILE],
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
model = YOLO(MODEL_PATH)
cap = cv2.VideoCapture(0) cap = cv2.VideoCapture(0)
count = 0 count = 0
stage = None stage = None
mode = 'p' # default: push-ups mode = 'p'
HINT = '[P]ush [U]pull [B]ench [C]url [S]itup [L]plank [R]eset [Q]uit'
with mp_pose.Pose(min_detection_confidence=0.6, min_tracking_confidence=0.6) as pose:
while cap.isOpened(): while cap.isOpened():
ret, frame = cap.read() ret, frame = cap.read()
if not ret: if not ret:
break break
image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) results = model(frame, verbose=False)
image.flags.writeable = False result = results[0]
results = pose.process(image)
image.flags.writeable = True
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
# Draw YOLO skeleton on a copy of the frame (no bounding boxes or conf labels)
image = result.plot(boxes=False, conf=False, labels=False)
# Extract the first detected person's keypoints
kps = conf = None
if (result.keypoints is not None
and result.keypoints.xyn is not None
and len(result.keypoints.xyn) > 0):
kps = result.keypoints.xyn[0].cpu().numpy() # (17, 2) normalised
conf = result.keypoints.conf[0].cpu().numpy() # (17,)
if kps is not None:
try: try:
lms = results.pose_landmarks.landmark
ex = EXERCISES[mode] ex = EXERCISES[mode]
new_stage, counted = ex['fn'](lms, stage) new_stage, counted = ex['fn'](kps, conf, stage)
if new_stage != stage: if new_stage != stage:
stage = new_stage stage = new_stage
beep() beep()
@@ -46,7 +54,7 @@ with mp_pose.Pose(min_detection_confidence=0.6, min_tracking_confidence=0.6) as
except Exception: except Exception:
pass pass
# --- UI --- # ── UI overlay ───────────────────────────────────────────────────────────
h, w = image.shape[:2] h, w = image.shape[:2]
ex = EXERCISES[mode] ex = EXERCISES[mode]
color = ex['color'] color = ex['color']
@@ -63,12 +71,9 @@ with mp_pose.Pose(min_detection_confidence=0.6, min_tracking_confidence=0.6) as
cv2.putText(image, HINT, (8, h - 8), cv2.putText(image, HINT, (8, h - 8),
cv2.FONT_HERSHEY_SIMPLEX, 0.48, (160, 160, 160), 1) cv2.FONT_HERSHEY_SIMPLEX, 0.48, (160, 160, 160), 1)
if results.pose_landmarks:
mp_draw.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
cv2.imshow('Exercise Counter', image) cv2.imshow('Exercise Counter', image)
key = cv2.waitKey(10) & 0xFF key = cv2.waitKey(1) & 0xFF
if key == ord('q'): if key == ord('q'):
break break
elif key == ord('r'): elif key == ord('r'):