Tacotron 2 - PyTorch implementation with faster-than-realtime inference

Overview

Tacotron 2 (without WaveNet)

PyTorch implementation of Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions.

This implementation includes distributed and automatic mixed precision support and uses the LJSpeech dataset.

Distributed and Automatic Mixed Precision support relies on NVIDIA's Apex and AMP.

Visit our website for audio samples using our published Tacotron 2 and WaveGlow models.

[Figure: alignment, predicted mel spectrogram, and target mel spectrogram]

Pre-requisites

  1. NVIDIA GPU + CUDA + cuDNN

Setup

  1. Download and extract the LJ Speech dataset
  2. Clone this repo: git clone https://github.com/NVIDIA/tacotron2.git
  3. cd into this repo: cd tacotron2
  4. Initialize submodule: git submodule init; git submodule update
  5. Update .wav paths: sed -i -- 's,DUMMY,ljs_dataset_folder/wavs,g' filelists/*.txt
    • Alternatively, set load_mel_from_disk=True in hparams.py and update mel-spectrogram paths
  6. Install PyTorch 1.0
  7. Install Apex
  8. Install python requirements or build docker image
    • Install python requirements: pip install -r requirements.txt

Training

  1. python train.py --output_directory=outdir --log_directory=logdir
  2. (OPTIONAL) tensorboard --logdir=outdir/logdir

Training using a pre-trained model

Training using a pre-trained model can lead to faster convergence.
By default, the dataset-dependent text embedding layers are ignored when warm-starting (see the sketch after the steps below).

  1. Download our published Tacotron 2 model
  2. python train.py --output_directory=outdir --log_directory=logdir -c tacotron2_statedict.pt --warm_start
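
For reference, a hedged sketch of what --warm_start does; it mirrors warm_start_model in train.py, where hparams.ignore_layers defaults to the text embedding (exact names may differ across versions):

```python
# Hedged sketch of warm-starting (mirrors warm_start_model in train.py).
import torch

def warm_start_model(checkpoint_path, model, ignore_layers=('embedding.weight',)):
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint_dict['state_dict']
    # Drop dataset-dependent layers (the text embedding) so a model with a
    # different character set can still reuse the remaining pretrained weights.
    state_dict = {k: v for k, v in state_dict.items() if k not in ignore_layers}
    model_dict = model.state_dict()
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)
    return model
```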

Multi-GPU (distributed) and Automatic Mixed Precision Training

  1. python -m multiproc train.py --output_directory=outdir --log_directory=logdir --hparams=distributed_run=True,fp16_run=True

Inference demo

  1. Download our published Tacotron 2 model
  2. Download our published WaveGlow model
  3. jupyter notebook --ip=127.0.0.1 --port=31337
  4. Load inference.ipynb

N.B. When performing mel-spectrogram to audio synthesis, make sure Tacotron 2 and the mel decoder (e.g., WaveGlow) were trained on the same mel-spectrogram representation.
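
For orientation, a hedged sketch of the notebook's flow (module names follow this repo and its waveglow submodule; checkpoint paths are illustrative, and the Denoiser import assumes the waveglow directory is on sys.path):

```python
# Hedged sketch of the inference.ipynb flow, not a canonical script.
import numpy as np
import torch

from hparams import create_hparams
from train import load_model
from text import text_to_sequence
from denoiser import Denoiser  # from the waveglow submodule

hparams = create_hparams()

# Tacotron 2: text -> mel spectrogram
model = load_model(hparams)
model.load_state_dict(torch.load("tacotron2_statedict.pt")['state_dict'])
model.cuda().eval().half()

# WaveGlow: mel spectrogram -> audio
waveglow = torch.load("waveglow_256channels.pt")['model']
waveglow.cuda().eval().half()
for k in waveglow.convinv:
    k.float()  # keep the invertible 1x1 convolutions in fp32 for stability
denoiser = Denoiser(waveglow)

sequence = np.array(text_to_sequence("Hello world.", ['english_cleaners']))[None, :]
sequence = torch.from_numpy(sequence).cuda().long()

_, mel_outputs_postnet, _, alignments = model.inference(sequence)
with torch.no_grad():
    audio = waveglow.infer(mel_outputs_postnet, sigma=0.666)
audio = denoiser(audio, strength=0.01)[:, 0]
```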

Related repos

WaveGlow: a faster-than-real-time flow-based generative network for speech synthesis.

nv-wavenet: a faster-than-real-time WaveNet implementation.

Acknowledgements

This implementation uses code from the following repos: Keith Ito and Prem Seetharaman, as described in our code.

We are inspired by Ryuichi Yamamoto's Tacotron PyTorch implementation.

We are thankful to the Tacotron 2 paper authors, especially Jonathan Shen, Yuxuan Wang and Zongheng Yang.

Comments
  • Optimize model for inference speed

    https://github.com/NVIDIA/waveglow/issues/54 In this issue, they discuss lowering some parameters to maximize inference speed, but I don't know how to do it properly: what can be reduced and what needs to remain? Has anyone done this before? Please send me your hparams configuration.

    If I trained my model using fp32, can it run inference in fp16, and vice versa? If so, will it improve inference speed? I am using an RTX 2080 Ti; my model runs 7 times faster than real time, and I am pretty sure it can be improved.

    And one more thing: is there any benefit to running inference on multiple GPUs?
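
    On the fp16 question: the inference notebook simply casts fp32-trained weights with .half() at load time, so training and inference precision can differ (whether it helps speed depends on the GPU). For the speed claim, below is a hedged sketch for measuring the real-time factor; waveglow and mel stand for an already-loaded vocoder and a GPU mel-spectrogram batch, and sigma=0.666 follows the inference notebook.

    ```python
    # Hedged sketch: real-time factor (RTF) of WaveGlow inference.
    import time
    import torch

    def real_time_factor(waveglow, mel, sample_rate=22050):
        torch.cuda.synchronize()           # make GPU timing meaningful
        start = time.perf_counter()
        with torch.no_grad():
            audio = waveglow.infer(mel, sigma=0.666)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        seconds_of_audio = audio.shape[-1] / sample_rate
        return seconds_of_audio / elapsed  # > 1.0 means faster than real time
    ```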

    opened by EuphoriaCelestial 66
  • How many iterations do I need?

    Training on a Russian dataset; the output says I have less than 0.2 loss, and the default is 500 epochs. Now I'm on epoch 1333 and still get "Warning! Reached max decoder steps". Should I keep counting, or is it screwed up? http://puu.sh/CgXpt/41886048cd.jpg

    opened by hadaev8 61
  • Little gap between words on the alignment plot

    Hi, my alignment plot has a little gap between words.

    [alignment plot at 63k iterations]

    batch_size=64, filter_length=2048, hop_length=275, win_length=1100

    And it sounds a bit like reading separate words. Has anyone had the same issue before?

    Thanks.

    opened by dnnnew 50
  • How to train a new model with a dataset in a different language?

    I would like to know whether it is possible to train a Tacotron 2 model for another language, using another dataset that has the same structure as the LJ Speech dataset. And if it is possible, is there any tutorial for doing so?
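
    For illustration, a hedged sketch of pointing the hparams at a new-language dataset (hparams names follow this repo; the filelist paths are placeholders, and basic_cleaners is just one choice from text/cleaners.py):

    ```python
    from hparams import create_hparams

    hparams = create_hparams()
    hparams.training_files = 'filelists/mylang_audio_text_train_filelist.txt'  # hypothetical
    hparams.validation_files = 'filelists/mylang_audio_text_val_filelist.txt'  # hypothetical
    # english_cleaners assumes English text; basic_cleaners or
    # transliteration_cleaners are the usual fallbacks for other languages.
    hparams.text_cleaners = ['basic_cleaners']
    ```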

    opened by EuphoriaCelestial 46
  • Audio examples?

    Very cool work, this! However, it would be ideal to also provide examples of input text + output audio from a trained system, alongside held-out examples from the database. This will give an impression of what kind of results the code is capable of generating with the LJSpeech data, and is standard practice in the text-to-speech field.

    Aside from synthesising held-out sentences from LJSpeech, Google's speech examples for Tacotron 2 provide another set of challenging text prompts to generate.

    Are there any plans to do this? Or are synthesised speech examples already available somewhere?

    opened by ghenter 36
  • Model cannot converge

    Hello, I have a question.

    I'm using a dataset with more than 17k sentences (about 30 hours of audio), 90% for training and 10% for validation. It's been training for 3 days (using batch_size 8) and has reached epoch 56. Please see the training info below: [grad norm, training loss, and validation loss plots]. I thought it looked good, but when I tested it, the output audio was wrong and the attention plot looks awful.

    And the loss doesn't seem able to decrease any more. Do I have to train for more epochs, or was there something wrong with my dataset, or something else? Please help me, thank you guys so much.

    opened by HiiamCong 32
  • Hoarseness in synthesised voice

    Hi, so we have been training both Tacotron 2 and WaveGlow models on clean male speech (Hindi language) of about 10+ hours at a 16 kHz sampling rate, using phonemic text. We keep the window and mel-band parameters unchanged from the 22.05 kHz setup in the original repo. Both models were trained from scratch in a distributed manner on 8 V100s for over 2 days. The resulting synthesis produces a hoarse voice. We tried synthesizing using models from different iterations; the hoarseness remains the same regardless. Comparing the spectrograms of the original (left side of the image) and the synthesized audio (right side) for the same text, we observe that the synthesized audio's spectrogram is smudged in the middle frequencies, in contrast to the crisp harmonics of the original audio.

    Any suggestions on how to train better?

    Left: original (audio); right: synthesized (audio)

    Spectrogram: left, original; right, synthesized
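
    One hedged adaptation, if the goal is to keep the analysis window and hop durations constant when moving from the repo's 22,050 Hz defaults (hparams.py: filter_length=1024, hop_length=256, win_length=1024, mel_fmax=8000) to 16 kHz; the rounded values below are assumptions, not tested settings:

    ```python
    src_sr, tgt_sr = 22050, 16000
    scale = tgt_sr / src_sr

    filter_length = round(1024 * scale)      # 1024 -> 743
    hop_length    = round(256 * scale)       # 256  -> 186
    win_length    = round(1024 * scale)      # 1024 -> 743
    mel_fmax      = min(8000.0, tgt_sr / 2)  # never above Nyquist (8 kHz at 16 kHz)
    ```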

    opened by sharathadavanne 28
  • Gate_output is being mis-predicted

    I have trained the model for only a few steps (8,000) for testing.

    At this checkpoint, the synthesized mel spectrogram's frame count always exceeds the maximum frames set in the hparams config.

    As far as I have checked, the problem is that gate_output is always less than the configured gate_threshold (it's around 0.0022 or 0.0023), so frames are generated forever.

    Is this because I haven't trained the model enough, or what else should I check?

    Please help, thank you so much.
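
    For context, a hedged sketch of the decoder's stop condition, mirroring the inference loop in model.py; decode_step is a stand-in for one decoder iteration returning a mel frame and a gate logit:

    ```python
    import torch

    def generate(decode_step, gate_threshold=0.5, max_decoder_steps=1000):
        mel_outputs = []
        while True:
            mel_output, gate_logit = decode_step()
            mel_outputs.append(mel_output)
            # Stop once sigmoid(gate) crosses the threshold. A gate stuck
            # near 0.002, as in this issue, never crosses 0.5, so generation
            # only ends when max_decoder_steps is reached.
            if torch.sigmoid(gate_logit) > gate_threshold:
                break
            if len(mel_outputs) == max_decoder_steps:
                print("Warning! Reached max decoder steps")
                break
        return torch.stack(mel_outputs)
    ```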

    opened by Thien223 26
  • Warning! Reached max decoder steps

    I had no such problem before, but today I used the pretrained models, and sometimes the spectrogram is bugged.

    Here is a notebook to reproduce it: https://colab.research.google.com/drive/1jR12cEKdkg0hlDUHGhf2fPb0RwqPwEYj?#scrollTo=CyBu2F7eisFM

    opened by hadaev8 22
  • Overfitting: validation loss stuck

    Hello, I am experimenting with a 3.5-hour dataset (16-bit, 22 kHz, 9.5 s average duration, 1,300 audio files). I am using the pretrained models provided by this repo.

    My hparams: p_attention_dropout=0.5, p_decoder_dropout=0.5, learning_rate=4e-5, batch_size=16.

    After 15k iterations: train loss 0.25. It is producing almost original-like outputs [alignment, predicted mel, and target mel plots].

    The only thing I couldn't improve is the validation loss; it has been stuck since iteration 4000 [validation loss plot].

    All suggestions are welcome! Thanks beforehand.

    opened by ksaidin 21
  • Scaling mel spectrogram output for the WaveNet vocoder

    Hello,

    First of all thanks for the nice Tacotron 2 implementation.

    I'm trying to use the trained Tacotron 2 outputs as inputs to r9y9's WaveNet vocoder. However, his pre-trained WaveNet works on mel spectrograms scaled to [0, 1].

    What is the output range for this Tacotron 2 implementation? I'm having a hard time finding this out in order to scale it.

    For reference, this is r9y9's normalization function, applied to the mel spectrogram before training, which scales it to [0, 1]:

    def _normalize(S):
        return np.clip((S - hparams.min_level_db) / -hparams.min_level_db, 0, 1)
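
    For context, a hedged sketch of how this repo compresses mels and one possible [0, 1] rescale: audio_processing.py applies log(clamp(x, min=1e-5)), so the floor is log(1e-5) ≈ -11.52, while the ceiling of 0.0 below is an assumption (linear mel magnitudes can exceed 1).

    ```python
    import torch

    def dynamic_range_compression(x, clip_val=1e-5):
        # Mirrors audio_processing.py: natural-log amplitudes, floored at clip_val.
        return torch.log(torch.clamp(x, min=clip_val))

    LOG_FLOOR = float(torch.log(torch.tensor(1e-5)))  # about -11.52

    def to_unit_range(mel, floor=LOG_FLOOR, ceil=0.0):
        # Assumption: log-mel values mostly lie in [floor, 0]; anything above
        # the assumed ceiling gets clipped.
        return torch.clamp((mel - floor) / (ceil - floor), 0.0, 1.0)
    ```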

    opened by G-Wang 20
  • Missing requirement and some requirement dependency issues

    • numpy==1.13.3
      • conflicts with tensorflow and matplotlib
    • scipy==1.0.0
      • functions that are used are missing in this version
    • torch
      • isn't listed at all
    • numba==0.48
      • newer versions are missing functions required by librosa
    • resampy==0.3.1
      • I forgot which module requires this, but it works
    opened by AdmiralPuni 0
  • Inference error! Please help me

    When I run the inference code, it produces the following error:

    waveglow_path = '/home/zhonghuihang/tacotron2-master/waveglow_256channels_universal_v5.pt'
    waveglow = torch.load(waveglow_path)['model']
    waveglow.cuda().eval().half()
    for k in waveglow.convinv:
        k.float()
    denoiser = Denoiser(waveglow)


    InvalidArgumentsError                     Traceback (most recent call last)
    /tmp/ipykernel_1534266/1318601042.py in <module>
          1 waveglow_path = '/home/zhonghuihang/tacotron2-master/waveglow_256channels_universal_v5.pt'
    ----> 2 waveglow = torch.load(waveglow_path)['model']
          3 waveglow.cuda().eval().half()
          4 for k in waveglow.convinv:
          5     k.float()

    ~/miniconda3/envs/torch/lib/python3.7/site-packages/torch/serialization.py in load(f, map_location, pickle_module, weights_only, **pickle_load_args)
    ~/miniconda3/envs/torch/lib/python3.7/site-packages/torch/serialization.py in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
    ~/miniconda3/envs/torch/lib/python3.7/site-packages/torch/serialization.py in find_class(self, mod_name, name)
    ~/miniconda3/envs/torch/lib/python3.7/site-packages/glow/__init__.py in <module>
    ---> 30 from .wgr import *  # For backwards compatibility.
    ~/miniconda3/envs/torch/lib/python3.7/site-packages/glow/wgr/__init__.py in <module>
    ---> 15 from glow.wgr.ridge_reduction import *
    ~/miniconda3/envs/torch/lib/python3.7/site-packages/glow/wgr/ridge_reduction.py in <module>
    ---> 15 from .ridge_udfs import *
    ~/miniconda3/envs/torch/lib/python3.7/site-packages/glow/wgr/ridge_udfs.py in <module>
    ---> 15 from glow.wgr.model_functions import *
    ~/miniconda3/envs/torch/lib/python3.7/site-packages/glow/wgr/model_functions.py in <module>
    --> 110 def assemble_block(n_rows: Int, n_cols: Int, pdf: pd.DataFrame, cov_matrix: NDArray[(Any, Any), Float], row_mask: NDArray[Any]) -> NDArray[Float]:
    ~/miniconda3/envs/torch/lib/python3.7/site-packages/nptyping/base_meta_classes.py in __getitem__(cls, item)
    --> 138 args = cls._get_item(item)
    ~/miniconda3/envs/torch/lib/python3.7/site-packages/nptyping/ndarray.py in _get_item(cls, item)
    ---> 69 shape, dtype = cls._get_from_tuple(item)
    ~/miniconda3/envs/torch/lib/python3.7/site-packages/nptyping/ndarray.py in _get_from_tuple(cls, item)
    --> 110 shape = cls._get_shape(item[0])
    ~/miniconda3/envs/torch/lib/python3.7/site-packages/nptyping/ndarray.py in _get_shape(cls, dtype_candidate)
    --> 124 raise InvalidArgumentsError(f"Unexpected argument '{dtype_candidate}', expecting Shape[] or Literal[]")

    InvalidArgumentsError: Unexpected argument '(typing.Any, typing.Any)', expecting Shape[] or Literal[] or typing.Any.
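
    The import frames show torch.load resolving the module path stored in the checkpoint ("glow") to an unrelated site-packages package (the glow.wgr genomics library) instead of WaveGlow's local glow.py. A hedged workaround sketch; the WaveGlow repo path is hypothetical:

    ```python
    # Make WaveGlow's glow.py shadow the unrelated site-packages `glow`
    # package before torch.load unpickles the model class.
    import sys
    sys.path.insert(0, '/path/to/waveglow')  # hypothetical path to the WaveGlow repo

    import torch

    waveglow_path = 'waveglow_256channels_universal_v5.pt'
    waveglow = torch.load(waveglow_path, map_location='cpu')['model']
    ```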

    opened by lunar333 0
  • Tacotron 2 no longer works on Google Colab

    Since Google Colab no longer supports TensorFlow 1, the Tacotron 2 training and synthesis notebooks are broken. Even when I use TensorFlow 2, Tacotron 2 still breaks because child directories are not recognized; for example, os.chdir('tacotron2') no longer works. Is there a way to fix this?
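
    A hedged sketch of the usual Colab pattern; os.chdir only fails if the clone step did not run, or ran in a different working directory:

    ```python
    import os
    import subprocess

    if not os.path.isdir('tacotron2'):
        subprocess.run(['git', 'clone', 'https://github.com/NVIDIA/tacotron2.git'], check=True)
    os.chdir('tacotron2')
    print(os.getcwd())  # should end in /tacotron2
    ```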

    opened by JCThornton90 1
  • build(deps): bump tensorflow from 1.15.2 to 2.9.3

    Bumps tensorflow from 1.15.2 to 2.9.3.

    Release notes

    Sourced from tensorflow's releases.

    TensorFlow 2.9.3

    Release 2.9.3

    This release introduces several vulnerability fixes:

    TensorFlow 2.9.2

    Release 2.9.2

    This release introduces several vulnerability fixes:

    ... (truncated)

    Changelog

    Sourced from tensorflow's changelog.

    Release 2.9.3

    This release introduces several vulnerability fixes:

    Release 2.8.4

    This release introduces several vulnerability fixes:

    ... (truncated)

    Commits
    • a5ed5f3 Merge pull request #58584 from tensorflow/vinila21-patch-2
    • 258f9a1 Update py_func.cc
    • cd27cfb Merge pull request #58580 from tensorflow-jenkins/version-numbers-2.9.3-24474
    • 3e75385 Update version numbers to 2.9.3
    • bc72c39 Merge pull request #58482 from tensorflow-jenkins/relnotes-2.9.3-25695
    • 3506c90 Update RELEASE.md
    • 8dcb48e Update RELEASE.md
    • 4f34ec8 Merge pull request #58576 from pak-laura/c2.99f03a9d3bafe902c1e6beb105b2f2417...
    • 6fc67e4 Replace CHECK with returning an InternalError on failing to create python tuple
    • 5dbe90a Merge pull request #58570 from tensorflow/r2.9-7b174a0f2e4
    • Additional commits viewable in compare view

    Dependabot compatibility score

    Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.


    Dependabot commands and options

    You can trigger Dependabot actions by commenting on this PR:

    • @dependabot rebase will rebase this PR
    • @dependabot recreate will recreate this PR, overwriting any edits that have been made to it
    • @dependabot merge will merge this PR after your CI passes on it
    • @dependabot squash and merge will squash and merge this PR after your CI passes on it
    • @dependabot cancel merge will cancel a previously requested merge and block automerging
    • @dependabot reopen will reopen this PR if it is closed
    • @dependabot close will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
    • @dependabot ignore this major version will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
    • @dependabot ignore this minor version will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
    • @dependabot ignore this dependency will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
    • @dependabot use these labels will set the current labels as the default for future PRs for this repo and language
    • @dependabot use these reviewers will set the current reviewers as the default for future PRs for this repo and language
    • @dependabot use these assignees will set the current assignees as the default for future PRs for this repo and language
    • @dependabot use this milestone will set the current milestone as the default for future PRs for this repo and language

    You can disable automated security fix PRs for this repo from the Security Alerts page.

    dependencies 
    opened by dependabot[bot] 0
  • AssertionError on assert os.path.isfile(checkpoint_path)

    Hello, as the title says, I'm getting an AssertionError on line 85 of train.py when I try to run the code with:

    python J:\tacotron2\train.py --output_directory="J:\tacotron\output" --log_directory="J:\tacotron\log" -c tacotron2_statedict.pt --warm_start

    I've installed Anaconda3 with Python 3.9, TensorFlow (GPU) 2.6.0, and all the required libraries (librosa, inflect, unidecode, keras). I've also made the modifications to the hparams.py file to make it work with TensorFlow 2, as suggested by @v-nhandt21 at https://github.com/NVIDIA/tacotron2/issues/278.

    What could be the problem? I have a list of 100 WAV files, 22050 Hz 16-bit.
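
    For what it's worth, the failing check in train.py is only a file-existence assert, so a relative checkpoint path resolves against the current working directory. A hedged sketch; the absolute path below is hypothetical:

    ```python
    import os

    checkpoint_path = r"J:\tacotron2\tacotron2_statedict.pt"  # hypothetical location
    assert os.path.isfile(checkpoint_path), f"checkpoint not found: {checkpoint_path}"
    ```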

    opened by Pitr2 0
  • Add CodeQL workflow for GitHub code scanning

    Hi NVIDIA/tacotron2!

    This is a one-off automatically generated pull request from LGTM.com. You might have heard that we’ve integrated LGTM’s underlying CodeQL analysis engine natively into GitHub. The result is GitHub code scanning!

    With LGTM fully integrated into code scanning, we are focused on improving CodeQL within the native GitHub code scanning experience. In order to take advantage of current and future improvements to our analysis capabilities, we suggest you enable code scanning on your repository. Please take a look at our blog post for more information.

    This pull request enables code scanning by adding an auto-generated codeql.yml workflow file for GitHub Actions to your repository (take a look!). We tested it before opening this pull request, so all should be working. In fact, you might already have seen some alerts appear on this pull request!

    Where needed and if possible, we’ve adjusted the configuration to the needs of your particular repository. But of course, you should feel free to tweak it further! Check this page for detailed documentation.

    Questions? Check out the FAQ below!

    FAQ


    How often will the code scanning analysis run?

    By default, code scanning will trigger a scan with the CodeQL engine on the following events:

    • On every pull request — to flag up potential security problems for you to investigate before merging a PR.
    • On every push to your default branch and other protected branches — this keeps the analysis results on your repository’s Security tab up to date.
    • Once a week at a fixed time — to make sure you benefit from the latest updated security analysis even when no code was committed or PRs were opened.

    What will this cost?

    Nothing! The CodeQL engine will run inside GitHub Actions, making use of your unlimited free compute minutes for public repositories.

    What types of problems does CodeQL find?

    The CodeQL engine that powers GitHub code scanning is the exact same engine that powers LGTM.com. The exact set of rules has been tweaked slightly, but you should see almost exactly the same types of alerts as you were used to on LGTM.com: we’ve enabled the security-and-quality query suite for you.

    How do I upgrade my CodeQL engine?

    No need! New versions of the CodeQL analysis are constantly deployed on GitHub.com; your repository will automatically benefit from the most recently released version.

    The analysis doesn’t seem to be working

    If you get an error in GitHub Actions that indicates that CodeQL wasn’t able to analyze your code, please follow the instructions here to debug the analysis.

    How do I disable LGTM.com?

    If you have LGTM’s automatic pull request analysis enabled, then you can follow these steps to disable the LGTM pull request analysis. You don’t actually need to remove your repository from LGTM.com; it will automatically be removed in the next few months as part of the deprecation of LGTM.com (more info here).

    Which source code hosting platforms does code scanning support?

    GitHub code scanning is deeply integrated within GitHub itself. If you’d like to scan source code that is hosted elsewhere, we suggest that you create a mirror of that code on GitHub.

    How do I know this PR is legitimate?

    This PR is filed by the official LGTM.com GitHub App, in line with the deprecation timeline that was announced on the official GitHub Blog. The proposed GitHub Action workflow uses the official open source GitHub CodeQL Action. If you have any other questions or concerns, please join the discussion here in the official GitHub community!

    I have another question / how do I get in touch?

    Please join the discussion here to ask further questions and send us suggestions!

    opened by lgtm-com[bot] 0