Unofficial Alias-Free GAN implementation. Based on rosinality's version with expanded training and inference options.

Overview

Alias-Free GAN

An unofficial version of Alias-Free Generative Adversarial Networks (https://arxiv.org/abs/2106.12423). This repository is heavily based on Kim Seonghyeon's (rosinality) implementation. The goal of this version is to be maintainable, easy to use, and to expand the features of existing implementations. It is built using PyTorch and PyTorch Lightning (a framework that abstracts away much of the hardware-specific code).

See open issues for unsupported features, planned features, and current bugs.

Licence

MIT licence badge

This project is officially licensed under the MIT licence. However, you are asked to use this repository with the intention of actively elevating historically marginalized communities. Avoid using this project to create anything that inflicts physical or psychological violence on individuals, groups, or animals.

Financial Support

If you have the financial means, please consider contributing to this project or to the creators of pretrained models.

This project takes money to run because, while many continuous integration (CI) tools are free for open source projects, they do not offer the hardware needed to run tests on GPUs or TPUs. Because of this, CI testing has to run on Google Cloud Platform, which charges for GPU and TPU instances. Any financial contributions to this project will go toward covering those costs first.

Licences and contribution terms for pretrained models are designed to be flexible, so please review the information for a particular model before using it.

Buy me a coffee: https://www.buymeacoffee.com/duskvirkus

Tezos Wallet Address: tz1PeE5NQyyjyJdnyqyKdtFN27TvN9j3ZcUe - Only send Tezos to this wallet.

About Branches

devel is the default branch and it includes the latest features but may have breaking changes.

stable is the latest stable release of the repository. The only updates it receives between versions are to the available pretrained models, additional notebooks, and any bug fixes.

Branch   All CI   GPU pytest             TPUs pytest
devel    CI       gpu pytest on gcloud   tpus pytest on gcloud
stable   CI       gpu pytest on gcloud   tpus pytest on gcloud

Examples

training example gif

Example of animation made from samples generated in training process.


linear interpolation example

Example of linear interpolation between two random seeds.


circular loop example

Example of circular loop interpolation.


noise loop example

Example of open simplex noise loop interpolation.


rosinality translate example

Example output from converted rosinality translate script.

Supported Model Architectures

Below is a list of supported model architectures. The hope is to support the official NVlabs code when it is released.

model-architecture: alias-free-rosinality-v1
  Base Repository: https://github.com/rosinality/alias-free-gan-pytorch
  Repository Version Limit [start, end): [fixed model commit on July 7th 2021, _)
  Description: Based on the rosinality implementation after some model fixes.


Notebooks

GPU Colab Training Notebook

devel Branch

View Notebook

Open In Colab

stable Branch

View Notebook

Open In Colab

GPU Colab Inference Notebook

devel Branch

View Notebook

Open In Colab

stable Branch

View Notebook

Open In Colab

TPU Notebooks

Coming at some point in the future.

Pre-trained Models

See pretrained_models.json.

Use the model_name value as the --resume_from argument for trainer.py.

Pretrained models will download automatically using wget, but here are links in case that isn't working for some reason. Place them under a pretrained directory in the project root directory.

rosinality-ffhq-800k

Description: A 256×256 model trained by rosinality on the FFHQ dataset for 800k steps at a batch size of 16.

Contributing

Contribute Pretrained Models

Contribute your trained models so others can transfer learn from them or use them in their projects. This helps reduce the training time and resources required to use this project.

You can do so by creating a pull request to the stable branch. Add information about your model to pretrained_models.json using the template below.

        {
            "model_name": "model-name-no-spaces-no-file-extension",
            "creator": "github username",
            "model_architecture": "see model architecture section",
            "description": "Describe your model. What was it trained on? How long was it trained for? Feel free to include links to make donations and suggested donation amounts. Also include licence information such as creative commons or other licencees.",
            "model_size": 512,
            "wget_url": "Please include a link to download your model in you're pull request and I will update this", 
            "sha1": "If you know how to make a sha1 hash then you can fill this out if not leave this blank."
        }

Contribute Notebooks

If you make a notebook and want to share it, that is welcome. If it's a standalone notebook, you can make a pull request directly to stable. If it requires changes to the code base, please open an issue to discuss which branch it should target.

Other Contributions

Other contributions are welcome, but open an issue to discuss what you want to change or add. Unless it's a small, non-breaking bug fix, pull requests may not be accepted without prior discussion.

Comments
  • Missing .pt file

    Hi. If I follow all your cells and instructions in order, I get this once I reach the second cell of "Generate Single Images"

    Using Alias-Free GAN version: 1.0.0
    Invalid path /content/drive/MyDrive/colab-gpu-alias-free/alias-free-gan/results/training-000005/000145-epoch-checkpoint.pt is not a .pt model file.

    I cannot find that file anywhere or any other .pt file for that matter.

    Thanks.

    opened by tnoya001 4
  • Curious about CI costs

    I've been thinking about adding CI to clip-guided-diffusion - I have a few integration tests that just run a few forward passes on CPU/GPU. Saw your note about pricing being an issue; it's essentially the main thing preventing me from doing that on a GPU setup in the cloud. For now, I'm fortunate that most of the checkpoints run inference fine on the RTX 2070 that I own.

    Assuming you don't have any issue with revealing such info; what are the costs like for CI on the project? Do you have any tips for people trying to do CI with machine learning?

    Thanks in advance for any info

    opened by afiaka87 2
  • v1.1.0 updates

    Bug fixes:

    • Fixed kimg counting when using --accumulate_grad_batches. kimgs were undercounted by a factor of accumulate_grad_batches. If using the v1.0.0 training notebook, multiply the kimg count by 4 to get the correct number, and use --start_kimg_count when resuming training to override the file name parsing.

    Non breaking updates:

    • Refactored downloading and loading of pretrained models. Added code to load pretrained models in all generate scripts.
    • Updated interpolation notebook to include commands to convert frames to .mp4 and .gif
    • Added more examples to readme and fixed a few typos.
    opened by duskvirkus 0
  • Switching to saving samples and checkpoints based on kimgs

    • fixed tensor device bug in EqualLinear.forward() method
    • added 3 training integration tests (from scratch, ffhq, and custom resume)
    • switched to using wget with digital ocean hosted pretrained models
    • moved ci files out of repository
    • added create_sample_grid_vectors.py script
    • added KimgSaverCallback which saves samples and checkpoints based on kimgs
    • removed --n_samples argument
    • added --save_sample_every_kimgs, --save_checkpoint_every_kimgs, --start_kimg_count, --stop_training_at_kimgs and --sample_grid arguments
    opened by duskvirkus 0
  • Fixing bugs and removing pytorch lightning checkpoints

    Fixes bugs from #16. Also fixes a Drive storage problem and yields somewhere in the neighborhood of a 15% performance improvement by removing PyTorch Lightning checkpoints.

    opened by duskvirkus 0
Releases (v1.1.0)
  • v1.1.0 (Aug 31, 2021)

    Bug fixes:

    • Fixed kimg counting when using --accumulate_grad_batches. kimgs were undercounted by a factor of accumulate_grad_batches. If using the v1.0.0 training notebook, multiply the kimg count by 4 to get the correct number, and use --start_kimg_count when resuming training to override the file name parsing.

    Non breaking updates:

    • Refactored downloading and loading of pretrained models. Added code to load pretrained models in all generate scripts.
    • Updated interpolation notebook to include commands to convert frames to .mp4 and .gif
    • Added more examples to readme and fixed a few typos.
  • v1.0.0 (Aug 25, 2021)

    Release Overview

    In theory, this version is a stable, working version of Alias-Free GAN based on rosinality's unofficial implementation.

    Changes from the Rosinality Repository

    • Converted to run using pytorch lightning - supports GPU and TPU training with many built-in options. For more information see: https://pytorch-lightning.readthedocs.io/en/1.4.2/
    • Added CPU op library for TPU support
    • prepare_data.py renamed to convert_dataset.py and moved to scripts/convert_dataset.py
    • generate.py adapted and put in a script under scripts/rosinality_generate.py

    Scripts Added

    Trainer

    Creates an Alias-Free GAN instance and trains the model, saving checkpoints based on kimgs (thousands of images).

    Using Alias-Free GAN version: 1.0.0
    usage: trainer.py [-h] --dataset_path DATASET_PATH [--resume_from RESUME_FROM]
                      --size SIZE [--batch BATCH] [--lr_g LR_G] [--lr_d LR_D]
                      [--r1 R1] [--augment AUGMENT] [--augment_p AUGMENT_P]
                      [--ada_target ADA_TARGET] [--ada_length ADA_LENGTH]
                      [--ada_every ADA_EVERY]
                      [--stylegan2_discriminator STYLEGAN2_DISCRIMINATOR]
                      [--save_sample_every_kimgs SAVE_SAMPLE_EVERY_KIMGS]
                      [--save_checkpoint_every_kimgs SAVE_CHECKPOINT_EVERY_KIMGS]
                      [--start_kimg_count START_KIMG_COUNT]
                      [--stop_training_at_kimgs STOP_TRAINING_AT_KIMGS]
                      [--sample_grid SAMPLE_GRID] [--logger [LOGGER]]
                      [--checkpoint_callback [CHECKPOINT_CALLBACK]]
                      [--default_root_dir DEFAULT_ROOT_DIR]
                      [--gradient_clip_val GRADIENT_CLIP_VAL]
                      [--gradient_clip_algorithm GRADIENT_CLIP_ALGORITHM]
                      [--process_position PROCESS_POSITION]
                      [--num_nodes NUM_NODES] [--num_processes NUM_PROCESSES]
                      [--devices DEVICES] [--gpus GPUS]
                      [--auto_select_gpus [AUTO_SELECT_GPUS]]
                      [--tpu_cores TPU_CORES] [--ipus IPUS]
                      [--log_gpu_memory LOG_GPU_MEMORY]
                      [--progress_bar_refresh_rate PROGRESS_BAR_REFRESH_RATE]
                      [--overfit_batches OVERFIT_BATCHES]
                      [--track_grad_norm TRACK_GRAD_NORM]
                      [--check_val_every_n_epoch CHECK_VAL_EVERY_N_EPOCH]
                      [--fast_dev_run [FAST_DEV_RUN]]
                      [--accumulate_grad_batches ACCUMULATE_GRAD_BATCHES]
                      [--max_epochs MAX_EPOCHS] [--min_epochs MIN_EPOCHS]
                      [--max_steps MAX_STEPS] [--min_steps MIN_STEPS]
                      [--max_time MAX_TIME]
                      [--limit_train_batches LIMIT_TRAIN_BATCHES]
                      [--limit_val_batches LIMIT_VAL_BATCHES]
                      [--limit_test_batches LIMIT_TEST_BATCHES]
                      [--limit_predict_batches LIMIT_PREDICT_BATCHES]
                      [--val_check_interval VAL_CHECK_INTERVAL]
                      [--flush_logs_every_n_steps FLUSH_LOGS_EVERY_N_STEPS]
                      [--log_every_n_steps LOG_EVERY_N_STEPS]
                      [--accelerator ACCELERATOR]
                      [--sync_batchnorm [SYNC_BATCHNORM]] [--precision PRECISION]
                      [--weights_summary WEIGHTS_SUMMARY]
                      [--weights_save_path WEIGHTS_SAVE_PATH]
                      [--num_sanity_val_steps NUM_SANITY_VAL_STEPS]
                      [--truncated_bptt_steps TRUNCATED_BPTT_STEPS]
                      [--resume_from_checkpoint RESUME_FROM_CHECKPOINT]
                      [--profiler PROFILER] [--benchmark [BENCHMARK]]
                      [--deterministic [DETERMINISTIC]]
                      [--reload_dataloaders_every_n_epochs RELOAD_DATALOADERS_EVERY_N_EPOCHS]
                      [--reload_dataloaders_every_epoch [RELOAD_DATALOADERS_EVERY_EPOCH]]
                      [--auto_lr_find [AUTO_LR_FIND]]
                      [--replace_sampler_ddp [REPLACE_SAMPLER_DDP]]
                      [--terminate_on_nan [TERMINATE_ON_NAN]]
                      [--auto_scale_batch_size [AUTO_SCALE_BATCH_SIZE]]
                      [--prepare_data_per_node [PREPARE_DATA_PER_NODE]]
                      [--plugins PLUGINS] [--amp_backend AMP_BACKEND]
                      [--amp_level AMP_LEVEL]
                      [--distributed_backend DISTRIBUTED_BACKEND]
                      [--move_metrics_to_cpu [MOVE_METRICS_TO_CPU]]
                      [--multiple_trainloader_mode MULTIPLE_TRAINLOADER_MODE]
                      [--stochastic_weight_avg [STOCHASTIC_WEIGHT_AVG]]
    
    optional arguments:
      -h, --help            show this help message and exit
    
    Trainer Script:
      --dataset_path DATASET_PATH
                            Path to dataset. Required!
      --resume_from RESUME_FROM
                            Resume from checkpoint or transfer learn off
                            pretrained model. Leave blank to train from scratch.
    
    AliasFreeGAN Model:
      --size SIZE           Pixel dimension of model. Must be 256, 512, or 1024.
                            Required!
      --batch BATCH         Batch size. Will be overridden if
                            --auto_scale_batch_size is used. (default: 16)
      --lr_g LR_G           Generator learning rate. (default: 0.002)
      --lr_d LR_D           Discriminator learning rate. (default: 0.002)
      --r1 R1               R1 regularization weights. (default: 10.0)
      --augment AUGMENT     Use augmentations. (default: False)
      --augment_p AUGMENT_P
                            Augment probability, the probability that augmentation
                            is applied. 0.0 is 0 percent and 1.0 is 100. If set to
                            0.0 and augment is enabled AdaptiveAugmentation will
                            be used. (default: 0.0)
      --ada_target ADA_TARGET
                            Target for AdaptiveAugmentation. (default: 0.6)
      --ada_length ADA_LENGTH
                            (default: 500000)
      --ada_every ADA_EVERY
                            How often to update augmentation probabilities when
                            using AdaptiveAugmentation. (default: 8)
      --stylegan2_discriminator STYLEGAN2_DISCRIMINATOR
                            Provide path to a rosinality stylegan2 checkpoint to
                            load the discriminator from it. Will load second so if
                            you load another model first it will override that
                            discriminator.
    
    kimg Saver Callback:
      --save_sample_every_kimgs SAVE_SAMPLE_EVERY_KIMGS
                            Sets the frequency of saving samples in kimgs
                            (thousands of images). (default: 1)
      --save_checkpoint_every_kimgs SAVE_CHECKPOINT_EVERY_KIMGS
                            Sets the frequency of saving model checkpoints in
                            kimgs (thousands of images). (default: 4)
      --start_kimg_count START_KIMG_COUNT
                            Manually override the start count for kimgs. If not
                            set the count will be inferred from checkpoint name.
                            If count can not be inferred it will default to 0.
      --stop_training_at_kimgs STOP_TRAINING_AT_KIMGS
                            Automatically stop training at this number of kimgs.
                            (default: 12800)
      --sample_grid SAMPLE_GRID
                            Sample grid to use for samples. Saved under
                            assets/sample_grids. (default:
                            default_5x3_sample_grid)
    
    pl.Trainer:
      --logger [LOGGER]     Logger (or iterable collection of loggers) for
                            experiment tracking. A ``True`` value uses the default
                            ``TensorBoardLogger``. ``False`` will disable logging.
                            If multiple loggers are provided and the `save_dir`
                            property of that logger is not set, local files
                            (checkpoints, profiler traces, etc.) are saved in
                            ``default_root_dir`` rather than in the ``log_dir`` of
                            any of the individual loggers.
      --checkpoint_callback [CHECKPOINT_CALLBACK]
                            If ``True``, enable checkpointing. It will configure a
                            default ModelCheckpoint callback if there is no user-
                            defined ModelCheckpoint in :paramref:`~pytorch_lightni
                            ng.trainer.trainer.Trainer.callbacks`.
      --default_root_dir DEFAULT_ROOT_DIR
                            Default path for logs and weights when no
                            logger/ckpt_callback passed. Default: ``os.getcwd()``.
                            Can be remote file paths such as `s3://mybucket/path`
                            or 'hdfs://path/'
      --gradient_clip_val GRADIENT_CLIP_VAL
                            0 means don't clip.
      --gradient_clip_algorithm GRADIENT_CLIP_ALGORITHM
                            'value' means clip_by_value, 'norm' means
                            clip_by_norm. Default: 'norm'
      --process_position PROCESS_POSITION
                            orders the progress bar when running multiple models
                            on same machine.
      --num_nodes NUM_NODES
                            number of GPU nodes for distributed training.
      --num_processes NUM_PROCESSES
                            number of processes for distributed training with
                            distributed_backend="ddp_cpu"
      --devices DEVICES     Will be mapped to either `gpus`, `tpu_cores`,
                            `num_processes` or `ipus`, based on the accelerator
                            type.
      --gpus GPUS           number of gpus to train on (int) or which GPUs to
                            train on (list or str) applied per node
      --auto_select_gpus [AUTO_SELECT_GPUS]
                            If enabled and `gpus` is an integer, pick available
                            gpus automatically. This is especially useful when
                            GPUs are configured to be in "exclusive mode", such
                            that only one process at a time can access them.
      --tpu_cores TPU_CORES
                            How many TPU cores to train on (1 or 8) / Single TPU
                            to train on [1]
      --ipus IPUS           How many IPUs to train on.
      --log_gpu_memory LOG_GPU_MEMORY
                            None, 'min_max', 'all'. Might slow performance
      --progress_bar_refresh_rate PROGRESS_BAR_REFRESH_RATE
                            How often to refresh progress bar (in steps). Value
                            ``0`` disables progress bar. Ignored when a custom
                            progress bar is passed to
                            :paramref:`~Trainer.callbacks`. Default: None, means a
                            suitable value will be chosen based on the environment
                            (terminal, Google COLAB, etc.).
      --overfit_batches OVERFIT_BATCHES
                            Overfit a fraction of training data (float) or a set
                            number of batches (int).
      --track_grad_norm TRACK_GRAD_NORM
                            -1 no tracking. Otherwise tracks that p-norm. May be
                            set to 'inf' infinity-norm.
      --check_val_every_n_epoch CHECK_VAL_EVERY_N_EPOCH
                            Check val every n train epochs.
      --fast_dev_run [FAST_DEV_RUN]
                            runs n if set to ``n`` (int) else 1 if set to ``True``
                            batch(es) of train, val and test to find any bugs (ie:
                            a sort of unit test).
      --accumulate_grad_batches ACCUMULATE_GRAD_BATCHES
                            Accumulates grads every k batches or as set up in the
                            dict.
      --max_epochs MAX_EPOCHS
                            Stop training once this number of epochs is reached.
                            Disabled by default (None). If both max_epochs and
                            max_steps are not specified, defaults to
                            ``max_epochs`` = 1000.
      --min_epochs MIN_EPOCHS
                            Force training for at least these many epochs.
                            Disabled by default (None). If both min_epochs and
                            min_steps are not specified, defaults to
                            ``min_epochs`` = 1.
      --max_steps MAX_STEPS
                            Stop training after this number of steps. Disabled by
                            default (None).
      --min_steps MIN_STEPS
                            Force training for at least these number of steps.
                            Disabled by default (None).
      --max_time MAX_TIME   Stop training after this amount of time has passed.
                            Disabled by default (None). The time duration can be
                            specified in the format DD:HH:MM:SS (days, hours,
                            minutes seconds), as a :class:`datetime.timedelta`, or
                            a dictionary with keys that will be passed to
                            :class:`datetime.timedelta`.
      --limit_train_batches LIMIT_TRAIN_BATCHES
                            How much of training dataset to check (float =
                            fraction, int = num_batches)
      --limit_val_batches LIMIT_VAL_BATCHES
                            How much of validation dataset to check (float =
                            fraction, int = num_batches)
      --limit_test_batches LIMIT_TEST_BATCHES
                            How much of test dataset to check (float = fraction,
                            int = num_batches)
      --limit_predict_batches LIMIT_PREDICT_BATCHES
                            How much of prediction dataset to check (float =
                            fraction, int = num_batches)
      --val_check_interval VAL_CHECK_INTERVAL
                            How often to check the validation set. Use float to
                            check within a training epoch, use int to check every
                            n steps (batches).
      --flush_logs_every_n_steps FLUSH_LOGS_EVERY_N_STEPS
                            How often to flush logs to disk (defaults to every 100
                            steps).
      --log_every_n_steps LOG_EVERY_N_STEPS
                            How often to log within steps (defaults to every 50
                            steps).
      --accelerator ACCELERATOR
                            Previously known as distributed_backend (dp, ddp,
                            ddp2, etc...). Can also take in an accelerator object
                            for custom hardware.
      --sync_batchnorm [SYNC_BATCHNORM]
                            Synchronize batch norm layers between process
                            groups/whole world.
      --precision PRECISION
                            Double precision (64), full precision (32) or half
                            precision (16). Can be used on CPU, GPU or TPUs.
      --weights_summary WEIGHTS_SUMMARY
                            Prints a summary of the weights when training begins.
      --weights_save_path WEIGHTS_SAVE_PATH
                            Where to save weights if specified. Will override
                            default_root_dir for checkpoints only. Use this if for
                            whatever reason you need the checkpoints stored in a
                            different place than the logs written in
                            `default_root_dir`. Can be remote file paths such as
                            `s3://mybucket/path` or 'hdfs://path/' Defaults to
                            `default_root_dir`.
      --num_sanity_val_steps NUM_SANITY_VAL_STEPS
                            Sanity check runs n validation batches before starting
                            the training routine. Set it to `-1` to run all
                            batches in all validation dataloaders.
      --truncated_bptt_steps TRUNCATED_BPTT_STEPS
                            Deprecated in v1.3 to be removed in 1.5. Please use :p
                            aramref:`~pytorch_lightning.core.lightning.LightningMo
                            dule.truncated_bptt_steps` instead.
      --resume_from_checkpoint RESUME_FROM_CHECKPOINT
                            Path/URL of the checkpoint from which training is
                            resumed. If there is no checkpoint file at the path,
                            start from scratch. If resuming from mid-epoch
                            checkpoint, training will start from the beginning of
                            the next epoch.
      --profiler PROFILER   To profile individual steps during training and assist
                            in identifying bottlenecks.
      --benchmark [BENCHMARK]
                            If true enables cudnn.benchmark.
      --deterministic [DETERMINISTIC]
                            If true enables cudnn.deterministic.
      --reload_dataloaders_every_n_epochs RELOAD_DATALOADERS_EVERY_N_EPOCHS
                            Set to a non-negative integer to reload dataloaders
                            every n epochs. Default: 0
      --reload_dataloaders_every_epoch [RELOAD_DATALOADERS_EVERY_EPOCH]
                            Set to True to reload dataloaders every epoch. ..
                            deprecated:: v1.4 ``reload_dataloaders_every_epoch``
                            has been deprecated in v1.4 and will be removed in
                            v1.6. Please use
                            ``reload_dataloaders_every_n_epochs``.
      --auto_lr_find [AUTO_LR_FIND]
                            If set to True, will make trainer.tune() run a
                            learning rate finder, trying to optimize initial
                            learning for faster convergence. trainer.tune() method
                            will set the suggested learning rate in self.lr or
                            self.learning_rate in the LightningModule. To use a
                            different key set a string instead of True with the
                            key name.
      --replace_sampler_ddp [REPLACE_SAMPLER_DDP]
                            Explicitly enables or disables sampler replacement. If
                            not specified this will toggled automatically when DDP
                            is used. By default it will add ``shuffle=True`` for
                            train sampler and ``shuffle=False`` for val/test
                            sampler. If you want to customize it, you can set
                            ``replace_sampler_ddp=False`` and add your own
                            distributed sampler.
      --terminate_on_nan [TERMINATE_ON_NAN]
                            If set to True, will terminate training (by raising a
                            `ValueError`) at the end of each training batch, if
                            any of the parameters or the loss are NaN or +/-inf.
      --auto_scale_batch_size [AUTO_SCALE_BATCH_SIZE]
                            If set to True, will `initially` run a batch size
                            finder trying to find the largest batch size that fits
                            into memory. The result will be stored in
                            self.batch_size in the LightningModule. Additionally,
                            can be set to either `power` that estimates the batch
                            size through a power search or `binsearch` that
                            estimates the batch size through a binary search.
      --prepare_data_per_node [PREPARE_DATA_PER_NODE]
                            If True, each LOCAL_RANK=0 will call prepare data.
                            Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare
                            data
      --plugins PLUGINS     Plugins allow modification of core behavior like ddp
                            and amp, and enable custom lightning plugins.
      --amp_backend AMP_BACKEND
                            The mixed precision backend to use ("native" or
                            "apex")
      --amp_level AMP_LEVEL
                            The optimization level to use (O1, O2, etc...).
      --distributed_backend DISTRIBUTED_BACKEND
                            deprecated. Please use 'accelerator'
      --move_metrics_to_cpu [MOVE_METRICS_TO_CPU]
                            Whether to force internal logged metrics to be moved
                            to cpu. This can save some gpu memory, but can make
                            training slower. Use with attention.
      --multiple_trainloader_mode MULTIPLE_TRAINLOADER_MODE
                            How to loop over the datasets when there are multiple
                            train loaders. In 'max_size_cycle' mode, the trainer
                            ends one epoch when the largest dataset is traversed,
                            and smaller datasets reload when running out of their
                            data. In 'min_size' mode, all the datasets reload when
                            reaching the minimum length of datasets.
      --stochastic_weight_avg [STOCHASTIC_WEIGHT_AVG]
                            Whether to use Stochastic Weight Averaging (SWA)
                            <https://pytorch.org/blog/pytorch-1.6-now-includes-
                            stochastic-weight-averaging/>
    

    Generate Images

    Using Alias-Free GAN version: 1.0.0
    usage: generate_images.py [-h] --load_model LOAD_MODEL --outdir OUTDIR
                              [--model_arch MODEL_ARCH] [--seed_start SEED_START]
                              [--seed_stop SEED_STOP] [--trunc TRUNC]
                              [--batch BATCH] --size SIZE
    
    optional arguments:
      -h, --help            show this help message and exit
    
    Generate Script:
      --load_model LOAD_MODEL
                            Load a model checkpoint to use for generating content.
      --outdir OUTDIR       Where to save the output images
      --model_arch MODEL_ARCH
                            The model architecture of the model to be loaded.
                            (default: alias-free-rosinality-v1)
      --seed_start SEED_START
                            Start range for seed values. (default: 0)
      --seed_stop SEED_STOP
                            Stop range for seed values. Is inclusive. (default:
                            99)
      --trunc TRUNC         Truncation psi (default: 0.75)
      --batch BATCH         Number of images to generate each batch. (default: 8)
    
    AliasFreeGenerator:
      --size SIZE           Pixel dimension of model. Must be 256, 512, or 1024.
                            Required!
    

    Generate Interpolation

    Using Alias-Free GAN version: 1.0.0
    usage: generate_interpolation.py [-h] --size SIZE --load_model LOAD_MODEL
                                     --outdir OUTDIR [--model_arch MODEL_ARCH]
                                     [--trunc TRUNC] [--batch BATCH]
                                     [--save_z_vectors SAVE_Z_VECTORS]
                                     [--log_args LOG_ARGS] [--method METHOD]
                                     [--path_to_z_vectors PATH_TO_Z_VECTORS]
                                     [--frames FRAMES] [--seeds SEEDS [SEEDS ...]]
                                     [--easing EASING]
                                     [--diameter DIAMETER [DIAMETER ...]]
    
    optional arguments:
      -h, --help            show this help message and exit
    
    AliasFreeGenerator:
      --size SIZE           Pixel dimension of model. Must be 256, 512, or 1024.
                            Required!
    
    Generate Script:
      --load_model LOAD_MODEL
                            Load a model checkpoint to use for generating content.
      --outdir OUTDIR       Where to save the output images
      --model_arch MODEL_ARCH
                            The model architecture of the model to be loaded.
                            (default: alias-free-rosinality-v1)
      --trunc TRUNC         Truncation psi (default: 0.75)
      --batch BATCH         Number of images to generate each batch. (default: 8)
      --save_z_vectors SAVE_Z_VECTORS
                            Save the z vectors used to interpolate. default: True
      --log_args LOG_ARGS   Saves the arguments to a text file for later
                            reference. default: True
      --method METHOD       Select a method for interpolation. Options:
                            ['circular', 'interpolate', 'load_z_vectors',
                            'simplex_noise'] default: interpolate
      --path_to_z_vectors PATH_TO_Z_VECTORS
                            Path to saved z vectors to load. For method:
                            'load_z_vectors'
      --frames FRAMES       Total number of frames to generate. For methods:
                            'interpolate', 'circular', 'simplex_noise'
      --seeds SEEDS [SEEDS ...]
                            Add a seed value to a interpolation walk. First seed
                            value will be used as the seed for a circular or noise
                            walk. If none are provided random ones will be
                            generated. For methods: 'interpolate', 'circular',
                            'simplex_noise'
      --easing EASING       How to ease between seeds. For method: 'interpolate'
                            Options: ['easeInBack', 'easeInBounce', 'easeInCirc',
                            'easeInCubic', 'easeInElastic', 'easeInExpo',
                            'easeInOutBack', 'easeInOutBounce', 'easeInOutCirc',
                            'easeInOutCubic', 'easeInOutElastic', 'easeInOutExpo',
                            'easeInOutQuad', 'easeInOutQuart', 'easeInOutQuint',
                            'easeInOutSine', 'easeInQuad', 'easeInQuart',
                            'easeInQuint', 'easeInSine', 'easeOutBack',
                            'easeOutBounce', 'easeOutCirc', 'easeOutCubic',
                            'easeOutElastic', 'easeOutExpo', 'easeOutQuad',
                            'easeOutQuart', 'easeOutQuint', 'easeOutSine',
                            'linear'] default: linear
      --diameter DIAMETER [DIAMETER ...]
                            Defines the diameter of the circular or noise path. If
                            two arguments are passed they will be used as a min
                            and max a range for random diameters. For method:
                            'circular', 'simplex_noise'
    

    Create Sample Grid Vectors

    Creates a pytorch file with vectors to be used by the trainer script to generate the sample grid.

    usage: create_sample_grid_vectors.py [-h] [--rows ROWS] [--cols COLS]
                                         [--seed SEED] [--style_dim STYLE_DIM]
                                         [--include_zero_point_five_vec INCLUDE_ZERO_POINT_FIVE_VEC]
                                         --save_location SAVE_LOCATION
    
    optional arguments:
      -h, --help            show this help message and exit
    
    Create Sample Grid Vectors Script:
      --rows ROWS           Number of rows in sample grid (default: 3)
      --cols COLS           Number of columns in sample grid (default: 5)
      --seed SEED           Random seed to use (default: 0)
      --style_dim STYLE_DIM
                            Style dimension size. (Not the same as model
                            resolution, you'll probably know if you have to change
                            this.) (default: 512)
      --include_zero_point_five_vec INCLUDE_ZERO_POINT_FIVE_VEC
                            Include vector with 0.5 for every dimension. Will be
                            put in 0, 0 spot on the grid. (default: True)
      --save_location SAVE_LOCATION
                            Where the sample grid vectors will be saved.
    

    TPU Setup

    A bash script to set up TPU IP addresses on Colab.

    Use the following code to run it in a .ipynb notebook.

    import os

    # Read scripts/tpu_setup.sh, strip the leading 'export ' from each
    # export line, split on the first '=', and copy the resulting
    # key/value pairs into this process's environment.
    with open('./scripts/tpu_setup.sh') as f:
        os.environ.update(
            line.replace('export ', '', 1).strip().split('=', 1) for line in f
            if 'export' in line
        )
    

    pip freeze on Colab

    You don't need all of these packages, but if there are package conflicts in the future, this is a known-working setup for GPU training and inference.

    absl-py==0.12.0
    aiohttp==3.7.4.post0
    alabaster==0.7.12
    albumentations==0.1.12
    altair==4.1.0
    appdirs==1.4.4
    argcomplete==1.12.3
    argon2-cffi==20.1.0
    arviz==0.11.2
    astor==0.8.1
    astropy==4.3.1
    astunparse==1.6.3
    async-timeout==3.0.1
    atari-py==0.2.9
    atomicwrites==1.4.0
    attrs==21.2.0
    audioread==2.1.9
    autograd==1.3
    Babel==2.9.1
    backcall==0.2.0
    beautifulsoup4==4.6.3
    bleach==4.0.0
    blis==0.4.1
    bokeh==2.3.3
    Bottleneck==1.3.2
    branca==0.4.2
    bs4==0.0.1
    CacheControl==0.12.6
    cached-property==1.5.2
    cachetools==4.2.2
    catalogue==1.0.0
    certifi==2021.5.30
    cffi==1.14.6
    cftime==1.5.0
    chardet==3.0.4
    charset-normalizer==2.0.4
    clang==5.0
    click==7.1.2
    cloudpickle==1.3.0
    cmake==3.12.0
    cmdstanpy==0.9.5
    colorcet==2.0.6
    colorlover==0.3.0
    community==1.0.0b1
    configparser==5.0.2
    contextlib2==0.5.5
    convertdate==2.3.2
    coverage==3.7.1
    coveralls==0.5
    crcmod==1.7
    cufflinks==0.17.3
    cupy-cuda101==9.1.0
    cvxopt==1.2.6
    cvxpy==1.0.31
    cycler==0.10.0
    cymem==2.0.5
    Cython==0.29.24
    daft==0.0.4
    dask==2.12.0
    datascience==0.10.6
    debugpy==1.0.0
    decorator==4.4.2
    defusedxml==0.7.1
    descartes==1.1.0
    dill==0.3.4
    distributed==1.25.3
    dlib @ file:///dlib-19.18.0-cp37-cp37m-linux_x86_64.whl
    dm-tree==0.1.6
    docker-pycreds==0.4.0
    docopt==0.6.2
    docutils==0.17.1
    dopamine-rl==1.0.5
    earthengine-api==0.1.278
    easydict==1.9
    ecos==2.0.7.post1
    editdistance==0.5.3
    en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.5/en_core_web_sm-2.2.5.tar.gz
    entrypoints==0.3
    ephem==4.0.0.2
    et-xmlfile==1.1.0
    fa2==0.3.5
    fastai==1.0.61
    fastdtw==0.3.4
    fastprogress==1.0.0
    fastrlock==0.6
    fbprophet==0.7.1
    feather-format==0.4.1
    filelock==3.0.12
    firebase-admin==4.4.0
    fix-yahoo-finance==0.0.22
    Flask==1.1.4
    flatbuffers==1.12
    folium==0.8.3
    fsspec==2021.7.0
    future==0.18.2
    gast==0.4.0
    GDAL==2.2.2
    gdown==3.6.4
    gensim==3.6.0
    geographiclib==1.52
    geopy==1.17.0
    gin-config==0.4.0
    gitdb==4.0.7
    GitPython==3.1.18
    glob2==0.7
    google==2.0.3
    google-api-core==1.26.3
    google-api-python-client==1.12.8
    google-auth==1.34.0
    google-auth-httplib2==0.0.4
    google-auth-oauthlib==0.4.5
    google-cloud-bigquery==1.21.0
    google-cloud-bigquery-storage==1.1.0
    google-cloud-core==1.0.3
    google-cloud-datastore==1.8.0
    google-cloud-firestore==1.7.0
    google-cloud-language==1.2.0
    google-cloud-storage==1.18.1
    google-cloud-translate==1.5.0
    google-colab @ file:///colabtools/dist/google-colab-1.0.0.tar.gz
    google-pasta==0.2.0
    google-resumable-media==0.4.1
    googleapis-common-protos==1.53.0
    googledrivedownloader==0.4
    graphviz==0.10.1
    greenlet==1.1.1
    grpcio==1.39.0
    gspread==3.0.1
    gspread-dataframe==3.0.8
    gym==0.17.3
    h5py==3.1.0
    HeapDict==1.0.1
    hijri-converter==2.1.3
    holidays==0.10.5.2
    holoviews==1.14.5
    html5lib==1.0.1
    httpimport==0.5.18
    httplib2==0.17.4
    httplib2shim==0.0.3
    humanize==0.5.1
    hyperopt==0.1.2
    ideep4py==2.0.0.post3
    idna==2.10
    imageio==2.4.1
    imagesize==1.2.0
    imbalanced-learn==0.4.3
    imblearn==0.0
    imgaug==0.2.9
    importlib-metadata==4.6.4
    importlib-resources==5.2.2
    imutils==0.5.4
    inflect==2.1.0
    iniconfig==1.1.1
    intel-openmp==2021.3.0
    intervaltree==2.1.0
    ipykernel==4.10.1
    ipython==5.5.0
    ipython-genutils==0.2.0
    ipython-sql==0.3.9
    ipywidgets==7.6.3
    itsdangerous==1.1.0
    jax==0.2.19
    jaxlib @ https://storage.googleapis.com/jax-releases/cuda110/jaxlib-0.1.70+cuda110-cp37-none-manylinux2010_x86_64.whl
    jdcal==1.4.1
    jedi==0.18.0
    jieba==0.42.1
    Jinja2==2.11.3
    joblib==1.0.1
    jpeg4py==0.1.4
    jsonschema==2.6.0
    jupyter==1.0.0
    jupyter-client==5.3.5
    jupyter-console==5.2.0
    jupyter-core==4.7.1
    jupyterlab-pygments==0.1.2
    jupyterlab-widgets==1.0.0
    kaggle==1.5.12
    kapre==0.3.5
    keras==2.6.0
    Keras-Preprocessing==1.1.2
    keras-vis==0.4.1
    kiwisolver==1.3.1
    korean-lunar-calendar==0.2.1
    librosa==0.8.1
    lightgbm==2.2.3
    llvmlite==0.34.0
    lmdb==0.99
    LunarCalendar==0.0.9
    lxml==4.2.6
    Markdown==3.3.4
    MarkupSafe==2.0.1
    matplotlib==3.2.2
    matplotlib-inline==0.1.2
    matplotlib-venn==0.11.6
    missingno==0.5.0
    mistune==0.8.4
    mizani==0.6.0
    mkl==2019.0
    mlxtend==0.14.0
    more-itertools==8.8.0
    moviepy==0.2.3.5
    mpmath==1.2.1
    msgpack==1.0.2
    multidict==5.1.0
    multiprocess==0.70.12.2
    multitasking==0.0.9
    murmurhash==1.0.5
    music21==5.5.0
    natsort==5.5.0
    nbclient==0.5.4
    nbconvert==5.6.1
    nbformat==5.1.3
    nest-asyncio==1.5.1
    netCDF4==1.5.7
    networkx==2.6.2
    nibabel==3.0.2
    ninja==1.10.2
    nltk==3.2.5
    notebook==5.3.1
    numba==0.51.2
    numexpr==2.7.3
    numpy==1.19.5
    nvidia-ml-py3==7.352.0
    oauth2client==4.1.3
    oauthlib==3.1.1
    okgrade==0.4.3
    opencv-contrib-python==4.1.2.30
    opencv-python==4.1.2.30
    opencv-python-headless==4.5.3.56
    openpyxl==2.5.9
    opensimplex==0.3
    opt-einsum==3.3.0
    osqp==0.6.2.post0
    packaging==21.0
    palettable==3.3.0
    pandas==1.1.5
    pandas-datareader==0.9.0
    pandas-gbq==0.13.3
    pandas-profiling==1.4.1
    pandocfilters==1.4.3
    panel==0.12.1
    param==1.11.1
    parso==0.8.2
    pathlib==1.0.1
    pathtools==0.1.2
    patsy==0.5.1
    pep517==0.11.0
    pexpect==4.8.0
    pickleshare==0.7.5
    Pillow==7.1.2
    pip-tools==6.2.0
    plac==1.1.3
    plotly==4.4.1
    plotnine==0.6.0
    pluggy==0.7.1
    pooch==1.4.0
    portpicker==1.3.9
    prefetch-generator==1.0.1
    preshed==3.0.5
    prettytable==2.1.0
    progressbar2==3.38.0
    prometheus-client==0.11.0
    promise==2.3
    prompt-toolkit==1.0.18
    protobuf==3.17.3
    psutil==5.4.8
    psycopg2==2.7.6.1
    ptyprocess==0.7.0
    py==1.10.0
    pyarrow==3.0.0
    pyasn1==0.4.8
    pyasn1-modules==0.2.8
    pycocotools==2.0.2
    pycparser==2.20
    pyct==0.4.8
    pydantic==1.8.2
    pydata-google-auth==1.2.0
    pyDeprecate==0.3.1
    pydot==1.3.0
    pydot-ng==2.0.0
    pydotplus==2.0.2
    PyDrive==1.3.1
    pyemd==0.5.1
    pyerfa==2.0.0
    pyglet==1.5.0
    Pygments==2.6.1
    pygobject==3.26.1
    pyhocon==0.3.58
    pymc3==3.11.2
    PyMeeus==0.5.11
    pymongo==3.12.0
    pymystem3==0.2.0
    PyOpenGL==3.1.5
    pyparsing==2.4.7
    pyrsistent==0.18.0
    pysndfile==1.3.8
    PySocks==1.7.1
    pystan==2.19.1.1
    pytest==3.6.4
    python-apt==0.0.0
    python-chess==0.23.11
    python-dateutil==2.8.2
    python-louvain==0.15
    python-slugify==5.0.2
    python-utils==2.5.6
    pytorch-lightning==1.4.2
    pytorch-lightning-bolts==0.3.2
    pytz==2018.9
    pyviz-comms==2.1.0
    PyWavelets==1.1.1
    PyYAML==5.4.1
    pyzmq==22.2.1
    qdldl==0.1.5.post0
    qtconsole==5.1.1
    QtPy==1.10.0
    regex==2019.12.20
    requests==2.23.0
    requests-oauthlib==1.3.0
    resampy==0.2.2
    retrying==1.3.3
    rpy2==3.4.5
    rsa==4.7.2
    scikit-image==0.16.2
    scikit-learn==0.22.2.post1
    scipy==1.4.1
    screen-resolution-extra==0.0.0
    scs==2.1.4
    seaborn==0.11.1
    semver==2.13.0
    Send2Trash==1.8.0
    sentry-sdk==1.3.1
    setuptools-git==1.2
    Shapely==1.7.1
    shortuuid==1.0.1
    simplegeneric==0.8.1
    six==1.15.0
    sklearn==0.0
    sklearn-pandas==1.8.0
    smart-open==5.1.0
    smmap==4.0.0
    snowballstemmer==2.1.0
    sortedcontainers==2.4.0
    SoundFile==0.10.3.post1
    spacy==2.2.4
    Sphinx==1.8.5
    sphinxcontrib-serializinghtml==1.1.5
    sphinxcontrib-websupport==1.2.4
    SQLAlchemy==1.4.22
    sqlparse==0.4.1
    srsly==1.0.5
    statsmodels==0.10.2
    subprocess32==3.5.4
    sympy==1.7.1
    tables==3.4.4
    tabulate==0.8.9
    tblib==1.7.0
    tensorboard==2.6.0
    tensorboard-data-server==0.6.1
    tensorboard-plugin-wit==1.8.0
    tensorflow @ file:///tensorflow-2.6.0-cp37-cp37m-linux_x86_64.whl
    tensorflow-datasets==4.0.1
    tensorflow-estimator==2.6.0
    tensorflow-gcs-config==2.6.0
    tensorflow-hub==0.12.0
    tensorflow-metadata==1.2.0
    tensorflow-probability==0.13.0
    termcolor==1.1.0
    terminado==0.11.0
    testpath==0.5.0
    text-unidecode==1.3
    textblob==0.15.3
    Theano-PyMC==1.1.2
    thinc==7.4.0
    tifffile==2021.8.8
    toml==0.10.2
    tomli==1.2.1
    toolz==0.11.1
    torch @ https://download.pytorch.org/whl/cu102/torch-1.9.0%2Bcu102-cp37-cp37m-linux_x86_64.whl
    torchmetrics==0.5.0
    torchsummary==1.5.1
    torchtext==0.10.0
    torchvision @ https://download.pytorch.org/whl/cu102/torchvision-0.10.0%2Bcu102-cp37-cp37m-linux_x86_64.whl
    tornado==5.1.1
    tqdm==4.62.0
    traitlets==5.0.5
    tweepy==3.10.0
    typeguard==2.7.1
    typing-extensions==3.7.4.3
    tzlocal==1.5.1
    uritemplate==3.0.1
    urllib3==1.24.3
    vega-datasets==0.9.0
    wandb==0.12.0
    wasabi==0.8.2
    wcwidth==0.2.5
    webencodings==0.5.1
    Werkzeug==1.0.1
    widgetsnbextension==3.5.1
    wordcloud==1.5.0
    wrapt==1.12.1
    xarray==0.18.2
    xgboost==0.90
    xkit==0.0.0
    xlrd==1.1.0
    xlwt==1.3.0
    yarl==1.6.3
    yellowbrick==0.9.1
    zict==2.0.0
    zipp==3.5.0
    