# R²SQL
The PyTorch implementation of the paper "Dynamic Hybrid Relation Network for Cross-Domain Context-Dependent Semantic Parsing" (AAAI 2021).
## Requirements
The model is tested in Python 3.6 with the following requirements:
```
torch==1.0.0
transformers==2.10.0
sqlparse
pymysql
progressbar
nltk
numpy
six
spacy
```
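A quick way to confirm the two pinned core dependencies are active (this check is not part of the repo, just a convenience):

```python
# Optional version check for the pinned core dependencies.
import torch
import transformers

print(torch.__version__)         # expected: 1.0.0
print(transformers.__version__)  # expected: 2.10.0
```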
All experiments on the SParC and CoSQL datasets were run on an NVIDIA V100 GPU with 32 GB of memory.
- Tip: a 16 GB GPU may raise an out-of-memory error.
## Setup
The SParC and CoSQL experiments live in two separate folders. Download the corresponding dataset from [SParC | CoSQL] into the `sparc/data` or `cosql/data` folder, respectively. The other related data files can be downloaded from EditSQL. Then, download the database sqlite files from [here] and place them under `data/database`.
Download the pretrained BERT model from [here] and save it as `model/bert/data/annotated_wikisql_and_PyTorch_bert_param/pytorch_model_uncased_L-12_H-768_A-12.bin`.
Download the GloVe embeddings file (`glove.840B.300d.txt`) and change `GLOVE_PATH` to your own path in all scripts.
Download the reranker models from [SParC reranker | CoSQL reranker] and save them as `submit_models/reranker_roberta.pt`. In addition, the roberta-base model can be downloaded from here and placed in `./[sparc|cosql]/local_param/`.
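As an optional sanity check (not part of the repo), a few lines of Python can confirm that the downloaded files sit at the paths listed above; adjust the paths for the `sparc/` or `cosql/` folder you are working in:

```python
# Hypothetical sanity check: verify the files from the setup steps above
# exist at the expected locations.
import os

expected = [
    "data/database",
    "model/bert/data/annotated_wikisql_and_PyTorch_bert_param/pytorch_model_uncased_L-12_H-768_A-12.bin",
    "submit_models/reranker_roberta.pt",
]
for path in expected:
    status = "OK     " if os.path.exists(path) else "MISSING"
    print(f"{status} {path}")
```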
## Usage
Train the model from scratch:

```bash
./sparc_train.sh
```
Test the model with a specific checkpoint:

```bash
./sparc_test.sh
```
The dev prediction file will then appear in the `results` folder, named like `save_%d_predictions.json`.
Get the evaluation result from the prediction file:

```bash
./sparc_evaluate.sh
```
The final result will appear in the `results` folder, named `*.eval`.
The CoSQL experiments can be reproduced in the same way.
You can download our trained checkpoints and results here:
## Reranker
If you want to train your own reranker model, you can download the training data from here:
- SParC: [reranker training data]
- CoSQL: [reranker training data]
Then you can train, test, and run prediction with it:
Train:

```bash
python -m reranker.main --train --batch_size 64 --epoches 50
```

Test:

```bash
python -m reranker.main --test --batch_size 64
```

Predict:

```bash
python -m reranker.predict
```
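For intuition, here is a rough sketch of what a RoBERTa-based reranker does: it scores each (question, candidate SQL) pair and keeps the highest-scoring candidate. The function names (`score_candidate`, `rerank`) and the untrained linear scorer are illustrative assumptions, not the interface of `reranker.main` or `reranker.predict`.

```python
# Illustrative only: rescore (question, candidate SQL) pairs with roberta-base
# and keep the highest-scoring candidate. The repo's reranker has its own
# training code and I/O; this shows just the core idea.
import torch
from transformers import RobertaModel, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
encoder = RobertaModel.from_pretrained("roberta-base")
scorer = torch.nn.Linear(encoder.config.hidden_size, 1)  # would be trained jointly

def score_candidate(question, sql):
    # Encode the pair and map the first token's representation to a scalar score.
    ids = tokenizer.encode(question, sql, add_special_tokens=True)
    with torch.no_grad():
        sequence_output = encoder(torch.tensor([ids]))[0]
    return scorer(sequence_output[:, 0]).item()

def rerank(question, candidates):
    # Pick the beam candidate the reranker scores highest.
    return max(candidates, key=lambda sql: score_candidate(question, sql))
```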
## Improvements
We have improved on the original version (described in the paper) and obtained further performance gains. Compared with the original version, we made the following changes:
- Added a self-ensemble strategy for prediction, which combines checkpoints from different epochs to produce the final result (a minimal sketch follows this list). To make this strategy easy to apply, we removed the task-related representation from the reranker module.
- Removed the decay function in DCRI; we found DCRI unstable with the decay function, so we let it degenerate into vanilla cross-attention.
- Replaced the BERT-based model with a RoBERTa-based model in the reranker module.
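The self-ensemble step can be viewed as a simple vote across per-checkpoint prediction files. The sketch below assumes each `save_%d_predictions.json` can be read as a list of predicted SQL strings, which may not match the repo's exact output schema:

```python
# Minimal self-ensemble sketch: majority vote across predictions from several
# epoch checkpoints. The file names and per-file format (a list of SQL strings)
# are assumptions for illustration.
import json
from collections import Counter

def self_ensemble(prediction_files):
    runs = [json.load(open(path)) for path in prediction_files]
    merged = []
    for candidates in zip(*runs):  # predictions for one example across checkpoints
        merged.append(Counter(candidates).most_common(1)[0][0])
    return merged

final = self_ensemble([
    "results/save_10_predictions.json",
    "results/save_20_predictions.json",
    "results/save_30_predictions.json",
])
```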
The final performance comparison on the dev sets is as follows:
| Model | SParC QM | SParC IM | CoSQL QM | CoSQL IM |
|---|---|---|---|---|
| EditSQL | 47.2 | 29.5 | 39.9 | 12.3 |
| R²SQL v1 (original paper) | 54.1 | 35.2 | 45.7 | 19.5 |
| R²SQL v2 (this repo) | 54.0 | 35.2 | 46.3 | 19.5 |
| R²SQL v2 + ensemble | 55.1 | 36.8 | 47.3 | 20.9 |
## Citation
Please star this repo and cite our paper if you use it in your work.
## Acknowledgments
This implementation is based on the EMNLP 2019 paper "Editing-Based SQL Query Generation for Cross-Domain Context-Dependent Questions" (EditSQL).