SimCTG - A Contrastive Framework for Neural Text Generation


Authors: Yixuan Su, Tian Lan, Yan Wang, Dani Yogatama, Lingpeng Kong, and Nigel Collier

This repository contains the code, models, and other resources for our paper A Contrastive Framework for Neural Text Generation.


1. Introduction:

Text generation is of great importance to many natural language processing applications. However, maximization-based decoding methods (e.g., beam search) of neural language models often lead to degenerate solutions: the generated text is unnatural and contains undesirable repetitions. Existing approaches introduce stochasticity via sampling or modify training objectives to decrease the probabilities of certain tokens (e.g., unlikelihood training). However, they often lead to solutions that lack coherence. In this work, we show that an underlying reason for model degeneration is the anisotropic distribution of token representations. We present a contrastive solution: (i) SimCTG, a contrastive training objective to calibrate the model's representation space, and (ii) contrastive search, a decoding method that encourages diversity while maintaining coherence in the generated text. Extensive experiments and analyses on three benchmarks in two languages demonstrate that our proposed approach outperforms state-of-the-art text generation methods as evaluated by both human and automatic metrics.
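To make the idea of calibrating the representation space concrete, here is a minimal sketch of a margin-based contrastive penalty over token representations. It assumes cosine similarity and a hinge margin; the function names, the margin value, and the toy vectors are illustrative assumptions, so please check the paper and the training code for the exact objective used in this repository.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def contrastive_loss(hidden_states, margin=0.5):
    """Average hinge penalty over all ordered pairs of distinct token positions.

    Each pair (i, j) with i != j contributes
    max(0, margin - s(h_i, h_i) + s(h_i, h_j)),
    where s is cosine similarity, so s(h_i, h_i) = 1.
    The margin default is an illustrative assumption.
    """
    n = len(hidden_states)
    total = 0.0
    for i in range(n):
        for j in range(n):
            if i != j:
                sim = cosine(hidden_states[i], hidden_states[j])
                total += max(0.0, margin - 1.0 + sim)
    return total / (n * (n - 1))

# Toy 2-d "token representations" (made-up values, not model outputs).
anisotropic = [[1.0, 0.0], [1.0, 0.1], [1.0, -0.1]]  # nearly parallel vectors
isotropic = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]    # well-spread vectors

print(contrastive_loss(anisotropic))  # positive: near-identical representations are penalized
print(contrastive_loss(isotropic))    # 0.0: every pair is already below the margin
```

In this sketch, representations that point in almost the same direction incur a penalty, while representations that are spread out incur none, which is the intuition behind counteracting anisotropy.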

2. News:

[2022/02/15] SimCTG is publicly released!

3. Citation:

If you find our paper and resources useful, please kindly leave a star and cite our paper. Thanks!

4. Huggingface Models:

| Model Name | Task | Language | Training Corpus (Size) | Model Size | Model Address |
|---|---|---|---|---|---|
| cambridgeltl/simctg_wikitext103 | Document Generation | English | Wikitext-103 (529MB) | 117M | [link] |
| cambridgeltl/simctg_lccc_dialogue | Open-domain Dialogue Generation | Chinese | LCCC (708MB) | 117M | [link] |
| cambridgeltl/simctg_english_wikipedia | General Domain Pre-training | English | Wikipedia (14.11GB) | 117M | [link] |

5. Environment Setup:

python version: 3.8
pip3 install -r requirements.txt

6. Example Usage of Contrastive Search:

6.1. Use SimCTG Pretrained on Wikipedia Corpus:

Here, we show how to use contrastive search to generate the result.

import torch
import sys
from simctg import SimCTGPretraining
# load SimCTG model pretrained on the large-scale Wikipedia corpus
model_path = r'cambridgeltl/simctg_english_wikipedia'
model = SimCTGPretraining(model_path)

# we randomly select a prefix from the dev set of Wikipedia pre-training corpus and prepare the text prefix input
text = r'Insect farming is the practice of raising and breeding insects as livestock, also referred to as minilivestock or micro stock. Insects may be farmed for the commodities'
tokens = model.tokenizer.tokenize(text)
input_ids = model.tokenizer.convert_tokens_to_ids(tokens)
input_ids = torch.LongTensor(input_ids).view(1,-1)

# use contrastive search to generate the result
beam_width, alpha, decoding_len = 5, 0.6, 128
eos_token = '<|endoftext|>'
print (model.fast_contrastive_search(input_ids, beam_width, alpha, decoding_len, eos_token))

   Insect farming is the practice of raising and breeding insects as livestock, also referred to as minilivestock
   or micro stock. Insects may be farmed for the  commodities they produce, such as honey, corn, sorghum, and 
   other crops. In some cases, the production of insects is a way to increase income for the owner or his family. 
   This type of farming has been described as "an economic system that benefits all people regardless of race, sex, 
   or social status" (p. 9). A large number of farmers in North America, Europe, and South America have used the 
   method of farming for food production in order to feed their families and livestock. The most common method of 
   farming is by hand-cropping, which consists of cutting a hole in the ground and using a saw

More details on how to pre-train SimCTG on a large-scale corpus, and on the argument setup of contrastive search, can be found [here].
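As a rough illustration of what the beam_width and alpha arguments control, the following minimal sketch implements a single contrastive search step in plain Python. It assumes that beam_width is the number of top candidates considered at each step and that alpha weights a penalty equal to the maximum cosine similarity between a candidate's representation and the representations of the preceding tokens. The function names and toy values are hypothetical; this is not the repository's fast_contrastive_search implementation.

```python
import math

def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def contrastive_search_step(candidates, context_states, alpha):
    """Pick one token from the top-k candidates.

    candidates: list of (token, probability, hidden_state) for the k most
                probable next tokens (k plays the role of beam_width)
    context_states: hidden states of the tokens in the context so far
    alpha: weight of the degeneration penalty (alpha = 0 is greedy search)
    """
    best_token, best_score = None, -math.inf
    for token, prob, state in candidates:
        # Penalize candidates whose representation is close to any earlier token.
        penalty = max(cosine(state, h) for h in context_states)
        score = (1.0 - alpha) * prob - alpha * penalty
        if score > best_score:
            best_token, best_score = token, score
    return best_token

# Toy values (made up for illustration, not model outputs).
context_states = [[1.0, 0.0], [0.9, 0.1]]
candidates = [
    ("the", 0.5, [1.0, 0.05]),   # most probable, but very similar to the context
    ("honey", 0.3, [0.0, 1.0]),  # less probable, but dissimilar to the context
]

print(contrastive_search_step(candidates, context_states, alpha=0.0))  # the
print(contrastive_search_step(candidates, context_states, alpha=0.6))  # honey
```

With alpha set to 0 the step reduces to greedy selection; a larger alpha shifts the choice toward candidates that differ from what has already been generated.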

6.2. Use Off-the-shelf Language Models from Different Languages:

Importantly, we found that contrastive search can be directly applied to off-the-shelf language models even without contrastive training. The only condition is that the corresponding language is naturally tokenized by character units, as is the case for Chinese, Japanese, and Korean. In the following, we showcase how to use contrastive search with off-the-shelf Chinese, Japanese, and Korean language models. More analysis of why contrastive search works well on vanilla language models can be found in Appendix C of our paper.

6.2.1. Chinese Language Model:
import torch
import sys
from simctg import SimCTGPretraining
# load an off-the-shelf Chinese GPT model
model_path = r'uer/gpt2-chinese-cluecorpussmall'
model = SimCTGPretraining(model_path)

# prepare text prefix input
text = r'苹果公司'
tokens = model.tokenizer.tokenize(text)
input_ids = model.tokenizer.convert_tokens_to_ids(tokens)
input_ids = torch.LongTensor(input_ids).view(1,-1)

# (1) use contrastive search to generate the result
beam_width, alpha, decoding_len = 3, 0.6, 128
eos_token = '[SEP]'
print (model.fast_contrastive_search(input_ids, beam_width, alpha, decoding_len, eos_token))

# (2) use nucleus sampling to generate the result
nucleus_p, decoding_len = 0.95, 128
eos_token = '[SEP]'
print (model.nucleus_sampling(input_ids, nucleus_p, decoding_len, eos_token))

# (3) use greedy search to generate the result
decoding_len = 128
eos_token = '[SEP]'
print (model.greedy_search(input_ids, decoding_len, eos_token))

# (4) use beam search to generate the result
beam_width, decoding_len = 10, 128
eos_token = '[SEP]'
print (model.beam_search(input_ids, beam_width, decoding_len, eos_token))

# ------------------------------------------ Another Example --------------------------------------------- #
# prepare text prefix input
text = r'百节年为首,春节是中华民族最隆重的传统佳节。它不仅集中体现了中华'
tokens = model.tokenizer.tokenize(text)
input_ids = model.tokenizer.convert_tokens_to_ids(tokens)
input_ids = torch.LongTensor(input_ids).view(1,-1)

# (1) use contrastive search to generate the result
beam_width, alpha, decoding_len = 3, 0.6, 128
eos_token = '[SEP]'
print (model.fast_contrastive_search(input_ids, beam_width, alpha, decoding_len, eos_token))

# (2) use nucleus sampling to generate the result
nucleus_p, decoding_len = 0.95, 128
eos_token = '[SEP]'
print (model.nucleus_sampling(input_ids, nucleus_p, decoding_len, eos_token))

# (3) use greedy search to generate the result
decoding_len = 128
eos_token = '[SEP]'
print (model.greedy_search(input_ids, decoding_len, eos_token))

# (4) use beam search to generate the result
beam_width, decoding_len = 10, 128
eos_token = '[SEP]'
print (model.beam_search(input_ids, beam_width, decoding_len, eos_token))

More details on how to use different decoding methods to generate the result can be found [here].
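Since the four decoding calls above follow the same pattern, a small helper can run them on the same prefix and collect the outputs side by side. The sketch below assumes only the method signatures shown in the examples above; compare_decoding_methods and the EchoModel stand-in are hypothetical names introduced for illustration, and the stand-in lets the sketch run without downloading a model.

```python
def compare_decoding_methods(model, input_ids, eos_token, decoding_len=128,
                             k=3, alpha=0.6, nucleus_p=0.95, beam_size=10):
    """Run the four decoding methods on the same prefix and collect the outputs."""
    return {
        "contrastive search": model.fast_contrastive_search(
            input_ids, k, alpha, decoding_len, eos_token),
        "nucleus sampling": model.nucleus_sampling(
            input_ids, nucleus_p, decoding_len, eos_token),
        "greedy search": model.greedy_search(
            input_ids, decoding_len, eos_token),
        "beam search": model.beam_search(
            input_ids, beam_size, decoding_len, eos_token),
    }

class EchoModel:
    """Hypothetical stand-in that mimics the call signatures used above."""

    def fast_contrastive_search(self, input_ids, beam_width, alpha, decoding_len, eos_token):
        return f"contrastive(k={beam_width}, alpha={alpha})"

    def nucleus_sampling(self, input_ids, nucleus_p, decoding_len, eos_token):
        return f"nucleus(p={nucleus_p})"

    def greedy_search(self, input_ids, decoding_len, eos_token):
        return f"greedy(len={decoding_len})"

    def beam_search(self, input_ids, beam_width, decoding_len, eos_token):
        return f"beam(width={beam_width})"

results = compare_decoding_methods(EchoModel(), input_ids=None, eos_token='[SEP]')
for name, text in results.items():
    print(f"{name}: {text}")
```

To use it with a real model, pass the SimCTGPretraining instance and the input_ids tensor prepared as in the examples above in place of the stand-in.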

6.2.2. Japanese Language Model:
import torch
import sys
from simctg import SimCTGPretraining
# load an off-the-shelf Japanese GPT model
model_path = r'colorfulscoop/gpt2-small-ja'
model = SimCTGPretraining(model_path)

# prepare text prefix input; the prefix is copied from a random Japanese Wikipedia page
text = r'臥龍桜(がりゅうざくら)は、岐阜県高山市一之宮町にある一本桜。龍が地'
tokens = model.tokenizer.tokenize(text)
input_ids = model.tokenizer.convert_tokens_to_ids(tokens)
input_ids = torch.LongTensor(input_ids).view(1,-1)

# (1) use contrastive search to generate the result
beam_width, alpha, decoding_len = 5, 0.6, 128
eos_token = model.tokenizer.eos_token
print (model.fast_contrastive_search(input_ids, beam_width, alpha, decoding_len, eos_token))

# (2) use nucleus sampling to generate the result
nucleus_p, decoding_len = 0.95, 128
eos_token = model.tokenizer.eos_token
print (model.nucleus_sampling(input_ids, nucleus_p, decoding_len, eos_token))

# (3) use greedy search to generate the result
decoding_len = 128
eos_token = model.tokenizer.eos_token
print (model.greedy_search(input_ids, decoding_len, eos_token))

# (4) use beam search to generate the result
beam_width, decoding_len = 10, 128
eos_token = model.tokenizer.eos_token
print (model.beam_search(input_ids, beam_width, decoding_len, eos_token))

[Note] Sadly, I do not speak Japanese (I wish I did!), so I can only judge the quality of the generated text using Google Translate. It would be great if anyone could tell me whether the generated text is good or not. Thank you in advance!

6.2.3. Korean Language Model:
import torch
import sys
from simctg import SimCTGPretraining
# load an off-the-shelf Korean GPT model
model_path = r'skt/ko-gpt-trinity-1.2B-v0.5'
model = SimCTGPretraining(model_path)

# prepare text prefix input (double quotes avoid backslash escapes ending up in the prefix)
text = "인간처럼 생각하고, 행동하는 '지능'을 통해 인류가 이제까지 풀지 못했던"
tokens = model.tokenizer.tokenize(text)
input_ids = model.tokenizer.convert_tokens_to_ids(tokens)
input_ids = torch.LongTensor(input_ids).view(1,-1)

# (1) use contrastive search to generate the result
# this model is fairly large, so we set the generation length (decoding_len) to 64
beam_width, alpha, decoding_len = 5, 0.6, 64
eos_token = model.tokenizer.eos_token
print (model.fast_contrastive_search(input_ids, beam_width, alpha, decoding_len, eos_token))

# (2) use nucleus sampling to generate the result
nucleus_p, decoding_len = 0.95, 64
eos_token = model.tokenizer.eos_token
print (model.nucleus_sampling(input_ids, nucleus_p, decoding_len, eos_token))

# (3) use greedy search to generate the result
decoding_len = 64
eos_token = model.tokenizer.eos_token
print (model.greedy_search(input_ids, decoding_len, eos_token))

# (4) use beam search to generate the result
# We do not print the result, because beam search stops generation immediately.

[Note] Sadly, I am not a Korean speaker either, so again I can only judge the quality of the generated text using Google Translate. It would be great if anyone could tell me whether the generated text is good or not. Thank you!

7. Document Generation:

A detailed tutorial for the document generation experiments is provided [here].

8. Open-domain Dialogue Generation:

A detailed tutorial for the open-domain dialogue generation experiments is provided [here].

9. Large-Scale Pre-training with SimCTG:

In addition to fine-tuning on downstream tasks (e.g. document generation and open-domain dialogue generation), we can also use a large-scale general domain corpus (i.e. Wikipedia) to pre-train a SimCTG model. Here, we show the details of how to pre-train SimCTG using a large-scale English Wikipedia corpus.

10. Contact:

If you have any questions, feel free to contact me via (ys484 at

