Status: Archive (code is provided as-is, no updates expected)

Grade School Math

[Blog Post] [Paper]

State-of-the-art language models can match human performance on many tasks, but they still struggle to robustly perform multi-step mathematical reasoning. To diagnose the failures of current models and support research, we're releasing GSM8K, a dataset of 8.5K high-quality, linguistically diverse grade school math word problems. We find that even the largest transformer models fail to achieve high test performance, despite the conceptual simplicity of this problem distribution.

Dataset Details

GSM8K consists of 8.5K high-quality grade school math problems created by human problem writers. We segmented these into 7.5K training problems and 1K test problems. These problems take between 2 and 8 steps to solve, and solutions primarily involve performing a sequence of elementary calculations using basic arithmetic operations (+ - / *) to reach the final answer. A bright middle school student should be able to solve every problem.

The raw data files can be found in:

  • grade_school_math/data/train.jsonl
  • grade_school_math/data/test.jsonl

Each line of those files corresponds to a single grade school math problem, saved as a JSON dictionary with a "question" key and an "answer" key. The answer contains calculation annotations, and the final numeric solution appears on the last line of the answer, preceded by ####.
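
For example, a minimal sketch (standard library only, assuming the file paths above) for loading the training set and inspecting one problem:

import json

# Each line of train.jsonl is a JSON object with "question" and "answer" keys.
with open("grade_school_math/data/train.jsonl") as f:
    train_examples = [json.loads(line) for line in f]

print(train_examples[0]["question"])
print(train_examples[0]["answer"])  # the last line looks like "#### <numeric answer>"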

Calculation Annotations

Our models frequently fail to accurately perform calculations. Although larger models make fewer arithmetic mistakes than smaller models, this remains a common source of errors. To mitigate this issue, we train our models to use a calculator by injecting calculation annotations into the training set. At training time, we simply finetune on this language data as is. At test time, a calculator will override sampling when the model chooses to use these annotations. An example implementation of the calculator sampling can be found in calculator.py.

If you would like to remove the calculator annotations, simply remove any string that starts with << and ends with >>.
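
A one-line regex is enough for this; the snippet below is a sketch rather than the repository's own code:

import re

def strip_calculator_annotations(answer: str) -> str:
    # Remove every substring of the form <<...>>, e.g. "<<50*3=150>>".
    return re.sub(r"<<.*?>>", "", answer)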

Solution Extraction

To extract the final numeric solution for a particular question, simply parse the completion and take the numeric value immediately following the #### token. Example Python code for doing so is provided in dataset.py:is_correct.
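
The logic amounts to something like the following sketch (the repository's own implementation lives in dataset.py:is_correct; the comma handling here is an assumption for answers such as "1,000"):

def extract_answer(completion: str) -> str:
    # The final numeric solution follows the "#### " token on the last line.
    if "#### " not in completion:
        return None
    answer = completion.split("#### ")[-1].strip()
    return answer.replace(",", "")  # drop thousands separators

def is_correct(model_completion: str, ground_truth_answer: str) -> bool:
    return extract_answer(model_completion) == extract_answer(ground_truth_answer)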

Socratic Dataset

During our research, we also investigated a modified solution format that injects automatically generated "Socratic subquestions" before each step. Although we ultimately did not use this format for any experiments in the paper, we make this data available to anyone who is interested.

We show an example below. Each Socratic subquestion precedes the solution step it corresponds to, separated from the step by **:

A carnival snack booth made $50 selling popcorn each day. It made three times as much selling cotton candy. For a 5-day activity, the booth has to pay $30 rent and $75 for the cost of the ingredients. How much did the booth earn for 5 days after paying the rent and the cost of ingredients?
How much did the booth make selling cotton candy each day? ** The booth made $50 x 3 = $<<50*3=150>>150 selling cotton candy each day.
How much did the booth make in a day? ** In a day, the booth made a total of $150 + $50 = $<<150+50=200>>200.
How much did the booth make in 5 days? ** In 5 days, they made a total of $200 x 5 = $<<200*5=1000>>1000.
How much did the booth have to pay? ** The booth has to pay a total of $30 + $75 = $<<30+75=105>>105.
How much did the booth earn after paying the rent and the cost of ingredients? ** Thus, the booth earned $1000 - $105 = $<<1000-105=895>>895.

We generated each Socratic subquestion by conditioning on each ground truth (contractor-provided) step in a solution, using a model specifically finetuned for this task (on around 800 examples). To construct the full Socratic dataset, each step in the solution was prefixed by the model-generated Socratic subquestion. Steps were otherwise left untouched.
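
If you only want the original (non-Socratic) solution text back, one approach, sketched below under the assumption that every solution line follows the "subquestion ** step" pattern shown above, is to drop everything up to and including the ** separator on each line:

def strip_socratic_subquestions(answer: str) -> str:
    # Keep only the step portion of each "subquestion ** step" line;
    # lines without the separator (e.g. the final "#### ..." line) pass through unchanged.
    return "\n".join(line.split(" ** ", 1)[-1] for line in answer.splitlines())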

These data files can be found in:

  • grade_school_math/data/train_socratic.jsonl
  • grade_school_math/data/test_socratic.jsonl

View Model Solutions

For each test question, we provide solutions generated from 6B finetuning, 6B verification, 175B finetuning and 175B verification. This data can be found in:

  • grade_school_math/data/example_model_solutions.jsonl

To view these results problem-by-problem, run:

python view_model_solutions.py
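
If you instead want aggregate numbers, the sketch below computes per-method accuracy from the same file; the exact key names (e.g. "6b_verification" with an "is_correct" field) are assumptions about the file's structure:

import json
from collections import Counter

METHODS = ["6b_ft", "6b_verification", "175b_ft", "175b_verification"]  # assumed keys

correct, total = Counter(), 0
with open("grade_school_math/data/example_model_solutions.jsonl") as f:
    for line in f:
        example = json.loads(line)
        total += 1
        for method in METHODS:
            if example.get(method, {}).get("is_correct"):
                correct[method] += 1

for method in METHODS:
    print(f"{method}: {correct[method] / total:.1%}")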

Citation

Please use the below BibTeX entry to cite this dataset:

@article{cobbe2021gsm8k,
  title={Training Verifiers to Solve Math Word Problems},
  author={Cobbe, Karl and Kosaraju, Vineet and Bavarian, Mohammad and Hilton, Jacob and Nakano, Reiichiro and Hesse, Christopher and Schulman, John},
  journal={arXiv preprint arXiv:2110.14168},
  year={2021}
}

Usage

We present a basic example of training a GPT-2-sized model and using the calculator in the sampling process. We include this code for illustrative purposes only; this pipeline was not used for any experiments in the paper.

Training a Model

python train.py

Sampling from the Model

python sample.py

The core calculator sampling logic can be found in calculator.py:sample. Note that this code is inefficient as implemented. Specifically, the function does not support batches, and does not cache activations from previous tokens.
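
Conceptually, the override works roughly as follows: tokens are sampled one at a time, and whenever the model has just produced the "=" inside a <<...>> annotation, the enclosed expression is evaluated and the result is appended in place of further sampling. The sketch below is a simplified illustration under that assumption, not the actual calculator.py code, and model.sample_next_token is a hypothetical helper:

def sample_with_calculator(model, prompt: str, max_tokens: int = 256) -> str:
    # Simplified sketch: one sequence at a time, no caching of past activations.
    text = prompt
    for _ in range(max_tokens):
        text += model.sample_next_token(text)  # hypothetical single-token sampler
        # If we just finished "<<expr=", compute expr and append "result>>".
        if text.endswith("=") and "<<" in text:
            expr = text[text.rindex("<<") + 2 : -1]
            try:
                text += str(eval(expr)) + ">>"  # dataset-style arithmetic only
            except Exception:
                pass  # not a valid expression; let the model finish the annotation
        if text.endswith("<|endoftext|>"):
            break
    return text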

Comments
  • How can I replicate these results with the GPT-3 API?

    Hi! I'm trying to replicate your results with the GPT-3 API. This is how I've preprocessed the train file:

    {"prompt": "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?\n\n###\n\n", "completion": " 72"}

    And here's the first line of the valid file:

    {"prompt": "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\n\n###\n\n", "completion": " 18"}

    I'm trying to replicate the result from the paper which doesn't use the intermediate steps:

    If we instead finetune a 6B model to directly output the final answer without any intermediate steps, performance drops drastically from 20.6% to 5.2%.

    I run this command: openai api fine_tunes.create -t train.jsonl -v valid.jsonl -m curie --n_epochs 20

    The training loss and accuracy looks great, but the validation accuracy just goes crazy throughout all of training:

    13488,10571584,107904,0.006697522398264077,1.0,1.0,0.2868101323364219,0.0,0.0
    13567,10633656,108536,0.004670373035388367,1.0,1.0,0.40770788124678037,0.0,0.0
    13646,10695280,109168,0.006482512512579658,1.0,1.0,0.23502722387660283,0.125,0.125
    13722,10755664,109776,0.007619243074462266,1.0,1.0,0.2898588183223495,0.125,0.125
    13800,10817728,110400,0.007864901396759683,1.0,1.0,0.3764918240812498,0.0,0.0
    13878,10877360,111024,0.007001778227947788,1.0,1.0,0.34287615764819,0.0,0.0
    13955,10938520,111640,0.005789751697015372,1.0,1.0,0.343928185654087,0.125,0.125
    14033,11000392,112264,0.005374998155119428,1.0,1.0,0.2626242895447271,0.0,0.0
    14110,11061552,112880,0.007522114686767156,1.0,1.0,0.303700584687286,0.0,0.0
    14189,11122920,113512,0.007042250184718232,1.0,1.0,0.25505893241963656,0.25,0.25
    14267,11184344,114136,0.005404598385505561,1.0,1.0,0.4671648049487169,0.0,0.0
    14344,11245504,114752,0.00575376992666942,1.0,1.0,0.3507942277975161,0.125,0.125
    14424,11306624,115392,0.00572749277984494,1.0,1.0,0.23907356193368778,0.25,0.25
    14502,11368880,116016,0.006820225469565155,1.0,1.0,0.321503115529883,0.25,0.25
    14579,11428760,116632,0.006588759870975103,1.0,1.0,0.33556188792293684,0.0,0.0
    14657,11489160,117256,0.008451912472405503,1.0,1.0,0.26277955851476575,0.125,0.125
    14735,11550648,117880,0.005669759222337177,1.0,1.0,0.29857896823688596,0.0,0.0
    14813,11613672,118504,0.0038765392963223673,1.0,1.0,0.2643208104990911,0.0,0.0
    14893,11674920,119144,0.005361676043338958,1.0,1.0,0.3359258272212552,0.125,0.125
    14971,11735512,119768,0.005775294550614619,1.0,1.0,0.3953682584280405,0.0,0.1111111111111111
    15049,11796040,120392,0.006509018494322124,1.0,1.0,0.38131386485963575,0.0,0.0
    15128,11857408,121024,0.006104460320729444,1.0,1.0,0.36498191312040756,0.125,0.125
    15205,11918056,121640,0.008038989705175963,1.0,1.0,0.34087099317847736,0.0,0.0
    15282,11980816,122256,0.0049923646901917085,1.0,1.0,0.3232976954536787,0.125,0.125
    15361,12042376,122888,0.004565948148458752,1.0,1.0,0.43752457097146086,0.0,0.0
    15437,12103336,123496,0.008198144564371599,1.0,1.0,0.3357829815657358,0.125,0.125
    

    I'm also not sure how, in some iterations, validation_sequence_accuracy and validation_token_accuracy can differ even though all validation completions are of length 1 (I pruned all the prompts where the answer was longer).

    I'd be grateful for any advice here. I've tried different LRs but that didn't seem to help.

    Thanks!

    opened by ofirpress 2
  • Verifiers

    @kcobbe @vineetsk10 The blog post mentions 'verifiers' that improve model performance. In the data, for instance, there is an entry indicating that an answer is not correct, e.g. "6b_verification": {"is_correct": false, ...}, which I believe comes from the verifiers.

    When we create new datasets to fine-tune the model, should we by convention name the verifier key "is_correct", or can it be something arbitrary like "correct_answer": false that the GPT-3 API would understand?

    opened by nsankar 1
  • The third example problem is incorrect

    "With one person drinking 5, that brings the total drank to 5+9+8+3=25" is incorrect according to the problem setup and should instead be 5+9+8=22.

    The accompanying paper estimates that fewer than 2% of problems have errors like the one above. Unless this example was explicitly cherry-picked to demonstrate an incorrect problem, it is either exceedingly unlucky that this problem made it into the showcased examples or the frequency of incorrect examples is likely higher.

    opened by melaniebeck 0