# (file listing metadata: 298 lines, 9.9 KiB, Python)
# Brickbreaker with PPO reinforcement learning using tinygrad
|
|
from typing import Tuple
|
|
import argparse, time
|
|
import numpy as np
|
|
from tinygrad import Tensor, TinyJit, nn
|
|
from tinygrad.helpers import trange
|
|
|
|
# --- Game Constants ---
# Playfield dimensions (used as the pygame window size).
SCREEN_W = 480
SCREEN_H = 560

# Paddle geometry, vertical centre line and per-step horizontal movement.
PADDLE_W = 80
PADDLE_H = 12
PADDLE_Y = SCREEN_H - 40
PADDLE_SPEED = 7

# Ball size and launch speed (distance travelled per environment step).
BALL_RADIUS = 7
BALL_SPEED = 5.0

# Brick grid layout: bricks span the full screen width, starting at BRICK_TOP.
BRICK_ROWS = 4
BRICK_COLS = 8
BRICK_W = SCREEN_W // BRICK_COLS
BRICK_H = 22
BRICK_TOP = 60

# Episode length cap and sizes of the agent's interface.
MAX_STEPS = 2000
N_ACTIONS = 3  # 0=left, 1=stay, 2=right
N_BRICKS = BRICK_COLS * BRICK_ROWS
STATE_SIZE = N_BRICKS + 5  # brick grid + ball x/y/vx/vy + paddle x
|
|
|
|
# Colors (RGB triples passed to pygame drawing calls)
BG = (15, 15, 30)
C_PADDLE = (80, 200, 255)
C_BALL = (255, 230, 80)
# One colour per brick row; the renderer indexes this list with the row number.
C_BRICKS = [
    (255, 80, 80),
    (255, 160, 60),
    (100, 220, 100),
    (100, 150, 255),
]
C_TEXT = (200, 200, 220)
|
|
|
|
# --- PPO Hyperparameters ---
# Optimisation
LR = 3e-4
BATCH_SIZE = 256
TRAIN_STEPS = 5  # gradient updates performed after each episode
EPISODES = 400  # default number of training episodes

# Loss shaping
PPO_EPSILON = 0.2  # clipping range for the probability ratio
ENTROPY_SCALE = 0.002  # weight of the entropy term in the total loss
GAMMA = 0.99  # reward discount factor

# Model / data sizes
HIDDEN = 128  # width of the first hidden layer of each head
REPLAY_BUFFER = 5000  # number of most recent transitions kept for training
|
|
|
|
|
|
class BrickBreakerEnv:
    """Single-ball brick-breaker game exposed as a small RL environment.

    Observations are float32 vectors holding the ball position (divided by the
    screen size), the ball velocity (divided by ``BALL_SPEED``), the paddle x
    position (divided by ``SCREEN_W``) and the flattened brick grid
    (1.0 = brick still present).

    Rewards: +1 per destroyed brick, -5 when the ball drops below the screen,
    +50 when the last brick is destroyed.
    """

    def __init__(self, render: bool = False):
        """Create the environment and start the first episode.

        Args:
            render: stored as ``render_mode``; the pygame window itself is
                only created by the first call to :meth:`render`.
        """
        self.render_mode = render
        # pygame objects are created lazily on the first render() call
        self._screen = None
        self._font = None
        self._clock = None
        self.reset()

    def reset(self) -> np.ndarray:
        """Start a new episode and return the initial observation."""
        self.paddle_x = float(SCREEN_W // 2)
        self.ball_x = float(SCREEN_W // 2)
        self.ball_y = float(PADDLE_Y - BALL_RADIUS - 8)
        # Launch upwards at a random angle with a random horizontal direction.
        angle = np.random.uniform(np.pi * 0.3, np.pi * 0.7)
        self.ball_vx = np.cos(angle) * BALL_SPEED * np.random.choice([-1, 1])
        self.ball_vy = -abs(np.sin(angle)) * BALL_SPEED
        self.bricks = np.ones((BRICK_ROWS, BRICK_COLS), dtype=bool)
        self.done = False
        self._steps = 0
        return self._state()

    def step(self, action: int) -> Tuple[np.ndarray, float, bool]:
        """Advance the game by one tick.

        Args:
            action: 0 moves the paddle left, 2 moves it right; any other
                value leaves the paddle where it is.

        Returns:
            ``(observation, reward, done)`` for the state after the tick.
        """
        self._steps += 1

        self._move_paddle(action)

        # Ball physics
        self.ball_x += self.ball_vx
        self.ball_y += self.ball_vy
        self._bounce_off_walls()
        self._bounce_off_paddle()
        reward = self._break_brick()

        # Ball out of bounds
        if self.ball_y - BALL_RADIUS > SCREEN_H:
            reward -= 5.0
            self.done = True

        # Level cleared
        if not self.bricks.any():
            reward += 50.0
            self.done = True

        # Episode length cap
        if self._steps >= MAX_STEPS:
            self.done = True

        return self._state(), reward, self.done

    def _move_paddle(self, action: int) -> None:
        """Shift the paddle by PADDLE_SPEED, keeping it fully on screen."""
        half_w = PADDLE_W / 2
        if action == 0:
            self.paddle_x = max(half_w, self.paddle_x - PADDLE_SPEED)
        elif action == 2:
            self.paddle_x = min(SCREEN_W - half_w, self.paddle_x + PADDLE_SPEED)

    def _bounce_off_walls(self) -> None:
        """Reflect the ball off the side walls and the ceiling."""
        if self.ball_x - BALL_RADIUS <= 0:
            self.ball_x = BALL_RADIUS
            self.ball_vx = abs(self.ball_vx)
        elif self.ball_x + BALL_RADIUS >= SCREEN_W:
            self.ball_x = SCREEN_W - BALL_RADIUS
            self.ball_vx = -abs(self.ball_vx)

        if self.ball_y - BALL_RADIUS <= 0:
            self.ball_y = BALL_RADIUS
            self.ball_vy = abs(self.ball_vy)

    def _bounce_off_paddle(self) -> None:
        """Send a descending ball back up when it touches the paddle."""
        if self.ball_vy <= 0:
            return  # only a ball moving downwards can hit the paddle
        paddle_top = PADDLE_Y - PADDLE_H // 2
        paddle_bottom = PADDLE_Y + PADDLE_H // 2
        if not paddle_top - BALL_RADIUS <= self.ball_y <= paddle_bottom:
            return
        if abs(self.ball_x - self.paddle_x) > PADDLE_W / 2 + BALL_RADIUS:
            return

        self.ball_y = paddle_top - BALL_RADIUS
        # Horizontal speed depends on where the ball struck the paddle; the
        # clip bounds it for hits on the very edge (|offset| slightly > 1).
        offset = (self.ball_x - self.paddle_x) / (PADDLE_W / 2)
        speed = np.hypot(self.ball_vx, self.ball_vy)
        max_vx = BALL_SPEED * 1.4
        self.ball_vx = np.clip(offset * max_vx, -max_vx, max_vx)
        # Keep the previous overall speed where possible; the 0.1 floor
        # guarantees the ball always leaves the paddle moving upwards.
        self.ball_vy = -np.sqrt(max(speed**2 - self.ball_vx**2, 0.1))

    def _break_brick(self) -> float:
        """Destroy at most one brick touched by the ball; return its reward.

        Bricks are scanned row by row and the scan stops at the first hit.
        Stopping matters: a ball within BALL_RADIUS of a row boundary lies in
        the hit zone of both rows, and handling both would flip the velocity
        twice and let the ball pass straight through.
        """
        for row in range(BRICK_ROWS):
            by1 = BRICK_TOP + row * BRICK_H
            by2 = by1 + BRICK_H
            if not by1 - BALL_RADIUS <= self.ball_y <= by2 + BALL_RADIUS:
                continue
            for col in range(BRICK_COLS):
                if not self.bricks[row, col]:
                    continue
                bx1 = col * BRICK_W
                bx2 = bx1 + BRICK_W
                if not bx1 <= self.ball_x <= bx2:
                    continue

                self.bricks[row, col] = False
                # bounce direction: hit top/bottom -> flip vy, hit sides -> flip vx
                overlap_y = min(abs(self.ball_y - by1), abs(self.ball_y - by2))
                overlap_x = min(abs(self.ball_x - bx1), abs(self.ball_x - bx2))
                if overlap_y <= overlap_x:
                    self.ball_vy = -self.ball_vy
                else:
                    self.ball_vx = -self.ball_vx
                return 1.0
        return 0.0

    def _state(self) -> np.ndarray:
        """Build the normalised float32 observation vector."""
        return np.array([
            self.ball_x / SCREEN_W,
            self.ball_y / SCREEN_H,
            self.ball_vx / BALL_SPEED,
            self.ball_vy / BALL_SPEED,
            self.paddle_x / SCREEN_W,
            *self.bricks.flatten().astype(np.float32),
        ], dtype=np.float32)

    def render(self, episode: int = 0, total_reward: float = 0.0) -> bool:
        """Draw the current frame in a pygame window.

        Args:
            episode: episode number shown in the HUD.
            total_reward: accumulated reward shown in the HUD.

        Returns:
            False if a window-close (QUIT) event was received, True otherwise.
        """
        import pygame  # imported lazily so headless training does not need pygame

        if self._screen is None:
            pygame.init()
            self._screen = pygame.display.set_mode((SCREEN_W, SCREEN_H))
            pygame.display.set_caption("Brickbreaker — tinygrad PPO")
            self._font = pygame.font.SysFont("monospace", 14)
            self._clock = pygame.time.Clock()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                return False

        self._screen.fill(BG)

        # Bricks
        for row in range(BRICK_ROWS):
            for col in range(BRICK_COLS):
                if self.bricks[row, col]:
                    r = pygame.Rect(col * BRICK_W + 2, BRICK_TOP + row * BRICK_H + 2, BRICK_W - 4, BRICK_H - 4)
                    pygame.draw.rect(self._screen, C_BRICKS[row % len(C_BRICKS)], r, border_radius=3)

        # Paddle
        pr = pygame.Rect(int(self.paddle_x - PADDLE_W // 2), PADDLE_Y - PADDLE_H // 2, PADDLE_W, PADDLE_H)
        pygame.draw.rect(self._screen, C_PADDLE, pr, border_radius=6)

        # Ball
        pygame.draw.circle(self._screen, C_BALL, (int(self.ball_x), int(self.ball_y)), BALL_RADIUS)

        # HUD
        bricks_left = int(self.bricks.sum())
        txt = self._font.render(f"ep:{episode} bricks:{bricks_left:2d}/{N_BRICKS} reward:{total_reward:+.0f}", True, C_TEXT)
        self._screen.blit(txt, (8, 8))

        pygame.display.flip()
        self._clock.tick(60)
        return True
|
|
|
|
|
|
class ActorCritic:
    """Policy and value networks implemented as two separate 3-layer MLPs.

    The actor ("a" layers) and the critic ("c" layers) share no weights; both
    use ReLU activations and hidden widths HIDDEN and HIDDEN // 2.
    """

    def __init__(self, in_features: int, out_features: int):
        """Create the layers.

        Args:
            in_features: size of the observation vector.
            out_features: number of discrete actions.
        """
        narrow = HIDDEN // 2

        # Actor: observation -> action log-probabilities
        self.a1 = nn.Linear(in_features, HIDDEN)
        self.a2 = nn.Linear(HIDDEN, narrow)
        self.a3 = nn.Linear(narrow, out_features)

        # Critic: observation -> scalar value estimate
        self.c1 = nn.Linear(in_features, HIDDEN)
        self.c2 = nn.Linear(HIDDEN, narrow)
        self.c3 = nn.Linear(narrow, 1)

    def __call__(self, obs: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(log_softmax of the actor logits, critic output)`` for ``obs``."""
        actor_hidden = self.a2(self.a1(obs).relu()).relu()
        log_probs = self.a3(actor_hidden).log_softmax()

        critic_hidden = self.c2(self.c1(obs).relu()).relu()
        value = self.c3(critic_hidden)

        return log_probs, value
|
|
|
|
|
|
def watch(model: ActorCritic, episode: int, n_games: int = 3) -> None:
    """Render ``n_games`` games played greedily by ``model`` in a pygame window.

    The action with the highest log-probability is taken at every step.
    Watching stops early if the window is closed. pygame is shut down on
    every exit path, including the early one.

    Args:
        model: policy whose first output is the per-action log-probabilities.
        episode: training episode number shown in the HUD.
        n_games: number of games to play back to back.
    """
    import pygame  # local import: only needed when a window is actually shown

    env = BrickBreakerEnv(render=True)
    try:
        for _ in range(n_games):
            obs = env.reset()
            done, total_rew = False, 0.0
            while not done:
                if not env.render(episode=episode, total_reward=total_rew):
                    return
                act = model(Tensor(obs))[0].argmax().item()
                obs, rew, done = env.step(act)
                total_rew += rew
            # Show the final frame for a moment before the next game starts.
            if not env.render(episode=episode, total_reward=total_rew):
                return
            time.sleep(1.5)
    finally:
        # Previously skipped when the user closed the window mid-game.
        pygame.quit()
|
|
|
|
|
|
def discounted_returns(rewards, gamma: float) -> list:
    """Return the discounted reward-to-go for every step of one episode.

    ``result[i]`` equals ``sum(gamma**k * rewards[i + k] for k >= 0)``. It is
    computed in a single backward pass (``G[i] = r[i] + gamma * G[i + 1]``)
    rather than with one partial sum per step.

    Args:
        rewards: per-step rewards of a single episode, in time order.
        gamma: discount factor.

    Returns:
        A list of Python floats with the same length as ``rewards``.
    """
    returns = []
    running = 0.0
    for r in reversed(rewards):
        running = float(r) + gamma * running
        returns.append(running)
    returns.reverse()
    return returns


def main() -> None:
    """Train the PPO agent on BrickBreakerEnv, then watch it play.

    Command-line flags: ``--episodes`` (number of training episodes),
    ``--watch-every`` (render one game every N episodes, 0 = off) and
    ``--watch-only`` (skip training and watch the untrained model).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--episodes", type=int, default=EPISODES)
    parser.add_argument("--watch-every", type=int, default=0, help="render a game every N episodes during training (0=off)")
    parser.add_argument("--watch-only", action="store_true", help="skip training, just watch a random agent")
    args = parser.parse_args()

    model = ActorCritic(STATE_SIZE, N_ACTIONS)

    if args.watch_only:
        watch(model, episode=0)
        return

    opt = nn.optim.Adam(nn.state.get_parameters(model), lr=LR)
    env = BrickBreakerEnv(render=False)

    @TinyJit
    def train_step(x: Tensor, a: Tensor, r: Tensor, old_log: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Run one PPO update on a batch; return (action, entropy, critic) losses."""
        with Tensor.train():
            log_dist, value = model(x)
            # One-hot mask selecting the action that was actually taken.
            mask = (a.reshape(-1, 1) == Tensor.arange(N_ACTIONS).reshape(1, -1).expand(a.shape[0], -1)).float()
            advantage = r.reshape(-1, 1) - value
            masked_adv = mask * advantage.detach()

            # PPO clipped surrogate: take the pessimistic minimum of the
            # unclipped and clipped objectives. Using the clipped term alone
            # would zero the gradient whenever the ratio leaves the clip
            # range, even when the update would move it back.
            ratios = (log_dist - old_log).exp()
            unclipped = masked_adv * ratios
            clipped = masked_adv * ratios.clip(1 - PPO_EPSILON, 1 + PPO_EPSILON)
            action_loss = -unclipped.minimum(clipped).sum(-1).mean()

            # sum(p * log p) is the negative entropy; minimising it keeps the
            # policy from collapsing too early.
            entropy_loss = (log_dist.exp() * log_dist).sum(-1).mean()
            critic_loss = advantage.square().mean()

            opt.zero_grad()
            (action_loss + entropy_loss * ENTROPY_SCALE + critic_loss).backward()
            opt.step()
            return action_loss.realize(), entropy_loss.realize(), critic_loss.realize()

    @TinyJit
    def get_action(obs: Tensor) -> Tensor:
        """Sample one action from the current policy."""
        return model(obs)[0].exp().multinomial().realize()

    # Buffers of the most recent transitions: observations, actions, returns.
    Xn, An, Rn = [], [], []
    t_start, total_steps = time.perf_counter(), 0
    best_reward = -float("inf")

    for ep in (pbar := trange(args.episodes)):
        get_action.reset()

        # Roll out one full episode with the current policy.
        obs = env.reset()
        rews = []
        done = False
        while not done:
            act = get_action(Tensor(obs)).item()
            Xn.append(np.copy(obs))
            An.append(act)
            obs, rew, done = env.step(act)
            rews.append(rew)
        total_steps += len(rews)

        # Discounted reward-to-go
        Rn += discounted_returns(rews, GAMMA)

        # Keep only the newest REPLAY_BUFFER transitions.
        Xn, An, Rn = Xn[-REPLAY_BUFFER:], An[-REPLAY_BUFFER:], Rn[-REPLAY_BUFFER:]

        if len(Xn) >= BATCH_SIZE:
            X, A, R = Tensor(np.array(Xn)), Tensor(An), Tensor(Rn)
            # Log-probabilities under the pre-update policy, fixed for all
            # TRAIN_STEPS updates of this episode.
            old_log = model(X)[0].detach()
            for _ in range(TRAIN_STEPS):
                samples = Tensor.randint(BATCH_SIZE, high=X.shape[0]).realize()
                train_step(X[samples], A[samples], R[samples], old_log[samples])

        ep_reward = sum(rews)
        best_reward = max(best_reward, ep_reward)
        bricks_hit = int(N_BRICKS - env.bricks.sum())
        sps = total_steps / (time.perf_counter() - t_start)
        pbar.set_description(
            f"ep:{ep:4d} sps:{sps:6.0f} bricks:{bricks_hit:2d}/{N_BRICKS} rew:{ep_reward:+6.1f} best:{best_reward:+6.1f}"
        )

        if args.watch_every > 0 and (ep + 1) % args.watch_every == 0:
            watch(model, episode=ep + 1, n_games=1)

    print("\nTraining done. Watching agent play...")
    watch(model, episode=args.episodes, n_games=3)


if __name__ == "__main__":
    main()
|