WRENCH: Weak supeRvision bENCHmark

Overview

🔧 What is it?

Wrench is a benchmark platform containing diverse weak supervision tasks. It also provides a common, easy-to-use framework for developing and evaluating your own weak supervision models within the benchmark.

For more information, check out our publications.

If you find this repository helpful, feel free to cite our publication:

@misc{zhang2021wrench,
      title={WRENCH: A Comprehensive Benchmark for Weak Supervision}, 
      author={Jieyu Zhang and Yue Yu and Yinghao Li and Yujing Wang and Yaming Yang and Mao Yang and Alexander Ratner},
      year={2021},
      eprint={2109.11377},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

🔧 What is weak supervision?

Weak supervision is a paradigm for creating training data automatically, without manual annotation.
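As a toy illustration of the idea (not code from Wrench), weak supervision replaces hand-labeled examples with several noisy heuristic "labeling functions" (LFs) whose votes are aggregated into a training label. The LFs, the spam task, and the convention that -1 means "abstain" are all assumptions made for this sketch; plain majority voting stands in for the label models listed further down.

```python
from collections import Counter
from typing import Callable, List, Optional

ABSTAIN = -1  # assumed convention for this sketch: -1 means "no opinion"

# Hypothetical labeling functions for a YouTube-style spam task (1 = spam, 0 = ham)
def lf_contains_link(text: str) -> int:
    return 1 if "http" in text.lower() else ABSTAIN

def lf_asks_to_subscribe(text: str) -> int:
    return 1 if "subscribe" in text.lower() else ABSTAIN

def lf_very_short(text: str) -> int:
    return 0 if len(text.split()) < 5 else ABSTAIN

LFS: List[Callable[[str], int]] = [lf_contains_link, lf_asks_to_subscribe, lf_very_short]

def weak_label(text: str) -> Optional[int]:
    """Majority vote over the LFs that did not abstain; None if every LF abstains."""
    votes = [v for v in (lf(text) for lf in LFS) if v != ABSTAIN]
    if not votes:
        return None
    # Counter.most_common orders equal counts by first occurrence
    return Counter(votes).most_common(1)[0][0]

comments = [
    "Check out my channel http://example.com and subscribe!",
    "Nice song",
    "I have listened to this track many times this week",
]
print([weak_label(c) for c in comments])  # [1, 0, None]
```

Examples that no LF covers (the third comment) get no label at all, which is one reason coverage is reported alongside accuracy.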

For a brief overview, please check out this blog.

To track recent advances in weak supervision, please follow this repo.

🔧 Installation

[1] Install Anaconda (instructions: https://www.anaconda.com/download/).

[2] Clone the repository:

git clone https://github.com/JieyuZ2/wrench.git
cd wrench

[3] Create the virtual environment:

conda env create -f environment.yml
source activate wrench

If this does not work, or if you want to use only a subset of Wrench's modules, check out this wiki page.

🔧 Available Datasets

The datasets can be downloaded via this link, or via the command line:

pip install gdown 
gdown https://drive.google.com/uc?id=19wMFmpoo_0ORhBzB6n16B1nRRX508AnJ
unzip datasets.zip
rm datasets.zip 

Documentation of the dataset format and usage can be found on this wiki page.
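The exact format is documented on the wiki page. Assuming each split file (e.g. train.json) is a JSON object keyed by example id whose entries hold "label", "weak_labels" (one vote per LF, -1 for abstain) and "data", the sketch below turns a split into a label vector and a weak-label matrix and reports LF coverage. Treat the field names as our reading of the format, not a specification; the helper names are hypothetical.

```python
import json
import numpy as np

ABSTAIN = -1  # assumed abstain value inside "weak_labels"

def load_split(path: str) -> dict:
    """Read one split file, e.g. datasets/youtube/train.json (assumed layout)."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)

def to_arrays(split: dict):
    """Turn {id: {"label", "weak_labels", "data"}} into (ids, y, L)."""
    ids = list(split)
    y = np.array([split[i]["label"] for i in ids])
    L = np.array([split[i]["weak_labels"] for i in ids])  # shape (n_examples, n_lfs)
    return ids, y, L

# Tiny in-memory stand-in for a real split file
example = {
    "0": {"label": 1, "weak_labels": [1, -1, 1], "data": {"text": "subscribe to my channel"}},
    "1": {"label": 0, "weak_labels": [-1, 0, -1], "data": {"text": "great song"}},
}
ids, y, L = to_arrays(example)
print(L.shape)                            # (2, 3)
print((L != ABSTAIN).mean(axis=0))        # per-LF coverage: [0.5 0.5 0.5]
print((L != ABSTAIN).any(axis=1).mean())  # share of examples hit by at least one LF: 1.0
```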

classification:

| Name | Task | # class | # LF | # train | # validation | # test | data source | LF source |
|---|---|---|---|---|---|---|---|---|
| Census | income classification | 2 | 83 | 10083 | 5561 | 16281 | link | link |
| Youtube | spam classification | 2 | 10 | 1586 | 120 | 250 | link | link |
| SMS | spam classification | 2 | 73 | 4571 | 500 | 500 | link | link |
| IMDB | sentiment classification | 2 | 8 | 20000 | 2500 | 2500 | link | link |
| Yelp | sentiment classification | 2 | 8 | 30400 | 3800 | 3800 | link | link |
| AGNews | topic classification | 4 | 9 | 96000 | 12000 | 12000 | link | link |
| TREC | question classification | 6 | 68 | 4965 | 500 | 500 | link | link |
| Spouse | relation classification | 2 | 9 | 22254 | 2801 | 2701 | link | link |
| SemEval | relation classification | 9 | 164 | 1749 | 200 | 692 | link | link |
| CDR | bio relation classification | 2 | 33 | 8430 | 920 | 4673 | link | link |
| Chemprot | chemical relation classification | 10 | 26 | 12861 | 1607 | 1607 | link | link |
| Commercial | video frame classification | 2 | 4 | 64130 | 9479 | 7496 | link | link |
| Tennis Rally | video frame classification | 2 | 6 | 6959 | 746 | 1098 | link | link |
| Basketball | video frame classification | 2 | 4 | 17970 | 1064 | 1222 | link | link |

sequence tagging:

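| Name | # class | # LF | # train | # validation | # test | data source | LF source |
|---|---|---|---|---|---|---|---|
| CoNLL-03 | 4 | 16 | 14041 | 3250 | 3453 | link | link |
| WikiGold | 4 | 16 | 1355 | 169 | 170 | link | link |
| OntoNotes 5.0 | 18 | 17 | 115812 | 5000 | 22897 | link | link |
| BC5CDR | 2 | 9 | 500 | 500 | 500 | link | link |
| NCBI-Disease | 1 | 5 | 592 | 99 | 99 | link | link |
| Laptop-Review | 1 | 3 | 2436 | 609 | 800 | link | link |
| MIT-Restaurant | 8 | 16 | 7159 | 500 | 1521 | link | link |
| MIT-Movies | 12 | 7 | 9241 | 500 | 2441 | link | link |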

The detailed documentation is coming soon.

🔧 Available Models

classification:

| Model | Model Type | Reference | Link to Wrench |
|---|---|---|---|
| Majority Voting | Label Model | -- | link |
| Weighted Majority Voting | Label Model | -- | link |
| Dawid-Skene | Label Model | link | link |
| Data Programming | Label Model | link | link |
| MeTaL | Label Model | link | link |
| FlyingSquid | Label Model | link | link |
| Logistic Regression | End Model | -- | link |
| MLP | End Model | -- | link |
| BERT | End Model | link | link |
| COSINE | End Model | link | link |
| Denoise | Joint Model | link | link |
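In the table, label models aggregate LF votes into (probabilistic) training labels, end models are trained on those labels, and joint models learn both together. The simplest label models are voting schemes; below is a minimal sketch of (weighted) majority voting that outputs class probabilities, assuming a weak-label matrix with -1 for abstain. The weights are supplied by hand for illustration; this is not Wrench's implementation, and the more sophisticated label models (Dawid-Skene, Data Programming, MeTaL, FlyingSquid) estimate LF reliabilities from the data instead.

```python
import numpy as np
from typing import Optional, Sequence

ABSTAIN = -1  # assumed abstain marker in the weak-label matrix

def vote_proba(L: np.ndarray, n_class: int, weights: Optional[Sequence[float]] = None) -> np.ndarray:
    """(Weighted) majority vote turned into per-class probabilities.

    L has shape (n_examples, n_lfs). weights=None gives every LF weight 1
    (plain majority voting). Rows where every LF abstains get a uniform row.
    """
    m = L.shape[1]
    w = np.ones(m) if weights is None else np.asarray(weights, dtype=float)
    # scores[i, c] = total weight of LFs voting class c on example i
    scores = np.stack([((L == c) * w).sum(axis=1) for c in range(n_class)], axis=1)
    totals = scores.sum(axis=1, keepdims=True)
    safe_totals = np.where(totals > 0, totals, 1.0)  # avoid division by zero
    return np.where(totals > 0, scores / safe_totals, 1.0 / n_class)

L = np.array([[1, 1, 0],
              [-1, -1, -1],
              [0, -1, 1]])
print(vote_proba(L, n_class=2))                     # rows ~ [1/3, 2/3], [0.5, 0.5], [0.5, 0.5]
print(vote_proba(L, n_class=2, weights=[1, 1, 3]))  # rows ~ [0.6, 0.4], [0.5, 0.5], [0.25, 0.75]
```

Up-weighting the third LF flips the first example from class 1 to class 0, which is exactly the kind of decision a learned label model makes when it judges one source more reliable than others.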

sequence tagging:

| Model | Model Type | Reference | Link to Wrench |
|---|---|---|---|
| Hidden Markov Model | Label Model | link | link |
| Conditional Hidden Markov Model | Label Model | link | link |
| LSTM-CNNs-CRF | End Model | link | link |
| BERT-CRF | End Model | link | link |
| LSTM-ConNet | Joint Model | link | link |
| BERT-ConNet | Joint Model | link | link |

classification-to-sequence-tagging wrapper:

Wrench also provides a SeqLabelModelWrapper that adapts label models for classification tasks to sequence tagging tasks.
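One plausible way to adapt a classification label model to sequence tagging, and a rough picture of what such a wrapper has to do, is to treat every token as an independent example: stack the per-token weak labels of all sentences, run the classification label model, and split the predictions back per sentence. The helper names and the toy O / B-PER / I-PER tag scheme below are hypothetical, and we have not checked that SeqLabelModelWrapper works exactly this way.

```python
import numpy as np
from typing import Callable, List

def apply_per_token(sent_weak_labels: List[np.ndarray],
                    classify: Callable[[np.ndarray], np.ndarray]) -> List[np.ndarray]:
    """Run a classification-style label model on every token, then regroup by sentence.

    Each element of sent_weak_labels has shape (n_tokens, n_lfs).
    """
    lengths = [w.shape[0] for w in sent_weak_labels]
    flat = np.concatenate(sent_weak_labels, axis=0)  # (total_tokens, n_lfs)
    flat_pred = classify(flat)                       # (total_tokens,)
    return np.split(flat_pred, np.cumsum(lengths)[:-1])

def token_majority_vote(L: np.ndarray, n_class: int) -> np.ndarray:
    """Hard majority vote per token (-1 = abstain). All-abstain rows and ties go
    to the lowest class index, which is tag 0 ("O") in the toy scheme below."""
    counts = np.stack([(L == c).sum(axis=1) for c in range(n_class)], axis=1)
    return counts.argmax(axis=1)

# Toy tag scheme (hypothetical): 0 = O, 1 = B-PER, 2 = I-PER
sent_a = np.array([[1, 1], [-1, 2], [-1, -1]])
sent_b = np.array([[2, -1], [0, 0]])
tags = apply_per_token([sent_a, sent_b], lambda L: token_majority_vote(L, n_class=3))
print([t.tolist() for t in tags])  # [[1, 2, 0], [2, 0]]
```

Because tokens are labeled independently, this ignores tag-transition constraints (for example, I-PER should not follow O); modeling such transitions is what sequence label models like the HMM and CHMM in the table add.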

🔧 Quick examples

🔧 Label model with parallel grid search for hyper-parameters

import logging
import numpy as np
import pprint

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.search import grid_search
from wrench import labelmodel 
from wrench.evaluation import AverageMeter

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)


#### Specify the hyper-parameter search space for grid search
search_space = {
    'Snorkel': {
        'lr': np.logspace(-5, -1, num=5, base=10),
        'l2': np.logspace(-5, -1, num=5, base=10),
        'n_epochs': [5, 10, 50, 100, 200],
    }
}

#### Initialize label model
label_model_name = 'Snorkel'
label_model = getattr(labelmodel, label_model_name)

#### Search best hyper-parameters using validation set in parallel
n_trials = 100
n_repeats = 5
target = 'acc'
searched_paras = grid_search(label_model(), dataset_train=train_data, dataset_valid=valid_data,
                             metric=target, direction='auto', search_space=search_space[label_model_name],
                             n_repeats=n_repeats, n_trials=n_trials, parallel=True)

#### Evaluate the label model with searched hyper-parameters and average meter
meter = AverageMeter(names=[target])
for i in range(n_repeats):
    model = label_model(**searched_paras)
    history = model.fit(dataset_train=train_data, dataset_valid=valid_data)
    metric_value = model.test(test_data, target)
    # pass the metric name ('acc') as the keyword, not the literal word "target"
    meter.update(**{target: metric_value})

metrics = meter.get_results()
pprint.pprint(metrics)
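The evaluation loop above retrains the label model n_repeats times and aggregates the per-run metric. The class below is a hypothetical stand-in, not Wrench's AverageMeter, that illustrates the pattern: one list of values per metric name, keyed by the keyword passed to update, reported as mean and standard deviation. It also shows why the keyword has to be the metric name itself.

```python
import statistics
from typing import Dict, List, Tuple

class RepeatMeter:
    """Hypothetical meter: collect one value per repeat for each named metric."""

    def __init__(self, names: List[str]):
        self.values: Dict[str, List[float]] = {name: [] for name in names}

    def update(self, **kwargs: float) -> None:
        for name, value in kwargs.items():
            if name not in self.values:
                raise KeyError(f"unknown metric {name!r}")
            self.values[name].append(value)

    def get_results(self) -> Dict[str, Tuple[float, float]]:
        """Mean and population standard deviation per metric."""
        return {name: (statistics.mean(v), statistics.pstdev(v))
                for name, v in self.values.items() if v}

target = 'acc'
meter = RepeatMeter(names=[target])
for run_value in [0.5, 0.75, 1.0]:        # e.g. test accuracy of three repeats
    meter.update(**{target: run_value})  # keyword is 'acc', not the word "target"
print(meter.get_results())
```

Calling `meter.update(target=0.9)` on this class raises KeyError, because the keyword name ("target") is used as the key rather than the value of the variable.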

For detailed guidance on grid_search, please check out this wiki page.
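Conceptually, grid search tries each combination of values in the search space, scores it on the validation set (averaged over repeats), and keeps the best. The sketch below is a sequential, exhaustive version with a toy objective in place of actual training; it does not reproduce the n_trials sampling, direction='auto', or parallel execution of Wrench's grid_search, and the function names are hypothetical.

```python
import itertools
from typing import Any, Callable, Dict, List, Optional, Tuple

def exhaustive_grid_search(objective: Callable[[Dict[str, Any], int], float],
                           search_space: Dict[str, List[Any]],
                           n_repeats: int = 1,
                           maximize: bool = True) -> Tuple[Optional[Dict[str, Any]], Optional[float]]:
    """Try every combination in search_space; score each by its mean over n_repeats seeds."""
    keys = list(search_space)
    best_params, best_score = None, None
    for combo in itertools.product(*(search_space[k] for k in keys)):
        params = dict(zip(keys, combo))
        score = sum(objective(params, seed) for seed in range(n_repeats)) / n_repeats
        better = best_score is None or (score > best_score if maximize else score < best_score)
        if better:
            best_params, best_score = params, score
    return best_params, best_score

# Toy objective standing in for "fit on train with these params, return validation accuracy"
def toy_valid_score(params: Dict[str, Any], seed: int) -> float:
    return -abs(params['lr'] - 0.01) + 0.001 * params['n_epochs']

space = {'lr': [0.001, 0.01, 0.1], 'n_epochs': [5, 50]}
print(exhaustive_grid_search(toy_valid_score, space, n_repeats=3))
```

The cost grows multiplicatively with every hyper-parameter added (the Snorkel space above already has 5 × 5 × 5 = 125 combinations), which is why capping the number of trials and running them in parallel matters in practice.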

🔧 Run a standard supervised learning pipeline

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.endmodel import MLPModel

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)


#### Train an MLP classifier (falls back to CPU if no GPU is available)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n_steps = 100000
batch_size = 128
test_batch_size = 1000 
patience = 200
evaluation_step = 50
target='acc'

model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, device=device, metric=target, 
                    patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)
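The patience and evaluation_step arguments control early stopping. Below is a minimal sketch, assuming that validation runs every evaluation_step training steps and that patience counts consecutive evaluations without improvement (check the MLPModel source for the exact semantics); the function name is hypothetical and the toy curve stands in for real validation accuracy.

```python
from typing import Callable, Tuple

def early_stopping_schedule(val_metric_at: Callable[[int], float], n_steps: int,
                            evaluation_step: int, patience: int) -> Tuple[int, float, int]:
    """Return (best_step, best_value, stop_step) for a higher-is-better metric.

    Assumption for this sketch: `patience` counts consecutive evaluations without
    improvement, and evaluations happen every `evaluation_step` training steps.
    """
    best_value, best_step, stale = float('-inf'), 0, 0
    for step in range(evaluation_step, n_steps + 1, evaluation_step):
        # a real trainer would run `evaluation_step` optimizer updates before this point
        value = val_metric_at(step)
        if value > best_value:
            best_value, best_step, stale = value, step, 0
        else:
            stale += 1
            if stale >= patience:
                return best_step, best_value, step
    return best_step, best_value, n_steps

def curve(step: int) -> float:
    """Toy validation curve: improves until step 500, then plateaus."""
    return min(step, 500) / 1000

print(early_stopping_schedule(curve, n_steps=100000, evaluation_step=50, patience=20))
# (500, 0.5, 1500)
```

Under that reading, the settings above (evaluation_step=50, patience=200) would tolerate up to 200 × 50 = 10,000 training steps without a new best validation score before stopping.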

🔧 Build a two-stage weak supervision pipeline

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.endmodel import MLPModel
from wrench.labelmodel import MajorityVoting

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)

#### Load dataset 
dataset_home = '../datasets'
data = 'youtube'

#### Extract data features using pre-trained BERT model and cache it
extract_fn = 'bert'
model_name = 'bert-base-cased'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=True, extract_fn=extract_fn,
                                                 cache_name=extract_fn, model_name=model_name)

#### Generate soft training labels via a label model
#### The weak labels provided by supervision sources are already encoded in the dataset object
label_model = MajorityVoting()
label_model.fit(train_data, valid_data)
soft_label = label_model.predict_proba(train_data)


#### Train an MLP classifier with soft labels (falls back to CPU if no GPU is available)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
n_steps = 100000
batch_size = 128
test_batch_size = 1000 
patience = 200
evaluation_step = 50
target='acc'

model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
history = model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=soft_label, 
                    device=device, metric=target, patience=patience, evaluation_step=evaluation_step)

#### Evaluate the trained model
metric_value = model.test(test_data, target)

#### We can also train an MLP classifier with hard labels
from snorkel.utils import probs_to_preds
hard_label = probs_to_preds(soft_label)
model = MLPModel(n_steps=n_steps, batch_size=batch_size, test_batch_size=test_batch_size)
model.fit(dataset_train=train_data, dataset_valid=valid_data, y_train=hard_label, 
          device=device, metric=target, patience=patience, evaluation_step=evaluation_step)
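The two variants above differ only in y_train: soft labels keep the label model's uncertainty, while hard labels commit to one class per example. The NumPy sketch below (generic code, not Wrench's training loss) shows the difference using cross-entropy against probabilistic targets: a model that is confident on an example the label model considered a coin flip is penalized under soft targets but not under argmax targets. Snorkel's probs_to_preds may break ties differently from the plain argmax used here.

```python
import numpy as np

def soft_cross_entropy(logits: np.ndarray, target_probs: np.ndarray) -> float:
    """Mean cross-entropy of softmax(logits) against probabilistic targets."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-(target_probs * log_probs).sum(axis=1).mean())

def argmax_hard_labels(probs: np.ndarray) -> np.ndarray:
    """Hard labels via argmax; ties go to the lowest class index."""
    return probs.argmax(axis=1)

soft_label = np.array([[0.9, 0.1],
                       [0.4, 0.6],
                       [0.5, 0.5]])          # e.g. label-model output for 3 examples
hard_label = argmax_hard_labels(soft_label)  # [0, 1, 0]
one_hot = np.eye(2)[hard_label]

logits = np.array([[2.0, 0.0], [0.0, 2.0], [2.0, 0.0]])  # model confident in class 0, 1, 0
print("loss vs soft labels:", soft_cross_entropy(logits, soft_label))
print("loss vs hard labels:", soft_cross_entropy(logits, one_hot))
```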

🔧 Procedural labeling function generator

import logging
import torch

from wrench.dataset import load_dataset
from wrench.logging import LoggingHandler
from wrench.synthetic import ConditionalIndependentGenerator, NGramLFGenerator
from wrench.labelmodel import FlyingSquid

#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
logger = logging.getLogger(__name__)


#### Generate synthetic dataset
generator = ConditionalIndependentGenerator(
    n_class=2,
    n_lfs=10,
    alpha=0.75, # mean accuracy
    beta=0.1, # mean propensity
    alpha_radius=0.2, # radius of accuracy
    beta_radius=0.1 # radius of propensity
)
train_data = generator.generate_split('train', 10000)
valid_data = generator.generate_split('valid', 1000)
test_data = generator.generate_split('test', 1000)

#### Evaluate label model on synthetic dataset
label_model = FlyingSquid()
label_model.fit(dataset_train=train_data, dataset_valid=valid_data)
target_value = label_model.test(test_data, metric_fn='auc')

#### Load a real-world dataset
dataset_home = '../datasets'
data = 'youtube'
train_data, valid_data, test_data = load_dataset(dataset_home, data, extract_feature=False)

#### Generate procedural labeling functions
generator = NGramLFGenerator(dataset=train_data, min_acc_gain=0.1, min_support=0.01, ngram_range=(1, 2))
applier = generator.generate(mode='correlated', n_lfs=10)
L_test = applier.apply(test_data)
L_train = applier.apply(train_data)


#### Evaluate label model on real-world dataset with semi-synthetic labeling functions
label_model = FlyingSquid()
label_model.fit(dataset_train=L_train, dataset_valid=valid_data)
target_value = label_model.test(L_test, metric_fn='auc')
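The generator's comments describe each LF by a mean accuracy (alpha) and a mean propensity, i.e. how often it does not abstain (beta), each with a radius. Assuming the radii define uniform ranges around those means and that LFs vote independently given the true label, a synthetic weak-label matrix could be sampled as below. This is our sketch of the idea, not ConditionalIndependentGenerator's code; the function name is hypothetical.

```python
import numpy as np

ABSTAIN = -1

def synthetic_weak_labels(y: np.ndarray, n_class: int, n_lfs: int,
                          alpha: float, beta: float,
                          alpha_radius: float, beta_radius: float, seed: int = 0):
    """Conditionally independent LFs: given the true label y, LFs fire and err independently.

    Assumption: LF j's accuracy is drawn from U[alpha - alpha_radius, alpha + alpha_radius]
    and its propensity (chance of not abstaining) from U[beta - beta_radius, beta + beta_radius],
    both clipped to [0, 1].
    """
    rng = np.random.default_rng(seed)
    n = len(y)
    accs = np.clip(rng.uniform(alpha - alpha_radius, alpha + alpha_radius, size=n_lfs), 0, 1)
    props = np.clip(rng.uniform(beta - beta_radius, beta + beta_radius, size=n_lfs), 0, 1)
    L = np.full((n, n_lfs), ABSTAIN)
    for j in range(n_lfs):
        fires = rng.random(n) < props[j]
        correct = rng.random(n) < accs[j]
        # shift by 1..n_class-1 so a wrong vote is never the true class
        wrong = (y + rng.integers(1, n_class, size=n)) % n_class
        L[:, j] = np.where(fires, np.where(correct, y, wrong), ABSTAIN)
    return L, accs, props

y = np.random.default_rng(1).integers(0, 2, size=20000)
L, accs, props = synthetic_weak_labels(y, n_class=2, n_lfs=10, alpha=0.75, beta=0.1,
                                       alpha_radius=0.2, beta_radius=0.1)
fired = L != ABSTAIN
print("coverage per LF:", fired.mean(axis=0).round(3))
print("accuracy per LF:", [round(float((L[f, j] == y[f]).mean()), 3)
                           for j, f in enumerate(fired.T) if f.any()])
```

Because the sampled accuracies and propensities are known, a label model's estimates can be compared directly against them, which is the main appeal of synthetic data for benchmarking.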

🔧 Contact

Contact person: Jieyu Zhang, [email protected]

Don't hesitate to send us an e-mail if you have any questions.

We're also open to any collaboration!

🔧 Contributing Dataset and Model

We sincerely welcome any contribution to the datasets or models!

Comments
  • ModuleNotFoundError: No module named 'tokenizations'

    Hi, I ran into some problems when trying to install the library. I used pip install ws-benchmark==1.1.2rc0 as suggested in the documentation; the installation succeeded, but when I run the code I get the error ModuleNotFoundError: No module named 'tokenizations'. Then I tried to clone the repository and create the environment using conda env create -f environment.yml, but the installation failed with the following error: FileNotFoundError: [Errno 2] No such file or directory: '/home/naiqing/miniconda3/envs/wrench/lib/python3.6/site-packages/huggingface_hub-0.0.16-py3.8.egg'. Do you have any idea what might cause the problem and how I can fix it?

    opened by Gnaiqing
  • Is there a limitation of using dataset for different algs?

    Firstly, thank you for building this awesome benchmark. When I try the examples with different datasets (e.g., astra with the youtube dataset), I get errors like this:

        loss = cross_entropy_with_probs(predict_l, batch['labels'].to(device))
    KeyError: 'labels'
    

    Can this be fixed?

    opened by mrbeann
  • Python Package Installation Fails

    Installing ws-benchmark python package fails due to dependency conflict (see stack trace below).

    Tested on system:

    • OS: ubuntu
    • Python: 3.8.13
    • Clean virtual environment

    Command to replicate:

    • pip install ws-benchmark

    Stack Trace:

    ERROR: Cannot install ws-benchmark and ws-benchmark==1.1.1 because these package versions have conflicting dependencies.
    
    The conflict is caused by:
        ws-benchmark 1.1.1 depends on networkx==2.7
        snorkel 0.9.7 depends on networkx<2.4 and >=2.2
    
    To fix this you could try to:
    1. loosen the range of package versions you've specified
    2. remove package versions to allow pip attempt to solve the dependency conflict
    
    ERROR: ResolutionImpossible: for help visit https://pip.pypa.io/en/latest/topics/dependency-resolution/#dealing-with-dependency-conflicts
    
    opened by bradleyfowler123
  • Using Multiple GPUs

    Hi,

    Is it possible to use multiple GPUs for the experiments, or is this planned for a future release? If it is not possible right now, it would be a nice feature.

    Best regards.

    opened by tolgayan
  • Running scripts

    Hi, I am trying to run some models on the IMDB dataset.

    MLP:

    import logging
    import torch
    import numpy as np
    from wrench.dataset import load_dataset
    from wrench.labelmodel import Snorkel
    from wrench.logging import LoggingHandler
    from wrench.search import grid_search
    from wrench.endmodel import EndClassifierModel
    
    #### Just some code to print debug information to stdout
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])
    
    logger = logging.getLogger(__name__)
    
    device = torch.device('cuda')
    
    if __name__ == '__main__':
        #### Load dataset
        dataset_path = '../datasets/'
        data = "imdb"
        bert_model_name = "bert-base-cased"
        train_data, valid_data, test_data = load_dataset(
            dataset_path,
            data,
            extract_feature=True,
            extract_fn='bert',  # extract bert embedding
            model_name=bert_model_name,
            cache_name='bert',
            dataset_type="TextDataset"
        )
    
        #### Run label model: Snorkel
        label_model = Snorkel(
            lr=0.005,
            l2=0,
            n_epochs=200,
            seed=123
        )
        label_model.fit(
            dataset_train=train_data,
            dataset_valid=valid_data
        )
    
        acc = label_model.test(test_data, 'acc')
        logger.info(f'label model test acc: {acc}')
    
        #### Filter out uncovered training data
        aggregated_hard_labels = label_model.predict(train_data)
        aggregated_soft_labels = label_model.predict_proba(train_data)
    
        #### Search Space
        search_space = {
            'optimizer_lr': np.logspace(-5, -1, num=5, base=10),
            'optimizer_weight_decay': np.logspace(-5, -1, num=5, base=10),
        }
    
        #### Initialize the model: MLP
        model = EndClassifierModel(
            batch_size=8,
            real_batch_size=8,
            test_batch_size=8,
            backbone='MLP',
            optimizer='Adam'
        )
    
        #### Search best hyper-parameters using validation set in parallel
        n_trials = 20
        n_repeats = 1
        searched_paras = grid_search(
            model,
            dataset_train=train_data,
            y_train=aggregated_soft_labels,
            dataset_valid=valid_data,
            metric='acc',
            direction='auto',
            search_space=search_space,
            n_repeats=n_repeats,
            n_trials=n_trials,
            parallel=True,
            device=device,
        )
    
    
        #### Run end model: MLP
        model = EndClassifierModel(
            batch_size=8,
            real_batch_size=8,
            test_batch_size=8,
            backbone='MLP',
            optimizer='Adam',
            **searched_paras
        )
        model.fit(
            dataset_train=train_data,
            y_train=aggregated_soft_labels,
            dataset_valid=valid_data,
            metric='acc',
            device=device
        )
    
        logger.info(model.predict(test_data).tolist())
    
        acc = model.test(test_data, 'acc')
        logger.info(f'end model (MLP) test acc: {acc}')
    
    

    for which I am getting the following output:

    100%|██████████| 20000/20000 [00:00<00:00, 902651.16it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 852639.45it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 829503.99it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 20000/20000 [1:42:45<00:00,  3.24it/s]  
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [13:24<00:00,  3.11it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [13:50<00:00,  3.01it/s]
    [I 2021-10-23 22:24:36,807] A new study created in memory with name: no-name-9e4ad09c-ea4a-4ee8-80c2-7633429e4038
    huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
    To disable this warning, you can either:
            - Avoid using `tokenizers` before the fork if possible
            - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
    2021-10-23 20:14:19 - loading data from ../datasets/imdb/train.json
    2021-10-23 20:14:19 - loading data from ../datasets/imdb/valid.json
    2021-10-23 20:14:19 - loading data from ../datasets/imdb/test.json
    2021-10-23 21:57:10 - saving features into ../datasets/imdb/train_bert.pkl
    2021-10-23 22:10:40 - saving features into ../datasets/imdb/valid_bert.pkl
    2021-10-23 22:24:36 - saving features into ../datasets/imdb/test_bert.pkl
    2021-10-23 22:24:36 - label model test acc: 0.716
    huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
    To disable this warning, you can either:
            - Avoid using `tokenizers` before the fork if possible
            - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
    100%|██████████| 1/1 [00:37<00:00, 37.48s/it]
    [I 2021-10-23 22:25:14,563] Trial 0 finished with value: 0.5012 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.0001}. Best is trial 0 with value: 0.5012.
    100%|██████████| 1/1 [00:23<00:00, 23.70s/it]
    [I 2021-10-23 22:25:38,448] Trial 1 finished with value: 0.496 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.1}. Best is trial 0 with value: 0.5012.
    100%|██████████| 1/1 [00:14<00:00, 14.53s/it]
    [I 2021-10-23 22:25:53,171] Trial 2 finished with value: 0.5004 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.001}. Best is trial 0 with value: 0.5012.
    100%|██████████| 1/1 [00:43<00:00, 43.73s/it]
    [I 2021-10-23 22:26:37,071] Trial 3 finished with value: 0.5088 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.001}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:18<00:00, 18.85s/it]
    [I 2021-10-23 22:26:56,161] Trial 4 finished with value: 0.488 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.1}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:38<00:00, 38.81s/it]
    [I 2021-10-23 22:27:35,214] Trial 5 finished with value: 0.4948 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.1}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:38<00:00, 38.15s/it]
    [I 2021-10-23 22:28:13,614] Trial 6 finished with value: 0.5024 and parameters: {'optimizer_lr': 0.01, 'optimizer_weight_decay': 0.01}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:15<00:00, 15.47s/it]
    [I 2021-10-23 22:28:29,335] Trial 7 finished with value: 0.4996 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 1e-05}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:22<00:00, 22.49s/it]
    [I 2021-10-23 22:28:52,093] Trial 8 finished with value: 0.5008 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 1e-05}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:40<00:00, 40.25s/it]
    [I 2021-10-23 22:29:32,594] Trial 9 finished with value: 0.5008 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.0001}. Best is trial 3 with value: 0.5088.
    100%|██████████| 1/1 [00:39<00:00, 39.06s/it]
    [I 2021-10-23 22:30:11,902] Trial 10 finished with value: 0.5116 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 1e-05}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:43<00:00, 43.46s/it]
    [I 2021-10-23 22:30:55,531] Trial 11 finished with value: 0.4912 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 1e-05}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:23<00:00, 23.41s/it]
    [I 2021-10-23 22:31:19,095] Trial 12 finished with value: 0.4956 and parameters: {'optimizer_lr': 0.001, 'optimizer_weight_decay': 0.01}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:22<00:00, 22.12s/it]
    [I 2021-10-23 22:31:41,374] Trial 13 finished with value: 0.492 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.01}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:15<00:00, 15.78s/it]
    [I 2021-10-23 22:31:57,283] Trial 14 finished with value: 0.5044 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.0001}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:37<00:00, 37.28s/it]
    [I 2021-10-23 22:32:34,728] Trial 15 finished with value: 0.488 and parameters: {'optimizer_lr': 1e-05, 'optimizer_weight_decay': 0.001}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:16<00:00, 16.04s/it]
    [I 2021-10-23 22:32:50,934] Trial 16 finished with value: 0.4924 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.001}. Best is trial 10 with value: 0.5116.
    100%|██████████| 1/1 [00:19<00:00, 19.65s/it]
    [I 2021-10-23 22:33:10,753] Trial 17 finished with value: 0.5156 and parameters: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.1}. Best is trial 17 with value: 0.5156.
    100%|██████████| 1/1 [00:15<00:00, 15.41s/it]
    [I 2021-10-23 22:33:26,345] Trial 18 finished with value: 0.5068 and parameters: {'optimizer_lr': 0.01, 'optimizer_weight_decay': 0.001}. Best is trial 17 with value: 0.5156.
    100%|██████████| 1/1 [00:16<00:00, 16.75s/it]
    [I 2021-10-23 22:33:43,222] Trial 19 finished with value: 0.498 and parameters: {'optimizer_lr': 0.0001, 'optimizer_weight_decay': 0.01}. Best is trial 17 with value: 0.5156.
    [TRAIN]:  15%|█████▌                               | 1499/10000 [00:21<02:04, 68.19steps/s, loss=4.02, val_acc=0.5, best_val_acc=0.508, best_step=500]
    2021-10-23 22:33:43 - [END: BEST VAL / PARAMS] Best value: 0.5156, Best paras: {'optimizer_lr': 0.1, 'optimizer_weight_decay': 0.1}
    2021-10-23 22:33:43 - 
    ==========[hyper parameters]==========
    {
        "batch_size": 8,
        "real_batch_size": 8,
        "test_batch_size": 8,
        "n_steps": 10000,
        "grad_norm": -1,
        "use_lr_scheduler": false,
        "binary_mode": false
    }
    ==========[optimizer config]==========
    {
        "name": "Adam",
        "paras": {
            "lr": 0.1,
            "weight_decay": 0.1
        }
    }
    ==========[backbone config]==========
    {
        "name": "MLP",
        "paras": {
            "hidden_size": 100,
            "dropout": 0.0
        }
    }
    
    2021-10-23 22:34:09 - [INFO] early stop @ step 1500!
    2021-10-23 22:34:09 - [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    2021-10-23 22:34:09 - end model (MLP) test acc: 0.5004
    

    COSINE:

    import logging
    import torch
    from wrench.dataset import load_dataset
    from wrench.logging import LoggingHandler
    from wrench.labelmodel import Snorkel
    from wrench.endmodel import Cosine
    
    #### Just some code to print debug information to stdout
    logging.basicConfig(format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO,
                        handlers=[LoggingHandler()])
    
    logger = logging.getLogger(__name__)
    
    device = torch.device('cuda')
    
    if __name__ == '__main__':
        #### Load dataset
        dataset_path = '../datasets/'
        data = "imdb"
        bert_model_name = "bert-base-cased"
        train_data, valid_data, test_data = load_dataset(
            dataset_path,
            data,
            extract_feature=True,
            extract_fn='bert',  # extract bert embedding
            model_name=bert_model_name,
            cache_name='bert',
            dataset_type="TextDataset"
        )
    
        #### Run label model: Snorkel
        label_model = Snorkel(
            lr=0.005,
            l2=0,
            n_epochs=200,
            seed=123
        )
        label_model.fit(
            dataset_train=train_data,
            dataset_valid=valid_data
        )
    
        acc = label_model.test(test_data, 'acc')
        logger.info(f'label model test acc: {acc}')
    
        #### Filter out uncovered training data
        aggregated_hard_labels = label_model.predict(train_data)
        aggregated_soft_labels = label_model.predict_proba(train_data)
    
    
        # COSINE
        model = Cosine(
            teacher_update=100,
            margin=1.0,
            thresh=0.6,
            lr=1e-5,
            mu=1.0,
            lamda=0.05,
            backbone='BERT',
            backbone_model_name=bert_model_name,
            batch_size=8,
            real_batch_size=8,
            test_batch_size=8,
        )
    
        model.fit(dataset_train=train_data,
                  dataset_valid=valid_data,
                  y_train=aggregated_hard_labels,
                  evaluation_step=10,
                  metric='acc',
                  patience=50,
                  device=device)
    
        acc = model.test(test_data, 'acc')
    
        logger.info(model.predict(test_data))
    
        logger.info(f'end model (COSINE) test acc: {acc}')
    

    For this script, I get the following output:

    100%|██████████| 20000/20000 [00:00<00:00, 899119.81it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 423667.07it/s]
    100%|██████████| 2500/2500 [00:00<00:00, 802645.44it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 20000/20000 [1:47:44<00:00,  3.09it/s]  
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [14:22<00:00,  2.90it/s]
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    100%|██████████| 2500/2500 [13:33<00:00,  3.07it/s] 
    Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
    - This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    - This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    [TRAIN] COSINE pretrain stage:   5%|▊               | 509/10000 [21:19<6:37:40,  2.51s/steps, loss=0.605, val_acc=0.5, best_val_acc=0.5, best_step=10]
    [TRAIN] COSINE distillation stage:   0%|                                                                                 | 0/10000 [03:05<?, ?steps/s]
    2021-10-23 20:14:13 - loading data from ../datasets/imdb/train.json
    2021-10-23 20:14:13 - loading data from ../datasets/imdb/valid.json
    2021-10-23 20:14:14 - loading data from ../datasets/imdb/test.json
    2021-10-23 22:02:05 - saving features into ../datasets/imdb/train_bert.pkl
    2021-10-23 22:16:34 - saving features into ../datasets/imdb/valid_bert.pkl
    2021-10-23 22:30:14 - saving features into ../datasets/imdb/test_bert.pkl
    2021-10-23 22:30:14 - label model test acc: 0.716
    2021-10-23 22:30:17 - 
    ==========[hyper parameters]==========
    {
        "teacher_update": 100,
        "margin": 1.0,
        "mu": 1.0,
        "thresh": 0.6,
        "lamda": 0.05,
        "batch_size": 8,
        "real_batch_size": 8,
        "test_batch_size": 8,
        "n_steps": 10000,
        "grad_norm": -1,
        "use_lr_scheduler": false,
        "binary_mode": false
    }
    ==========[optimizer config]==========
    {
        "name": "Adam",
        "paras": {
            "lr": 0.001,
            "weight_decay": 0.0
        }
    }
    ==========[backbone config]==========
    {
        "name": "BERT",
        "paras": {
            "model_name": "bert-base-cased",
            "max_tokens": 512,
            "fine_tune_layers": -1
        }
    }
    ==========[label model_config config]==========
    {
        "name": "MajorityVoting",
        "paras": {}
    }
    
    2021-10-23 22:51:52 - [INFO] early stop @ step 510!
    2021-10-23 22:55:20 - early stop because all the data are filtered!
    2021-10-23 22:56:06 - [1 1 1 ... 1 1 1]
    2021-10-23 22:56:06 - end model (COSINE) test acc: 0.5
    

    As can be seen, in both runs the label model reaches a test accuracy of 0.716, but the end model only reaches 0.5004 (MLP) and 0.5 (COSINE).
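Both end-model accuracies are about 0.5, and the COSINE log prints predictions of [1 1 1 ... 1 1 1]. One quick way to check whether an end model has collapsed onto a single class is to look at the distribution of its predictions and compare its accuracy with that of a constant predictor. The sketch below is hypothetical. prediction_summary is not part of WRENCH, and the sketch assumes that predictions and gold labels are plain lists of integer class ids, for example obtained from model.predict(test_data) and the test-set labels.

```python
from collections import Counter


def prediction_summary(preds, labels):
    """Hypothetical helper: compare an end model's predictions with a constant-class baseline."""
    if len(preds) != len(labels):
        raise ValueError("preds and labels must have the same length")
    pred_counts = Counter(preds)
    label_counts = Counter(labels)
    # Accuracy of the end model itself.
    acc = sum(p == y for p, y in zip(preds, labels)) / len(labels)
    # Best accuracy reachable by always predicting a single class (the majority class).
    majority_class, majority_count = label_counts.most_common(1)[0]
    constant_acc = majority_count / len(labels)
    return {
        "accuracy": acc,
        "n_predicted_classes": len(pred_counts),
        "prediction_counts": dict(pred_counts),
        "majority_class": majority_class,
        "constant_baseline_acc": constant_acc,
    }


if __name__ == "__main__":
    # Toy balanced binary test set where the model predicts class 1 everywhere.
    toy_labels = [0, 1] * 5
    toy_preds = [1] * 10
    print(prediction_summary(toy_preds, toy_labels))
```

In this toy input, only one class is predicted and the accuracy equals the constant baseline. On a roughly balanced binary test set, that pattern would match the 0.5 figures reported above.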

    Am I doing something completely wrong? Could you please tell me whether I am running the code correctly, or whether there is some issue with the hyperparameters?

    I would greatly appreciate it if you could give me some advice. I would also be very glad if you could include an example script for running the COSINE model.

    Thanks for the benchmark, I really appreciate it!

    opened by viheheb757 4
  • Reproducing Table 11 for classification

    Thanks for this package @JieyuZ2 -- do you happen to have an orchestration script for reproducing Table 11 (and therefore Table 3) in the Wrench paper?

    opened by pmangg 3
  • No module named 'wrench.classification.self_training'

    Hi, I am trying to run run_denoise.py but I am getting the following error:

    Traceback (most recent call last):
      File "run_denoise.py", line 5, in <module>
        from wrench.classification import Denoise
      File "/gpfs/space/home/wrench/wrench/classification/__init__.py", line 4, in <module>
        from .self_training import LDSelfTrain, DDSelfTrain
    ModuleNotFoundError: No module named 'wrench.classification.self_training'
    

    Could you please add LDSelfTrain and DDSelfTrain classes?

    opened by andreaspung 3
  • Questions on the use of ground-truth labels for validation

    Thanks for putting together the benchmark! This is really great work! It seems that both the label model and the end model use the ground-truth labels for validation. For example, the base label model uses the ground-truth labels of the validation set to calculate the class balance weights: https://github.com/JieyuZ2/wrench/blob/544119e781d010797cf153307aa1090361c99522/wrench/basemodel.py#L286 I have a few questions about this:

    (1) A valid baseline for the label models would be a classifier trained on the validation set, with the weak labels of the LFs as features and the ground-truth labels as the target. Given that the validation set for most datasets is not actually small, I suspect such a model might be a pretty strong baseline compared to the unsupervised label models.

    (2) Just as we combine the weak labels on the training set into aggregated labels, we could also compute aggregated labels for the validation set. The end model could then use these aggregated labels, instead of the ground-truth labels of the validation set, for validation. Wouldn't this be a more realistic setting, especially considering that the goal of weak supervision is to replace human labeling with programmatic labeling?
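Point (2) can be illustrated with a minimal, hypothetical sketch. The validation set's weak labels are aggregated by majority vote, and an end model's predictions are scored against those aggregated labels instead of against the ground truth. The sketch assumes that LF outputs are integer class ids and that -1 means "abstain". This is intended to match WRENCH's abstain convention, but treat it as an assumption. majority_vote and accuracy_against are illustrative names, not WRENCH functions. Majority vote also stands in here for any label model.

```python
from collections import Counter
from typing import List, Sequence

ABSTAIN = -1  # assumed abstain marker for a labeling function


def majority_vote(weak_labels: Sequence[Sequence[int]]) -> List[int]:
    """Aggregate each row of LF outputs by majority vote.

    Rows where every LF abstains get ABSTAIN. Ties are broken toward the
    smallest class id (a simplifying choice for this sketch).
    """
    aggregated = []
    for row in weak_labels:
        votes = Counter(lf_label for lf_label in row if lf_label != ABSTAIN)
        if not votes:
            aggregated.append(ABSTAIN)
            continue
        top = max(votes.values())
        aggregated.append(min(c for c, n in votes.items() if n == top))
    return aggregated


def accuracy_against(preds: Sequence[int], targets: Sequence[int]) -> float:
    """Accuracy of preds on the examples whose target is not ABSTAIN."""
    pairs = [(p, t) for p, t in zip(preds, targets) if t != ABSTAIN]
    if not pairs:
        return float("nan")
    return sum(p == t for p, t in pairs) / len(pairs)


if __name__ == "__main__":
    # Toy validation set: 4 examples, 3 LFs each.
    valid_weak_labels = [[1, 1, -1], [0, -1, 0], [-1, -1, -1], [0, 1, -1]]
    valid_gold = [1, 0, 1, 1]
    end_model_preds = [1, 0, 0, 1]

    pseudo_labels = majority_vote(valid_weak_labels)
    print("aggregated labels:", pseudo_labels)
    print("acc vs. gold:", accuracy_against(end_model_preds, valid_gold))
    print("acc vs. aggregated:", accuracy_against(end_model_preds, pseudo_labels))
```

Uncovered validation examples (all LFs abstain) are simply skipped when scoring against aggregated labels. This is one possible design choice among several.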

    I appreciate any explanations. Thanks!

    opened by wurenzhi 2
  • Clarifying dataset download links

    Great work on the benchmark!

    Under the "Available Datasets" section on the main README, you provide 2 links for downloading the WRENCH datasets:

    One point of confusion is that the expanded datasets found via the Google Drive link are different from the direct-download zip file. For example, classification/youtube/train.json on Google Drive has 1686 instances, while the zip file contains 1586 for the same file, matching the statistics reported in the README. Could you make it unambiguous in the documentation which download is the correct one?
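One way to check which copy you have is to count the examples in each split file. This minimal sketch assumes that a WRENCH split file such as train.json is a single JSON object whose top-level keys are example ids. A top-level list is handled too. count_instances and the example path are hypothetical; point the path at wherever each download was unpacked.

```python
import json
from pathlib import Path


def count_instances(json_path):
    """Count examples in a WRENCH-style split file (assumed: one JSON object keyed by example id)."""
    with open(json_path, encoding="utf-8") as f:
        data = json.load(f)
    if isinstance(data, (dict, list)):
        return len(data)
    raise ValueError(f"unexpected top-level JSON type in {json_path}: {type(data).__name__}")


if __name__ == "__main__":
    # Hypothetical location of an extracted copy; adjust as needed.
    path = Path("datasets/classification/youtube/train.json")
    if path.exists():
        n = count_instances(path)
        print(f"{path}: {n} instances (the README table lists 1586 for the Youtube train split)")
```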

    opened by jason-fries 2
  • Fix retained probabilities

    This pull request fixes a bug that led to the wrong probabilities being stored along with the predictions of each labeling function.

    Previously, all probabilities (a 2D tensor of size batch × classes) were saved alongside the class predictions. However, what should have been saved is the probability associated with each prediction of the model.
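The fix described above amounts to keeping one number per example instead of the whole probability row. Here is a minimal plain-Python sketch; probs_of_predictions is a hypothetical name and not the code from the pull request. With a PyTorch probability tensor probs and a LongTensor preds, the same selection can be written as probs[torch.arange(len(preds)), preds] or probs.gather(1, preds.unsqueeze(1)).squeeze(1).

```python
from typing import List, Sequence


def probs_of_predictions(probs: Sequence[Sequence[float]], preds: Sequence[int]) -> List[float]:
    """For each row, return the probability assigned to that row's predicted class.

    probs: batch x n_classes matrix of class probabilities.
    preds: predicted class index for each row.
    """
    if len(probs) != len(preds):
        raise ValueError("probs and preds must have the same number of rows")
    return [row[c] for row, c in zip(probs, preds)]


if __name__ == "__main__":
    probs = [[0.1, 0.9], [0.7, 0.3], [0.4, 0.6]]
    # Argmax per row gives the predicted class.
    preds = [max(range(len(row)), key=row.__getitem__) for row in probs]
    print(preds, probs_of_predictions(probs, preds))
```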

    opened by benbo 2
  • New Release

    Hi! Love the repo, super useful so far and really easy interface to use. Thanks for putting it together!

    I was wondering whether there are plans to cut another release any time soon. We use the v1.0 tag to make sure the version is consistent across multiple builds. I noticed a few bug fixes and QOL improvements since the last release, and it would be nice to have those marked with a new tag.

    opened by rsmith49 2
  • Numba 0.43 doesn't work with newer Python versions

    Version 0.43 of the numba package, specified here, doesn't work with Python 3.9. Upgrading the package to the latest version (0.54) resolves the issue. Traceback:

    /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/llvmlite/llvmpy/__init__.py:3: UserWarning: The module `llvmlite.llvmpy` is deprecated and will be removed in the future.
      warnings.warn(
    /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/llvmlite/llvmpy/core.py:8: UserWarning: The module `llvmlite.llvmpy.core` is deprecated and will be removed in the future. Equivalent functionality is provided by `llvmlite.ir`.
      warnings.warn(
    Traceback (most recent call last):
      File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
      File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
      File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
      File "<frozen importlib._bootstrap_external>", line 850, in exec_module
      File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/wrench/labelmodel/__init__.py", line 1, in <module>
        from .dawid_skene import DawidSkene
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/wrench/labelmodel/dawid_skene.py", line 6, in <module>
        from numba import njit, prange
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/__init__.py", line 25, in <module>
        from .decorators import autojit, cfunc, generated_jit, jit, njit, stencil
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/decorators.py", line 12, in <module>
        from .targets import registry
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/targets/registry.py", line 5, in <module>
        from . import cpu
      File "/home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/targets/cpu.py", line 9, in <module>
        from numba import _dynfunc, config
    ImportError: /home/murphy/anaconda3/envs/taskbase/lib/python3.9/site-packages/numba/_dynfunc.cpython-39-x86_64-linux-gnu.so: undefined symbol: _PyObject_GC_UNTRACK
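The traceback shows that wrench.labelmodel imports numba through its dawid_skene module. As a result, an incompatible numba version fails as soon as the label models are imported. Below is a minimal sketch of a pre-flight version check. It assumes, based only on this issue, that 0.54 is a working minimum on Python 3.9. It uses importlib.metadata, which requires Python 3.8+. MIN_NUMBA, parse_version, and numba_is_new_enough are hypothetical names.

```python
import re
from importlib.metadata import PackageNotFoundError, version

MIN_NUMBA = (0, 54)  # assumption taken from this issue, not an official requirement


def parse_version(v):
    """Turn a version string like '0.54.1' into (0, 54, 1); suffixes such as 'rc1' are ignored."""
    nums = []
    for piece in v.split("."):
        m = re.match(r"\d+", piece)
        if m is None:
            break
        nums.append(int(m.group()))
        if m.end() != len(piece):  # e.g. '0rc1': keep the leading 0, drop the rest
            break
    return tuple(nums)


def numba_is_new_enough(min_version=MIN_NUMBA):
    """Return True if an installed numba is at least min_version."""
    try:
        installed = version("numba")
    except PackageNotFoundError:
        return False
    return parse_version(installed) >= min_version


if __name__ == "__main__":
    print("numba >= %s installed:" % ".".join(map(str, MIN_NUMBA)), numba_is_new_enough())
```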
    
    opened by susuzheng 0
  • COSINE for token classification?

    Hi,

    I would like to know whether the code for the COSINE weak-supervision technique already supports token classification. If not, what changes would I need to make to build a weakly supervised training pipeline using weakly labeled and unlabeled datasets?

    opened by KrishnanJothi 0
  • Balance in Dawid Skene is obsolete.

    https://github.com/JieyuZ2/wrench/blob/6d8397956533fc6c2fe50e93fcfe0a2303bdd05f/wrench/labelmodel/dawid_skene.py#L55

    I noticed that this balance variable is not used anywhere in this file. If that is intended, I think it should be removed from the input parameters.

    opened by ch-shin 1
  • Balance sum to 1

    https://github.com/JieyuZ2/wrench/blob/ab717ac26a76649c8fdb946a28dffe7e682c80ba/wrench/basemodel.py#L303

    Hi, I found a minor issue: the class prior computed by this function does not sum to 1. I hope you can fix it.
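For reference, dividing per-class counts by the total count yields a class prior that is guaranteed to sum to 1. This is a minimal sketch, not a reproduction of the code linked above. class_prior is a hypothetical name. Ignoring labels outside [0, n_class) (which includes an abstain value of -1) and falling back to a uniform prior on empty input are illustrative choices.

```python
from collections import Counter
from typing import List, Sequence


def class_prior(labels: Sequence[int], n_class: int) -> List[float]:
    """Empirical class distribution over n_class classes, normalized to sum to 1.

    Labels outside [0, n_class), such as an abstain value of -1, are ignored.
    Classes that never occur get probability 0.
    """
    counts = Counter(y for y in labels if 0 <= y < n_class)
    total = sum(counts.values())
    if total == 0:
        # No usable labels: fall back to a uniform prior.
        return [1.0 / n_class] * n_class
    return [counts.get(c, 0) / total for c in range(n_class)]


if __name__ == "__main__":
    prior = class_prior([0, 1, 1, 1, -1], n_class=2)
    print(prior, sum(prior))
```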

    opened by Gnaiqing 0
  • about COSINE endmodel

    Hi @JieyuZ2 and @yinxiangshi, I am trying to run the COSINE end model but have some trouble reproducing the results in the COSINE paper. Although I used the suggested hyperparameters, I still get only a marginal benefit with WRENCH, and I'm not sure what is wrong. Could you share the scripts you used when evaluating COSINE? Thanks.

    opened by Gnaiqing 0
  • Recommended parameters to use for each algorithm and dataset

    I've tried several combinations of algorithms and datasets, but I found it hard to get results similar to those in the paper. I suspect this is due to inappropriate parameter settings, so I think it would be great if this repo could provide some recommended parameters (especially for the newly added algorithms, where it's hard to judge whether the results are right).

    opened by mrbeann 0
Releases(v1.1)
  • v1.1(Nov 9, 2021)

    What's new:

    • A set of new methods: WeaSEL, ImplyLoss, ASTRA, MeanTeacher, Meta-Weight-Net, Learning-to-Reweight
    • A new EndClassifierModel that unifies all the classification backbones
    • Two new datasets for image classification
    • Support for torch native AMP for inference in the validation step
    • Support for training on multiple GPUs via torch's DistributedDataParallel and the new parallel_fit function
    • Fixed some bugs
  • v1.0(Sep 7, 2021)

Owner
Jieyu Zhang
CS PhD
Mixup for Supervision, Semi- and Self-Supervision Learning Toolbox and Benchmark

OpenSelfSup News Downstream tasks now support more methods(Mask RCNN-FPN, RetinaNet, Keypoints RCNN) and more datasets(Cityscapes). 'GaussianBlur' is

AI Lab, Westlake University 332 Jan 3, 2023
Official implementation of "Not only Look, but also Listen: Learning Multimodal Violence Detection under Weak Supervision" ECCV2020

XDVioDet Official implementation of "Not only Look, but also Listen: Learning Multimodal Violence Detection under Weak Supervision" ECCV2020. The proj

peng 64 Dec 12, 2022
A curated list of programmatic weak supervision papers and resources

A curated list of programmatic weak supervision papers and resources

Jieyu Zhang 118 Jan 2, 2023
Self-training with Weak Supervision (NAACL 2021)

This repo holds the code for our weak supervision framework, ASTRA, described in our NAACL 2021 paper: "Self-Training with Weak Supervision"

Microsoft 148 Nov 20, 2022
Code and data of the ACL 2021 paper: Few-Shot Text Ranking with Meta Adapted Synthetic Weak Supervision

MetaAdaptRank This repository provides the implementation of meta-learning to reweight synthetic weak supervision data described in the paper Few-Shot

THUNLP 5 Jun 16, 2022
Hierarchical Metadata-Aware Document Categorization under Weak Supervision (WSDM'21)

Hierarchical Metadata-Aware Document Categorization under Weak Supervision This project provides a weakly supervised framework for hierarchical metada

Yu Zhang 53 Sep 17, 2022
Open source implementation of AceNAS: Learning to Rank Ace Neural Architectures with Weak Supervision of Weight Sharing

AceNAS This repo is the experiment code of AceNAS, and is not considered as an official release. We are working on integrating AceNAS as a built-in st

Yuge Zhang 6 Sep 7, 2022
A weakly-supervised scene graph generation codebase. The implementation of our CVPR2021 paper ``Linguistic Structures as Weak Supervision for Visual Scene Graph Generation''

README.md shall be finished soon. WSSGG 0 Overview 1 Installation 1.1 Faster-RCNN 1.2 Language Parser 1.3 GloVe Embeddings 2 Settings 2.1 VG-GT-Graph

Keren Ye 35 Nov 20, 2022
Learning trajectory representations using self-supervision and programmatic supervision.

Trajectory Embedding for Behavior Analysis (TREBA) Implementation from the paper: Jennifer J. Sun, Ann Kennedy, Eric Zhan, David J. Anderson, Yisong Y

null 58 Jan 6, 2023
ReSSL: Relational Self-Supervised Learning with Weak Augmentation

ReSSL: Relational Self-Supervised Learning with Weak Augmentation This repository contains PyTorch evaluation code, training code and pretrained model

mingkai 45 Oct 25, 2022
PyTorch implementation of Weak-shot Fine-grained Classification via Similarity Transfer

SimTrans-Weak-Shot-Classification This repository contains the official PyTorch implementation of the following paper: Weak-shot Fine-grained Classifi

BCMI 60 Dec 2, 2022
[NeurIPS 2021] A weak-shot object detection approach by transferring semantic similarity and mask prior.

[NeurIPS 2021] A weak-shot object detection approach by transferring semantic similarity and mask prior.

BCMI 49 Jul 27, 2022
Weak-supervised Visual Geo-localization via Attention-based Knowledge Distillation

Weak-supervised Visual Geo-localization via Attention-based Knowledge Distillation Introduction WAKD is a PyTorch implementation for our ICPR-2022 pap

null 2 Oct 20, 2022
[CVPR 2022] Back To Reality: Weak-supervised 3D Object Detection with Shape-guided Label Enhancement

Back To Reality: Weak-supervised 3D Object Detection with Shape-guided Label Enhancement Announcement: We have not tested the code yet. We will fini

Xiuwei Xu 7 Oct 30, 2022
Code for the ICML 2021 paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision"

ViLT Code for the paper: "ViLT: Vision-and-Language Transformer Without Convolution or Region Supervision" Install pip install -r requirements.txt pip

Wonjae Kim 922 Jan 1, 2023
[CVPR 2021] Semi-Supervised Semantic Segmentation with Cross Pseudo Supervision

TorchSemiSeg [CVPR 2021] Semi-Supervised Semantic Segmentation with Cross Pseudo Supervision by Xiaokang Chen1, Yuhui Yuan2, Gang Zeng1, Jingdong Wang

Chen XiaoKang 387 Jan 8, 2023