""" Train Grid-JEPA on ARC-AGI-1 data. Self-contained — can run in HuggingFace Zero-GPU, local GPU, or via hf_jobs. Usage (CPU): python train_arc1.py --epochs 50 --batch_size 32 --lr 3e-4 Usage (GPU): python train_arc1.py --epochs 50 --batch_size 32 --lr 3e-4 --device cuda After training, the model pushes to: https://huggingface.co/guychuk/arc-agi-3-grid-jepa """ import os, random, sys, argparse, json from pathlib import Path import numpy as np, torch from torch.utils.data import Dataset, DataLoader _SCRIPT_DIR = Path(__file__).parent.resolve() sys.path.insert(0, str(_SCRIPT_DIR / "src")) sys.path.insert(0, str(_SCRIPT_DIR / "src" / "models")) from models.grid_jepa import GridJEPA import trackio MAX_GRID = 30 def load_tasks(data_dir): tasks = [] for p in sorted(Path(data_dir).glob("*.json")): with open(p) as f: tasks.append(json.load(f)) return tasks def to_tensor(grid): arr = np.array(grid, dtype=np.int64) H, W = arr.shape t = torch.zeros(MAX_GRID, MAX_GRID, dtype=torch.long) t[:H, :W] = torch.from_numpy(arr) return t class ARCDataset(Dataset): def __init__(self, data_dir): self.samples = [] for task in load_tasks(data_dir): for pair in task.get("train", []): self.samples.append({"input": to_tensor(pair["input"]), "output": to_tensor(pair["output"])}) def __len__(self): return len(self.samples) def __getitem__(self, idx): s = self.samples[idx] return {"context_grid": s["input"], "target_grid": s["output"]} def collate(batch): return {"context_grid": torch.stack([b["context_grid"] for b in batch]), "target_grid": torch.stack([b["target_grid"] for b in batch])} def sample_masks(B, H, W, ratio=0.4, device="cpu"): N = H * W; n_t = max(1, int(N * ratio)) ctx = torch.zeros(B, N, dtype=torch.bool, device=device) tgt = torch.zeros(B, N, dtype=torch.bool, device=device) for b in range(B): idx = list(range(N)); random.shuffle(idx) tgt[b, idx[:n_t]] = True; ctx[b, idx[n_t:]] = True return ctx, tgt def main(): p = argparse.ArgumentParser() p.add_argument("--data_dir", default=str(_SCRIPT_DIR / "data" / "training")) p.add_argument("--epochs", type=int, default=50) p.add_argument("--batch_size", type=int, default=32) p.add_argument("--lr", type=float, default=3e-4) p.add_argument("--ema", type=float, default=0.996) p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") p.add_argument("--save_dir", default="./checkpoints") p.add_argument("--hub_model_id", default="guychuk/arc-agi-3-grid-jepa") p.add_argument("--push_to_hub", action="store_true") args = p.parse_args() # Auto-download data if not present if not Path(args.data_dir).exists(): print("ARC-AGI-1 data not found. Cloning from GitHub...") os.system("git clone --depth 1 https://github.com/fchollet/ARC-AGI.git " + str(_SCRIPT_DIR / "arc_data_source")) args.data_dir = str(_SCRIPT_DIR / "arc_data_source" / "data" / "training") print(f"Data dir set to: {args.data_dir}") dev = torch.device(args.device) os.makedirs(args.save_dir, exist_ok=True) run = trackio.init(project="arc-agi-3", name="grid-jepa-arc1", group="pretrain") print(f"Device: {dev}") model = GridJEPA( num_colors=10, embed_dim=192, encoder_depth=6, predictor_depth=6, num_heads=6, max_grid_size=MAX_GRID, ema_decay=args.ema ).to(dev) print(f"Params: {sum(p.numel() for p in model.parameters()):,}") ds = ARCDataset(args.data_dir) print(f"Samples: {len(ds)}") ld = DataLoader(ds, batch_size=args.batch_size, shuffle=True, collate_fn=collate, num_workers=0) opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.05) for epoch in range(1, args.epochs + 1): model.train(); epoch_loss = 0.0; n = 0 for batch in ld: ctx_g = batch["context_grid"].to(dev) tgt_g = batch["target_grid"].to(dev) B = ctx_g.shape[0] ctx_mask, target_mask = sample_masks(B, MAX_GRID, MAX_GRID, ratio=0.5, device=dev) a_key = torch.zeros(B, dtype=torch.long, device=dev) a_pos = torch.zeros(B, dtype=torch.long, device=dev) opt.zero_grad() loss, _ = model(tgt_g, ctx_mask, target_mask, a_key, a_pos) if torch.isnan(loss) or loss.item() == 0: continue loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step(); model.update_ema() epoch_loss += loss.item(); n += 1 avg = epoch_loss / max(n, 1) print(f"Epoch {epoch}/{args.epochs}: loss={avg:.4f}") run.log({"epoch_loss": avg, "epoch": epoch}) if epoch % 10 == 0: ck = os.path.join(args.save_dir, f"ckpt_{epoch}.pt") torch.save({"epoch": epoch, "model": model.state_dict()}, ck) print(f" Saved {ck}") # FINAL SAVE + PUSH (job storage is ephemeral) final = os.path.join(args.save_dir, "final.pt") torch.save({"model": model.state_dict(), "epoch": args.epochs}, final) print(f"Final checkpoint: {final}") if args.push_to_hub: from huggingface_hub import HfApi try: api = HfApi() api.upload_file(path_or_fileobj=final, path_in_repo="checkpoints/final.pt", repo_id=args.hub_model_id, repo_type="model") print(f"Pushed to: https://huggingface.co/{args.hub_model_id}") except Exception as e: print(f"Push failed: {e}") run.finish() if __name__ == "__main__": main()