Re-implementation of the paper 'Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets'
Paper
The original paper can be found here: https://arxiv.org/abs/2201.02177
Datasets
The paper isn't explicit about how it defines division, so I'm using integer division (see the sketch after this list):
- $$x\circ y = (x \,//\, y) \bmod p$$, for some prime $$p$$ and $$0\leq x,y \leq p$$
- $$x\circ y = \begin{cases} (x \,//\, y) \bmod p & \text{if } y \text{ is odd} \\ (x - y) \bmod p & \text{otherwise} \end{cases}$$, for some prime $$p$$ and $$0\leq x,y \leq p$$
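A minimal sketch of these two operations as I interpret them (the function names are illustrative, not the repo's actual helpers):

```python
def div(x: int, y: int, p: int) -> int:
    """x / y as integer division followed by reduction mod p (requires y > 0)."""
    return (x // y) % p

def div_odd(x: int, y: int, p: int) -> int:
    """Integer division mod p when y is odd, subtraction mod p otherwise."""
    return (x // y) % p if y % 2 == 1 else (x - y) % p
```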
Hyperparameters
The default hyperparameters are taken from the paper, but they can be adjusted via the command line when running `train.py`.
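The argument groups below are written in argparse style; a minimal sketch of how `train.py` presumably declares them (abridged to one flag per group; the full list is documented below):

```python
import argparse

# Sketch of the presumed argparse wiring in train.py (abridged).
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--width", type=int, default=128)
parser.add_argument("--data_name", type=str, default="perm")
parser.add_argument("--batch_size", type=int, default=512)
args = parser.parse_args()
```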
Running experiments
To run with default settings, simply run `python train.py`. The first time you train on any dataset you have to specify `--force_data`.
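For example, a first run on the modular division dataset could look like this (all flags are documented under Arguments below):

```
python train.py --force_data --data_name div --num_elements 97 --verbose
```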
Arguments:
optimizer args
- "--lr", type=float, default=1e-3
- "--weight_decay", type=float, default=1
- "--beta1", type=float, default=0.9
- "--beta2", type=float, default=0.98
model args
- "--num_heads", type=int, default=4
- "--layers", type=int, default=2
- "--width", type=int, default=128
data args
- "--data_name", type=str, default="perm", choices=[
- "perm_xy", # permutation composition x * y
- "perm_xyx1", # permutation composition x * y * x^-1
- "perm_xyx", # permutation composition x * y * x
- "plus", # x + y
- "minus", # x - y
- "div", # x / y
- "div_odd", # x / y if y is odd else x - y
- "x2y2", # x^2 + y^2
- "x2xyy2", # x^2 + y^2 + xy
- "x2xyy2x", # x^2 + y^2 + xy + x
- "x3xy", # x^3 + y
- "x3xy2y" # x^3 + xy^2 + y ]
- "--num_elements", type=int, default=5 (choose 5 for permutation data, 97 for arithmetic data)
- "--data_dir", type=str, default="./data"
- "--force_data", action="store_true", help="Whether to force dataset creation."
training args
- "--batch_size", type=int, default=512
- "--steps", type=int, default=10**5
- "--train_ratio", type=float, default=0.5
- "--seed", type=int, default=42
- "--verbose", action="store_true"
- "--log_freq", type=int, default=10
- "--num_workers", type=int, default=4