# Brickbreaker with PPO reinforcement learning using tinygrad
from typing import List, Sequence, Tuple
import argparse
import time

import numpy as np
from tinygrad import Tensor, TinyJit, nn
from tinygrad.helpers import trange

# --- Game Constants ---
SCREEN_W, SCREEN_H = 480, 560
PADDLE_W, PADDLE_H = 80, 12
PADDLE_Y = SCREEN_H - 40
PADDLE_SPEED = 7
BALL_RADIUS = 7
BALL_SPEED = 5.0
BRICK_ROWS, BRICK_COLS = 4, 8
BRICK_W = SCREEN_W // BRICK_COLS
BRICK_H = 22
BRICK_TOP = 60
MAX_STEPS = 2000
N_ACTIONS = 3  # 0=left, 1=stay, 2=right
N_BRICKS = BRICK_ROWS * BRICK_COLS
STATE_SIZE = 5 + N_BRICKS  # ball x/y/vx/vy + paddle x + brick grid

# Colors
BG = (15, 15, 30)
C_PADDLE = (80, 200, 255)
C_BALL = (255, 230, 80)
C_BRICKS = [(255, 80, 80), (255, 160, 60), (100, 220, 100), (100, 150, 255)]
C_TEXT = (200, 200, 220)

# --- PPO Hyperparameters ---
BATCH_SIZE = 256
ENTROPY_SCALE = 0.002
REPLAY_BUFFER = 5000
PPO_EPSILON = 0.2
HIDDEN = 128
LR = 3e-4
TRAIN_STEPS = 5
EPISODES = 400
GAMMA = 0.99


class BrickBreakerEnv:
    """Brick-breaker simulation exposing a reset()/step() interface.

    The observation is a float32 vector of length STATE_SIZE: normalized ball
    position and velocity, normalized paddle x, then the flattened brick grid
    (1.0 = brick present). Rendering is optional and imports pygame lazily.
    """

    def __init__(self, render: bool = False):
        self.render_mode = render
        self._screen = None  # pygame surface; created on the first render() call
        self.reset()

    def reset(self) -> np.ndarray:
        """Start a new episode and return the initial observation."""
        self.paddle_x = float(SCREEN_W // 2)
        self.ball_x = float(SCREEN_W // 2)
        self.ball_y = float(PADDLE_Y - BALL_RADIUS - 8)
        # Launch upward at a random angle with a random horizontal direction.
        angle = np.random.uniform(np.pi * 0.3, np.pi * 0.7)
        self.ball_vx = np.cos(angle) * BALL_SPEED * np.random.choice([-1, 1])
        self.ball_vy = -abs(np.sin(angle)) * BALL_SPEED
        self.bricks = np.ones((BRICK_ROWS, BRICK_COLS), dtype=bool)
        self.done = False
        self._steps = 0
        return self._state()

    def step(self, action: int) -> Tuple[np.ndarray, float, bool]:
        """Advance the simulation by one tick.

        Args:
            action: 0 moves the paddle left, 2 moves it right, anything else
                leaves it in place.

        Returns:
            (observation, reward, done). Reward is +1 for a destroyed brick,
            -5 when the ball falls below the screen and +50 when the last
            brick is cleared. The episode also ends after MAX_STEPS ticks.
        """
        self._steps += 1
        reward = 0.0

        # Paddle movement, clamped so the paddle stays fully on screen.
        if action == 0:
            self.paddle_x = max(PADDLE_W / 2, self.paddle_x - PADDLE_SPEED)
        elif action == 2:
            self.paddle_x = min(SCREEN_W - PADDLE_W / 2, self.paddle_x + PADDLE_SPEED)

        # Ball physics
        self.ball_x += self.ball_vx
        self.ball_y += self.ball_vy

        # Side walls
        if self.ball_x - BALL_RADIUS <= 0:
            self.ball_x = BALL_RADIUS
            self.ball_vx = abs(self.ball_vx)
        elif self.ball_x + BALL_RADIUS >= SCREEN_W:
            self.ball_x = SCREEN_W - BALL_RADIUS
            self.ball_vx = -abs(self.ball_vx)

        # Ceiling
        if self.ball_y - BALL_RADIUS <= 0:
            self.ball_y = BALL_RADIUS
            self.ball_vy = abs(self.ball_vy)

        # Paddle collision (only while the ball is moving downward)
        if (self.ball_vy > 0
                and PADDLE_Y - PADDLE_H // 2 - BALL_RADIUS <= self.ball_y <= PADDLE_Y + PADDLE_H // 2
                and abs(self.ball_x - self.paddle_x) <= PADDLE_W / 2 + BALL_RADIUS):
            self.ball_y = PADDLE_Y - PADDLE_H // 2 - BALL_RADIUS
            # Roughly -1..1; can slightly exceed that at the paddle edges
            # because of the BALL_RADIUS tolerance above, hence the clip.
            offset = (self.ball_x - self.paddle_x) / (PADDLE_W / 2)
            speed = np.hypot(self.ball_vx, self.ball_vy)
            self.ball_vx = np.clip(offset * BALL_SPEED * 1.4, -BALL_SPEED * 1.4, BALL_SPEED * 1.4)
            # The 0.1 floor keeps a non-zero upward component when |vx| >= speed.
            self.ball_vy = -np.sqrt(max(speed**2 - self.ball_vx**2, 0.1))

        # Brick collisions: at most one brick is destroyed per tick.
        if self._hit_brick():
            reward += 1.0

        # Ball out of bounds
        if self.ball_y - BALL_RADIUS > SCREEN_H:
            reward -= 5.0
            self.done = True

        # Level cleared
        if not self.bricks.any():
            reward += 50.0
            self.done = True

        if self._steps >= MAX_STEPS:
            self.done = True

        return self._state(), reward, self.done

    def _hit_brick(self) -> bool:
        """Destroy the first brick (row-major order) the ball overlaps and bounce.

        Returns True if a brick was hit. Stopping at the first hit prevents a
        ball on the seam between two rows from removing two bricks in one tick
        and having its velocity flipped twice (i.e. tunnelling through).
        """
        for row in range(BRICK_ROWS):
            by1 = BRICK_TOP + row * BRICK_H
            by2 = by1 + BRICK_H
            if not (by1 - BALL_RADIUS <= self.ball_y <= by2 + BALL_RADIUS):
                continue
            for col in range(BRICK_COLS):
                if not self.bricks[row, col]:
                    continue
                bx1 = col * BRICK_W
                bx2 = bx1 + BRICK_W
                if not (bx1 <= self.ball_x <= bx2):
                    continue
                self.bricks[row, col] = False
                # bounce direction: hit top/bottom -> flip vy, hit sides -> flip vx
                overlap_y = min(abs(self.ball_y - by1), abs(self.ball_y - by2))
                overlap_x = min(abs(self.ball_x - bx1), abs(self.ball_x - bx2))
                if overlap_y <= overlap_x:
                    self.ball_vy = -self.ball_vy
                else:
                    self.ball_vx = -self.ball_vx
                return True
        return False

    def _state(self) -> np.ndarray:
        """Return the current observation as a float32 vector of length STATE_SIZE."""
        scalars = np.array([
            self.ball_x / SCREEN_W,
            self.ball_y / SCREEN_H,
            self.ball_vx / BALL_SPEED,
            self.ball_vy / BALL_SPEED,
            self.paddle_x / SCREEN_W,
        ], dtype=np.float32)
        return np.concatenate((scalars, self.bricks.ravel().astype(np.float32)))

    def render(self, episode: int = 0, total_reward: float = 0.0) -> bool:
        """Draw the current frame with pygame.

        The window, font and clock are created on the first call.

        Returns:
            False if a pygame QUIT event was received (nothing is drawn in
            that case), True otherwise.
        """
        import pygame
        if self._screen is None:
            pygame.init()
            self._screen = pygame.display.set_mode((SCREEN_W, SCREEN_H))
            pygame.display.set_caption("Brickbreaker — tinygrad PPO")
            self._font = pygame.font.SysFont("monospace", 14)
            self._clock = pygame.time.Clock()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                return False

        self._screen.fill(BG)

        # Bricks
        for row in range(BRICK_ROWS):
            for col in range(BRICK_COLS):
                if self.bricks[row, col]:
                    r = pygame.Rect(col * BRICK_W + 2, BRICK_TOP + row * BRICK_H + 2, BRICK_W - 4, BRICK_H - 4)
                    pygame.draw.rect(self._screen, C_BRICKS[row % len(C_BRICKS)], r, border_radius=3)

        # Paddle
        pr = pygame.Rect(int(self.paddle_x - PADDLE_W // 2), PADDLE_Y - PADDLE_H // 2, PADDLE_W, PADDLE_H)
        pygame.draw.rect(self._screen, C_PADDLE, pr, border_radius=6)

        # Ball
        pygame.draw.circle(self._screen, C_BALL, (int(self.ball_x), int(self.ball_y)), BALL_RADIUS)

        # HUD
        bricks_left = int(self.bricks.sum())
        txt = self._font.render(f"ep:{episode} bricks:{bricks_left:2d}/{N_BRICKS} reward:{total_reward:+.0f}", True, C_TEXT)
        self._screen.blit(txt, (8, 8))

        pygame.display.flip()
        self._clock.tick(60)  # cap the display loop at 60 FPS
        return True


class ActorCritic:
    """Separate actor and critic MLPs sharing only the input observation."""

    def __init__(self, in_features: int, out_features: int):
        self.a1 = nn.Linear(in_features, HIDDEN)
        self.a2 = nn.Linear(HIDDEN, HIDDEN // 2)
        self.a3 = nn.Linear(HIDDEN // 2, out_features)

        self.c1 = nn.Linear(in_features, HIDDEN)
        self.c2 = nn.Linear(HIDDEN, HIDDEN // 2)
        self.c3 = nn.Linear(HIDDEN // 2, 1)

    def __call__(self, obs: Tensor) -> Tuple[Tensor, Tensor]:
        """Return (log action probabilities, state value) for `obs`."""
        a = self.a1(obs).relu()
        a = self.a2(a).relu()
        act = self.a3(a).log_softmax()
        c = self.c1(obs).relu()
        c = self.c2(c).relu()
        return act, self.c3(c)


def discounted_returns(rewards: Sequence[float], gamma: float = GAMMA) -> List[float]:
    """Return the discounted reward-to-go for every step of one episode.

    Element i is sum_k gamma**k * rewards[i + k], computed with a single
    backward pass (G_i = r_i + gamma * G_{i+1}).
    """
    returns: List[float] = []
    running = 0.0
    for rew in reversed(rewards):
        running = float(rew) + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def watch(model: ActorCritic, episode: int, n_games: int = 3) -> None:
    """Render `n_games` games played greedily (argmax action) by `model`.

    Returns early if the window is closed. pygame is shut down on every exit
    path, including the early return and exceptions.
    """
    import pygame
    env = BrickBreakerEnv(render=True)
    try:
        for _ in range(n_games):
            obs = env.reset()
            done, total_rew = False, 0.0
            while not done:
                if not env.render(episode=episode, total_reward=total_rew):
                    return
                act = model(Tensor(obs))[0].argmax().item()
                obs, rew, done = env.step(act)
                total_rew += rew
            # Show the final frame briefly before the next game starts.
            env.render(episode=episode, total_reward=total_rew)
            time.sleep(1.5)
    finally:
        pygame.quit()


def main() -> None:
    """Parse CLI arguments, train the agent with PPO, then show it playing."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--episodes", type=int, default=EPISODES)
    parser.add_argument("--watch-every", type=int, default=0, help="render a game every N episodes during training (0=off)")
    parser.add_argument("--watch-only", action="store_true", help="skip training, just watch a random agent")
    args = parser.parse_args()

    model = ActorCritic(STATE_SIZE, N_ACTIONS)

    if args.watch_only:
        watch(model, episode=0)
        return

    opt = nn.optim.Adam(nn.state.get_parameters(model), lr=LR)
    env = BrickBreakerEnv(render=False)

    @TinyJit
    def train_step(x: Tensor, a: Tensor, r: Tensor, old_log: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        with Tensor.train():
            log_dist, value = model(x)
            # One-hot mask selecting the action that was actually taken.
            mask = (a.reshape(-1, 1) == Tensor.arange(N_ACTIONS).reshape(1, -1).expand(a.shape[0], -1)).float()
            advantage = r.reshape(-1, 1) - value
            masked_adv = mask * advantage.detach()
            ratios = (log_dist - old_log).exp()
            # PPO clipped surrogate: take the pessimistic (minimum) of the
            # unclipped and clipped objectives.
            unclipped = masked_adv * ratios
            clipped = masked_adv * ratios.clip(1 - PPO_EPSILON, 1 + PPO_EPSILON)
            action_loss = -unclipped.minimum(clipped).sum(-1).mean()
            # sum(p * log p) is the negative entropy; minimizing it raises entropy.
            entropy_loss = (log_dist.exp() * log_dist).sum(-1).mean()
            critic_loss = advantage.square().mean()
            opt.zero_grad()
            (action_loss + entropy_loss * ENTROPY_SCALE + critic_loss).backward()
            opt.step()
            return action_loss.realize(), entropy_loss.realize(), critic_loss.realize()

    @TinyJit
    def get_action(obs: Tensor) -> Tensor:
        return model(obs)[0].exp().multinomial().realize()

    Xn, An, Rn = [], [], []
    t_start, total_steps = time.perf_counter(), 0
    best_reward = -float("inf")

    for ep in (pbar := trange(args.episodes)):
        get_action.reset()

        # Roll out one episode with the current stochastic policy.
        obs = env.reset()
        rews = []
        done = False
        while not done:
            act = get_action(Tensor(obs)).item()
            Xn.append(np.copy(obs))
            An.append(act)
            obs, rew, done = env.step(act)
            rews.append(rew)
        total_steps += len(rews)

        # Discounted reward-to-go
        Rn += discounted_returns(rews, GAMMA)

        # Keep only the most recent REPLAY_BUFFER transitions.
        Xn, An, Rn = Xn[-REPLAY_BUFFER:], An[-REPLAY_BUFFER:], Rn[-REPLAY_BUFFER:]

        if len(Xn) >= BATCH_SIZE:
            X, A, R = Tensor(np.array(Xn)), Tensor(An), Tensor(Rn)
            old_log = model(X)[0].detach()
            for _ in range(TRAIN_STEPS):
                samples = Tensor.randint(BATCH_SIZE, high=X.shape[0]).realize()
                train_step(X[samples], A[samples], R[samples], old_log[samples])

        ep_reward = sum(rews)
        best_reward = max(best_reward, ep_reward)
        bricks_hit = int(N_BRICKS - env.bricks.sum())
        sps = total_steps / (time.perf_counter() - t_start)
        pbar.set_description(
            f"ep:{ep:4d} sps:{sps:6.0f} bricks:{bricks_hit:2d}/{N_BRICKS} rew:{ep_reward:+6.1f} best:{best_reward:+6.1f}"
        )

        if args.watch_every > 0 and (ep + 1) % args.watch_every == 0:
            watch(model, episode=ep + 1, n_games=1)

    print("\nTraining done. Watching agent play...")
    watch(model, episode=args.episodes, n_games=3)


if __name__ == "__main__":
    main()