First game brickbreaker
This commit is contained in:
297
brickbreaker_ppo.py
Normal file
297
brickbreaker_ppo.py
Normal file
@@ -0,0 +1,297 @@
|
||||
# Brickbreaker with PPO reinforcement learning using tinygrad
|
||||
from typing import Tuple
|
||||
import argparse, time
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, TinyJit, nn
|
||||
from tinygrad.helpers import trange
|
||||
|
||||
# --- Game Constants ---
# Playfield size in pixels.
SCREEN_W = 480
SCREEN_H = 560

# Paddle geometry; PADDLE_Y is the vertical centre line of the paddle.
PADDLE_W = 80
PADDLE_H = 12
PADDLE_Y = SCREEN_H - 40
PADDLE_SPEED = 7  # pixels moved per step

# Ball geometry and nominal launch speed (pixels per step).
BALL_RADIUS = 7
BALL_SPEED = 5.0

# Brick grid layout: the grid spans the full screen width.
BRICK_ROWS = 4
BRICK_COLS = 8
BRICK_W = SCREEN_W // BRICK_COLS
BRICK_H = 22
BRICK_TOP = 60  # y coordinate of the top edge of the first brick row

# Episode and observation/action space sizes.
MAX_STEPS = 2000
N_ACTIONS = 3  # 0=left, 1=stay, 2=right
N_BRICKS = BRICK_ROWS * BRICK_COLS
STATE_SIZE = 5 + N_BRICKS  # ball x/y/vx/vy + paddle x + brick grid

# Colors (RGB)
BG = (15, 15, 30)
C_PADDLE = (80, 200, 255)
C_BALL = (255, 230, 80)
C_BRICKS = [
    (255, 80, 80),
    (255, 160, 60),
    (100, 220, 100),
    (100, 150, 255),
]  # one colour per brick row, cycled if there are more rows
C_TEXT = (200, 200, 220)

# --- PPO Hyperparameters ---
BATCH_SIZE = 256        # samples per optimisation step
ENTROPY_SCALE = 0.002   # weight of the entropy term in the total loss
REPLAY_BUFFER = 5000    # most recent transitions kept for training
PPO_EPSILON = 0.2       # probability-ratio clip range
HIDDEN = 128            # width of the first hidden layer
LR = 3e-4               # Adam learning rate
TRAIN_STEPS = 5         # optimisation steps per episode
EPISODES = 400          # default number of training episodes
GAMMA = 0.99            # reward discount factor
|
||||
|
||||
|
||||
class BrickBreakerEnv:
    """Brick-breaker game exposed through a reset/step environment interface.

    Observations are float32 vectors of length ``5 + BRICK_ROWS * BRICK_COLS``:
    ball x and y scaled by the screen size, ball vx and vy scaled by
    ``BALL_SPEED``, paddle x scaled by the screen width, followed by the
    row-major brick grid (1.0 while a brick is still present).
    Actions are 0 = left, 1 = stay, 2 = right.
    """

    def __init__(self, render: bool = False):
        """Create the environment and put it into its initial state.

        Args:
            render: stored as ``render_mode``. Drawing only happens when
                :meth:`render` is called explicitly.
        """
        self.render_mode = render
        # pygame objects are created lazily by the first render() call.
        self._screen = None
        self._font = None
        self._clock = None
        self.reset()

    def reset(self) -> np.ndarray:
        """Start a new game and return the first observation.

        The paddle is centred, the ball is placed just above it and launched
        upwards at a random angle, and every brick is restored.
        """
        self.paddle_x = float(SCREEN_W // 2)
        self.ball_x = float(SCREEN_W // 2)
        self.ball_y = float(PADDLE_Y - BALL_RADIUS - 8)
        angle = np.random.uniform(np.pi * 0.3, np.pi * 0.7)
        self.ball_vx = np.cos(angle) * BALL_SPEED * np.random.choice([-1, 1])
        # y grows downwards, so a negative vy sends the ball towards the bricks.
        self.ball_vy = -abs(np.sin(angle)) * BALL_SPEED
        self.bricks = np.ones((BRICK_ROWS, BRICK_COLS), dtype=bool)
        self.done = False
        self._steps = 0
        return self._state()

    def step(self, action: int) -> Tuple[np.ndarray, float, bool]:
        """Advance the game by one tick.

        Args:
            action: 0 moves the paddle left, 2 moves it right, anything else
                leaves it where it is.

        Returns:
            ``(observation, reward, done)``. The reward is +1 for a destroyed
            brick, -5 when the ball falls below the screen and +50 when the
            last brick is destroyed. ``done`` is set when the ball is lost,
            the level is cleared, or ``MAX_STEPS`` ticks have elapsed.
        """
        self._steps += 1
        reward = 0.0

        # Paddle movement, clamped so the paddle stays fully on screen.
        if action == 0:
            self.paddle_x = max(PADDLE_W / 2, self.paddle_x - PADDLE_SPEED)
        elif action == 2:
            self.paddle_x = min(SCREEN_W - PADDLE_W / 2, self.paddle_x + PADDLE_SPEED)

        # Ball physics
        self.ball_x += self.ball_vx
        self.ball_y += self.ball_vy

        # Side walls: snap the ball back inside and force vx away from the wall.
        if self.ball_x - BALL_RADIUS <= 0:
            self.ball_x = BALL_RADIUS
            self.ball_vx = abs(self.ball_vx)
        elif self.ball_x + BALL_RADIUS >= SCREEN_W:
            self.ball_x = SCREEN_W - BALL_RADIUS
            self.ball_vx = -abs(self.ball_vx)

        # Ceiling
        if self.ball_y - BALL_RADIUS <= 0:
            self.ball_y = BALL_RADIUS
            self.ball_vy = abs(self.ball_vy)

        # Paddle collision (only while the ball is travelling downwards).
        paddle_top = PADDLE_Y - PADDLE_H // 2
        if (self.ball_vy > 0
                and paddle_top - BALL_RADIUS <= self.ball_y <= PADDLE_Y + PADDLE_H // 2
                and abs(self.ball_x - self.paddle_x) <= PADDLE_W / 2 + BALL_RADIUS):
            self.ball_y = paddle_top - BALL_RADIUS
            # Roughly -1..1 across the paddle. The hit window extends
            # BALL_RADIUS past each edge, so the value is clipped below.
            offset = (self.ball_x - self.paddle_x) / (PADDLE_W / 2)
            speed = np.hypot(self.ball_vx, self.ball_vy)
            self.ball_vx = np.clip(offset * BALL_SPEED * 1.4, -BALL_SPEED * 1.4, BALL_SPEED * 1.4)
            # Keep the incoming speed where possible; the 0.1 floor guarantees
            # the ball always leaves the paddle moving upwards.
            self.ball_vy = -np.sqrt(max(speed**2 - self.ball_vx**2, 0.1))

        # Brick collisions: at most one brick per tick.
        if self._hit_brick():
            reward += 1.0

        # Ball out of bounds
        if self.ball_y - BALL_RADIUS > SCREEN_H:
            reward -= 5.0
            self.done = True

        # Level cleared
        if not self.bricks.any():
            reward += 50.0
            self.done = True

        if self._steps >= MAX_STEPS:
            self.done = True

        return self._state(), reward, self.done

    def _hit_brick(self) -> bool:
        """Destroy the first brick (row-major order) the ball touches and bounce.

        The vertical hit bands of adjacent rows overlap by ``2 * BALL_RADIUS``,
        so the scan stops at the first hit. Otherwise two stacked bricks could
        be removed in the same tick and the two velocity flips would cancel,
        letting the ball tunnel through the grid.

        Returns:
            True if a brick was destroyed, False otherwise.
        """
        for row in range(BRICK_ROWS):
            by1 = BRICK_TOP + row * BRICK_H
            by2 = by1 + BRICK_H
            if not (by1 - BALL_RADIUS <= self.ball_y <= by2 + BALL_RADIUS):
                continue
            for col in range(BRICK_COLS):
                if not self.bricks[row, col]:
                    continue
                bx1 = col * BRICK_W
                bx2 = bx1 + BRICK_W
                if not (bx1 <= self.ball_x <= bx2):
                    continue
                self.bricks[row, col] = False
                # bounce direction: hit top/bottom -> flip vy, hit sides -> flip vx
                overlap_y = min(abs(self.ball_y - by1), abs(self.ball_y - by2))
                overlap_x = min(abs(self.ball_x - bx1), abs(self.ball_x - bx2))
                if overlap_y <= overlap_x:
                    self.ball_vy = -self.ball_vy
                else:
                    self.ball_vx = -self.ball_vx
                return True
        return False

    def _state(self) -> np.ndarray:
        """Build the float32 observation vector for the current game state."""
        ball_and_paddle = np.array([
            self.ball_x / SCREEN_W,
            self.ball_y / SCREEN_H,
            self.ball_vx / BALL_SPEED,
            self.ball_vy / BALL_SPEED,
            self.paddle_x / SCREEN_W,
        ], dtype=np.float32)
        return np.concatenate((ball_and_paddle, self.bricks.flatten().astype(np.float32)))

    def render(self, episode: int = 0, total_reward: float = 0.0) -> bool:
        """Draw the current frame in a pygame window, capped at 60 FPS.

        The window, font and clock are created on the first call.

        Args:
            episode: episode number shown in the HUD.
            total_reward: accumulated reward shown in the HUD.

        Returns:
            False if the user asked to close the window (nothing is drawn for
            that call), True otherwise.
        """
        import pygame

        if self._screen is None:
            pygame.init()
            self._screen = pygame.display.set_mode((SCREEN_W, SCREEN_H))
            pygame.display.set_caption("Brickbreaker — tinygrad PPO")
            self._font = pygame.font.SysFont("monospace", 14)
            self._clock = pygame.time.Clock()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                return False

        self._screen.fill(BG)

        # Bricks (drawn with a 2 px gap on every side)
        for row in range(BRICK_ROWS):
            for col in range(BRICK_COLS):
                if self.bricks[row, col]:
                    r = pygame.Rect(col * BRICK_W + 2, BRICK_TOP + row * BRICK_H + 2, BRICK_W - 4, BRICK_H - 4)
                    pygame.draw.rect(self._screen, C_BRICKS[row % len(C_BRICKS)], r, border_radius=3)

        # Paddle
        pr = pygame.Rect(int(self.paddle_x - PADDLE_W // 2), PADDLE_Y - PADDLE_H // 2, PADDLE_W, PADDLE_H)
        pygame.draw.rect(self._screen, C_PADDLE, pr, border_radius=6)

        # Ball
        pygame.draw.circle(self._screen, C_BALL, (int(self.ball_x), int(self.ball_y)), BALL_RADIUS)

        # HUD
        bricks_left = int(self.bricks.sum())
        txt = self._font.render(f"ep:{episode} bricks:{bricks_left:2d}/{N_BRICKS} reward:{total_reward:+.0f}", True, C_TEXT)
        self._screen.blit(txt, (8, 8))

        pygame.display.flip()
        self._clock.tick(60)
        return True
|
||||
|
||||
|
||||
class ActorCritic:
    """Policy and value networks as two independent three-layer MLPs.

    Both networks read the same observation. The actor produces
    log-probabilities over the actions and the critic a single value.
    """

    def __init__(self, in_features: int, out_features: int):
        """Build both networks.

        Args:
            in_features: size of the observation vector.
            out_features: number of discrete actions.
        """
        narrow = HIDDEN // 2

        # Actor: observation -> HIDDEN -> HIDDEN/2 -> action logits
        self.a1 = nn.Linear(in_features, HIDDEN)
        self.a2 = nn.Linear(HIDDEN, narrow)
        self.a3 = nn.Linear(narrow, out_features)

        # Critic: observation -> HIDDEN -> HIDDEN/2 -> scalar value
        self.c1 = nn.Linear(in_features, HIDDEN)
        self.c2 = nn.Linear(HIDDEN, narrow)
        self.c3 = nn.Linear(narrow, 1)

    def __call__(self, obs: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(action log-probabilities, value estimate)`` for ``obs``."""
        policy_hidden = self.a2(self.a1(obs).relu()).relu()
        log_probs = self.a3(policy_hidden).log_softmax()

        value_hidden = self.c2(self.c1(obs).relu()).relu()
        value = self.c3(value_hidden)
        return log_probs, value
|
||||
|
||||
|
||||
def watch(model: ActorCritic, episode: int, n_games: int = 3) -> None:
    """Play ``n_games`` rendered games, always taking the model's most likely action.

    The final frame of each game stays on screen for 1.5 seconds. Closing the
    window stops playback immediately. pygame is shut down on every exit path,
    so the window never outlives this call.

    Args:
        model: policy whose first output is a vector of action log-probabilities.
        episode: episode number shown in the HUD.
        n_games: number of games to play back to back.
    """
    # Imported up front so the ``finally`` clause below can always shut pygame down.
    import pygame

    env = BrickBreakerEnv(render=True)
    try:
        for _ in range(n_games):
            obs = env.reset()
            done, total_rew = False, 0.0
            while not done:
                if not env.render(episode=episode, total_reward=total_rew):
                    return
                act = model(Tensor(obs))[0].argmax().item()
                obs, rew, done = env.step(act)
                total_rew += rew
            # Show the final frame; a close request here also ends playback.
            if not env.render(episode=episode, total_reward=total_rew):
                return
            time.sleep(1.5)
    finally:
        pygame.quit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: train an ActorCritic on BrickBreakerEnv with PPO,
    # optionally rendering games during training, then watch the final agent.
    parser = argparse.ArgumentParser()
    parser.add_argument("--episodes", type=int, default=EPISODES)
    parser.add_argument("--watch-every", type=int, default=0, help="render a game every N episodes during training (0=off)")
    parser.add_argument("--watch-only", action="store_true", help="skip training, just watch a random agent")
    args = parser.parse_args()

    model = ActorCritic(STATE_SIZE, N_ACTIONS)
    opt = nn.optim.Adam(nn.state.get_parameters(model), lr=LR)
    env = BrickBreakerEnv(render=False)

    @TinyJit
    def train_step(x: Tensor, a: Tensor, r: Tensor, old_log: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Run one PPO update on a batch and return (action, entropy, critic) losses.

        Args:
            x: batch of observations.
            a: actions taken for those observations.
            r: discounted reward-to-go for those observations.
            old_log: action log-probabilities recorded before this round of updates.
        """
        with Tensor.train():
            log_dist, value = model(x)

            # One-hot mask selecting the action that was actually taken.
            mask = (a.reshape(-1, 1) == Tensor.arange(N_ACTIONS).reshape(1, -1).expand(a.shape[0], -1)).float()
            advantage = r.reshape(-1, 1) - value
            masked_adv = mask * advantage.detach()

            # PPO clipped surrogate: take the pessimistic minimum of the
            # unclipped and clipped terms. Using only the clipped term would
            # zero the gradient whenever the ratio leaves the clip range, even
            # when the update would move it back.
            ratios = (log_dist - old_log).exp()
            unclipped = masked_adv * ratios
            clipped = masked_adv * ratios.clip(1 - PPO_EPSILON, 1 + PPO_EPSILON)
            action_loss = -unclipped.minimum(clipped).sum(-1).mean()

            # sum(p * log p) is the negative entropy, so minimising it favours exploration.
            entropy_loss = (log_dist.exp() * log_dist).sum(-1).mean()
            critic_loss = advantage.square().mean()

            opt.zero_grad()
            (action_loss + entropy_loss * ENTROPY_SCALE + critic_loss).backward()
            opt.step()
            return action_loss.realize(), entropy_loss.realize(), critic_loss.realize()

    @TinyJit
    def get_action(obs: Tensor) -> Tensor:
        """Sample an action index from the current policy distribution."""
        return model(obs)[0].exp().multinomial().realize()

    if args.watch_only:
        watch(model, episode=0)
        # SystemExit works without the ``site`` module, unlike the exit() helper.
        raise SystemExit(0)

    # Rolling buffers of observations, actions and discounted returns.
    Xn, An, Rn = [], [], []
    t_start, total_steps = time.perf_counter(), 0
    best_reward = -float("inf")

    for ep in (pbar := trange(args.episodes)):
        # NOTE(review): the sampler's JIT cache is dropped every episode,
        # presumably so it is re-captured after the weights were updated. Confirm.
        get_action.reset()

        obs = env.reset()
        rews = []
        done = False
        while not done:
            act = get_action(Tensor(obs)).item()
            Xn.append(np.copy(obs))
            An.append(act)
            obs, rew, done = env.step(act)
            rews.append(rew)
        total_steps += len(rews)

        # Discounted reward-to-go, accumulated backwards in O(n):
        # G_t = r_t + GAMMA * G_{t+1}
        returns, running = [], 0.0
        for step_reward in reversed(rews):
            running = step_reward + GAMMA * running
            returns.append(float(running))
        Rn += returns[::-1]

        # Keep only the most recent transitions.
        Xn, An, Rn = Xn[-REPLAY_BUFFER:], An[-REPLAY_BUFFER:], Rn[-REPLAY_BUFFER:]

        if len(Xn) >= BATCH_SIZE:
            X, A, R = Tensor(np.array(Xn)), Tensor(An), Tensor(Rn)
            # Reference log-probabilities, held fixed during this round of updates.
            old_log = model(X)[0].detach()
            for _ in range(TRAIN_STEPS):
                samples = Tensor.randint(BATCH_SIZE, high=X.shape[0]).realize()
                train_step(X[samples], A[samples], R[samples], old_log[samples])

        ep_reward = sum(rews)
        best_reward = max(best_reward, ep_reward)
        bricks_hit = int(N_BRICKS - env.bricks.sum())
        sps = total_steps / (time.perf_counter() - t_start)
        pbar.set_description(
            f"ep:{ep:4d} sps:{sps:6.0f} bricks:{bricks_hit:2d}/{N_BRICKS} rew:{ep_reward:+6.1f} best:{best_reward:+6.1f}"
        )

        if args.watch_every > 0 and (ep + 1) % args.watch_every == 0:
            watch(model, episode=ep + 1, n_games=1)

    print("\nTraining done. Watching agent play...")
    watch(model, episode=args.episodes, n_games=3)
|
||||
Reference in New Issue
Block a user