"""Pretrain a GridJEPA model on ARC training pairs and push the result to the Hub.

The script loads every ``*.json`` task from the training directory, pads each
grid to ``MAX_GRID x MAX_GRID``, trains for ``NUM_EPOCHS`` epochs with random
context/target masks, writes periodic checkpoints plus a final one, and then
makes a best-effort upload of the final checkpoint.
"""
import json
import math
import os
import random
import sys
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

HUB_DIR = Path(__file__).parent.resolve()
# The project packages live under src/; they must be on sys.path before the
# project imports below are resolved.
sys.path.insert(0, str(HUB_DIR / "src"))
sys.path.insert(0, str(HUB_DIR / "src" / "models"))
from models.grid_jepa import GridJEPA
import trackio

MAX_GRID = 30      # side length every grid is padded to
NUM_EPOCHS = 50    # total training epochs
CKPT_EVERY = 10    # write an intermediate checkpoint every this many epochs


def load_tasks(d):
    """Load every ``*.json`` file directly inside directory *d*.

    Files are read in sorted path order so the resulting list is deterministic.

    Returns:
        A list with one parsed JSON object per file.
    """
    tasks = []
    for path in sorted(Path(d).glob("*.json")):
        with open(path, encoding="utf-8") as f:
            tasks.append(json.load(f))
    return tasks


def to_tensor(grid):
    """Convert a 2-D integer grid into a zero-padded ``MAX_GRID x MAX_GRID`` tensor.

    The grid is placed in the top-left corner; remaining cells stay 0.

    Raises:
        ValueError: if *grid* is not 2-D or exceeds ``MAX_GRID`` in either
            dimension.
    """
    arr = np.asarray(grid, dtype=np.int64)
    if arr.ndim != 2:
        raise ValueError(f"grid must be 2-D, got shape {arr.shape}")
    H, W = arr.shape
    if H > MAX_GRID or W > MAX_GRID:
        raise ValueError(
            f"grid of shape {H}x{W} exceeds MAX_GRID={MAX_GRID}"
        )
    t = torch.zeros(MAX_GRID, MAX_GRID, dtype=torch.long)
    t[:H, :W] = torch.from_numpy(arr)
    return t


class ARCDataset(Dataset):
    """Dataset of padded (input, output) grid pairs from the ``train`` split of each task."""

    def __init__(self, d):
        """Eagerly load and pad every ``train`` pair of every task found in *d*."""
        self.samples = [
            {"input": to_tensor(pair["input"]), "output": to_tensor(pair["output"])}
            for task in load_tasks(d)
            for pair in task.get("train", [])
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        return {"context_grid": s["input"], "target_grid": s["output"]}


def collate(batch):
    """Stack per-sample grids into batched ``context_grid`` / ``target_grid`` tensors."""
    return {
        "context_grid": torch.stack([b["context_grid"] for b in batch]),
        "target_grid": torch.stack([b["target_grid"] for b in batch]),
    }


def sample_masks(B, H, W, ratio=0.4, device="cpu"):
    """Sample complementary random context/target masks over ``H * W`` positions.

    For each of the *B* rows, ``max(1, int(H * W * ratio))`` positions are
    chosen uniformly at random as targets; all remaining positions are context.

    Note: randomness comes from torch's RNG (seed with ``torch.manual_seed``).

    Returns:
        ``(ctx, tgt)``: two boolean tensors of shape ``(B, H * W)`` on *device*
        that are exact complements of each other.
    """
    N = H * W
    nt = max(1, int(N * ratio))
    # argsort of i.i.d. uniform noise yields an independent random permutation
    # per row, avoiding a Python-level shuffle and list indexing per sample.
    perm = torch.rand(B, N, device=device).argsort(dim=1)
    tgt = torch.zeros(B, N, dtype=torch.bool, device=device)
    tgt.scatter_(1, perm[:, :nt], True)
    ctx = ~tgt
    return ctx, tgt


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[TRAIN] Device: {device}", flush=True)

run = trackio.init(project="arc-agi-3", name="grid-jepa-fast", group="pretrain")
print("[TRAIN] Trackio OK", flush=True)

model = GridJEPA(
    num_colors=10,
    embed_dim=192,
    encoder_depth=6,
    predictor_depth=6,
    num_heads=6,
    max_grid_size=MAX_GRID,
    ema_decay=0.996,
).to(device)
print(f"[TRAIN] Params: {sum(p.numel() for p in model.parameters()):,}", flush=True)

# Prefer data shipped next to this script; fall back to the container path.
data_dir = HUB_DIR / "data" / "training"
if not data_dir.exists():
    data_dir = Path("/app/arc_data_source/data/training")
ds = ARCDataset(str(data_dir))
print(f"[TRAIN] Samples: {len(ds)}", flush=True)

loader = DataLoader(ds, batch_size=8, shuffle=True, collate_fn=collate, num_workers=0)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)

save_dir = Path("/app/arc-jepa-hf/checkpoints")
# parents=True: the parent directory may not exist on a fresh machine.
save_dir.mkdir(parents=True, exist_ok=True)

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    epoch_loss = 0.0
    n = 0
    for batch in loader:
        # NOTE(review): only tgt_g is passed to the model; ctx_g is used just
        # for the batch size. Confirm this is intended.
        ctx_g = batch["context_grid"].to(device)
        tgt_g = batch["target_grid"].to(device)
        B = ctx_g.shape[0]
        ctx_mask, target_mask = sample_masks(
            B, MAX_GRID, MAX_GRID, ratio=0.5, device=device
        )
        # All-zero placeholders for the model's last two inputs — presumably
        # action conditioning; verify against GridJEPA.forward.
        a_key = torch.zeros(B, dtype=torch.long, device=device)
        a_pos = torch.zeros(B, dtype=torch.long, device=device)

        opt.zero_grad()
        loss, _ = model(tgt_g, ctx_mask, target_mask, a_key, a_pos)
        # Read the loss once (a single device sync) and skip degenerate
        # batches: NaN/inf would corrupt the weights, zero carries no signal.
        loss_val = loss.item()
        if not math.isfinite(loss_val) or loss_val == 0:
            continue
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        model.update_ema()
        epoch_loss += loss_val
        n += 1

    # max(n, 1) guards against an epoch in which every batch was skipped.
    avg = epoch_loss / max(n, 1)
    print(f"[TRAIN] Epoch {epoch}/{NUM_EPOCHS}: loss={avg:.4f}", flush=True)
    run.log({"epoch_loss": avg, "epoch": epoch})

    if epoch % CKPT_EVERY == 0 or epoch == NUM_EPOCHS:
        ck = save_dir / f"ckpt_{epoch}.pt"
        torch.save({"epoch": epoch, "model": model.state_dict()}, ck)
        print(f"[TRAIN] Saved {ck}", flush=True)

final = save_dir / "final.pt"
torch.save({"model": model.state_dict(), "epoch": NUM_EPOCHS}, final)
print(f"[TRAIN] Final: {final}", flush=True)

# Best-effort upload: any failure (including a missing huggingface_hub package)
# is reported but must not prevent the tracking run from being finished.
try:
    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj=str(final),
        path_in_repo="checkpoints/final.pt",
        repo_id="guychuk/arc-agi-3-grid-jepa",
        repo_type="model",
    )
    print("[TRAIN] Pushed to hub", flush=True)
except Exception as e:
    print(f"[TRAIN] Push failed: {e}", flush=True)

run.finish()
print("[TRAIN] Done", flush=True)