Neural Tangent Generalization Attacks (NTGA)

Overview

This is the repository for Neural Tangent Generalization Attacks by Chia-Hung Yuan and Shan-Hung Wu, in Proceedings of ICML 2021.

We propose the generalization attack, a new direction for poisoning attacks, in which an attacker modifies the training data to spoil the training process so that the trained network fails to generalize. We devise Neural Tangent Generalization Attack (NTGA), the first efficient work enabling clean-label, black-box generalization attacks against Deep Neural Networks.
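The idea can be written compactly. The following is a simplified statement in our own notation, not the paper's exact objective: X_train, Y_train, X_val, Y_val, the perturbation P, the budget ε, and the learning rate η are our labels, and the L∞ budget is an assumption consistent with the eps values listed under NTGA Attack. A clean-label generalization attack perturbs only the training inputs and solves a bilevel problem. As we understand it, NTGA makes the inner problem tractable by replacing it with the closed-form mean prediction of an infinitely wide network trained by gradient descent on the MSE loss for time t, where K is the neural tangent kernel (NTK) of the surrogate:

```latex
\max_{\|P\|_\infty \le \epsilon} \;
  \mathcal{L}\big(f(X_{\mathrm{val}}; \theta^{*}),\, Y_{\mathrm{val}}\big)
\quad \text{s.t.} \quad
\theta^{*} \in \arg\min_{\theta} \mathcal{L}\big(f(X_{\mathrm{train}} + P; \theta),\, Y_{\mathrm{train}}\big)

\bar{f}_t(X_{\mathrm{val}})
  = K(X_{\mathrm{val}}, \tilde{X})\, K(\tilde{X}, \tilde{X})^{-1}
    \big(I - e^{-\eta K(\tilde{X}, \tilde{X})\, t}\big)\, Y_{\mathrm{train}},
\qquad \tilde{X} = X_{\mathrm{train}} + P
```

NTGA then maximizes the validation loss of this surrogate prediction with respect to P; t plays the same role as the --t argument. Please refer to the main paper for the precise formulation.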

NTGA sharply degrades generalization: test accuracy drops from 99% to 15%, from 92% to 33%, and from 99% to 72% on MNIST, CIFAR-10, and 2-class ImageNet, respectively. Please see Results or the main paper for more complete results. We also release the unlearnable MNIST, CIFAR-10, and 2-class ImageNet datasets generated by NTGA, which can be downloaded from Unlearnable Datasets, and we have launched competitions on learning from unlearnable data. The following figures show a clean example and its poisoned counterpart.

[Figure panels: Clean | NTGA]

Installation

Our code uses TensorFlow 2.0 and the Neural Tangents library, which is built on top of JAX. To use JAX with a GPU, please follow JAX's GPU installation instructions. Otherwise, install the CPU version of JAX by running

pip install jax jaxlib --upgrade

Once JAX is installed, clone the repository and install the remaining requirements by running

git clone https://github.com/lionelmessi6410/ntga.git
cd ntga
pip install -r requirements.txt

If you only want to examine the effectiveness of NTGA, you can download the datasets from Unlearnable Datasets and evaluate them with evaluate.py or any code or model you prefer. evaluate.py does not require a separate JAX installation; all of its dependencies are specified in requirements.txt.

Usage

NTGA Attack

To generate poisoned data by NTGA, run

python generate_attack.py --model_type fnn --dataset cifar10 --save_path ./data/

There are a few important arguments:

  • --model_type: A string. Surrogate model used to craft poisoned data. One of fnn or cnn, which stand for the fully-connected and convolutional networks, respectively.
  • --dataset: A string. One of mnist, cifar10, or imagenet.
  • --t: An integer. Time step used to craft poisoned data. Please refer to the main paper for more details.
  • --eps: A float. Strength of NTGA. The default settings for MNIST, CIFAR-10, and ImageNet are 0.3, 8/255, and 0.1, respectively.
  • --nb_iter: An integer. Number of iterations used to generate poisoned data.
  • --block_size: An integer. Block size of the B-NTGA algorithm.
  • --batch_size: An integer. Batch size.
  • --save_path: A string. Directory in which to save the poisoned data.
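As a minimal sketch of how eps and nb_iter typically interact in an L∞ projected gradient ascent (an assumption about the attack's inner loop, not the repository's implementation), each step moves every pixel by a fixed step along the sign of the loss gradient, then projects it back into the eps-ball around the clean pixel and into the valid [0, 1] range. The step-size default and the toy loss below are hypothetical; in NTGA the gradient would instead come from the loss of the NTK-based surrogate prediction with respect to the training inputs.

```python
def sign(v):
    """Return -1, 0 or 1 depending on the sign of v."""
    return (v > 0) - (v < 0)


def pgd_linf(x_clean, grad_fn, eps, nb_iter, step_size=None, clip_min=0.0, clip_max=1.0):
    """L-infinity projected gradient ascent on a flat list of pixel values.

    grad_fn(x) must return the gradient of the attacker's loss at x.
    The default step_size of eps / nb_iter is an assumption made for this sketch.
    """
    if step_size is None:
        step_size = eps / nb_iter
    x_adv = list(x_clean)
    for _ in range(nb_iter):
        grad = grad_fn(x_adv)
        x_next = []
        for xa, xc, g in zip(x_adv, x_clean, grad):
            v = xa + step_size * sign(g)         # ascend the attacker's loss
            v = min(max(v, xc - eps), xc + eps)  # project onto the eps-ball around the clean pixel
            v = min(max(v, clip_min), clip_max)  # keep a valid pixel range
            x_next.append(v)
        x_adv = x_next
    return x_adv


def toy_grad(x):
    """Gradient of the hypothetical toy loss sum((x - 0.5)**2); pushes pixels away from 0.5."""
    return [2.0 * (xi - 0.5) for xi in x]


x_poison = pgd_linf([0.2, 0.5, 0.95], toy_grad, eps=0.1, nb_iter=5, step_size=0.05)
print(x_poison)
```

Because of the projection, no pixel ever moves more than eps from its clean value, which is why a larger eps directly allows stronger but more visible perturbations.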

In general, attacks based on the FNN surrogate are more effective against fully-connected target networks, while attacks based on the CNN surrogate work better against convolutional target networks. The hyperparameter t plays an important role in NTGA: it controls when an attack takes effect during the training process of a target model. With a smaller t, the attack has a better chance of affecting training before early stopping.
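To build intuition for t, here is a small illustration based on standard NTK training dynamics for the MSE loss (our interpretation, not code from this repository): under gradient flow, the training residual along each eigen-direction of the kernel decays as exp(-lr * eigenvalue * t). The kernel spectrum below is hypothetical.

```python
import math


def fitted_fraction(eigenvalues, t, lr=1.0):
    """Fraction of the training targets fitted along each NTK eigen-direction at time t.

    Under gradient flow on the MSE loss, the residual along a direction with
    eigenvalue lam decays as exp(-lr * lam * t), so the fitted part is 1 - exp(-lr * lam * t).
    """
    return [1.0 - math.exp(-lr * lam * t) for lam in eigenvalues]


# Hypothetical spectrum: one dominant, one medium and one weak direction.
spectrum = [10.0, 1.0, 0.01]
for t in (1, 64, 4096):
    print(t, [round(f, 3) for f in fitted_fraction(spectrum, t)])
```

With a small t only the dominant directions are fitted, so a poison crafted for a small t targets what the network learns early in training. Larger t lets weaker directions contribute as well, which is consistent with NTGA(1) producing the lowest-frequency perturbations reported in Results.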

Both eps and block_size influence the effectiveness of NTGA. A larger eps leads to a stronger but more perceptible perturbation, while a larger block_size yields a stronger collaborative effect among the poisoned examples (and thus a stronger attack) but also increases both time and space complexity. If you encounter out-of-memory (OOM) errors, especially with --model_type cnn, try reducing block_size and batch_size.
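To see why block_size drives both time and memory, here is a back-of-the-envelope sketch. It assumes, hypothetically, that the training set is split into consecutive blocks that are attacked independently, that each block needs one block_size by block_size train-train kernel stored as float32, and that inverting it costs on the order of block_size cubed operations. It ignores the validation kernel, CNN kernels with spatial dimensions, and gradient memory, so real usage will be higher.

```python
def block_plan(n_train, block_size, bytes_per_entry=4):
    """Rough cost of attacking training data block by block (hypothetical model).

    Assumes consecutive blocks, one float32 scalar per kernel entry and an
    O(block_size**3) solve per block.
    """
    blocks = [range(start, min(start + block_size, n_train))
              for start in range(0, n_train, block_size)]
    kernel_bytes = block_size ** 2 * bytes_per_entry  # one train-train kernel per block
    solve_cost = len(blocks) * block_size ** 3        # rough operation count for all inversions
    return blocks, kernel_bytes, solve_cost


# Hypothetical example: 50,000 training images.
for b in (512, 1024):
    blocks, kernel_bytes, solve_cost = block_plan(50_000, b)
    print(f"block_size={b}: {len(blocks)} blocks, "
          f"{kernel_bytes / 2**20:.0f} MiB kernel per block, ~{solve_cost:.2e} ops")
```

Doubling block_size halves the number of blocks but quadruples both the kernel memory per block and the total solve cost, which is why reducing block_size is the first thing to try after an OOM error.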

For ImageNet or another custom dataset, please set the path to the dataset directly in the code. The original clean data and the poisoned data crafted by NTGA can be downloaded from Unlearnable Datasets.

Evaluation

Next, you can examine the effectiveness of the poisoned data crafted by NTGA by calling

python evaluate.py --model_type densenet121 --dataset cifar10 --dtype NTGA \
	--x_train_path ./data/x_train_cifar10_ntga_cnn_best.npy \
	--y_train_path ./data/y_train_cifar10.npy \
	--batch_size 128 --save_path ./figure/

If you are interested in the performance on the clean data, run

python evaluate.py --model_type densenet121 --dataset cifar10 --dtype Clean \
	--batch_size 128 --save_path ./figures/

The script also plots the learning curve and saves it to the directory given by --save_path. The following figures show the results of DenseNet121 trained on CIFAR-10. The left figure shows the normal learning curve, where the network is trained on clean data and test accuracy reaches ~93%. In contrast, the right figure shows the effect of NTGA: training accuracy reaches ~100%, but test accuracy drops sharply to ~37%; in other words, the model fails to generalize.

There are a few important arguments:

  • --model_type: A string. Target model used to evaluate poisoned data. One of fnn, fnn_relu, cnn, resnet18, resnet34, or densenet121.
  • --dataset: A string. One of mnist, cifar10, or imagenet.
  • --dtype: A string. One of Clean or NTGA, used in the figure title.
  • --x_train_path: A string. Path for poisoned training data. Leave it empty for clean data (mnist or cifar10).
  • --y_train_path: A string. Path for training labels. Leave it empty for clean data (mnist or cifar10).
  • --x_val_path: A string. Path for validation data.
  • --y_val_path: A string. Path for validation labels.
  • --x_test_path: A string. Path for testing data. The ground truth (y_test) is hidden; you can submit your predictions to Competitions.
  • --epoch: An integer. Number of training epochs.
  • --batch_size: An integer.
  • --save_path: A string.

Visualization

What does the poisoned data look like? Is it truly imperceptible to humans? You can visualize the poisoned data and their normalized perturbations by calling

python plot_visualization.py --dataset cifar10 \
	--x_train_path ./data/x_train_cifar10.npy \
	--x_train_ntga_path ./data/x_train_cifar10_ntga_fnn_t1.npy \
	--save_path ./figure/

The following figure shows some poisoned CIFAR-10 images. As we can see, they look almost the same as the original clean data. However, while training on the clean data achieves ~92% test accuracy, training on the poisoned data drops it sharply to ~35%.
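The exact normalization used by plot_visualization.py is not described here; a common choice, assumed in this sketch, is to min-max scale the perturbation (poisoned minus clean) to [0, 1] so that tiny changes, such as those within the 8/255 budget on CIFAR-10, become visible. normalize_perturbation is a hypothetical helper that works on a flat list of pixel values for a single image.

```python
def normalize_perturbation(clean, poisoned):
    """Min-max scale the perturbation (poisoned - clean) to [0, 1] for display."""
    delta = [p - c for p, c in zip(poisoned, clean)]
    lo, hi = min(delta), max(delta)
    if hi == lo:  # no perturbation at all: show a flat mid-gray image
        return [0.5] * len(delta)
    return [(d - lo) / (hi - lo) for d in delta]


clean = [0.2, 0.4, 0.6]
poisoned = [0.2 + 8 / 255, 0.4 - 8 / 255, 0.6]
print(normalize_perturbation(clean, poisoned))
```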

Here we also visualize the high-resolution ImageNet dataset and find even more interesting results:

The perturbations are nearly invisible. The only difference between the clean and poisoned images is the hue!

There are a few important arguments:

  • --dataset: A string. One of mnist, cifar10, or imagenet.
  • --x_train_path: A string. Path for clean training data.
  • --x_train_ntga_path: A string. Path for poisoned training data.
  • --num: An integer. Number of examples to visualize. Valid values are 1 to 5.
  • --save_path: A string.

Results

Here we briefly report the performance of NTGA and two baselines (RFA and DeepConfuse) equipped with the FNN and CNN surrogates. NTGA(·) denotes an attack generated by NTGA with the hyperparameter t described in NTGA Attack, and NTGA(best) reports the result with the best t for each dataset and surrogate combination. NTGA(1) yields the most imperceptible poisoned data, since it has the lowest-frequency perturbations.

As we can see, NTGA transfers remarkably well across a wide range of models, including Fully-connected Networks (FNNs) and Convolutional Neural Networks (CNNs), trained under various conditions regarding the optimization method, loss function, etc.

FNN Surrogate (test accuracy, %)

Target\Attack   Clean   RFA     DeepConfuse  NTGA(1)  NTGA(best)
Dataset: MNIST
FNN             96.26   74.23   -            3.95     2.57
FNN-ReLU        97.87   84.62   -            2.08     2.18
CNN             99.49   86.99   -            33.80    26.03
Dataset: CIFAR-10
FNN             49.57   37.79   -            36.05    20.63
FNN-ReLU        54.55   43.19   -            40.08    25.95
CNN             78.12   74.71   -            48.46    36.05
ResNet18        91.92   88.76   -            39.72    39.68
DenseNet121     92.71   88.81   -            46.50    47.36
Dataset: ImageNet
FNN             91.60   90.20   -            76.60    76.60
FNN-ReLU        92.20   89.60   -            80.00    80.00
CNN             96.00   95.80   -            77.80    77.80
ResNet18        99.80   98.20   -            76.40    76.40
DenseNet121     98.40   96.20   -            72.80    72.80

CNN Surrogate (test accuracy, %)

Target\Attack   Clean   RFA     DeepConfuse  NTGA(1)  NTGA(best)
Dataset: MNIST
FNN             96.26   69.95   15.48        8.46     4.63
FNN-ReLU        97.87   84.15   17.50        3.48     2.86
CNN             99.49   94.92   46.21        23.89    15.64
Dataset: CIFAR-10
FNN             49.57   41.31   32.59        28.84    28.81
FNN-ReLU        54.55   46.87   35.06        32.77    32.11
CNN             78.12   73.80   44.84        41.17    40.52
ResNet18        91.92   89.54   41.10        34.74    33.29
DenseNet121     92.71   90.50   54.99        43.54    37.79
Dataset: ImageNet
FNN             91.60   87.80   90.80        75.80    75.80
FNN-ReLU        92.20   87.60   91.00        80.00    80.00
CNN             96.00   94.40   93.00        79.00    79.00
ResNet18        99.80   96.00   92.80        76.40    76.40
DenseNet121     98.40   90.40   92.80        80.60    80.60
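When reading the tables, it can help to express each attack as an accuracy drop relative to Clean. The snippet below does this arithmetic for two CIFAR-10 rows of the CNN-surrogate table, with values copied from that table; accuracy_drop is a hypothetical helper, not part of the repository.

```python
def accuracy_drop(clean_acc, poisoned_acc):
    """Return the absolute drop (percentage points) and relative drop (%) in test accuracy."""
    drop = clean_acc - poisoned_acc
    return drop, 100.0 * drop / clean_acc


# (Clean, NTGA(best)) test accuracy in %, CIFAR-10, CNN surrogate.
rows = {"ResNet18": (91.92, 33.29), "DenseNet121": (92.71, 37.79)}
for target, (clean, poisoned) in rows.items():
    drop, relative = accuracy_drop(clean, poisoned)
    print(f"{target}: -{drop:.2f} points ({relative:.1f}% of clean accuracy)")
```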

Unlearnable Datasets

Here we publicly release the poisoned datasets generated by NTGA. We provide five versions of each dataset. FNN(·) denotes an attack generated by NTGA from the FNN surrogate with hyperparameter t, and CNN(·) likewise from the CNN surrogate. The best t is selected according to the empirical results. For the 2-class ImageNet, we choose n01560419 and n01910747 (bulbul vs. jellyfish) from the original ImageNet dataset. Please refer to the main paper and supplementary materials for more details.

  • MNIST
    • FNN(best) = FNN(64)
    • CNN(best) = CNN(64)
  • CIFAR-10
    • FNN(best) = FNN(4096)
    • CNN(best) = CNN(8)
  • ImageNet
    • FNN(best) = FNN(1)
    • CNN(best) = CNN(1)

Please support the project by giving it a star if you find this code or the datasets helpful for your research.

Dataset\Attack Clean FNN(1) FNN(best) CNN(1) CNN(best)
MNIST Download Download Download Download Download
CIFAR-10 Download Download Download Download Download
ImageNet Download Download Download Download Download

We do not provide the test labels (y_test.npy) for any dataset because they are used in the Competitions. Nevertheless, if you are a researcher and need these data for academic purposes, we are willing to provide the complete dataset; please send an email to [email protected]. Last but not least, using these data to participate in the competitions defeats the entire purpose. So seriously, don't do that.

Competitions

We have launched three competitions on Kaggle on learning from the unlearnable MNIST, CIFAR-10, and 2-class ImageNet datasets created by Neural Tangent Generalization Attack. Feel free to give them a shot if you are interested. We welcome anyone who can successfully train a model on the unlearnable data and overturn our conclusions.

Kaggle Competitions Unlearnable MNIST Unlearnable CIFAR-10 Unlearnable ImageNet

For instance, you can create a submission file by calling:

python evaluate.py --model_type resnet18 --dataset cifar10 --dtype NTGA \
	--x_train_path ./data/x_train_cifar10_unlearn.npy \
	--y_train_path ./data/y_train_cifar10.npy \
	--x_val_path ./data/x_val_cifar10.npy \
	--y_val_path ./data/y_val_cifar10.npy \
	--x_test_path ./data/x_test_cifar10.npy \
	--save_path ./figure/

The predictions will be saved as y_pred_cifar10.csv. Make sure to specify --x_test_path for the test data.

Citation

If you find this code or the datasets helpful for your research, please cite our ICML 2021 paper.

@inproceedings{yuan2021neural,
	title={Neural Tangent Generalization Attacks},
	author={Yuan, Chia-Hung and Wu, Shan-Hung},
	booktitle={International Conference on Machine Learning},
	pages={12230--12240},
	year={2021},
	organization={PMLR}
}

Comments
  • Which jax and jaxline version did you use?

    Hi, nice work and thanks for sharing the code! Could you provide the details of your jax and jaxline versions? Perhaps the CUDA and cuDNN versions would also be helpful in debugging the building errors. Thanks for your help!

    opened by SongweiGe
  •  'ShapedArray' object has no attribute 'val'

    Hi, nice work, and thanks for sharing the code. When running the code, we encountered the following error.

    jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ShapedArray' object has no attribute 'val'
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    

    The detailed output is below

    Loading dataset...
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    Building model...
    Generating NTGA....
      0%|                                                                                                                                                                      | 0/78 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "generate_attack.py", line 228, in <module>
        main()
      File "generate_attack.py", line 195, in main
        nb_iter=args.nb_iter, clip_min=0, clip_max=1, batch_size=args.batch_size)
      File "/home/yila22/prj/ntga/attacks/projected_gradient_descent.py", line 82, in projected_gradient_descent
        fx_train_0, fx_test_0, eps, norm, clip_min, clip_max, targeted, batch_size)
      File "/home/yila22/prj/ntga/attacks/fast_gradient_method.py", line 66, in fast_gradient_method
        targeted)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 427, in cache_miss
        donated_invars=donated_invars, inline=inline)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
        outs = primitive.process(top_trace, fun, tracers, params)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
        return trace.process_call(self, fun, tracers, params)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 606, in process_call
        return primitive.impl(f, *tracers, **params)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/xla.py", line 593, in _xla_call_impl
        *unsafe_map(arg_spec, args))
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 262, in memoized_fun
        ans = call(fun, *args)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/xla.py", line 668, in _xla_callable
        fun, abstract_args, pe.debug_info_final(fun, "jit"))
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1284, in trace_to_jaxpr_final
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 829, in grad_f
        _, g = value_and_grad_f(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 901, in value_and_grad_f
        ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 1997, in _vjp
        flat_fun, primals_flat, reduce_axes=reduce_axes)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 115, in vjp
        out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 102, in linearize
        jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 505, in trace_to_jaxpr
        jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "generate_attack.py", line 146, in adv_loss
        ntk_train_train = kernel_fn(x_train, x_train, 'ntk')
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/traceback_util.py", line 183, in reraise_with_filtered_traceback
        return fun(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/api.py", line 427, in cache_miss
        donated_invars=donated_invars, inline=inline)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
        outs = primitive.process(top_trace, fun, tracers, params)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
        return trace.process_call(self, fun, tracers, params)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/ad.py", line 318, in process_call
        result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
        outs = primitive.process(top_trace, fun, tracers, params)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
        return trace.process_call(self, fun, tracers, params)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 195, in process_call
        f, in_pvals, app, instantiate=False)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 303, in partial_eval
        out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux()
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1560, in bind
        return call_bind(self, fun, *args, **params)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1551, in call_bind
        outs = primitive.process(top_trace, fun, tracers, params)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/core.py", line 1563, in process
        return trace.process_call(self, fun, tracers, params)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1072, in process_call
        jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/interpreters/partial_eval.py", line 1262, in trace_to_subjaxpr_dynamic
        ans = fun.call_wrapped(*in_tracers)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
        ans = self.f(*args, **dict(self.params, **kwargs))
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
        return g(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
        fn_out = fn(*canonicalized_args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2935, in kernel_fn_any
        **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2852, in kernel_fn_x1
        out_kernel = kernel_fn(kernel, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
        return g(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
        return kernel_fn(k, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
        k = f(k, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
        return g(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
        fn_out = fn(*canonicalized_args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
        **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
        out_kernel = kernel_fn(kernel, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
        return g(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
        return kernel_fn(k, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
        k = f(k, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
        return g(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
        fn_out = fn(*canonicalized_args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
        **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
        out_kernel = kernel_fn(kernel, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
        return g(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
        return kernel_fn(k, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
        k = f(k, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
        return g(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
        fn_out = fn(*canonicalized_args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
        **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2844, in kernel_fn_kernel
        return _set_shapes(init_fn, kernel, out_kernel)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2780, in _set_shapes
        shape1 = _propagate_shape(init_fn, in_kernel.shape1)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in _propagate_shape
        out_shape = tree_map(lambda x: int(x.val), out_shape)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/tree_util.py", line 168, in tree_map
        return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/jax/_src/tree_util.py", line 168, in <genexpr>
        return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in <lambda>
        out_shape = tree_map(lambda x: int(x.val), out_shape)
    jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'ShapedArray' object has no attribute 'val'
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "generate_attack.py", line 228, in <module>
        main()
      File "generate_attack.py", line 195, in main
        nb_iter=args.nb_iter, clip_min=0, clip_max=1, batch_size=args.batch_size)
      File "/home/yila22/prj/ntga/attacks/projected_gradient_descent.py", line 82, in projected_gradient_descent
        fx_train_0, fx_test_0, eps, norm, clip_min, clip_max, targeted, batch_size)
      File "/home/yila22/prj/ntga/attacks/fast_gradient_method.py", line 66, in fast_gradient_method
        targeted)
      File "generate_attack.py", line 146, in adv_loss
        ntk_train_train = kernel_fn(x_train, x_train, 'ntk')
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
        return g(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
        fn_out = fn(*canonicalized_args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2935, in kernel_fn_any
        **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2852, in kernel_fn_x1
        out_kernel = kernel_fn(kernel, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
        return g(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
        return kernel_fn(k, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
        k = f(k, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
        return g(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
        fn_out = fn(*canonicalized_args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
        **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
        out_kernel = kernel_fn(kernel, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
        return g(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
        return kernel_fn(k, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
        k = f(k, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
        return g(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
        fn_out = fn(*canonicalized_args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
        **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2843, in kernel_fn_kernel
        out_kernel = kernel_fn(kernel, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
        return g(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 185, in new_kernel_fn
        return kernel_fn(k, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 318, in kernel_fn
        k = f(k, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 84, in h
        return g(*args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/utils/utils.py", line 118, in getter_fn
        fn_out = fn(*canonicalized_args, **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2928, in kernel_fn_any
        **kwargs)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2844, in kernel_fn_kernel
        return _set_shapes(init_fn, kernel, out_kernel)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2780, in _set_shapes
        shape1 = _propagate_shape(init_fn, in_kernel.shape1)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in _propagate_shape
        out_shape = tree_map(lambda x: int(x.val), out_shape)
      File "/home/yila22/anaconda3/envs/ntga1/lib/python3.6/site-packages/neural_tangents/stax.py", line 2770, in <lambda>
        out_shape = tree_map(lambda x: int(x.val), out_shape)
    AttributeError: 'ShapedArray' object has no attribute 'val'
    
    opened by liuyixin-louis
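The traceback above ends inside neural_tangents' shape propagation (`_propagate_shape`), where the code calls `.val` on a JAX abstract value and gets a `ShapedArray` that has no such attribute. An error like this may come from a mismatch between the installed `jax`/`jaxlib` and `neural_tangents` versions. That is an assumption: the report itself does not state the cause. The minimal sketch below uses only the standard library to collect those versions so they can be attached to a report like this one. The helper name `package_versions` is hypothetical and is not part of NTGA or Neural Tangents.

```python
import importlib
import sys


def package_versions(names):
    """Map each module name to its __version__, "unknown", or None.

    None means the import failed (module missing or broken install);
    "unknown" means the module imported but exposes no __version__.
    """
    versions = {}
    for name in names:
        try:
            module = importlib.import_module(name)
        except Exception:  # not installed, or raised an error while importing
            versions[name] = None
            continue
        versions[name] = getattr(module, "__version__", "unknown")
    return versions


if __name__ == "__main__":
    # Report the interpreter version plus the packages named in the traceback.
    print("python", sys.version.split()[0])
    for name, version in package_versions(["jax", "jaxlib", "neural_tangents"]).items():
        print(name, version)
```

The helper catches any exception on import, not only `ImportError`. This lets a partially broken environment still produce a report instead of crashing the diagnostic itself.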
Owner
Chia-Hung Yuan
My goal is to develop robust machine learning to reliably interact with a dynamic and uncertain world.
Defending graph neural networks against adversarial attacks (NeurIPS 2020)

GNNGuard: Defending Graph Neural Networks against Adversarial Attacks Authors: Xiang Zhang, Marinka Zitnik (marinka@hms.

Zitnik Lab @ Harvard 44 Dec 7, 2022
Stable Neural ODE with Lyapunov-Stable Equilibrium Points for Defending Against Adversarial Attacks

Stable Neural ODE with Lyapunov-Stable Equilibrium Points for Defending Against Adversarial Attacks Stable Neural ODE with Lyapunov-Stable Equilibrium

Kang Qiyu 8 Dec 12, 2022
Code used to generate the results appearing in "Train longer, generalize better: closing the generalization gap in large batch training of neural networks"

Train longer, generalize better - Big batch training This is a code repository used to generate the results appearing in "Train longer, generalize bet

Elad Hoffer 145 Sep 16, 2022
This repository contains notebook implementations of the following Neural Process variants: Conditional Neural Processes (CNPs), Neural Processes (NPs), Attentive Neural Processes (ANPs).

The Neural Process Family This repository contains notebook implementations of the following Neural Process variants: Conditional Neural Processes (CN

DeepMind 892 Dec 28, 2022
Image-Scaling Attacks and Defenses

Image-Scaling Attacks & Defenses This repository belongs to our publication: Erwin Quiring, David Klein, Daniel Arp, Martin Johns and Konrad Rieck. Ad

Erwin Quiring 163 Nov 21, 2022
Attack classification models with transferability, black-box attack; unrestricted adversarial attacks on imagenet

Attack classification models with transferability, black-box attack; unrestricted adversarial attacks on imagenet, CVPR2021 Security AI Challenger Program, Phase 6: Unrestricted Adversarial Attacks on ImageNet, 4th place in the finals (team name: Advers)

51 Dec 1, 2022
transfer attack; adversarial examples; black-box attack; unrestricted Adversarial Attacks on ImageNet; CVPR2021 Tianchi black-box competition

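transfer_adv CVPR-2021 AIC-VI: unrestricted Adversarial Attacks on ImageNet CVPR2021 Security AI Challenger Program, Phase 6, Track 2: Unrestricted Adversarial Attacks on ImageNet. Introduction: Deep neural networks have achieved state-of-the-art performance on a wide range of visual recognition problems.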

25 Dec 8, 2022
Code for "Diversity can be Transferred: Output Diversification for White- and Black-box Attacks"

Output Diversified Sampling (ODS) This is the github repository for the NeurIPS 2020 paper "Diversity can be Transferred: Output Diversification for W

50 Dec 11, 2022
Implementation of Wasserstein adversarial attacks.

Stronger and Faster Wasserstein Adversarial Attacks Code for Stronger and Faster Wasserstein Adversarial Attacks, appeared in ICML 2020. This reposito

21 Oct 6, 2022