diff --git a/main.py b/main.py index a65708d..55fa1b5 100644 --- a/main.py +++ b/main.py @@ -1,43 +1,51 @@ import os import subprocess import cv2 -import mediapipe as mp +from ultralytics import YOLO from exercises import EXERCISES -mp_pose = mp.solutions.pose -mp_draw = mp.solutions.drawing_utils +BEEP_FILE = os.path.join(os.path.dirname(__file__), 'beep.mp3') +MODEL_PATH = os.path.join(os.path.dirname(__file__), 'yolo11x-pose.pt') + +HINT = '[P]ush [U]pull [B]ench [C]url [S]itup [L]plank [R]eset [Q]uit' -BEEP_FILE = os.path.join(os.path.dirname(__file__), 'beep.mp3') def beep(): subprocess.Popen(['afplay', BEEP_FILE], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) -cap = cv2.VideoCapture(0) +model = YOLO(MODEL_PATH) + +cap = cv2.VideoCapture(0) count = 0 stage = None -mode = 'p' # default: push-ups +mode = 'p' -HINT = '[P]ush [U]pull [B]ench [C]url [S]itup [L]plank [R]eset [Q]uit' +while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break -with mp_pose.Pose(min_detection_confidence=0.6, min_tracking_confidence=0.6) as pose: - while cap.isOpened(): - ret, frame = cap.read() - if not ret: - break + results = model(frame, verbose=False) + result = results[0] - image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - image.flags.writeable = False - results = pose.process(image) - image.flags.writeable = True - image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + # Draw YOLO skeleton on a copy of the frame (no bounding boxes or conf labels) + image = result.plot(boxes=False, conf=False, labels=False) + # Extract the first detected person's keypoints + kps = conf = None + if (result.keypoints is not None + and result.keypoints.xyn is not None + and len(result.keypoints.xyn) > 0): + kps = result.keypoints.xyn[0].cpu().numpy() # (17, 2) normalised + conf = result.keypoints.conf[0].cpu().numpy() # (17,) + + if kps is not None: try: - lms = results.pose_landmarks.landmark - ex = EXERCISES[mode] - new_stage, counted = ex['fn'](lms, stage) + ex = EXERCISES[mode] + new_stage, counted = ex['fn'](kps, conf, stage) if new_stage != stage: stage = new_stage beep() @@ -46,38 +54,35 @@ with mp_pose.Pose(min_detection_confidence=0.6, min_tracking_confidence=0.6) as except Exception: pass - # --- UI --- - h, w = image.shape[:2] - ex = EXERCISES[mode] - color = ex['color'] + # ── UI overlay ─────────────────────────────────────────────────────────── + h, w = image.shape[:2] + ex = EXERCISES[mode] + color = ex['color'] - cv2.rectangle(image, (0, 0), (w, 80), (15, 15, 15), -1) - cv2.putText(image, ex['name'], (12, 30), - cv2.FONT_HERSHEY_SIMPLEX, 0.85, color, 2) - cv2.putText(image, f'{ex["unit"]}: {count}', (12, 68), - cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 255, 0), 3) - cv2.putText(image, (stage or '---').upper(), (w - 108, 50), - cv2.FONT_HERSHEY_SIMPLEX, 0.9, (180, 180, 0), 2) + cv2.rectangle(image, (0, 0), (w, 80), (15, 15, 15), -1) + cv2.putText(image, ex['name'], (12, 30), + cv2.FONT_HERSHEY_SIMPLEX, 0.85, color, 2) + cv2.putText(image, f'{ex["unit"]}: {count}', (12, 68), + cv2.FONT_HERSHEY_SIMPLEX, 1.3, (0, 255, 0), 3) + cv2.putText(image, (stage or '---').upper(), (w - 108, 50), + cv2.FONT_HERSHEY_SIMPLEX, 0.9, (180, 180, 0), 2) - cv2.rectangle(image, (0, h - 30), (w, h), (15, 15, 15), -1) - cv2.putText(image, HINT, (8, h - 8), - cv2.FONT_HERSHEY_SIMPLEX, 0.48, (160, 160, 160), 1) + cv2.rectangle(image, (0, h - 30), (w, h), (15, 15, 15), -1) + cv2.putText(image, HINT, (8, h - 8), + cv2.FONT_HERSHEY_SIMPLEX, 0.48, (160, 160, 160), 1) - if results.pose_landmarks: - mp_draw.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS) + cv2.imshow('Exercise Counter', image) - cv2.imshow('Exercise Counter', image) - - key = cv2.waitKey(10) & 0xFF - if key == ord('q'): - break - elif key == ord('r'): - count = 0 - stage = None - elif key != 255 and chr(key) in EXERCISES: - mode = chr(key) - stage = None - count = 0 + key = cv2.waitKey(1) & 0xFF + if key == ord('q'): + break + elif key == ord('r'): + count = 0 + stage = None + elif key != 255 and chr(key) in EXERCISES: + mode = chr(key) + stage = None + count = 0 cap.release() cv2.destroyAllWindows()