中文无监督SimCSE Pytorch实现

Overview

A PyTorch implementation of unsupervised SimCSE

SimCSE: Simple Contrastive Learning of Sentence Embeddings


1. 用法

无监督训练

python train_unsup.py ./data/news_title.txt ./path/to/huggingface_pretrained_model

详细参数

usage: train_unsup.py [-h] [--pretrained PRETRAINED] [--model_out MODEL_OUT]
                      [--num_proc NUM_PROC] [--max_length MAX_LENGTH]
                      [--batch_size BATCH_SIZE] [--epochs EPOCHS] [--lr LR]
                      [--tao TAO] [--device DEVICE]
                      [--display_interval DISPLAY_INTERVAL]
                      [--save_interval SAVE_INTERVAL] [--pool_type POOL_TYPE]
                      [--dropout_rate DROPOUT_RATE]
                      train_file

positional arguments:
  train_file            train text file

optional arguments:
  -h, --help            show this help message and exit
  --pretrained PRETRAINED
                        huggingface pretrained model (default: hfl/chinese-
                        bert-wwm-ext)
  --model_out MODEL_OUT
                        model output path (default: ./model)
  --num_proc NUM_PROC   dataset process thread num (default: 5)
  --max_length MAX_LENGTH
                        sentence max length (default: 100)
  --batch_size BATCH_SIZE
                        batch size (default: 64)
  --epochs EPOCHS       epochs (default: 2)
  --lr LR               learning rate (default: 1e-05)
  --tao TAO             temperature (default: 0.05)
  --device DEVICE       device (default: cuda)
  --display_interval DISPLAY_INTERVAL
                        display interval (default: 50)
  --save_interval SAVE_INTERVAL
                        save interval (default: 100)
  --pool_type POOL_TYPE
                        pool_type (default: cls)
  --dropout_rate DROPOUT_RATE
                        dropout_rate (default: 0.3)

相似文本检索测试

python test_unsup.py
query title:
基金亏损路未尽 后市看法仍偏谨慎

sim title:
基金亏损路未尽 后市看法仍偏谨慎
海通证券:私募对后市看法偏谨慎
连塑基本面不容乐观 后市仍有下行空间
基金谨慎看待后市行情
稳健投资者继续保持观望 市场走势还未明朗
下半年基金投资谨慎乐观
华安基金许之彦:下半年谨慎乐观
楼市主导 期指后市不容乐观
基金公司谨慎看多明年市
前期乐观预期被否 基金重归谨慎

STS-B数据集训练和测试

中文STS-B数据集,详情见这里

# 训练
python train_unsup.py ./data/STS-B/cnsd-sts-train_unsup.txt

# 验证
python eval_unsup.py
模型 STS-B dev STS-B test
hfl/chinese-bert-wwm-ext 0.3326 0.3209
simcse 0.7499 0.6909

与苏剑林的实验结果接近,BERT-P1是0.3465,SIMCSE是0.6904

2. 参考

You might also like...
Comments
  • 怎样实现随机输出

    怎样实现随机输出

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from transformers import BertTokenizer,BertModel,BertConfig
    
    tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
    batch_x = tokenizer(["中国人民公安大学年硕士研究生目录及书目","中国人民公安大学年硕士研究生目录及书目"], return_tensors="pt", padding=True, truncation=True, max_length=128)
    
    class SimCSE(nn.Module):
        def __init__(self, pretrained="bert-base-chinese", pool_type="pooler", dropout_prob=0.3):
            super().__init__()
            conf = BertConfig.from_pretrained(pretrained)
            conf.attention_probs_dropout_prob = dropout_prob
            conf.hidden_dropout_prob = dropout_prob
            self.encoder = BertModel.from_pretrained(pretrained, config=conf)
            assert pool_type in ["cls", "pooler"], "invalid pool_type: %s" % pool_type
            self.pool_type = pool_type
    
        def forward(self, input_ids, attention_mask, token_type_ids):
            output = self.encoder(input_ids,
                                  attention_mask=attention_mask,
                                  token_type_ids=token_type_ids)
            if self.pool_type == "cls":
                output = output[0][:, 0]
            elif self.pool_type == "pooler":
                output = output[1]
            return output
    
    model = SimCSE()
    pred = model(input_ids = batch_x["input_ids"],attention_mask=batch_x["attention_mask"],token_type_ids=batch_x["token_type_ids"])
    
    

    您好,我按照这样的思路输入两次,但是最后输出的pred[0]和pred[1]是完全一样的,所以想请教下您的随机输出是怎么实现的?

    opened by duruiting 4
  • SimCSERetrieval.py  中encode_file() 存在不足

    SimCSERetrieval.py 中encode_file() 存在不足

    你好,看了你的代码受益匪浅,但是在代码SimCSERetrieval.py 中encode_file() 有一些问题如下:

      if len(texts) >= self.batch_size:
          vecs = self.encode_batch(texts)
          vecs = vecs / vecs.norm(dim=1, keepdim=True)
          all_texts.extend(texts)
          all_ids.extend(idxs)
          all_vecs.append(vecs.cpu())
          texts = []
          idxs = []
    

    如果 fname中的样本数N, N%self.batch_size = d,那么会遗漏d个样本。 我在测试中使用样本数为: N = 559 batch_size = 64 得到all_vecs.shape[0]=512

    opened by shencangblue 0
Owner
null