# Grid-JEPA: JEPA-Based World Model for ARC-AGI-3

A neural architecture for the [ARC-AGI-3 competition](https://www.kaggle.com/competitions/arc-prize-2026-arc-agi-3) combining Joint Embedding Predictive Architecture (JEPA), Recurrent State-Space Models (RSSM), and Test-Time Training (TTT) to solve novel interactive grid-world tasks.

## Architecture

```
┌─────────────────────────────────────────────────────────────────────┐
│                           ARC-AGI-3 Agent                           │
├─────────────────────────────────────────────────────────────────────┤
│  Observation Grid (64x64, 16 colors)                                │
│         ↓                                                           │
│  ┌─────────────┐                                                    │
│  │  Grid-JEPA  │ ← I-JEPA adapted for discrete grid worlds          │
│  │   Encoder   │   1×1 patches, latent-space prediction             │
│  └─────────────┘                                                    │
│         ↓ Latent Representation                                     │
│  ┌─────────────┐                                                    │
│  │    RSSM     │ ← Recurrent State-Space Model (DreamerV3-style)    │
│  │ World Model │   GRU dynamics + discrete latents                  │
│  └─────────────┘                                                    │
│         ↓ Hidden State (PERSISTS across levels!)                    │
│  ┌─────────────┐    ┌─────────────┐                                 │
│  │  Planning   │ ←→ │ Exploration │                                 │
│  │ (Imagination│    │  (Novelty)  │                                 │
│  │  Rollouts)  │    │             │                                 │
│  └─────────────┘    └─────────────┘                                 │
│         ↓                                                           │
│  ┌──────────────┐                                                   │
│  │Goal Inference│ ← Discovers objectives from terminal states       │
│  └──────────────┘                                                   │
│         ↓                                                           │
│  Action (key, position) → Environment                               │
└─────────────────────────────────────────────────────────────────────┘
```

## Key Innovations

### 1. JEPA for Discrete Grid Worlds

- **1×1 patch embeddings**: Each grid cell is semantically meaningful (colors are categorical)
- **Latent-space prediction**: Predicts transformations (rotate, fill, move) without pixel reconstruction
- **Action-conditioned predictor**: Inspired by Image World Models (Meta, 2024)

### 2. Persistent World Model State (Critical for ARC-AGI-3)

- **RSSM state persists across levels** within the same environment
- Level 3 requires knowledge from Level 1-2; resetting = instant failure

### 3. Uncertainty-Aware Prediction

- **Tracks prediction errors** over a sliding window
- **Triggers hypothesis revision** when errors are consistently high
- Prevents "latching onto early hypothesis" failure mode

### 4. Test-Time Training (TTT)

- **Per-task LoRA adapters** for each novel environment
- Fine-tunes on collected demos with geometric augmentations
- Based on TTT for ARC (arXiv:2411.07279) achieving 53% on ARC-AGI-1

## Repository Structure

```
arc-jepa/
├── src/
│   ├── models/
│   │   ├── encoder.py       # GridPatchEmbed + ViT encoders + EMA
│   │   ├── predictor.py     # Action-conditioned predictor
│   │   ├── grid_jepa.py     # Complete Grid-JEPA system
│   │   ├── rssm.py          # Recurrent State-Space Model
│   │   ├── agent.py         # Full ARC agent (JEPA + RSSM + planning)
│   │   └── ttt_adapter.py   # LoRA TTT adapter
│   ├── data/                # Dataset loaders + augmentations
│   ├── training/            # Training scripts
│   └── utils/               # Utilities
├── tests/                   # Unit tests
└── README.md                # This file
```

## Core Components

### `agent.py`: Complete Agent (Central Module)

- `ARCAgent`: Full agent loop encoding the core insight of this project
- `GoalInferenceModule`: Discovers objectives from terminal/done states
- `ExplorationPolicy`: Novelty-seeking with undo loop avoidance
- `PlanningModule`: Imagination-based action selection via RSSM rollouts
- `UncertaintyTracker`: Hypothesis revision when predictions fail consistently

### `encoder.py`: Grid-JEPA Encoder

- `GridPatchEmbed`: 1×1 patch embeddings for color grids
- `ViTEncoder`: Multi-head attention transformer blocks
- `EMATargetEncoder`: EMA-updated target encoder (prevents collapse)

### `predictor.py`: Action-Conditioned Predictor

- `DiscreteActionEmbed`: Embeds (action_key, cell_position) pairs
- `ActionConditionedPredictor`: Predicts target patches from context + action
- `GridWorldPredictor`: Full predictor + decoder to color logits

### `rssm.py`: Recurrent State-Space Model

- `observe()`: Update state with new observation (posterior)
- `imagine()`: Predict next state given action (prior)
- `rollout()`: Imagine future trajectories for planning
- Straight-through gradients for discrete latents

### `ttt_adapter.py`: Test-Time Training

- `LoRALayer`: Low-rank adaptation (W' = W + BA)
- `PredictorLoRAAdapter`: Per-task LoRA on JEPA predictor
- `TTTTrainer`: Fine-tunes on demos with augmentation voting

## Key Design Decisions

| Decision | Rationale |
|----------|-----------|
| **1×1 patches** | Grid cells are semantically meaningful, unlike image pixels |
| **L2 latent loss** | Reconstruction forces modeling irrelevant visual details |
| **EMA target encoder** | Prevents representation collapse in self-supervised learning |
| **Feature conditioning** | Outperforms concatenation for action conditioning |
| **Straight-through latents** | Enables gradient flow through discrete RSSM states |
| **State persistence** | ARC-AGI-3 levels build on each other |
| **Uncertainty tracking** | Prevents getting stuck on wrong hypotheses |
| **LoRA TTT** | Efficient per-task adaptation without catastrophic forgetting |

## Papers

1. **I-JEPA** (arXiv:2301.08243): Foundation of encoder design
2. **Image World Models** (arXiv:2403.00504): Action-conditioned predictor
3. **DreamerV3** (arXiv:2301.04104): RSSM dynamics architecture
4. **TTT for ARC** (arXiv:2411.07279): Per-task LoRA fine-tuning
5. **ARC-AGI-3** (arXiv:2603.24621): Competition specification

## License

MIT License. Open source as required for ARC Prize eligibility.
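
## Appendix: Uncertainty Tracking Sketch

The sliding-window error tracking described under "Uncertainty-Aware Prediction" can be illustrated with a small, dependency-free sketch. This is not the repository's `UncertaintyTracker`: the class name `SlidingWindowUncertainty`, its parameters (`window_size`, `threshold`), and the rule that "consistently high" means every error in a full window exceeds the threshold are all assumptions made for illustration.

```python
from collections import deque


class SlidingWindowUncertainty:
    """Hypothetical stand-in for the idea behind UncertaintyTracker."""

    def __init__(self, window_size: int = 8, threshold: float = 0.5) -> None:
        # deque with maxlen drops the oldest error automatically.
        self.errors = deque(maxlen=window_size)
        self.threshold = threshold  # assumed scalar error threshold

    def update(self, prediction_error: float) -> bool:
        """Record one prediction error; return True if revision is due."""
        self.errors.append(prediction_error)
        return self.should_revise()

    def should_revise(self) -> bool:
        # Wait for a full window so a single early spike cannot trigger.
        if len(self.errors) < self.errors.maxlen:
            return False
        # Assumption: "consistently high" = every error above threshold.
        return all(e > self.threshold for e in self.errors)

    def reset(self) -> None:
        """Clear history, e.g. after the agent adopts a new hypothesis."""
        self.errors.clear()


if __name__ == "__main__":
    tracker = SlidingWindowUncertainty(window_size=3, threshold=0.5)
    for err in [0.9, 0.8, 0.7, 0.1]:
        print(err, tracker.update(err))
```

Requiring a full window before signalling a revision trades reaction speed for robustness to one-off mispredictions; a mean-based rule would be an equally plausible choice.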