Files
tinygrad-apps/brickbreaker_ppo.py
2026-05-15 11:33:20 +01:00

298 lines
9.9 KiB
Python

# Brickbreaker with PPO reinforcement learning using tinygrad
from typing import Tuple
import argparse, time
import numpy as np
from tinygrad import Tensor, TinyJit, nn
from tinygrad.helpers import trange
# --- Game Constants ---
# All geometry is in screen pixels; the y axis grows downward (the ceiling is
# y=0 and the paddle sits near SCREEN_H).
SCREEN_W, SCREEN_H = 480, 560
PADDLE_W, PADDLE_H = 80, 12
# y coordinate of the paddle's centre line.
PADDLE_Y = SCREEN_H - 40
# Pixels the paddle moves per step for a left/right action.
PADDLE_SPEED = 7
BALL_RADIUS = 7
# Launch speed of the ball (pixels per step); also used to normalise velocities in the observation.
BALL_SPEED = 5.0
BRICK_ROWS, BRICK_COLS = 4, 8
# Bricks tile the full screen width.
BRICK_W = SCREEN_W // BRICK_COLS
BRICK_H = 22
# y coordinate of the top edge of the first brick row.
BRICK_TOP = 60
# Episodes are truncated after this many steps.
MAX_STEPS = 2000
N_ACTIONS = 3 # 0=left, 1=stay, 2=right
N_BRICKS = BRICK_ROWS * BRICK_COLS
STATE_SIZE = 5 + N_BRICKS # ball x/y/vx/vy + paddle x + brick grid
# Colors (RGB tuples passed to pygame drawing calls)
BG = (15, 15, 30)
C_PADDLE = (80, 200, 255)
C_BALL = (255, 230, 80)
# One colour per brick row (indexed by row % len(C_BRICKS)).
C_BRICKS = [(255, 80, 80), (255, 160, 60), (100, 220, 100), (100, 150, 255)]
C_TEXT = (200, 200, 220)
# --- PPO Hyperparameters ---
# Minibatch size sampled from the buffer for each gradient step.
BATCH_SIZE = 256
# Weight of the (negative-)entropy term in the total loss.
ENTROPY_SCALE = 0.002
# Maximum number of most-recent transitions kept for training.
REPLAY_BUFFER = 5000
# Clip range for the PPO probability ratio: [1 - eps, 1 + eps].
PPO_EPSILON = 0.2
# Width of the first hidden layer (the second uses HIDDEN // 2).
HIDDEN = 128
# Adam learning rate.
LR = 3e-4
# Number of minibatch updates performed after each episode.
TRAIN_STEPS = 5
# Default number of training episodes (overridable with --episodes).
EPISODES = 400
# Discount factor for the reward-to-go.
GAMMA = 0.99
class BrickBreakerEnv:
    """Minimal Breakout-style environment with a gym-like ``reset``/``step`` API.

    The paddle slides horizontally along ``PADDLE_Y``. The ball bounces off the
    side walls, the ceiling, the paddle and the bricks. An episode ends when the
    ball falls below the screen, when every brick is cleared, or after
    ``MAX_STEPS`` steps.

    Rewards: +1 per brick destroyed, +50 for clearing the level, -5 for losing
    the ball.
    """

    # Upper bound on |vx| after a paddle bounce, as a fraction of the current
    # ball speed. Previously a hit near the paddle edge could set |vx| above the
    # ball speed, which collapsed vy to -sqrt(0.1). The ball then crawled almost
    # horizontally for over a thousand steps.
    MAX_BOUNCE_VX_FRACTION = 0.9

    def __init__(self, render: bool = False):
        """Create the environment and reset it to a fresh episode.

        Args:
            render: stored as ``render_mode`` (not otherwise read by this class).
                The pygame window is only created lazily by :meth:`render`.
        """
        self.render_mode = render
        self._screen = None  # pygame display surface, created on first render()
        self.reset()

    def reset(self) -> np.ndarray:
        """Start a new episode and return the initial observation.

        The paddle and ball start centred horizontally, with the ball just above
        the paddle. The ball is launched upward at a random angle drawn from
        [0.3*pi, 0.7*pi] radians, with a random horizontal sign.
        """
        self.paddle_x = float(SCREEN_W // 2)
        self.ball_x = float(SCREEN_W // 2)
        self.ball_y = float(PADDLE_Y - BALL_RADIUS - 8)
        angle = np.random.uniform(np.pi * 0.3, np.pi * 0.7)
        self.ball_vx = np.cos(angle) * BALL_SPEED * np.random.choice([-1, 1])
        self.ball_vy = -abs(np.sin(angle)) * BALL_SPEED  # negative vy = upward
        self.bricks = np.ones((BRICK_ROWS, BRICK_COLS), dtype=bool)
        self.done = False
        self._steps = 0
        return self._state()

    def step(self, action: int) -> Tuple[np.ndarray, float, bool]:
        """Advance the simulation by one frame.

        Args:
            action: 0 moves the paddle left and 2 moves it right. Any other
                value (1 = stay) leaves it in place.

        Returns:
            ``(observation, reward, done)``.
        """
        self._steps += 1
        reward = 0.0
        # Paddle movement, clamped so the paddle stays fully on screen.
        if action == 0:
            self.paddle_x = max(PADDLE_W / 2, self.paddle_x - PADDLE_SPEED)
        elif action == 2:
            self.paddle_x = min(SCREEN_W - PADDLE_W / 2, self.paddle_x + PADDLE_SPEED)
        # Ball physics
        self.ball_x += self.ball_vx
        self.ball_y += self.ball_vy
        self._bounce_walls()
        self._bounce_paddle()
        reward += self._collide_bricks()
        # Ball fell completely below the bottom edge.
        if self.ball_y - BALL_RADIUS > SCREEN_H:
            reward -= 5.0
            self.done = True
        # Level cleared
        if not self.bricks.any():
            reward += 50.0
            self.done = True
        if self._steps >= MAX_STEPS:
            self.done = True
        return self._state(), reward, self.done

    def _bounce_walls(self) -> None:
        """Reflect the ball off the side walls and the ceiling, clamping it back inside."""
        if self.ball_x - BALL_RADIUS <= 0:
            self.ball_x = BALL_RADIUS
            self.ball_vx = abs(self.ball_vx)
        elif self.ball_x + BALL_RADIUS >= SCREEN_W:
            self.ball_x = SCREEN_W - BALL_RADIUS
            self.ball_vx = -abs(self.ball_vx)
        if self.ball_y - BALL_RADIUS <= 0:
            self.ball_y = BALL_RADIUS
            self.ball_vy = abs(self.ball_vy)

    def _bounce_paddle(self) -> None:
        """Reflect a descending ball that overlaps the paddle.

        The horizontal velocity grows with the distance from the paddle centre
        to the contact point. The vertical component is then chosen so the ball
        keeps its current speed. |vx| is capped at ``MAX_BOUNCE_VX_FRACTION`` of
        that speed, so the ball always leaves with a real upward component.
        """
        if not (self.ball_vy > 0
                and PADDLE_Y - PADDLE_H // 2 - BALL_RADIUS <= self.ball_y <= PADDLE_Y + PADDLE_H // 2
                and abs(self.ball_x - self.paddle_x) <= PADDLE_W / 2 + BALL_RADIUS):
            return
        self.ball_y = PADDLE_Y - PADDLE_H // 2 - BALL_RADIUS
        # Roughly -1..1; slightly beyond at the corners, because the hit test
        # above includes BALL_RADIUS.
        offset = (self.ball_x - self.paddle_x) / (PADDLE_W / 2)
        speed = np.hypot(self.ball_vx, self.ball_vy)
        max_vx = min(BALL_SPEED * 1.4, self.MAX_BOUNCE_VX_FRACTION * speed)
        self.ball_vx = float(np.clip(offset * BALL_SPEED * 1.4, -max_vx, max_vx))
        # |vx| <= 0.9 * speed, so the radicand is at least 0.19 * speed**2 > 0.
        self.ball_vy = -float(np.sqrt(speed**2 - self.ball_vx**2))

    def _collide_bricks(self) -> float:
        """Destroy at most one live brick touched by the ball and bounce off it.

        Returns:
            1.0 if a brick was destroyed this step, else 0.0.

        The scan stops at the first hit. Adjacent rows' collision bands (brick
        height plus BALL_RADIUS above and below) overlap vertically. Continuing
        the scan could therefore destroy two bricks in one step and flip ``vy``
        twice, which cancels the bounce.
        """
        for row in range(BRICK_ROWS):
            by1 = BRICK_TOP + row * BRICK_H
            by2 = by1 + BRICK_H
            if not by1 - BALL_RADIUS <= self.ball_y <= by2 + BALL_RADIUS:
                continue
            for col in range(BRICK_COLS):
                if not self.bricks[row, col]:
                    continue
                bx1 = col * BRICK_W
                bx2 = bx1 + BRICK_W
                if not bx1 <= self.ball_x <= bx2:
                    continue
                self.bricks[row, col] = False
                # Bounce direction: if the ball centre is nearer a top/bottom
                # edge, flip vy; if nearer a side edge, flip vx.
                overlap_y = min(abs(self.ball_y - by1), abs(self.ball_y - by2))
                overlap_x = min(abs(self.ball_x - bx1), abs(self.ball_x - bx2))
                if overlap_y <= overlap_x:
                    self.ball_vy = -self.ball_vy
                else:
                    self.ball_vx = -self.ball_vx
                return 1.0
        return 0.0

    def _state(self) -> np.ndarray:
        """Return the float32 observation of length ``STATE_SIZE``.

        Layout: ball x/y normalised by the screen size, ball vx/vy normalised by
        ``BALL_SPEED``, paddle x normalised by the screen width, then the brick
        grid flattened row-major (1.0 = brick alive).
        """
        return np.array([
            self.ball_x / SCREEN_W,
            self.ball_y / SCREEN_H,
            self.ball_vx / BALL_SPEED,
            self.ball_vy / BALL_SPEED,
            self.paddle_x / SCREEN_W,
            *self.bricks.flatten().astype(np.float32),
        ], dtype=np.float32)

    def render(self, episode: int = 0, total_reward: float = 0.0) -> bool:
        """Draw the current frame in a pygame window, capped at 60 FPS.

        The window is created on the first call.

        Args:
            episode: episode number shown in the HUD.
            total_reward: running reward shown in the HUD.

        Returns:
            False if the user closed the window (a QUIT event was received),
            else True.
        """
        import pygame
        if self._screen is None:
            pygame.init()
            self._screen = pygame.display.set_mode((SCREEN_W, SCREEN_H))
            pygame.display.set_caption("Brickbreaker — tinygrad PPO")
            self._font = pygame.font.SysFont("monospace", 14)
            self._clock = pygame.time.Clock()
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                return False
        self._screen.fill(BG)
        # Bricks (inset by 2 px so neighbours are visually separated)
        for row in range(BRICK_ROWS):
            for col in range(BRICK_COLS):
                if self.bricks[row, col]:
                    r = pygame.Rect(col * BRICK_W + 2, BRICK_TOP + row * BRICK_H + 2, BRICK_W - 4, BRICK_H - 4)
                    pygame.draw.rect(self._screen, C_BRICKS[row % len(C_BRICKS)], r, border_radius=3)
        # Paddle
        pr = pygame.Rect(int(self.paddle_x - PADDLE_W // 2), PADDLE_Y - PADDLE_H // 2, PADDLE_W, PADDLE_H)
        pygame.draw.rect(self._screen, C_PADDLE, pr, border_radius=6)
        # Ball
        pygame.draw.circle(self._screen, C_BALL, (int(self.ball_x), int(self.ball_y)), BALL_RADIUS)
        # HUD
        bricks_left = int(self.bricks.sum())
        txt = self._font.render(f"ep:{episode} bricks:{bricks_left:2d}/{N_BRICKS} reward:{total_reward:+.0f}", True, C_TEXT)
        self._screen.blit(txt, (8, 8))
        pygame.display.flip()
        self._clock.tick(60)
        return True
class ActorCritic:
    """Policy and value networks that share only their input.

    Each tower is a 3-layer ReLU MLP (in -> HIDDEN -> HIDDEN // 2 -> out).
    Calling the model returns ``(log_probs, value)``:

    - ``log_probs`` is the log-softmax over ``out_features`` actions.
    - ``value`` is the critic output, whose last dimension has size 1.
    """

    def __init__(self, in_features: int, out_features: int):
        narrow = HIDDEN // 2
        # Actor tower (constructed first, as before, so parameter order is unchanged).
        self.a1, self.a2, self.a3 = (
            nn.Linear(in_features, HIDDEN),
            nn.Linear(HIDDEN, narrow),
            nn.Linear(narrow, out_features),
        )
        # Critic tower producing a single value.
        self.c1, self.c2, self.c3 = (
            nn.Linear(in_features, HIDDEN),
            nn.Linear(HIDDEN, narrow),
            nn.Linear(narrow, 1),
        )

    def __call__(self, obs: Tensor) -> Tuple[Tensor, Tensor]:
        policy_features = self.a2(self.a1(obs).relu()).relu()
        value_features = self.c2(self.c1(obs).relu()).relu()
        return self.a3(policy_features).log_softmax(), self.c3(value_features)
def watch(model: ActorCritic, episode: int, n_games: int = 3) -> None:
    """Play ``n_games`` greedy games with ``model`` in a pygame window.

    Each step takes the argmax of the model's first output (the action
    log-probabilities) as the action.

    Args:
        model: the policy network to visualise.
        episode: training episode number shown in the HUD.
        n_games: number of games played back to back, with a 1.5 s pause
            after each one.

    Closing the window stops playback immediately. pygame is shut down on
    every exit path, so the window never lingers while training continues.
    """
    import pygame
    env = BrickBreakerEnv(render=True)
    try:
        for _ in range(n_games):
            obs = env.reset()
            done, total_rew = False, 0.0
            while not done:
                if not env.render(episode=episode, total_reward=total_rew):
                    return
                act = model(Tensor(obs))[0].argmax().item()
                obs, rew, done = env.step(act)
                total_rew += rew
            # Show the final frame of the game before pausing.
            if not env.render(episode=episode, total_reward=total_rew):
                return
            time.sleep(1.5)
    finally:
        pygame.quit()
def discounted_returns(rewards: list, gamma: float) -> list:
    """Return the discounted reward-to-go for every step of one episode.

    ``G[t] = rewards[t] + gamma * G[t + 1]``. The end of the list is treated
    as terminal (``G[T] = 0``). The result is computed in a single backward
    pass, O(len(rewards)).

    Args:
        rewards: per-step rewards of one episode, in time order.
        gamma: discount factor.

    Returns:
        A list of floats with the same length as ``rewards``.
    """
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = float(rewards[t]) + gamma * running
        returns[t] = running
    return returns


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--episodes", type=int, default=EPISODES)
    parser.add_argument("--watch-every", type=int, default=0, help="render a game every N episodes during training (0=off)")
    parser.add_argument("--watch-only", action="store_true", help="skip training, just watch a random agent")
    args = parser.parse_args()

    model = ActorCritic(STATE_SIZE, N_ACTIONS)
    opt = nn.optim.Adam(nn.state.get_parameters(model), lr=LR)
    env = BrickBreakerEnv(render=False)

    @TinyJit
    def train_step(x: Tensor, a: Tensor, r: Tensor, old_log: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Run one PPO gradient step on a minibatch.

        Args:
            x: observations.
            a: actions taken.
            r: discounted returns.
            old_log: reference policy log-probabilities for the PPO ratio.

        Returns:
            The realized (action_loss, entropy_loss, critic_loss).
        """
        with Tensor.train():
            log_dist, value = model(x)
            # One-hot mask selecting the column of the action actually taken.
            mask = (a.reshape(-1, 1) == Tensor.arange(N_ACTIONS).reshape(1, -1).expand(a.shape[0], -1)).float()
            advantage = r.reshape(-1, 1) - value
            masked_adv = mask * advantage.detach()
            ratios = (log_dist - old_log).exp()
            # PPO clipped surrogate: the pessimistic minimum of the unclipped and
            # clipped objectives. Clipping alone drops the gradient that should
            # still flow when the ratio moves in the unfavourable direction.
            unclipped = masked_adv * ratios
            clipped = masked_adv * ratios.clip(1 - PPO_EPSILON, 1 + PPO_EPSILON)
            action_loss = -unclipped.minimum(clipped).sum(-1).mean()
            # Mean of sum(p * log p) is the negative entropy; minimising it rewards exploration.
            entropy_loss = (log_dist.exp() * log_dist).sum(-1).mean()
            critic_loss = advantage.square().mean()
            opt.zero_grad()
            (action_loss + entropy_loss * ENTROPY_SCALE + critic_loss).backward()
            opt.step()
            return action_loss.realize(), entropy_loss.realize(), critic_loss.realize()

    @TinyJit
    def get_action(obs: Tensor) -> Tensor:
        """Sample an action from the current policy for a single observation."""
        return model(obs)[0].exp().multinomial().realize()

    if args.watch_only:
        watch(model, episode=0)
        raise SystemExit(0)

    # Replay buffer: observations, actions and discounted returns, index-aligned.
    Xn, An, Rn = [], [], []
    t_start, total_steps = time.perf_counter(), 0
    best_reward = -float("inf")
    for ep in (pbar := trange(args.episodes)):
        # NOTE(review): the action JIT is reset every episode; presumably this
        # avoids a stale capture of the model. Confirm before removing.
        get_action.reset()
        obs = env.reset()
        rews = []
        done = False
        while not done:
            act = get_action(Tensor(obs)).item()
            Xn.append(np.copy(obs))
            An.append(act)
            obs, rew, done = env.step(act)
            rews.append(rew)
        total_steps += len(rews)
        Rn += discounted_returns(rews, GAMMA)
        # Keep only the most recent REPLAY_BUFFER transitions.
        Xn, An, Rn = Xn[-REPLAY_BUFFER:], An[-REPLAY_BUFFER:], Rn[-REPLAY_BUFFER:]
        if len(Xn) >= BATCH_SIZE:
            X, A, R = Tensor(np.array(Xn)), Tensor(An), Tensor(Rn)
            # Realize the reference log-probs once, before any update. This pins
            # them to the pre-update weights and avoids re-running the forward
            # pass over the whole buffer on every minibatch.
            old_log = model(X)[0].detach().realize()
            for _ in range(TRAIN_STEPS):
                samples = Tensor.randint(BATCH_SIZE, high=X.shape[0]).realize()
                train_step(X[samples], A[samples], R[samples], old_log[samples])
        ep_reward = sum(rews)
        best_reward = max(best_reward, ep_reward)
        bricks_hit = int(N_BRICKS - env.bricks.sum())
        sps = total_steps / (time.perf_counter() - t_start)
        pbar.set_description(
            f"ep:{ep:4d} sps:{sps:6.0f} bricks:{bricks_hit:2d}/{N_BRICKS} rew:{ep_reward:+6.1f} best:{best_reward:+6.1f}"
        )
        if args.watch_every > 0 and (ep + 1) % args.watch_every == 0:
            watch(model, episode=ep + 1, n_games=1)
    print("\nTraining done. Watching agent play...")
    watch(model, episode=args.episodes, n_games=3)