Tez is a super-simple and lightweight trainer for PyTorch. It also comes with many utilities that you can use to tackle over 90% of deep learning projects in PyTorch.

Overview

Tez: a simple pytorch trainer

NOTE: Currently, we are not accepting any pull requests! All PRs will be closed. If you want a feature or something doesn't work, please create an issue.

tez (तेज़ / تیز) means sharp, fast & active. This is a simple, to-the-point library to make your PyTorch training easy.

This library is currently at a very early stage, so there might be breaking changes.

The idea behind tez is simple:

  • keep things as simple as possible
  • make it as customizable as possible
  • clean code
  • faster prototyping
  • production ready

Currently, tez supports CPU and GPU training. More coming soon!

Using tez is super-easy. We don't want to take you far away from PyTorch: you do everything on your own and just use tez to make a few things simpler.

Training using Tez:

  • To train a model, define a dataset and a model. The dataset class is the same class you would write for any PyTorch project.

  • Create your model class. Instead of inheriting from nn.Module, import tez and inherit from tez.Model as shown in the following example.

import tez
import torch
import torch.nn as nn
import transformers
from sklearn import metrics


class MyModel(tez.Model):
    def __init__(self):
        super().__init__()
        # define your layers here; example layers matching the forward pass below
        self.bert = transformers.BertModel.from_pretrained("bert-base-uncased")
        self.bert_drop = nn.Dropout(0.3)
        self.out = nn.Linear(768, 1)
        # tell tez when to step the scheduler
        self.step_scheduler_after = "batch"

    def monitor_metrics(self, outputs, targets):
        if targets is None:
            return {}
        outputs = torch.sigmoid(outputs).cpu().detach().numpy() >= 0.5
        targets = targets.cpu().detach().numpy()
        accuracy = metrics.accuracy_score(targets, outputs)
        return {"accuracy": accuracy}

    def fetch_optimizer(self):
        # create your own optimizer, for example:
        return torch.optim.AdamW(self.parameters(), lr=1e-4)

    def fetch_scheduler(self):
        # create your own scheduler; self.optimizer is already set at this point
        return torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=1)

    def forward(self, ids, mask, token_type_ids, targets=None):
        _, o_2 = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        b_o = self.bert_drop(o_2)
        output = self.out(b_o)

        # calculate loss here (only when targets are available)
        loss = nn.BCEWithLogitsLoss()(output, targets) if targets is not None else None

        # calculate the metric dictionary here
        metric_dict = self.monitor_metrics(output, targets)
        return output, loss, metric_dict

Everything is super-intuitive!
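
For context, the dataset that feeds such a model is plain PyTorch. The one convention assumed here (inferred from the forward signature above, not from official docs) is that __getitem__ returns a dict whose keys match forward's argument names:

# illustrative dataset sketch: the keys returned by __getitem__ must match
# the forward() arguments above (ids, mask, token_type_ids, targets)
import torch
import transformers


class SomeTrainDataset(torch.utils.data.Dataset):
    def __init__(self, texts, targets, max_len=64):
        self.texts = texts
        self.targets = targets
        self.max_len = max_len
        self.tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        enc = self.tokenizer(
            self.texts[idx],
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )
        return {
            "ids": torch.tensor(enc["input_ids"], dtype=torch.long),
            "mask": torch.tensor(enc["attention_mask"], dtype=torch.long),
            "token_type_ids": torch.tensor(enc["token_type_ids"], dtype=torch.long),
            # shape (1,) per sample so batches match the (batch, 1) model output
            "targets": torch.tensor([self.targets[idx]], dtype=torch.float),
        }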

  • Now you can train your model!
# init datasets
train_dataset = SomeTrainDataset()
valid_dataset = SomeValidDataset()

# init model
model = MyModel()


# init callbacks; you can also write your own callback (see the sketch after this snippet)
es = tez.callbacks.EarlyStopping(monitor="valid_loss", model_path="model.bin")

# train model. a familiar api!
model.fit(
    train_dataset,
    valid_dataset=valid_dataset,
    train_bs=32,
    device="cuda",
    epochs=50,
    callbacks=[es],
    fp16=True,
)

# save model (with optimizer and scheduler for future!)
model.save("model.bin")
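
As referenced above, here is a minimal sketch of a custom callback. It assumes tez.callbacks exposes a Callback base class with an on_epoch_end(model) hook; the class name and body are illustrative, not documented API:

# illustrative custom callback: print the learning rate after every epoch
class LrPrinter(tez.callbacks.Callback):
    def on_epoch_end(self, model):
        for group in model.optimizer.param_groups:
            print(f"epoch finished, lr={group['lr']}")

# pass it alongside the built-in callbacks:
# model.fit(..., callbacks=[es, LrPrinter()])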

You can check out more examples in the examples/ directory.

Comments
  • ValueError: operands could not be broadcast together with shapes (256,256,4) (3,) (256,256,4)

    I am trying to use this package, and it throws the error below. I am using the same pipeline as in the cassava leaf detection problem, but on a different dataset where the image size is (256, 256).

    Could you please help here?

    Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b4-6ed6700e.pth
    100% 74.4M/74.4M [00:00<00:00, 107MB/s]
    Loaded pretrained weights for efficientnet-b4
    0%| | 0/51 [00:00<?, ?it/s]

    ValueError                                Traceback (most recent call last)
    <ipython-input> in <module>()
         11     epochs=10,
         12     callbacks=[es],
    ---> 13     fp16=True,
         14 )
         15 model.save("model.bin")

    6 frames
    /usr/local/lib/python3.6/dist-packages/tez/model/model.py in fit(self, train_dataset, valid_dataset, train_sampler, valid_sampler, device, epochs, train_bs, valid_bs, n_jobs, callbacks, fp16)
        295     self.train_state = enums.TrainingState.EPOCH_START
        296     self.train_state = enums.TrainingState.TRAIN_EPOCH_START
    --> 297     train_loss = self.train_one_epoch(self.train_loader, device)
        298     self.train_state = enums.TrainingState.TRAIN_EPOCH_END
        299     if self.valid_loader:

    /usr/local/lib/python3.6/dist-packages/tez/model/model.py in train_one_epoch(self, data_loader, device)
        176     losses = AverageMeter()
        177     tk0 = tqdm(data_loader, total=len(data_loader))
    --> 178     for b_idx, data in enumerate(tk0):
        179         self.train_state = enums.TrainingState.TRAIN_STEP_START
        180         loss, metrics = self.train_one_step(data, device)

    /usr/local/lib/python3.6/dist-packages/tqdm/std.py in __iter__(self)
       1102         fp_write=getattr(self.fp, 'write', sys.stderr.write))
       1103
    -> 1104     for obj in iterable:
       1105         yield obj
       1106         # Update and possibly print the progressbar.

    /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
        433     if self._sampler_iter is None:
        434         self._reset()
    --> 435     data = self._next_data()
        436     self._num_yielded += 1
        437     if self._dataset_kind == _DatasetKind.Iterable and \

    /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
       1083     else:
       1084         del self._task_info[idx]
    -> 1085         return self._process_data(data)
       1086
       1087     def _try_put_index(self):

    /usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
       1109     self._try_put_index()
       1110     if isinstance(data, ExceptionWrapper):
    -> 1111         data.reraise()
       1112     return data
       1113

    /usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
        426     # have message field
        427     raise self.exc_type(message=msg)
    --> 428     raise self.exc_type(msg)
        429
        430

    ValueError: Caught ValueError in DataLoader worker process 0.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/usr/local/lib/python3.6/dist-packages/tez/datasets/image_classification.py", line 48, in __getitem__
        augmented = self.augmentations(image=image)
      File "/usr/local/lib/python3.6/dist-packages/albumentations/core/composition.py", line 171, in __call__
        data = t(**data)
      File "/usr/local/lib/python3.6/dist-packages/albumentations/core/transforms_interface.py", line 38, in __call__
        res[key] = target_function(arg, **dict(params, **target_dependencies))
      File "/usr/local/lib/python3.6/dist-packages/albumentations/augmentations/transforms.py", line 808, in apply
        return F.normalize(image, self.mean, self.std, self.max_pixel_value)
      File "/usr/local/lib/python3.6/dist-packages/albumentations/augmentations/functional.py", line 93, in normalize
        img -= mean
    ValueError: operands could not be broadcast together with shapes (256,256,4) (3,) (256,256,4)

    opened by nvnvashisth 10
  • zero_grad for accumulation_steps = 1 not working as expected

    As far as I know, in the normal execution flow we first call zero_grad for each batch and then do the forward pass. But looking at the code, this is not what happens when accumulation_steps = 1 and batch = 1: the first forward pass executes without zero_grad being called first.

    I tried to reproduce it, and it behaves exactly as described above.

    Also, I think we can fix this by removing the condition on lines 330-331 in the tez.py file.
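
    For reference, the conventional per-batch ordering described above looks like this in plain PyTorch (an illustrative sketch, not tez code):

        for batch in data_loader:
            optimizer.zero_grad()             # clear gradients before the forward pass
            output, loss, _ = model(**batch)  # forward pass
            loss.backward()
            optimizer.step()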

    opened by abdurrehman11 9
  • Can it work without CUDA

    I am getting an error when I execute the code with a CPU-only configuration.

    Traceback (most recent call last):
      File "c:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\recommender.py", line 88, in <module>
        train()
      File "c:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\recommender.py", line 82, in train
        model.fit(
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\tez\model\model.py", line 309, in fit
        self._init_model(
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\tez\model\model.py", line 93, in _init_model
        self.to(self.device)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 852, in to
        return self._apply(convert)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 530, in _apply
        module._apply(fn)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 552, in _apply
        param_applied = fn(param)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\nn\modules\module.py", line 850, in convert
        return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
      File "C:\Users\Hemanth\Desktop\Data Analytics analyticvidya\recommender system\venv\lib\site-packages\torch\cuda\__init__.py", line 166, in _lazy_init
        raise AssertionError("Torch not compiled with CUDA enabled")
    AssertionError: Torch not compiled with CUDA enabled

    opened by hemanthh17 7
  • Documentation improvement - How is tez faster?

    Great to see a nice PyTorch training library.

    I think it would help users adopt it if the docs showed what kind of performance improvements come out of the box with Tez. For example, comparing how fp16 is enabled in tez vs vanilla PyTorch could be informative, or just a quick list of optimisations that are easy to enable with Tez, such as fp16 (see the sketch below).
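
    To make the suggested comparison concrete: this is roughly the boilerplate that mixed precision needs in vanilla PyTorch (a sketch using the standard torch.cuda.amp API; the loop and variable names are illustrative), versus a single flag in tez:

        import torch

        scaler = torch.cuda.amp.GradScaler()
        for batch in data_loader:
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():
                output, loss, _ = model(**batch)
            scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
            scaler.step(optimizer)         # unscales gradients, then steps
            scaler.update()

        # tez equivalent: model.fit(..., fp16=True)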

    opened by swartchris8 5
  • Is it possible to set variable Lr per epoch

    @abhishekkrthakur I find this framework great and easy to use. But as someone fairly new to it, I was wondering whether there is a way to pass a variable LR for training, say one per epoch.

    Also, is there a way to continue training from a particular epoch if, say, the local system crashed or was interrupted during the training process?

    opened by gauravbrills 3
  • Applying metrics after the epoch

    I am using tez to classify melanoma images (the Kaggle SIIM binary classification). With wtfml it is possible to get AUC ~ 0.85; with tez I am only getting AUC ~ 0.6. I found that this happens in tez when using metrics.roc_auc_score(...) inside the monitor_metrics method: it raises ValueError exceptions that must be handled by returning auc = 0.5 (the error occurs when the batch contains only one class).

    In wtfml, metrics.roc_auc_score(...) is applied only after Engine.evaluate. In that case the data always contains both classes (because StratifiedKFold guarantees that).

    I am wondering if it is possible, in tez, to apply metrics.roc_auc_score(...) only after the epoch rather than on every training batch. That way the data will always contain both classes, avoiding the ValueError exceptions (a sketch of this idea follows the PS below).

    PS:

    1. In the class __init__ I am using:

           self.step_scheduler_after = "epoch"
           self.step_scheduler_metric = "valid_auc"

    2. In the monitor_metrics method:

           try:
               auc = metrics.roc_auc_score(targets, outputs.ravel())
           except ValueError:
               auc = 0.5
           return {"auc": auc}

    3. My model.fit call is:

           model.fit(
               train_dataset,
               valid_dataset=valid_dataset,
               train_bs=32,
               valid_bs=16,
               device="cuda",
               epochs=50,
               callbacks=[es],
               fp16=False,
               n_jobs=2,
           )
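
    A sketch of the epoch-level idea (hand-rolled, not a tez API; epoch_targets and epoch_outputs are illustrative lists accumulated over the validation batches):

        import numpy as np
        from sklearn import metrics

        # score once per epoch: across a full stratified epoch
        # both classes are present, so no ValueError is raised
        auc = metrics.roc_auc_score(
            np.concatenate(epoch_targets),
            np.concatenate(epoch_outputs).ravel(),
        )
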
    opened by waldcarl 2
  • Issue while using Auc metric on imbalanced dataset like melanoma(ValueError: Only one class present in y_true. ROC AUC score is not defined in that case)

    This problem occurs because the metric is calculated per batch (as a running metric).

    I found this explanation on Stack Overflow:

    You cannot have an ROC curve without both positive and negative examples in your dataset. With only one class in the dataset, you cannot measure your false-positive rate, and therefore cannot plot an ROC curve. This is why you get this error message.

    How should this problem be handled?

    opened by IamSantoshKumar 2
  • Error in Multiclass TypeError: dropout(): argument 'input' (position 1) must be Tensor, not str

    /usr/local/lib/python3.7/dist-packages/torch/cuda/amp/grad_scaler.py:116: UserWarning: torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.
      warnings.warn("torch.cuda.amp.GradScaler is enabled, but CUDA is not available. Disabling.")
    0%| | 0/2939 [00:00<?, ?it/s]
    /usr/local/lib/python3.7/dist-packages/torch/cuda/amp/autocast_mode.py:118: UserWarning: torch.cuda.amp.autocast only affects CUDA ops, but CUDA is not available. Disabling.
      warnings.warn("torch.cuda.amp.autocast only affects CUDA ops, but CUDA is not available. Disabling.")

    TypeError                                Traceback (most recent call last)
    <ipython-input> in <module>()
        143     epochs=3,
        144     callbacks=[tb_logger, es],
    --> 145     fp16=True,
        146 )
        147 model.save("model.bin")

    8 frames
    /usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in dropout(input, p, training, inplace)
       1074     if p < 0.0 or p > 1.0:
       1075         raise ValueError("dropout probability has to be between 0 and 1, " "but got {}".format(p))
    -> 1076     return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
       1077
       1078

    TypeError: dropout(): argument 'input' (position 1) must be Tensor, not str

    opened by gokulguptanew 1
  • Text classification examples - Tokenizer is defined twice

    The tokenizer is defined both in the model and the dataset in the BERT text classification examples.

    multi_class.py, line 50:

        self.tokenizer = transformers.BertTokenizer.from_pretrained(
            "bert-base-uncased", do_lower_case=True
        )

    opened by obesp 1
  • Small error in image_classification.py

    If augmentations is None, we get an error: UnboundLocalError: local variable 'augmented' referenced before assignment.

    elif self.backend == "cv2":
                image = cv2.imread(self.image_paths[item])
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                if self.resize is not None:
                    image = cv2.resize(
                        image,
                        (self.resize[1], self.resize[0]),
                        interpolation=cv2.INTER_CUBIC,
                    )
                if self.augmentations is not None:
                    augmented = self.augmentations(image=image)
                    image = augmented["image"]
    

    If the indentation is fixed as shown above, this error is resolved.

    opened by VpkPrasanna 1
  • Small error in model.py

    Hi! Love this library.
    In tez/model/model.py there is probably a mistake in line 90:

    self.train_loader = torch.utils.data.DataLoader(
                    train_dataset,
                    batch_size=train_bs,
                    num_workers=n_jobs,
                    sampler=valid_sampler,
                    shuffle=True,
                )
    

    I guess train_sampler is meant to be used here, not valid_sampler.

    opened by hocop 1
  • run example code error

    When I run the example code:

    accelerate launch   imdb_sentiment_classification.py
    

    After running for some epochs, I get this error:

    INFO:tez.callbacks.early_stopping:EarlyStopping counter: 4/5
    [train] accuracy=0.9915, loss=0.0269 [valid] accuracy=0.8953, loss=0.4287 [e=5 steps=2112]                                                                                                 
     30%|████████████████████████████████▍                                                                           | 2112/7040 [05:45<06:40, 12.32it/s, accuracy=0.991, epoch=5, loss=0.0269]2022-09-17 07:55:02,832 INFO EarlyStopping counter: 5/5
    INFO:tez.callbacks.early_stopping:EarlyStopping counter: 5/5
     30%|████████████████████████████████▍                                                                           | 2112/7040 [05:47<13:31,  6.07it/s, accuracy=0.991, epoch=5, loss=0.0269]
    
    
    
    
    [E ProcessGroupNCCL.cpp:719] [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808970 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:719] [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808984 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:719] [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1809275 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
    terminate called after throwing an instance of 'std::runtime_error'
      what():  [Rank 2] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1809275 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
    terminate called after throwing an instance of 'std::runtime_error'
      what():  [Rank 3] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808970 milliseconds before timing out.
    [E ProcessGroupNCCL.cpp:406] Some NCCL operations have failed or timed out. Due to the asynchronous nature of CUDA kernels, subsequent GPU operations might run on corrupted/incomplete data. To avoid this inconsistency, we are taking the entire process down.
    terminate called after throwing an instance of 'std::runtime_error'
      what():  [Rank 1] Watchdog caught collective operation timeout: WorkNCCL(SeqNum=34532, OpType=BROADCAST, Timeout(ms)=1800000) ran for 1808984 milliseconds before timing out.
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/resource_sharer.py", line 138, in _serve
        with self._listener.accept() as conn:
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 470, in accept
        deliver_challenge(c, self._authkey)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 745, in deliver_challenge
        response = connection.recv_bytes(256)        # reject large message
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 221, in recv_bytes
        buf = self._recv_bytes(maxlength)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 419, in _recv_bytes
        buf = self._recv(4)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 384, in _recv
        chunk = read(handle, remaining)
    ConnectionResetError: [Errno 104] Connection reset by peer
    WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 113654 closing signal SIGTERM
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/resource_sharer.py", line 138, in _serve
        with self._listener.accept() as conn:
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 470, in accept
        deliver_challenge(c, self._authkey)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 745, in deliver_challenge
        response = connection.recv_bytes(256)        # reject large message
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 221, in recv_bytes
        buf = self._recv_bytes(maxlength)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 419, in _recv_bytes
        buf = self._recv(4)
      File "/root/miniconda3/envs/lightning/lib/python3.9/multiprocessing/connection.py", line 384, in _recv
        chunk = read(handle, remaining)
    ConnectionResetError: [Errno 104] Connection reset by peer
    ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: -6) local_rank: 2 (pid: 113655) of binary: /root/miniconda3/envs/lightning/bin/python
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/bin/torchrun", line 33, in <module>
        sys.exit(load_entry_point('torch==1.11.0', 'console_scripts', 'torchrun')())
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 345, in wrapper
        return f(*args, **kwargs)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/run.py", line 724, in main
        run(args)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/run.py", line 715, in run
        elastic_launch(
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 131, in __call__
        return launch_agent(self._config, self._entrypoint, list(args))
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/torch/distributed/launcher/api.py", line 245, in launch_agent
        raise ChildFailedError(
    torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
    =======================================================
    imdb_sentiment_classification.py FAILED
    -------------------------------------------------------
    Failures:
    [1]:
      time      : 2022-09-17_08:25:22
      host      : dy-a100-779-tlzrv
      rank      : 3 (local_rank: 3)
      exitcode  : -6 (pid: 113656)
      error_file: <N/A>
      traceback : Signal 6 (SIGABRT) received by PID 113656
    -------------------------------------------------------
    Root Cause (first observed failure):
    [0]:
      time      : 2022-09-17_08:25:22
      host      : dy-a100-779-tlzrv
      rank      : 2 (local_rank: 2)
      exitcode  : -6 (pid: 113655)
      error_file: <N/A>
      traceback : Signal 6 (SIGABRT) received by PID 113655
    =======================================================
    Traceback (most recent call last):
      File "/root/miniconda3/envs/lightning/bin/accelerate", line 33, in <module>
        sys.exit(load_entry_point('accelerate==0.12.0.dev0', 'console_scripts', 'accelerate')())
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/accelerate-0.12.0.dev0-py3.9.egg/accelerate/commands/accelerate_cli.py", line 43, in main
        args.func(args)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/accelerate-0.12.0.dev0-py3.9.egg/accelerate/commands/launch.py", line 734, in launch_command
        multi_gpu_launcher(args)
      File "/root/miniconda3/envs/lightning/lib/python3.9/site-packages/accelerate-0.12.0.dev0-py3.9.egg/accelerate/commands/launch.py", line 374, in multi_gpu_launcher
        raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
    subprocess.CalledProcessError: Command '['torchrun', '--nproc_per_node', '4', 'imdb_sentiment_classification.py']' returned non-zero exit status 1.
    
    opened by bestpredicts 0
  • Getting error while importing enums from tez.

    Traceback (most recent call last):
      File "/content/tez/tez/model/model.py", line 12, in <module>
        from tez import enums
      File "/content/tez/tez/model/tez.py", line 11, in <module>
        from tez import enums
    ImportError: cannot import name 'enums' from 'tez' (/content/tez/tez/model/tez.py)

    Looking forward to a reply.

    opened by VikasRathod314 3
  • Saving validation score

    Is it possible to somehow save a list of the validation scores (per epoch or per batch) after training? I have some problems with output on my server (it usually gets deleted), but I really need the validation scores to compare models. It would be really convenient if I could get them in one file, for example.
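
    One way to do this, sketched below, is a small custom callback that appends the metrics to a file after every epoch. This assumes a tez.callbacks.Callback base class with an on_epoch_end(model) hook and the model.metrics dict that callbacks such as EarlyStopping monitor; all names here are assumptions, not documented API:

        import json
        from tez.callbacks import Callback

        class MetricsFileLogger(Callback):
            # illustrative: append one JSON line of metrics per epoch
            def __init__(self, path="metrics.jsonl"):
                self.path = path

            def on_epoch_end(self, model):
                with open(self.path, "a") as f:
                    f.write(json.dumps(model.metrics) + "\n")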

    opened by 25icecreamflavors 0
  • Saving after training an epoch

    How can I save the model after each training epoch? I use the fit method for 5 epochs and do not really understand how to save after each one, not only after the last one.

    opened by 25icecreamflavors 2
  • How can we access the input_ids/attention mask in each train batch loop?

    I tried using a train-step callback, but I am not sure how to get access to the dataloader's input_ids and attention mask during each train step. Is this possible?

    BTW Thanks for the library!

    opened by tkmaker 0
Releases: v0.1.8

Owner: abhishek thakur (Kaggle: www.kaggle.com/abhishek)