"""Brickbreaker with PPO reinforcement learning using tinygrad.

Trains an actor-critic agent on a small Breakout clone, optionally rendering
games with pygame (imported lazily, only when a frame is drawn).
Run as a script; see ``--help`` for options.
"""
import argparse
import time
from typing import List, Sequence, Tuple

import numpy as np
from tinygrad import Tensor, TinyJit, nn
from tinygrad.helpers import trange

# --- Game Constants ---
SCREEN_W, SCREEN_H = 480, 560
PADDLE_W, PADDLE_H = 80, 12
PADDLE_Y = SCREEN_H - 40
PADDLE_SPEED = 7
BALL_RADIUS = 7
BALL_SPEED = 5.0
BRICK_ROWS, BRICK_COLS = 4, 8
BRICK_W = SCREEN_W // BRICK_COLS
BRICK_H = 22
BRICK_TOP = 60
MAX_STEPS = 2000
N_ACTIONS = 3  # 0=left, 1=stay, 2=right
N_BRICKS = BRICK_ROWS * BRICK_COLS
STATE_SIZE = 5 + N_BRICKS  # ball x/y/vx/vy + paddle x + brick grid

# Colors
BG = (15, 15, 30)
C_PADDLE = (80, 200, 255)
C_BALL = (255, 230, 80)
C_BRICKS = [(255, 80, 80), (255, 160, 60), (100, 220, 100), (100, 150, 255)]
C_TEXT = (200, 200, 220)

# --- PPO Hyperparameters ---
BATCH_SIZE = 256
ENTROPY_SCALE = 0.002
REPLAY_BUFFER = 5000
PPO_EPSILON = 0.2
HIDDEN = 128
LR = 3e-4
TRAIN_STEPS = 5
EPISODES = 400
GAMMA = 0.99


class BrickBreakerEnv:
    """Single-ball Breakout clone with a 3-action paddle.

    ``step`` returns ``(state, reward, done)``. Rewards:
    +1 per brick destroyed, -5 when the ball falls below the screen,
    and +50 for clearing every brick. An episode also ends after
    ``MAX_STEPS`` steps.
    """

    def __init__(self, render: bool = False):
        self.render_mode = render  # stored only; frames are drawn when render() is called
        self._screen = None  # pygame surface, created lazily by render()
        self.reset()

    def reset(self) -> np.ndarray:
        """Center paddle and ball, launch the ball upward at a random angle, restore all bricks.

        Returns:
            The initial state vector (see ``_state``).
        """
        self.paddle_x = float(SCREEN_W // 2)
        self.ball_x = float(SCREEN_W // 2)
        self.ball_y = float(PADDLE_Y - BALL_RADIUS - 8)
        # Launch angle in [0.3*pi, 0.7*pi], i.e. within ~36 degrees of vertical, random left/right.
        angle = np.random.uniform(np.pi * 0.3, np.pi * 0.7)
        self.ball_vx = np.cos(angle) * BALL_SPEED * np.random.choice([-1, 1])
        self.ball_vy = -abs(np.sin(angle)) * BALL_SPEED
        self.bricks = np.ones((BRICK_ROWS, BRICK_COLS), dtype=bool)
        self.done = False
        self._steps = 0
        return self._state()

    def step(self, action: int) -> Tuple[np.ndarray, float, bool]:
        """Advance the game by one frame.

        Args:
            action: 0 moves the paddle left, 2 moves it right; any other value keeps it still.

        Returns:
            ``(state, reward, done)`` after the frame.
        """
        self._steps += 1
        self._move_paddle(action)
        self._move_ball()
        reward = self._collide_bricks()

        # Ball out of bounds
        if self.ball_y - BALL_RADIUS > SCREEN_H:
            reward -= 5.0
            self.done = True
        # Level cleared
        if not self.bricks.any():
            reward += 50.0
            self.done = True
        if self._steps >= MAX_STEPS:
            self.done = True
        return self._state(), reward, self.done

    def _move_paddle(self, action: int) -> None:
        """Shift the paddle by PADDLE_SPEED for actions 0/2, keeping it fully on screen."""
        if action == 0:
            self.paddle_x = max(PADDLE_W / 2, self.paddle_x - PADDLE_SPEED)
        elif action == 2:
            self.paddle_x = min(SCREEN_W - PADDLE_W / 2, self.paddle_x + PADDLE_SPEED)

    def _move_ball(self) -> None:
        """Integrate the ball position, then bounce it off the side walls, ceiling and paddle."""
        self.ball_x += self.ball_vx
        self.ball_y += self.ball_vy

        # Side walls: clamp back inside and force the velocity to point away from the wall.
        if self.ball_x - BALL_RADIUS <= 0:
            self.ball_x = BALL_RADIUS
            self.ball_vx = abs(self.ball_vx)
        elif self.ball_x + BALL_RADIUS >= SCREEN_W:
            self.ball_x = SCREEN_W - BALL_RADIUS
            self.ball_vx = -abs(self.ball_vx)

        # Ceiling
        if self.ball_y - BALL_RADIUS <= 0:
            self.ball_y = BALL_RADIUS
            self.ball_vy = abs(self.ball_vy)

        # Paddle: only while the ball is falling (vy > 0), so it cannot be re-caught on the way up.
        if (self.ball_vy > 0
                and PADDLE_Y - PADDLE_H // 2 - BALL_RADIUS <= self.ball_y <= PADDLE_Y + PADDLE_H // 2
                and abs(self.ball_x - self.paddle_x) <= PADDLE_W / 2 + BALL_RADIUS):
            self.ball_y = PADDLE_Y - PADDLE_H // 2 - BALL_RADIUS
            # Where the ball hit the paddle steers the bounce: -1 at the left edge, +1 at the
            # right edge (slightly beyond +-1 when only the ball's radius overlaps the paddle).
            offset = (self.ball_x - self.paddle_x) / (PADDLE_W / 2)
            speed = np.hypot(self.ball_vx, self.ball_vy)
            self.ball_vx = np.clip(offset * BALL_SPEED * 1.4, -BALL_SPEED * 1.4, BALL_SPEED * 1.4)
            # NOTE(review): |vx| can reach 1.4 * BALL_SPEED, which can exceed `speed`; the 0.1 floor
            # then leaves |vy| ~= 0.32, i.e. an almost horizontal ball. Confirm this is intended.
            self.ball_vy = -np.sqrt(max(speed**2 - self.ball_vx**2, 0.1))

    def _collide_bricks(self) -> float:
        """Destroy at most one brick overlapped by the ball and bounce off it.

        Bricks are scanned row by row, then column by column, and the scan stops at the
        first hit. Stopping after one brick matters: two hits in one frame would flip the
        velocity twice and let the ball tunnel straight through.

        Returns:
            1.0 if a brick was destroyed this frame, else 0.0.
        """
        for row in range(BRICK_ROWS):
            for col in range(BRICK_COLS):
                if not self.bricks[row, col]:
                    continue
                bx1 = col * BRICK_W
                bx2 = bx1 + BRICK_W
                by1 = BRICK_TOP + row * BRICK_H
                by2 = by1 + BRICK_H
                if bx1 <= self.ball_x <= bx2 and by1 - BALL_RADIUS <= self.ball_y <= by2 + BALL_RADIUS:
                    self.bricks[row, col] = False
                    # Bounce direction: closer to a top/bottom edge -> flip vy, else flip vx.
                    overlap_y = min(abs(self.ball_y - by1), abs(self.ball_y - by2))
                    overlap_x = min(abs(self.ball_x - bx1), abs(self.ball_x - bx2))
                    if overlap_y <= overlap_x:
                        self.ball_vy = -self.ball_vy
                    else:
                        self.ball_vx = -self.ball_vx
                    return 1.0
        return 0.0

    def _state(self) -> np.ndarray:
        """Return a float32 vector of length STATE_SIZE.

        Layout: ball x/y (screen-normalised), ball vx/vy (divided by BALL_SPEED),
        paddle x (screen-normalised), then the brick grid flattened row-major
        as 1.0 for present and 0.0 for destroyed.
        """
        return np.array([
            self.ball_x / SCREEN_W,
            self.ball_y / SCREEN_H,
            self.ball_vx / BALL_SPEED,
            self.ball_vy / BALL_SPEED,
            self.paddle_x / SCREEN_W,
            *self.bricks.flatten().astype(np.float32),
        ], dtype=np.float32)

    def render(self, episode: int = 0, total_reward: float = 0.0) -> bool:
        """Draw the current frame with pygame, capped at 60 FPS.

        The window is created on the first call.

        Returns:
            False if the user closed the window (pygame QUIT event), else True.
        """
        import pygame  # imported lazily so constructing/stepping the env does not need pygame

        if self._screen is None:
            pygame.init()
            self._screen = pygame.display.set_mode((SCREEN_W, SCREEN_H))
            pygame.display.set_caption("Brickbreaker — tinygrad PPO")
            self._font = pygame.font.SysFont("monospace", 14)
            self._clock = pygame.time.Clock()
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                return False

        self._screen.fill(BG)
        # Bricks
        for row in range(BRICK_ROWS):
            for col in range(BRICK_COLS):
                if self.bricks[row, col]:
                    r = pygame.Rect(col * BRICK_W + 2, BRICK_TOP + row * BRICK_H + 2, BRICK_W - 4, BRICK_H - 4)
                    pygame.draw.rect(self._screen, C_BRICKS[row % len(C_BRICKS)], r, border_radius=3)
        # Paddle
        pr = pygame.Rect(int(self.paddle_x - PADDLE_W // 2), PADDLE_Y - PADDLE_H // 2, PADDLE_W, PADDLE_H)
        pygame.draw.rect(self._screen, C_PADDLE, pr, border_radius=6)
        # Ball
        pygame.draw.circle(self._screen, C_BALL, (int(self.ball_x), int(self.ball_y)), BALL_RADIUS)
        # HUD
        bricks_left = int(self.bricks.sum())
        txt = self._font.render(
            f"ep:{episode} bricks:{bricks_left:2d}/{N_BRICKS} reward:{total_reward:+.0f}", True, C_TEXT)
        self._screen.blit(txt, (8, 8))
        pygame.display.flip()
        self._clock.tick(60)
        return True

    def close(self) -> None:
        """Shut down pygame if render() opened a window; safe to call more than once."""
        if self._screen is not None:
            import pygame

            pygame.quit()
            self._screen = None


class ActorCritic:
    """Two independent 3-layer ReLU MLPs: an actor (policy) and a critic (value).

    Calling the model returns ``(log_probs, value)``: log-softmax scores over
    ``out_features`` actions, and one scalar value per input.
    """

    def __init__(self, in_features: int, out_features: int):
        self.a1 = nn.Linear(in_features, HIDDEN)
        self.a2 = nn.Linear(HIDDEN, HIDDEN // 2)
        self.a3 = nn.Linear(HIDDEN // 2, out_features)
        self.c1 = nn.Linear(in_features, HIDDEN)
        self.c2 = nn.Linear(HIDDEN, HIDDEN // 2)
        self.c3 = nn.Linear(HIDDEN // 2, 1)

    def __call__(self, obs: Tensor) -> Tuple[Tensor, Tensor]:
        a = self.a1(obs).relu()
        a = self.a2(a).relu()
        act = self.a3(a).log_softmax()
        c = self.c1(obs).relu()
        c = self.c2(c).relu()
        return act, self.c3(c)


def discounted_returns(rewards: Sequence[float], gamma: float = GAMMA) -> List[float]:
    """Return the discounted reward-to-go ``G_t = r_t + gamma * G_{t+1}`` for every step.

    Uses a single backward pass (O(n)) instead of re-summing each suffix.

    >>> discounted_returns([1.0, 1.0], gamma=0.5)
    [1.5, 1.0]
    """
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def watch(model: ActorCritic, episode: int, n_games: int = 3) -> None:
    """Render up to ``n_games`` games played greedily (argmax action) by ``model``.

    Closing the window stops early. The pygame window is always shut down on exit.
    """
    env = BrickBreakerEnv(render=True)
    try:
        for _ in range(n_games):
            obs = env.reset()
            done, total_rew = False, 0.0
            while not done:
                if not env.render(episode=episode, total_reward=total_rew):
                    return
                act = model(Tensor(obs))[0].argmax().item()
                obs, rew, done = env.step(act)
                total_rew += rew
            env.render(episode=episode, total_reward=total_rew)
            time.sleep(1.5)  # pause on the final frame of each game
    finally:
        env.close()


def main() -> None:
    """Parse CLI flags, train the agent with PPO, then watch it play."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--episodes", type=int, default=EPISODES)
    parser.add_argument("--watch-every", type=int, default=0,
                        help="render a game every N episodes during training (0=off)")
    parser.add_argument("--watch-only", action="store_true",
                        help="skip training, just watch a random agent")
    args = parser.parse_args()

    model = ActorCritic(STATE_SIZE, N_ACTIONS)
    opt = nn.optim.Adam(nn.state.get_parameters(model), lr=LR)
    env = BrickBreakerEnv(render=False)

    @TinyJit
    def train_step(x: Tensor, a: Tensor, r: Tensor, old_log: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Run one PPO optimizer step on a minibatch; return (action, entropy, critic) losses."""
        with Tensor.train():
            log_dist, value = model(x)
            # One-hot mask selecting the action actually taken in each sample.
            mask = (a.reshape(-1, 1) == Tensor.arange(N_ACTIONS).reshape(1, -1).expand(a.shape[0], -1)).float()
            advantage = r.reshape(-1, 1) - value
            masked_adv = mask * advantage.detach()
            ratios = (log_dist - old_log).exp()
            # PPO clipped surrogate: the pessimistic minimum of the unclipped and clipped terms.
            # With the clipped term alone, the gradient would vanish even when the ratio has
            # already moved in the harmful direction.
            surrogate = (masked_adv * ratios).minimum(
                masked_adv * ratios.clip(1 - PPO_EPSILON, 1 + PPO_EPSILON))
            action_loss = -surrogate.sum(-1).mean()
            # Negative entropy: minimising it pushes the policy towards exploration.
            entropy_loss = (log_dist.exp() * log_dist).sum(-1).mean()
            critic_loss = advantage.square().mean()
            opt.zero_grad()
            (action_loss + entropy_loss * ENTROPY_SCALE + critic_loss).backward()
            opt.step()
            return action_loss.realize(), entropy_loss.realize(), critic_loss.realize()

    @TinyJit
    def get_action(obs: Tensor) -> Tensor:
        """Sample one action from the policy for a single observation."""
        return model(obs)[0].exp().multinomial().realize()

    if args.watch_only:
        watch(model, episode=0)
        return

    # Replay buffer: observations, actions and discounted returns (last REPLAY_BUFFER steps).
    Xn, An, Rn = [], [], []
    t_start, total_steps = time.perf_counter(), 0
    best_reward = -float("inf")
    for ep in (pbar := trange(args.episodes)):
        # NOTE(review): the JIT capture is reset every episode; presumably to avoid stale
        # captures after training updates — confirm whether this is still required.
        get_action.reset()
        obs = env.reset()
        rews = []
        done = False
        while not done:
            act = get_action(Tensor(obs)).item()
            Xn.append(np.copy(obs))
            An.append(act)
            obs, rew, done = env.step(act)
            rews.append(rew)
        total_steps += len(rews)

        Rn += discounted_returns(rews)
        Xn, An, Rn = Xn[-REPLAY_BUFFER:], An[-REPLAY_BUFFER:], Rn[-REPLAY_BUFFER:]

        if len(Xn) >= BATCH_SIZE:
            X, A, R = Tensor(np.array(Xn)), Tensor(An), Tensor(Rn)
            # Realize now so the reference log-probs are pinned before the optimizer updates the
            # weights, and are not recomputed over the whole buffer for every minibatch.
            old_log = model(X)[0].detach().realize()
            for _ in range(TRAIN_STEPS):
                samples = Tensor.randint(BATCH_SIZE, high=X.shape[0]).realize()
                train_step(X[samples], A[samples], R[samples], old_log[samples])

        ep_reward = sum(rews)
        best_reward = max(best_reward, ep_reward)
        bricks_hit = int(N_BRICKS - env.bricks.sum())
        sps = total_steps / (time.perf_counter() - t_start)
        pbar.set_description(
            f"ep:{ep:4d} sps:{sps:6.0f} bricks:{bricks_hit:2d}/{N_BRICKS} rew:{ep_reward:+6.1f} best:{best_reward:+6.1f}"
        )
        if args.watch_every > 0 and (ep + 1) % args.watch_every == 0:
            watch(model, episode=ep + 1, n_games=1)

    print("\nTraining done. Watching agent play...")
    watch(model, episode=args.episodes, n_games=3)


if __name__ == "__main__":
    main()