PyTorch-centric library for evaluating and enhancing the robustness of AI technologies

Overview

Responsible AI Toolbox

A library that provides high-quality, PyTorch-centric tools for evaluating and enhancing both the robustness and the explainability of AI models.

Check out our documentation for more information.

The rAI-toolbox works great with PyTorch Lightning and Hydra 🐉. Check out rai_toolbox.mushin to see how we use these frameworks to create efficient, configurable, and reproducible ML workflows with minimal boilerplate code.

Citation

Using rai_toolbox for your research? Please cite the following publication:

@article{soklaski2022tools,
  title={Tools and Practices for Responsible AI Engineering},
  author={Soklaski, Ryan and Goodwin, Justin and Brown, Olivia and Yee, Michael and Matterer, Jason},
  journal={arXiv preprint arXiv:2201.05647},
  year={2022}
}

Contributing

If you would like to contribute to this repo, please refer to our CONTRIBUTING.md document.

Disclaimer

DISTRIBUTION STATEMENT A. Approved for public release. Distribution is unlimited.

© 2022 MASSACHUSETTS INSTITUTE OF TECHNOLOGY

  • Subject to FAR 52.227-11 – Patent Rights – Ownership by the Contractor (May 2014)
  • SPDX-License-Identifier: MIT

This material is based upon work supported by the Under Secretary of Defense for Research and Engineering under Air Force Contract No. FA8702-15-D-0001. Any opinions, findings, conclusions or recommendations expressed in this material are those of the author(s) and do not necessarily reflect the views of the Under Secretary of Defense for Research and Engineering.

A portion of this research was sponsored by the United States Air Force Research Laboratory and the United States Air Force Artificial Intelligence Accelerator and was accomplished under Cooperative Agreement Number FA8750-19-2-1000. The views and conclusions contained in this document are those of the authors and should not be interpreted as representing the official policies, either expressed or implied, of the United States Air Force or the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for Government purposes notwithstanding any copyright notation herein.

The software/firmware is provided to you on an As-Is basis.

Comments
  • Update workflows

    See example use here: https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/blob/workflow-docs/examples/MNIST-Translation-Robustness.ipynb

    • [x] Create base class for workflows
    • [x] Update docs
    • [x] Create tests for workflows
    opened by jgbos 7
  • Strange computational graph issue with `gradient_ascent` and `LightningModule`

    First, here is a simple example of running gradient_ascent that works without error:

    from functools import partial
    import torch as tr
    from torchvision import models
    from rai_toolbox.optim import L2ProjectedOptim
    from rai_toolbox.perturbations.solvers import gradient_ascent
    
    model = models.resnet18()
    data = tr.rand(10, 3, 100, 100, dtype=tr.float)
    target = tr.randint(0, 2, size=(10,))
    pert = partial(
        gradient_ascent, optimizer=L2ProjectedOptim, epsilon=1.0, steps=1, lr=1.0
    )
    
    # run gradient ascent
    pert(model=model, data=data, target=target)
    

    Now set up and run the same thing using Trainer.predict:

    import pytorch_lightning as pl
    
    class Lit(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.model = model
            self.pert = pert
    
        def predict_step(self, batch, *args, **kwargs):
            data, target = batch
            data = self.pert(model=self.model, data=data, target=target)
            logits = self.model(data)
            return logits.sum()
    
    trainer = pl.Trainer()
    trainer.predict(
        Lit(),
        datamodule=pl.LightningDataModule.from_datasets(
            predict_dataset=tr.utils.data.TensorDataset(data, target),
            batch_size=1,
            num_workers=0,
        ),
    )
    

    Here we get the following error:

    ...
    /tmp/ipykernel_74682/1909129363.py in predict_step(self, batch, *args, **kwargs)
         27     def predict_step(self, batch, *args, **kwargs):
         28         data, target = batch
    ---> 29         data = self.pert(model=self, data=data, target=target)
         30         logits = self.model(data)
         31         return logits.sum()
    
    ~/projects/raiden/rai_toolbox/src/rai_toolbox/perturbations/solvers.py in gradient_ascent(model, data, target, optimizer, steps, perturbation_model, targeted, use_best, criterion, reduction_fn, **optim_kwargs)
        277             # Update the perturbation
        278             optim.zero_grad(set_to_none=True)
    --> 279             loss.backward()
        280             optim.step()
        281 
    
    ~/.conda/envs/rai_md/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
        394                 create_graph=create_graph,
        395                 inputs=inputs)
    --> 396         torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
        397 
        398     def register_hook(self, hook):
    
    ~/.conda/envs/rai_md/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
        171     # some Python versions print out the first line of a multi-line function
        172     # calls in the traceback and some print out the last line
    --> 173     Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
        174         tensors, grad_tensors_, retain_graph, create_graph, inputs,
        175         allow_unreachable=True, accumulate_grad=True)  # Calls into the C++ engine to run the backward pass
    
    RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
    

    If I enter the debugger, everything seems to be set up correctly, except that pmodel(data) does not return a tensor with a grad_fn!

    #
    # pdb at `loss.backward()` line
    #
    > pmodel.delta.requires_grad
    True
    
    > tr.is_grad_enabled()
    True
    
    > pmodel.delta + data
    ... # tensor output without `grad_fn`
    
    # try reinitializing
    > perturbation_model(data)(data)
    ... # tensor output WITH `grad_fn`
    

    I have no idea how to debug this and find out what is wrong.

    @rsokl do you get this error in your environment?

    opened by jgbos 6
  • Docs: perturbation explanation

    Starting an explanation of our approach to data perturbations. I still intend to add more to this today, but feel free to take a look and let me know your thoughts on how it's going so far, especially on what should or shouldn't be included.

    opened by oliviamb 4
  • CIFAR10-Adversarial-Perturbations.ipynb -- Standard rai-toolbox[mushin] install doesn't include dill module

    The CIFAR10-Adversarial-Perturbations.ipynb example Jupyter notebook attempts to load the pretrained models and fails with ModuleNotFoundError: No module named 'dill'. The dill module is not included in the standard rai-toolbox[mushin] install.

    Full error traceback below:

    ModuleNotFoundError                       Traceback (most recent call last)
    Input In [9], in <cell line: 3>()
          1 # Load pretrained model that was trained using a robust approach (i.e., adversarial training)
          2 ckpt_robust = "mitll_cifar_l2_1_0.pt"
    ----> 3 model_robust = load_model(ckpt_robust)
          4 model_robust.eval();
          6 # Load pretrained model that was trained with standard approach
    
    Input In [7], in load_model(ckpt)
          2 def load_model(ckpt):
    ----> 3     base_model = load_from_checkpoint(
          4         model = resnet50(),
          5         ckpt = ckpt,
          6         weights_key="state_dict",
          7     )
          9     normalizer = transforms.Normalize(
         10         mean=[0.4914, 0.4822, 0.4465],
         11         std=[0.2023, 0.1994, 0.2010],
         12     )
         14     model = nn.Sequential(normalizer, base_model)
    
    File ~/dev/rai-toolbox-james/responsible-ai-toolbox/src/rai_toolbox/mushin/_utils.py:60, in load_from_checkpoint(model, ckpt, weights_key, weights_key_strip, model_attr)
         57     ckpt = Path.home() / ".torch" / "models" / ckpt
         58 log.info(f"Loading model checkpoint from {ckpt}")
    ---> 60 ckpt_data: Dict[str, Any] = torch.load(ckpt, map_location="cpu")
         62 if weights_key is not None:
         63     assert weights_key in ckpt_data
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:713, in load(f, map_location, pickle_module, **pickle_load_args)
        711             return torch.jit.load(opened_file)
        712         return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    --> 713 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:930, in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
        928 unpickler = UnpicklerWrapper(f, **pickle_load_args)
        929 unpickler.persistent_load = persistent_load
    --> 930 result = unpickler.load()
        932 deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)
        934 offset = f.tell() if f_should_read_directly else None
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:746, in _legacy_load.<locals>.UnpicklerWrapper.find_class(self, mod_name, name)
        744     except KeyError:
        745         pass
    --> 746 return super().find_class(mod_name, name)
    
    ModuleNotFoundError: No module named 'dill'
    

    Installing dill with pip install dill in the Python environment corrects this error.
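
A quick way to confirm the missing dependency before running the notebook is to check whether `dill` can be imported. The sketch below uses only the standard library; the helper name `missing_modules` is hypothetical and is not part of rai_toolbox.

```python
import importlib.util


def missing_modules(names):
    """Return the subset of top-level module names that cannot be imported."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# The traceback above shows that unpickling the checkpoint requires `dill`.
missing = missing_modules(["dill"])
if missing:
    print("Install before loading checkpoints: pip install " + " ".join(missing))
```

An empty list means every listed module is importable in the current environment.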

    opened by miscpeeps 3
  • test for ensuring hydra ddp raises is raising for the wrong reason

    @jgbos

    In the following test:

    https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/blob/3882320391c87b6bf6330a09471a9808fca01160/tests/test_mushin/test_lightning_hydra_ddp.py#L27-L40

    launch(Config, pl_main_task) raises TypeError because pl_main_task doesn't accept a single config (pyright warned me about this). I doubt this is what you meant to exercise in this test.

    I am confused about what this test is doing. Config = make_config(trainer=trainer, wrong_config_name=module, devices=2) makes it seem like we are checking that launch fails for a config with a bad field name, but the test seems like it should be exercising ddp.
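
One general way to guard against a test that passes for the wrong reason is to assert on the error message as well as the exception type. The sketch below uses only `unittest`; `launch_stub`, `bad_task`, and the error message are hypothetical stand-ins, not the toolbox's `launch` or the test linked above.

```python
import unittest


def launch_stub(config, task_fn):
    """Hypothetical launcher: validates the config, then calls the task."""
    if "module" not in config:
        raise TypeError("config is missing required field 'module'")
    return task_fn(config["trainer"], config["module"])


def bad_task(cfg):
    """Task with the wrong signature: accepts one argument instead of two."""
    return cfg


class TestLaunchRaises(unittest.TestCase):
    def test_type_only_check_hides_the_cause(self):
        # Passes, but only because of the task signature; the config is valid.
        with self.assertRaises(TypeError):
            launch_stub({"trainer": 1, "module": 2}, bad_task)

    def test_message_check_pins_the_cause(self):
        # Passes only if the error is about the missing config field.
        with self.assertRaisesRegex(TypeError, "missing required field"):
            launch_stub({"trainer": 1, "wrong_config_name": 2}, bad_task)
```

Both tests pass, but only the second one documents which failure it expects.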

    bug test-suite 
    opened by rsokl 3
  • CIFAR10-Adversarial-Perturbations.ipynb -- Load pretrained CIFAR-10 models not included and incorrectly named

    The CIFAR10-Adversarial-Perturbations.ipynb example Jupyter notebook and tutorial reference mitll_cifar_l2_1_0.pt and mitll_cifar_nat.pt as pretrained CIFAR-10 models. These models are not included in the standard rai-toolbox[mushin] install (perhaps due to licensing, or a desire to have the most up-to-date models?).

    The models downloaded from the URLs at the robustness GitHub repository are named cifar_l2_1_0.pt and cifar_nat.pt, and will cause the following error on In[10] of CIFAR10-Adversarial-Perturbations.ipynb:

    FileNotFoundError                         Traceback (most recent call last)
    Input In [8], in <cell line: 3>()
          1 # Load pretrained model that was trained using a robust approach (i.e., adversarial training)
          2 ckpt_robust = "mitll_cifar_l2_1_0.pt"
    ----> 3 model_robust = load_model(ckpt_robust)
          4 model_robust.eval();
          6 # Load pretrained model that was trained with standard approach
    
    Input In [7], in load_model(ckpt)
          2 def load_model(ckpt):
    ----> 3     base_model = load_from_checkpoint(
          4         model = resnet50(),
          5         ckpt = ckpt,
          6         weights_key="state_dict",
          7     )
          9     normalizer = transforms.Normalize(
         10         mean=[0.4914, 0.4822, 0.4465],
         11         std=[0.2023, 0.1994, 0.2010],
         12     )
         14     model = nn.Sequential(normalizer, base_model)
    
    File ~/dev/rai-toolbox-james/responsible-ai-toolbox/src/rai_toolbox/mushin/_utils.py:60, in load_from_checkpoint(model, ckpt, weights_key, weights_key_strip, model_attr)
         57     ckpt = Path.home() / ".torch" / "models" / ckpt
         58 log.info(f"Loading model checkpoint from {ckpt}")
    ---> 60 ckpt_data: Dict[str, Any] = torch.load(ckpt, map_location="cpu")
         62 if weights_key is not None:
         63     assert weights_key in ckpt_data
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:699, in load(f, map_location, pickle_module, **pickle_load_args)
        696 if 'encoding' not in pickle_load_args.keys():
        697     pickle_load_args['encoding'] = 'utf-8'
    --> 699 with _open_file_like(f, 'rb') as opened_file:
        700     if _is_zipfile(opened_file):
        701         # The zipfile reader is going to advance the current file position.
        702         # If we want to actually tail call to torch.jit.load, we need to
        703         # reset back to the original position.
        704         orig_position = opened_file.tell()
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:231, in _open_file_like(name_or_buffer, mode)
        229 def _open_file_like(name_or_buffer, mode):
        230     if _is_path(name_or_buffer):
    --> 231         return _open_file(name_or_buffer, mode)
        232     else:
        233         if 'w' in mode:
    
    File ~/anaconda3/envs/rai-toolbox-james/lib/python3.9/site-packages/torch/serialization.py:212, in _open_file.__init__(self, name, mode)
        211 def __init__(self, name, mode):
    --> 212     super(_open_file, self).__init__(open(name, mode))
    
    FileNotFoundError: [Errno 2] No such file or directory: '/home/scott/.torch/models/mitll_cifar_l2_1_0.pt'
    

    The models must be renamed and manually copied to /home/$USER/.torch/models to proceed with the tutorial.
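
The manual rename-and-copy step can be scripted. This is a minimal sketch using only the standard library; the helper `install_checkpoints` is hypothetical, and the `mitll_` prefix and the `~/.torch/models` destination are taken from the file names and path shown in the traceback above.

```python
import shutil
from pathlib import Path


def install_checkpoints(src_dir, dest_dir=None, names=("cifar_l2_1_0.pt", "cifar_nat.pt")):
    """Copy downloaded checkpoints into dest_dir, adding the `mitll_` prefix."""
    src_dir = Path(src_dir)
    dest_dir = Path(dest_dir) if dest_dir is not None else Path.home() / ".torch" / "models"
    dest_dir.mkdir(parents=True, exist_ok=True)
    installed = []
    for name in names:
        target = dest_dir / f"mitll_{name}"
        # Raises FileNotFoundError if the checkpoint has not been downloaded yet.
        shutil.copyfile(src_dir / name, target)
        installed.append(target)
    return installed
```

For example, `install_checkpoints(Path.home() / "Downloads")` would copy both files, assuming they were downloaded to that folder.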

    opened by miscpeeps 2
  • `zen` should not attempt to populate `*args` and `**kwargs`

    Previously zen would attempt to find a kwargs field in the config:

    Before:

    def f(x, **kwargs): return x
    
    cfg = make_config(x=1)
    
    zen(f)(cfg)  # AttributeError: 'Config' object has no attribute 'kwargs'
    

    Now zen skips *args and **kwargs:

    def f(x, **kwargs): return x
    
    cfg = make_config(x=1)
    
    zen(f)(cfg)  # returns 1
    

    In the future we might permit some configured behavior for populating these.
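
The skipping behavior can be illustrated with `inspect.signature`. The following is a simplified sketch of the idea, not the actual implementation of `zen`; `call_with_config` and the `SimpleNamespace` config are hypothetical stand-ins.

```python
import inspect
from types import SimpleNamespace

_VARIADIC = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)


def call_with_config(func, cfg):
    """Call func with arguments read from same-named attributes of cfg.

    Variadic parameters (*args, **kwargs) are skipped rather than looked up.
    """
    kwargs = {
        name: getattr(cfg, name)
        for name, param in inspect.signature(func).parameters.items()
        if param.kind not in _VARIADIC
    }
    return func(**kwargs)


def f(x, **kwargs):
    return x


cfg = SimpleNamespace(x=1)
print(call_with_config(f, cfg))  # prints 1
```

Filtering on `Parameter.kind` means the config never needs fields named `args` or `kwargs`.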

    bug 
    opened by rsokl 2
  • Update gradient-descent solver

    • Renames: gradient_descent -> gradient_ascent
    • (bug fix) Ensures that returned loss always has the correct sign. Previously, when targeted=False the returned loss values would be negated relative to the actual loss landscape
    • Adds examples section to docs
    • Ensures that data and target can be any array-like input, not necessarily a tensor
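
The sign convention described in the bug fix can be sketched in plain Python. This illustrates the idea only and is not the solver's implementation; `ascend`, `loss_fn`, and `grad_fn` are hypothetical names, and the sketch assumes that an untargeted run maximizes the loss while a targeted run minimizes it.

```python
def ascend(loss_fn, grad_fn, x, lr=0.1, steps=100, targeted=False):
    """Scalar gradient ascent (or descent, if targeted) on loss_fn."""
    direction = -1.0 if targeted else 1.0
    for _ in range(steps):
        x = x + direction * lr * grad_fn(x)
    # Return the loss exactly as loss_fn evaluates it, without applying
    # `direction`, so its sign matches the loss landscape in both modes.
    return x, loss_fn(x)


# Untargeted: climb a concave loss whose maximum is at x = 3.
x_adv, loss = ascend(lambda x: -(x - 3.0) ** 2, lambda x: -2.0 * (x - 3.0), x=0.0)
```

Keeping the step direction separate from the reported value is what prevents the returned loss from being negated when `targeted=False`.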
    bug code quality 
    opened by rsokl 2
  • Bump pydata-sphinx-theme from 0.8.1 to 0.11.0 in /docs

    Bumps pydata-sphinx-theme from 0.8.1 to 0.11.0.

    Release notes

    Sourced from pydata-sphinx-theme's releases.

    v0.11.0

    What's Changed

    New Contributors

    Full Changelog: https://github.com/pydata/pydata-sphinx-theme/compare/v0.10.1...v0.11.0

    v0.11.0rc3

    What's Changed

    ... (truncated)

    Commits

    Dependabot compatibility score

    Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting @dependabot rebase.


    Dependabot commands and options

    You can trigger Dependabot actions by commenting on this PR:

    • @dependabot rebase will rebase this PR
    • @dependabot recreate will recreate this PR, overwriting any edits that have been made to it
    • @dependabot merge will merge this PR after your CI passes on it
    • @dependabot squash and merge will squash and merge this PR after your CI passes on it
    • @dependabot cancel merge will cancel a previously requested merge and block automerging
    • @dependabot reopen will reopen this PR if it is closed
    • @dependabot close will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
    • @dependabot ignore this major version will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
    • @dependabot ignore this minor version will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
    • @dependabot ignore this dependency will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
    dependencies python 
    opened by dependabot[bot] 1
  • Bump sphinx-tabs from 3.3.1 to 3.4.1 in /docs

    Bumps sphinx-tabs from 3.3.1 to 3.4.1.

    Release notes

    Sourced from sphinx-tabs's releases.

    Version 3.4.1

    What's Changed

    Full Changelog: https://github.com/executablebooks/sphinx-tabs/compare/v3.4.0...v3.4.1

    Version 3.4.0

    What's Changed

    New Contributors

    Full Changelog: https://github.com/executablebooks/sphinx-tabs/compare/v3.3.1...v3.4.0

    Changelog

    Sourced from sphinx-tabs's changelog.

    3.4.1 - 2022-07-02

    Added

    • Weekly scheduled testing, to catch breaking changes in unpinned dependencies

    Changed

    • docutils version pin to allow use of version 0.18.x

    Removed

    • sphinx version pinning - only the latest version of sphinx will now be fully supported, but previous versions will work if sphinx dependencies (i.e. jinja2) are managed correctly. This is in line with the approach at sphinx
    • tests that were specific to older versions of sphinx and pygments
    • jinja2 version pinning, as this is now pinned in latest version of sphinx

    3.4.0 - 2022-06-26

    Added

    • Testing for sphinx 5
    • Testing for python 3.10

    Fixed

    • Fixed parsing of MyST content, where first line was being stripped
    • Typos in documentation
    • Failing regression tests

    Changed

    • Testing to use an up-to-date pytest version

    Removed

    • Testing for python 3.6 and sphinx versions 2 and 4 (see #164). Note that the package will likely continue to work fine with these, but this won't be assured by tests
    Commits

    dependencies python 
    opened by dependabot[bot] 1
  • Bump sphinx-codeautolink from 0.10.0 to 0.12.0 in /docs

    Bumps sphinx-codeautolink from 0.10.0 to 0.12.0.

    Commits

    dependencies python 
    opened by dependabot[bot] 1
  • Bump sphinx from 5.3.0 to 6.0.0 in /docs

    Bumps sphinx from 5.3.0 to 6.0.0.

    Release notes

    Sourced from sphinx's releases.

    v6.0.0

    Changelog: https://www.sphinx-doc.org/en/master/changes.html

    v6.0.0b2

    Changelog: https://www.sphinx-doc.org/en/master/changes.html

    v6.0.0b1

    Changelog: https://www.sphinx-doc.org/en/master/changes.html

    Changelog

    Sourced from sphinx's changelog.

    Release 6.0.0 (released Dec 29, 2022)

    Dependencies

    • #10468: Drop Python 3.6 support
    • #10470: Drop Python 3.7, Docutils 0.14, Docutils 0.15, Docutils 0.16, and Docutils 0.17 support. Patch by Adam Turner

    Incompatible changes

    • #7405: Removed the jQuery and underscore.js JavaScript frameworks.

      These frameworks are no longer automatically injected into themes from Sphinx 6.0. If you develop a theme or extension that uses the jQuery, $, or $u global objects, you need to update your JavaScript to modern standards, or use the mitigation below.

      The first option is to use the sphinxcontrib.jquery_ extension, which has been developed by the Sphinx team and contributors. To use this, add sphinxcontrib.jquery to the extensions list in conf.py, or call app.setup_extension("sphinxcontrib.jquery") if you develop a Sphinx theme or extension.

      The second option is to manually ensure that the frameworks are present. To re-add jQuery and underscore.js, you will need to copy jquery.js and underscore.js from the Sphinx repository_ to your static directory, and add the following to your layout.html:

      .. code-block:: html+jinja

      {%- block scripts %} {{ super() }} {%- endblock %}

      .. _sphinxcontrib.jquery: https://github.com/sphinx-contrib/jquery/

      Patch by Adam Turner.

    • #10471, #10565: Removed deprecated APIs scheduled for removal in Sphinx 6.0. See :ref:dev-deprecated-apis for details. Patch by Adam Turner.

    • #10901: C Domain: Remove support for parsing pre-v3 style type directives and roles. Also remove associated configuration variables c_allow_pre_v3 and c_warn_on_allowed_pre_v3. Patch by Adam Turner.

    Features added

    ... (truncated)

    Commits

    dependencies python 
    opened by dependabot[bot] 1
  • Update madry example

    Currently we use Workflow.run within hydra.main, which no longer works: https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/blob/5348f7d6a96837d0f9d9ce1b5e71ebfdee6f4b88/experiments/madry/run.py#L29

    We should update this to leverage zen, while making sure that the plotting still works (i.e., that we give the workflow the necessary context to gather the xarray).

    opened by rsokl 0
  • Add Pickled Hydra Runs (e.g., Rerun) to Support PL `ddp`

    In this PR we will attempt to address two issues:

    1. Reproducible Hydra experiments purely from the run directory by pickling both the runtime configuration and the task function
      • An extension of Hydra's experimental rerun
    2. Solving Hydra+DDP for PyTorch Lightning ddp strategy by saving the task function
      • The current solution in HydraDDP has strong constraints on the expected task function. This limits what the user can do in their experiments.

    Hydra Rerun Capability

    Here we take advantage of Hydra Callbacks to save the runtime configuration and the desired task function. Currently our callback takes a task function on initialization, but future Hydra versions may allow Hydra to pass the task function to the callback methods.

    Callback implementation: MushinPickleJobCallback. This takes in a Hydra task function on initialization and saves the task function and runtime configuration in the hydra.runtime.output_dir folder. The pickled files are stored in:

    <hydra.runtime.output_dir>/config.pickle
    <hydra.runtime.output_dir>/task_fn.pickle
    

    This implementation uses cloudpickle to support pickling of the task function. The only downside of this approach is that the task function must be hashable for pickling and "instantiable" for Hydra from the command line, e.g., defining the task function in the notebook won't work.

    Note: Submitit is capable of pickling functions that were created in __main__, so this should be possible

    Execution: With the configuration and task function saved in the job directory, we can rerun any experiment using:

    $ python -m rai_toolbox.mushin._hydra_rerun +config=<path to config.pickle> +task_fn=<path to task_fn.pickle>
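
The save-then-rerun pattern can be sketched with the standard library. This is not the `MushinPickleJobCallback` implementation: it uses `pickle` rather than `cloudpickle`, and `save_job`, `rerun_job`, and the example config are hypothetical. Only the two file names come from the description above.

```python
import json
import pickle
import tempfile
from pathlib import Path


def save_job(output_dir, config, task_fn):
    """Pickle the runtime config and the task function into output_dir."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "config.pickle").write_bytes(pickle.dumps(config))
    # The standard pickle module stores functions by reference (module + name),
    # so task_fn must be importable when the job is rerun.
    (output_dir / "task_fn.pickle").write_bytes(pickle.dumps(task_fn))


def rerun_job(output_dir):
    """Reload the pickled config and task function, then run the task."""
    output_dir = Path(output_dir)
    config = pickle.loads((output_dir / "config.pickle").read_bytes())
    task_fn = pickle.loads((output_dir / "task_fn.pickle").read_bytes())
    return task_fn(config)


# `json.dumps` stands in for a real task function only because it is importable.
with tempfile.TemporaryDirectory() as run_dir:
    save_job(run_dir, {"lr": 0.1, "steps": 10}, json.dumps)
    print(rerun_job(run_dir))  # prints {"lr": 0.1, "steps": 10}
```

Because everything needed to repeat the run lives in the job directory, the rerun step takes nothing but that path.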
    

    Lightning DDP

    Challenges this PR solves for Hydra+DDP:

    • Runs from notebook
    • Supports generic task functions (i.e., solves HydraDDP issue)
    • Task functions can run multiple Trainer methods (e.g., Trainer.fit followed by Trainer.test). HydraDDP does not support these types of task functions

    First we must configure our custom Hydra Callback, MushinPickleJobCallback:

    task_fn_cfg = builds(...)
    
    callback_cfg = dict(
        save_job_info=builds(MushinPickleJobCallback, task_fn=task_fn_cfg)
    )
    
    cs = ConfigStore.instance()
    cs.store(name="pickle_job", group="hydra/callbacks", node=callback_cfg)
    

    The Trainer strategy can then be configured with our custom Lightning ddp strategy, HydraRerunDDP:

    TrainerConfig = builds(Trainer, strategy=builds(HydraRerunDDP))
    

    We must set hydra/callbacks in the overrides to launch a job:

    task_fn = instantiate(task_fn_cfg)
    launch(Config, task_fn, overrides=["hydra/callbacks=pickle_job", ...])
    

    Notes

    • MushinPickleJobCallback will clean up the PL environment automatically at the end of a job.
    • See tests for examples.

    I plan to update this comment to better describe everything

    TODOS

    • [ ] Should we deprecate HydraDDP in favor of this
    • [ ] Can we pickle and use task functions built in a "main" setting like the notebook?
    • [ ] Structure of Hydra specific and Lightning specific code
    • [ ] More tests:
      • Validate results, not just that the pickle file is available
      • Test Hydra rerun without Lightning
    opened by jgbos 1
  • Implements elastic-net attack

    Derived from: https://arxiv.org/pdf/1709.04114.pdf

    Here is a trivial scenario in which we merely perturb the "logits" themselves so that the specified target is optimized for. Note that the longer we run the optimizer, the more the learned perturbation shrinks (while still amounting to a successful attack).

    >>> from rai_toolbox.perturbations.solvers import elastic_net_attack
    >>> logits = [[0.497, 0.503]]
    >>> target = [0]
    
    >>> for num_steps in [1, 10, 100]:
    ...     _, x_adv, _ = elastic_net_attack(
    ...         model=lambda x: x,
    ...         data=logits,
    ...         target=target,
    ...         beta=1e-3,
    ...         c=2,
    ...         steps=num_steps,
    ...         confidence=.01,
    ...         lr=0.5,
    ...     )
    ...     print(f"num-steps: {num_steps}\n{x_adv}")
    num-steps: 1
    tensor([[ 1.4960, -0.4960]])
    num-steps: 10
    tensor([[0.5062, 0.4938]])
    num-steps: 100
    tensor([[0.5018, 0.4982]])
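
For reference, an elastic-net penalty combines an L1 term and a squared-L2 term on the perturbation, and the L1 part is typically handled with a soft-thresholding (shrinkage) step. The sketch below shows both in plain Python on lists; it is an illustration under those assumptions, not the code behind `elastic_net_attack`, and `elastic_net_penalty` and `shrink` are hypothetical names.

```python
def elastic_net_penalty(x_adv, x_orig, beta):
    """beta * ||delta||_1 + ||delta||_2^2, where delta = x_adv - x_orig."""
    delta = [a - o for a, o in zip(x_adv, x_orig)]
    return beta * sum(abs(d) for d in delta) + sum(d * d for d in delta)


def shrink(z, x_orig, beta):
    """Soft-threshold z toward x_orig: offsets within beta snap back to x_orig."""
    out = []
    for zi, oi in zip(z, x_orig):
        diff = zi - oi
        if diff > beta:
            out.append(zi - beta)
        elif diff < -beta:
            out.append(zi + beta)
        else:
            out.append(oi)
    return out
```

A larger `beta` returns more coordinates to their original values, which is what keeps the perturbation sparse.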
    
    opened by rsokl 0
  • Use fused multiply-add to apply `grad_scale` and `grad_bias`

    https://pytorch.org/docs/stable/generated/torch.add.html

    >>> a = torch.randn(4)
    >>> a
    tensor([ 0.0202,  1.0985,  1.3506, -0.6056])
    
    >>> b = torch.randn(4)
    >>> b
    tensor([-0.9732, -0.3497,  0.6245,  0.4022])
    >>> c = torch.randn(4, 1)
    >>> c
    tensor([[ 0.3743],
            [-1.7724],
            [-0.5811],
            [-0.8017]])
    >>> torch.add(b, c, alpha=10)
    tensor([[  2.7695,   3.3930,   4.3672,   4.1450],
            [-18.6971, -18.0736, -17.0994, -17.3216],
            [ -6.7845,  -6.1610,  -5.1868,  -5.4090],
            [ -8.9902,  -8.3667,  -7.3925,  -7.6147]])
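
Applied to this issue, the idea might look like the sketch below. `torch.add(input, other, alpha=a)` computes `input + a * other`, so `grad_scale * grad + grad_bias` can be written as a single call. The helper name `scale_and_bias` is hypothetical and is not the toolbox's implementation; whether the operation is actually fused at the kernel level depends on the backend.

```python
import torch


def scale_and_bias(grad, grad_scale=1.0, grad_bias=0.0):
    """Return grad_scale * grad + grad_bias using one torch.add call.

    Hypothetical helper for illustration only; not part of rai_toolbox.
    """
    # Wrap the bias in a 0-dim tensor so it broadcasts against `grad`.
    bias = torch.as_tensor(grad_bias, dtype=grad.dtype, device=grad.device)
    # torch.add(input, other, alpha=a) evaluates input + a * other.
    return torch.add(bias, grad, alpha=grad_scale)


grad = torch.tensor([1.0, -2.0, 0.5])
fused = scale_and_bias(grad, grad_scale=10.0, grad_bias=0.25)
naive = grad * 10.0 + 0.25  # two separate element-wise operations
print(torch.allclose(fused, naive))  # True
```

The single-call form avoids materializing the intermediate `grad * grad_scale` tensor explicitly in Python.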
    
    opened by rsokl 0
Releases (v0.2.1)
  • v0.2.1 (Jun 16, 2022)

    See changelog

    What's Changed

    • Fix TopQGradient device mismatch by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/64
    • Fixes to_xarray when target_job_dirs points to job that performed multirun over sequence values by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/68

    Full Changelog: https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/compare/v0.2.0...v0.2.1

  • v0.2.0 (Jun 1, 2022)

    See changelog for details

    What's Changed

    • zen should not attempt to populate *args and **kwargs by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/48
    • Workflow Improvement by @jgbos in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/47
    • Adds zen callbacks by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/49
    • Remove numpy dependency (defer to pytorch) by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/50
    • Make methods static where possible; simplify examples; cleanup formating by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/51
    • Add working_subdir data variable to xarray by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/52
    • Improve parity between pre-step and post-step method names by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/54
    • fix typo in univ_adv_pert.rst by @Jasha10 in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/56
    • Update workflows by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/55
    • Adds Support for Lightning's Trainer.predict by @jgbos in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/53
    • Fix hypothesis by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/58
    • Add pre-task method to workflow by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/57
    • Deprecate ParamTransformingOptimizer.project by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/59
    • Add pre-release and nightly CI jobs by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/60
    • Ensure workflow overrides roundtrip by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/61
    • Deprecate evaluation_task in favor of task by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/62
    • Enable user-specified functions for loading metrics files by @rsokl in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/63

    New Contributors

    • @Jasha10 made their first contribution in https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/pull/56

    Full Changelog: https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/compare/v0.1.1...v0.2.0
