geometric-intelligence/group-agf
Sequential Group Composition

A Window into the Mechanics of Deep Learning


Paper · PDF · Install · Usage · Citation

Giovanni Luca Marchetti* · Daniel Kunin* · Adele Myers · Francisco Acosta · Nina Miolane


How do neural networks trained on sequences acquire the ability to perform structured operations, such as arithmetic, geometric, and algorithmic computation?

We introduce the sequential group composition task: networks receive a sequence of elements from a finite group encoded in a real vector space and must predict their cumulative product. We prove that two-layer networks learn this task one irreducible representation at a time, in an order determined by the Fourier statistics of the encoding -- producing a characteristic staircase in the training loss.
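
To make the task concrete, here is a minimal illustrative sketch (not code from this repository) for the cyclic group $C_{10}$, where composition is addition modulo 10 and elements can be encoded as one-hot vectors in a real vector space:

```python
# Toy illustration of the sequential group composition task for C_10.
# Input: a sequence of k group elements; target: their cumulative product,
# which for the cyclic group C_n is addition modulo n.

def compose_cyclic(seq, n=10):
    """Cumulative product of a sequence of C_n elements (sum mod n)."""
    total = 0
    for g in seq:
        total = (total + g) % n
    return total

def one_hot(g, n=10):
    """Encode a group element as a vector in R^n."""
    return [1.0 if i == g else 0.0 for i in range(n)]

seq = [3, 8, 9]                      # k = 3 elements of C_10
target = compose_cyclic(seq)         # (3 + 8 + 9) mod 10 = 0
inputs = [one_hot(g) for g in seq]   # what the network actually sees
```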

Staircase learning across five finite groups

Two-layer networks learn group composition one irreducible representation at a time. Top: power spectrum of the learned function over training. Bottom: training loss showing a characteristic staircase. Each column is a different finite group.


Installation

Prerequisites

  • Conda (Miniconda or Anaconda)

Setup (Linux)

# Create and activate the conda environment
conda env create -f conda.yaml
conda activate group-agf

# Install all Python dependencies (pinned versions from poetry.lock)
poetry install

Usage

Single Run

Train a model on a specific group:

python src/main.py --config src/configs/config_d5.yaml

Results (loss curves, predictions, power spectra) are saved to a timestamped directory under runs/.

Supported Groups

The repository includes preconfigured experiments for eight groups:

| Group | Config | Order | k | Architecture |
|---|---|---|---|---|
| Cyclic $C_{10}$ | `config_c10_k3.yaml` | 10 | 3 | TwoLayerMLP |
| Cyclic $C_{11}$ | `config_c11.yaml` | 11 | 2 | TwoLayerMLP |
| Product $C_4 \times C_4$ | `config_c4x4_k3.yaml` | 16 | 3 | TwoLayerMLP |
| Product $C_5 \times C_5$ | `config_c5xc5.yaml` | 25 | 2 | TwoLayerMLP |
| Dihedral $D_3$ | `config_d3.yaml` | 6 | 2 | TwoLayerMLP |
| Dihedral $D_5$ | `config_d5.yaml` | 10 | 2 | TwoLayerMLP |
| Octahedral $O_h$ | `config_oh.yaml` | 24 | 2 | TwoLayerMLP |
| Icosahedral $A_5$ | `config_a5.yaml` | 60 | 2 | TwoLayerMLP |

Reproduce Paper's Figure

Reproduce the paper's figure (training loss + power spectrum for C11, C5xC5, D5, Oh, A5):

python src/main.py --combined-plot

This uses precomputed data from runs_data/ and produces combined_loss_and_power.pdf in seconds. No GPU or training is needed.

If you want to retrain from scratch instead, delete runs_data/ first; the command will automatically detect CUDA, train each group, and generate the plot.

Parameter Sweeps

Run experiments across multiple configurations and random seeds:

python src/run_sweep.py --sweep src/sweep_configs/example_sweep.yaml

Multi-GPU support:

# Auto-detect and use all available GPUs
python src/run_sweep.py --sweep src/sweep_configs/example_sweep.yaml --gpus auto

# Use specific GPUs
python src/run_sweep.py --sweep src/sweep_configs/example_sweep.yaml --gpus 0,1,2,3

Sweep results are saved to sweeps/{sweep_name}_{timestamp}/ with per-seed results and aggregated summaries.

Configuration

Key parameters in the YAML config files:

| Parameter | Options | Description |
|---|---|---|
| `data.group_name` | `cn`, `cnxcn`, `dihedral`, `octahedral`, `A5` | Group to learn |
| `data.k` | integer | Number of elements to compose |
| `data.template_type` | `custom_fourier`, `onehot`, `mnist` | Template generation method |
| `model.model_type` | `QuadraticRNN`, `TwoLayerMLP` | Architecture |
| `model.hidden_dim` | integer | Hidden layer size |
| `model.init_scale` | float | Weight initialization scale |
| `training.optimizer` | `auto`, `adam`, `per_neuron`, `hybrid` | Optimizer (`auto` recommended) |
| `training.learning_rate` | float | Base learning rate |
| `training.mode` | `online`, `offline` | Training mode |
| `training.epochs` | integer | Number of epochs (offline mode) |
Example config -- D5 with custom Fourier template:

data:
  group_name: dihedral
  group_n: 5
  k: 2
  template_type: custom_fourier
  powers: [0.0, 3000.0, 2000.0, 1000.0]

model:
  model_type: TwoLayerMLP
  hidden_dim: 300
  init_scale: 0.0001

training:
  optimizer: per_neuron
  learning_rate: 0.006
  mode: offline
  epochs: 5000

Repository Structure

group-agf/
├── src/                          # Source code
│   ├── main.py                   # Training entry point (CLI)
│   ├── model.py                  # TwoLayerMLP, QuadraticRNN
│   ├── optimizer.py              # PerNeuronScaledSGD, HybridRNNOptimizer
│   ├── dataset.py                # Dataset generation and loading
│   ├── template.py               # Template construction functions
│   ├── viz.py                    # Plotting, visualization, and power spectrum helpers
│   ├── train.py                  # Training loops (offline and online)
│   ├── run_sweep.py              # Parameter sweep runner
│   ├── sweep_analysis.py         # Sweep result loading and analysis
│   ├── groups/                   # Finite group implementations
│   │   ├── group.py              # Abstract Group base class (Fourier, power spectrum)
│   │   ├── cn.py                 # CyclicGroup (C_n)
│   │   ├── cnxcn.py              # ProductCyclicGroup (C_n × C_m)
│   │   ├── dn.py                 # DihedralGroup (D_n)
│   │   ├── oh.py                 # OctahedralGroup (O_h)
│   │   ├── a5.py                 # AlternatingGroup (A_5 / icosahedral)
│   │   └── irrep.py              # IrreducibleRepresentation helper
│   ├── configs/                  # Group-specific configurations
│   │   └── config_*.yaml
├── runs_data/                    # Precomputed data for combined plot (Figure 1)
│   ├── {C11,C5xC5,D5,Oh,A5}/
│   │   ├── config.yaml
│   │   ├── train_loss_history.npy
│   │   └── power_data.npz
├── test/                         # Unit and integration tests
├── notebooks/                    # Jupyter notebooks for exploration
├── pyproject.toml                # Project metadata and dependencies
├── poetry.lock                   # Pinned dependency versions
└── conda.yaml                    # Conda environment specification
Module details

model.py -- Neural Network Architectures

| Model | Description | Input |
|---|---|---|
| `TwoLayerMLP` | Two-layer feedforward network with configurable nonlinearity (`square`, `relu`, `tanh`, `gelu`) | Flattened binary pair, shape `(N, 2 * group_size)` |
| `QuadraticRNN` | Recurrent network: `h_t = (W_mix h_{t-1} + W_drive x_t)^2` | Sequence, shape `(N, k, p)` |
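
The QuadraticRNN recurrence can be sketched in plain Python (illustrative only; the repository's model.py implements it in PyTorch with learned weights):

```python
# One step of the recurrence h_t = (W_mix h_{t-1} + W_drive x_t)^2,
# with the square applied elementwise.

def matvec(W, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def quadratic_rnn_step(W_mix, W_drive, h_prev, x_t):
    pre = [a + b for a, b in zip(matvec(W_mix, h_prev), matvec(W_drive, x_t))]
    return [p * p for p in pre]   # elementwise square nonlinearity

# Tiny example: hidden dim 2, input dim 2, sequence of length 3.
W_mix = [[0.5, 0.0], [0.0, 0.5]]
W_drive = [[1.0, 0.0], [0.0, 1.0]]
h = [0.0, 0.0]
for x in [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]:
    h = quadratic_rnn_step(W_mix, W_drive, h, x)
```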

optimizer.py -- Custom Optimizers

| Optimizer | Description | Recommended for |
|---|---|---|
| `PerNeuronScaledSGD` | SGD with per-neuron learning-rate scaling exploiting model homogeneity | `TwoLayerMLP` |
| `HybridRNNOptimizer` | Scaled SGD for MLP weights + Adam for recurrent weights | `QuadraticRNN` |
| `Adam` (PyTorch built-in) | Standard Adam | `QuadraticRNN` |
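
The exact scaling rule lives in src/optimizer.py; as a hedged sketch of the general idea, one common per-neuron scheme for homogeneous activations (like the square) scales each neuron's step by that neuron's current weight norm:

```python
# Illustrative per-neuron learning-rate scaling (NOT the repository's exact
# rule): each neuron's gradient step is multiplied by ||w_i||, so large
# neurons move faster, which can sharpen the staircase dynamics.

import math

def per_neuron_sgd_step(neurons, grads, lr):
    """neurons, grads: lists of per-neuron weight vectors (lists of floats)."""
    updated = []
    for w, g in zip(neurons, grads):
        scale = math.sqrt(sum(wi * wi for wi in w))  # one scale per neuron
        updated.append([wi - lr * scale * gi for wi, gi in zip(w, g)])
    return updated
```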

groups/ -- Finite Group Implementations

  • Group (abstract base class) -- defines the interface (order, elements, irreps, regular_rep) and provides concrete Fourier analysis methods: fourier, inverse_fourier, power_spectrum
  • CyclicGroup ($C_n$), ProductCyclicGroup ($C_n \times C_m$), DihedralGroup ($D_n$), OctahedralGroup ($O_h$), AlternatingGroup ($A_5$)
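
For the abelian special case $C_n$, the power spectrum reduces to the ordinary DFT; here is an illustrative stand-in for the `power_spectrum` interface (the repository handles general non-abelian irreps too, and its normalization convention may differ):

```python
# Power spectrum of a real function on the cyclic group C_n via the DFT.

import cmath
import math

def power_spectrum_cn(f):
    """Power per frequency of a function f given as a list of n values."""
    n = len(f)
    powers = []
    for freq in range(n):
        coeff = sum(f[x] * cmath.exp(-2j * cmath.pi * freq * x / n)
                    for x in range(n))
        powers.append(abs(coeff) ** 2 / n)  # normalization is a choice
    return powers

# A pure frequency-1 signal on C_8 puts all its power on frequencies 1 and 7.
f = [math.cos(2 * math.pi * x / 8) for x in range(8)]
spec = power_spectrum_cn(f)
```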

dataset.py -- Data Generation

  • GroupCompositionDataset -- PyTorch map-style dataset for offline group composition; supports arbitrary sequence length k, sampled/exhaustive mode, and return_all_outputs
  • _OnlineModularAdditionDataset1D, _OnlineModularAdditionDataset2D -- generate samples on-the-fly (GPU-accelerated) via __iter__
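
In the spirit of the online modular-addition datasets, a pure-Python sketch of on-the-fly sample generation for $C_n$ (the repository's version is a GPU-accelerated PyTorch iterable; names and shapes below are illustrative):

```python
# Generate (sequence of one-hot vectors, one-hot target) pairs for modular
# addition on C_n: the target is the sum of the k sampled elements mod n.

import random

def sample_batch(n, k, batch_size, rng):
    batch = []
    for _ in range(batch_size):
        seq = [rng.randrange(n) for _ in range(k)]
        target = sum(seq) % n
        x = [[1.0 if i == g else 0.0 for i in range(n)] for g in seq]
        y = [1.0 if i == target else 0.0 for i in range(n)]
        batch.append((x, y))
    return batch

rng = random.Random(0)
batch = sample_batch(n=11, k=2, batch_size=4, rng=rng)
```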

template.py -- Template Construction

  • one_hot -- one-hot encoding with zeroth frequency removed
  • custom_fourier -- template from desired per-irrep power values
  • make_template -- config-driven template factory (dispatches to the above)
  • mnist_1d, mnist_2d -- templates derived from MNIST images
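
Removing the zeroth frequency from a one-hot template amounts to subtracting the mean, so the encoding has no component on the trivial irrep; a minimal sketch (the repository's template.py is the authoritative version):

```python
# One-hot template with the zeroth (constant) frequency removed: subtract
# the mean 1/n from every coordinate, leaving a zero-sum vector.

def one_hot_no_dc(g, n):
    return [(1.0 if i == g else 0.0) - 1.0 / n for i in range(n)]

v = one_hot_no_dc(3, 10)
dc = sum(v)   # zero sum <=> zero coefficient on the trivial irrep
```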

viz.py -- Visualization and Power Spectrum Helpers

Plotting and analysis functions for training: plot_train_loss_with_theory, plot_predictions_group, plot_power_group, plot_wmix_structure, plot_irreps, plot_combined_loss_and_power, plot_loss_power_and_weight_power, and more. Also includes power spectrum helpers moved from the former power.py: loss_plateau_predictions, powers_per_neuron_rows, model_power_over_time, topk_template_freqs.

train.py -- Training Loops

  • train(model, loader, criterion, optimizer, ...) -- epoch-based offline training
  • train_online(model, loader, criterion, optimizer, ...) -- step-based online training
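
Schematically, the epoch-based offline loop has the usual shape; the framework-free stand-in below is illustrative only (the repository's train.py uses PyTorch models, loaders, and optimizers):

```python
# Skeleton of an epoch-based training loop in the style of
# train(model, loader, criterion, optimizer, ...).

def train(model, loader, criterion, optimizer, epochs):
    """Run `epochs` passes over `loader`, returning per-epoch mean loss."""
    history = []
    for _ in range(epochs):
        epoch_loss = 0.0
        for x, y in loader:
            loss, grad = criterion(model, x, y)   # forward pass + gradient
            optimizer(model, grad)                # in-place parameter update
            epoch_loss += loss
        history.append(epoch_loss / len(loader))
    return history

# Toy instantiation: fit y = 2x with one scalar weight and squared error.
model = {"w": 0.0}

def criterion(m, x, y):
    err = m["w"] * x - y
    return err * err, 2 * err * x

def optimizer(m, grad):
    m["w"] -= 0.1 * grad

loader = [(1.0, 2.0)] * 5
history = train(model, loader, criterion, optimizer, epochs=20)
```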

sweep_analysis.py -- Sweep Result Analysis

Utilities for loading, aggregating, and plotting parameter sweep results: load_sweep_results_grid, load_training_loss_curves, export_lightweight_data, plot_theory_boundaries.

Testing

# All tests (unit + integration)
pytest test/ -v

# Notebook tests only (requires jupyter/nbconvert)
NOTEBOOK_TEST_MODE=1 pytest test/test_notebooks.py -v

Development

# Install pre-commit hooks
pre-commit install

# Run linting
ruff check .
ruff format --check .

Citation

If you find this work useful, please cite:

@article{marchetti2026sequential,
  title   = {Sequential Group Composition: A Window into the Mechanics of Deep Learning},
  author  = {Marchetti, Giovanni Luca and Kunin, Daniel and Myers, Adele and Acosta, Francisco and Miolane, Nina},
  journal = {arXiv preprint arXiv:2602.03655},
  year    = {2026}
}

License

This project is licensed under the MIT License. See LICENSE for details.

About

Group Alternating Gradient Flows
