This package requires jax, tensorflow, and numpy. Either tensorflow or scikit-learn can be used for loading data.
To run in a nix-shell with the required packages (at the specific versions used), run:
nix-shell
Results are generated by main.py, which must be run with its required arguments (e.g. python main.py --lr). The results described in the paper are stored as CSV files in the results subfolder.
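As a quick way to inspect the stored output, the sketch below lists the CSV files in a results folder and reports each file's header row and number of data rows. It is only an illustration and not part of this package: the helper name summarize_results is hypothetical, the default path "results" assumes the script is run from the repository root, and each CSV file is assumed to start with a header row. No particular column names are assumed.

```python
import csv
from pathlib import Path


def summarize_results(results_dir="results"):
    """Return {filename: (header, n_rows)} for every CSV file in results_dir.

    Hypothetical helper: assumes each file begins with a header row.
    """
    summary = {}
    for path in sorted(Path(results_dir).glob("*.csv")):
        with path.open(newline="") as f:
            rows = list(csv.reader(f))
        # First row is treated as the header; the rest are data rows.
        header = rows[0] if rows else []
        summary[path.name] = (header, max(len(rows) - 1, 0))
    return summary


if __name__ == "__main__":
    for name, (header, n_rows) in summarize_results().items():
        print(f"{name}: {n_rows} rows, columns = {header}")
```

Running this from the repository root should print one line per CSV file, which can help confirm which columns are available before running analysis.py.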
Figures in the paper can be reproduced by running analysis.py. To generate plots with the bounds and errors on the same scale (as described in the appendix), set the variable BOUND_SCALE_AXIS in this file to False.