
worrisome-nn

Code release for the ECAI 2023 paper

Worrisome Properties of Neural Network Controllers and Their Symbolic Representations, by Jacek Cyranka, Kevin E. M. Church and Jean-Philippe Lessard.

An extended version of the paper is available on arXiv.

The codebase consists of two separate packages, one in Python and one in Julia.

Controller training, symbolic/small-net regression and the search for persistent solutions are done in Python, whereas the computer-assisted proofs are performed in Julia because of the availability of convenient libraries. Controller weights, including those of the neural networks, are transferred between the two packages through CSV files.

The raw CSV data files with the controller data included in the paper's appendix are located in controllers_data/pendulum and controllers_data/cartpole_swingup, in subdirectories corresponding to the respective controller class.
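The CSV hand-off between the two packages can be as simple as one file per weight matrix and one per bias vector. The sketch below shows this with Python's standard csv module; the file names (`<prefix>_W<i>.csv`, `<prefix>_b<i>.csv`), the nested-list layout and the helper names save_layers/load_layers are hypothetical and do not reproduce the repository's actual format, so inspect the files in controllers_data before reusing it.

```python
import csv
import os
from typing import List, Tuple

Matrix = List[List[float]]
Vector = List[float]


def save_layers(weights: List[Matrix], biases: List[Vector], folder: str, prefix: str) -> None:
    """Write each layer's weight matrix and bias vector to its own CSV file.

    The naming scheme <prefix>_W<i>.csv / <prefix>_b<i>.csv is hypothetical.
    """
    os.makedirs(folder, exist_ok=True)
    for i, (W, b) in enumerate(zip(weights, biases), start=1):
        with open(os.path.join(folder, f"{prefix}_W{i}.csv"), "w", newline="") as f:
            csv.writer(f).writerows(W)  # one matrix row per CSV row
        with open(os.path.join(folder, f"{prefix}_b{i}.csv"), "w", newline="") as f:
            csv.writer(f).writerow(b)  # the bias vector is stored as a single row


def load_layers(folder: str, prefix: str, n_layers: int) -> Tuple[List[Matrix], List[Vector]]:
    """Read back the files written by save_layers."""
    weights, biases = [], []
    for i in range(1, n_layers + 1):
        with open(os.path.join(folder, f"{prefix}_W{i}.csv"), newline="") as f:
            weights.append([[float(v) for v in row] for row in csv.reader(f)])
        with open(os.path.join(folder, f"{prefix}_b{i}.csv"), newline="") as f:
            biases.append([float(v) for v in next(csv.reader(f))])
    return weights, biases
```

The csv module writes floats with str(), which in Python 3 is the shortest representation that round-trips, so the values survive the transfer exactly; on the Julia side such files can be read with CSV.jl and DataFrames.jl.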

Studied problems

At present, our implementation is limited to two simple benchmark problems, implemented in environment/cartpole_swingup_modif.py and environment/pendulum.py. To perform the proofs, we required the problems to be defined in closed form and reimplemented them in Julia. We are working on extending our method to more complicated problems, including the standard MuJoCo benchmark suite.
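For reference, closed-form pendulum dynamics fit in a few lines. The sketch below follows the standard Gym Pendulum-v1 step (g = 10, m = l = 1, dt = 0.05, torque clipped to [-2, 2], angular velocity clipped to [-8, 8], theta = 0 upright); it is an assumption that environment/pendulum.py uses the same constants and update order, so check that file before relying on the numbers.

```python
import math
from typing import Tuple

G, M, L, DT = 10.0, 1.0, 1.0, 0.05  # Gym Pendulum-v1 defaults (assumed to match the repo)
MAX_TORQUE, MAX_SPEED = 2.0, 8.0


def angle_normalize(x: float) -> float:
    """Map an angle to [-pi, pi); theta = 0 is the upright position."""
    return ((x + math.pi) % (2 * math.pi)) - math.pi


def clip(x: float, lo: float, hi: float) -> float:
    return max(lo, min(hi, x))


def pendulum_step(th: float, thdot: float, u: float, dt: float = DT) -> Tuple[float, float, float]:
    """One semi-implicit Euler step; returns (new_th, new_thdot, cost), with reward = -cost."""
    u = clip(u, -MAX_TORQUE, MAX_TORQUE)
    cost = angle_normalize(th) ** 2 + 0.1 * thdot ** 2 + 0.001 * u ** 2
    new_thdot = thdot + (3 * G / (2 * L) * math.sin(th) + 3.0 / (M * L ** 2) * u) * dt
    new_thdot = clip(new_thdot, -MAX_SPEED, MAX_SPEED)
    new_th = th + new_thdot * dt  # uses the updated velocity (semi-implicit Euler)
    return new_th, new_thdot, cost
```

Passing a smaller `dt` is one way to probe how a controller behaves under a finer discretization of the same vector field.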

Controllers Data

Each directory corresponding to a studied controller class contains information on the studied controllers and the computed metrics reported in the paper (such as average return and penalty).

  1. ReLU... directories contain data of the deep NN controllers (e.g., controllers_data/pendulum/ReLU_256_256_256_simple2);
  2. SmallNet directories contain data of the distilled SmallNets (e.g., controllers_data/pendulum/SmallNet);
  3. The 'finetune' directories contain data of the fine-tuned controllers (e.g., controllers_data/pendulum/Symbolic_finetune_CMA ...);
  4. persistent... directories contain the persistent solutions and potential orbits that were found, listed as hall-of-fame files (e.g., controllers_data/pendulum/Symbolic_persistent_CMA).
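Hall-of-fame files can be inspected with a few lines of Python. The sketch below assumes columns named `Complexity`, `Loss` and `Equation`, similar to PySR's hall-of-fame output; the CSV files in controllers_data (in particular the persistent-solution ones) may use different column names, so check the header of the actual file and adjust the keys. The helper names are hypothetical.

```python
import csv
from typing import Dict, List, Optional, TextIO


def read_hof(f: TextIO) -> List[Dict[str, str]]:
    """Parse a hall-of-fame CSV into a list of row dictionaries keyed by the header."""
    return list(csv.DictReader(f))


def best_within_complexity(rows: List[Dict[str, str]], max_complexity: int) -> Optional[Dict[str, str]]:
    """Return the lowest-loss row with Complexity <= max_complexity (None if none qualify).

    The column names 'Complexity' and 'Loss' are assumptions; adapt them to the actual file.
    """
    candidates = [r for r in rows if int(r["Complexity"]) <= max_complexity]
    if not candidates:
        return None
    return min(candidates, key=lambda r: float(r["Loss"]))
```

Capping the complexity reflects the usual trade-off in symbolic regression: a slightly higher loss is often accepted in exchange for a much shorter expression.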

Python micro-documentation

Requirements:

  • pytorch
  • pysr
  • pandas
  • matplotlib
  • gym
  • sympy
  • numpy
  • scikit-learn
  • sb3 (for RL training)
  • sb3-zoo (for RL training)
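Before running the workflow, it can help to confirm that the packages above are importable. A small sketch using `importlib.util.find_spec`; the mapping from the names above to import names (scikit-learn to `sklearn`, sb3 to `stable_baselines3`, sb3-zoo to `rl_zoo3`) is our assumption about how they are installed, and older rl-baselines3-zoo checkouts are run from the repository rather than installed as a package.

```python
import importlib.util
from typing import Dict

# Import names for the requirements listed above (assumed; pip names differ for some).
MODULES = {
    "pytorch": "torch",
    "pysr": "pysr",
    "pandas": "pandas",
    "matplotlib": "matplotlib",
    "gym": "gym",
    "sympy": "sympy",
    "numpy": "numpy",
    "scikit-learn": "sklearn",
    "sb3": "stable_baselines3",
    "sb3-zoo": "rl_zoo3",
}


def check_requirements(modules: Dict[str, str] = MODULES) -> Dict[str, bool]:
    """Report which requirements can be found, without actually importing them."""
    return {name: importlib.util.find_spec(mod) is not None for name, mod in modules.items()}


if __name__ == "__main__":
    for name, ok in check_requirements().items():
        print(f"{name:14s} {'ok' if ok else 'MISSING'}")
```

Using find_spec instead of importing avoids side effects such as PySR setting up its Julia backend.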

The workflow to reproduce the results:

  1. Train a deep NN controller with stable-baselines3 using rl-baselines3-zoo (https://github.com/DLR-RM/rl-baselines3-zoo) and put the checkpoint in the working directory; the pretrained agents are included in the respective folders;
  2. Run the python/derive_symbolic.py script (the first argument is the directory with the trained sb3 checkpoint); it outputs a hall-of-fame CSV file with the symbolic controllers it found;
  3. Define the symbolic controllers that were found as parametrized PyTorch functions and serialize them to a .pt file (see the example in python/pendulum_symbolic2.py); the serialized (pickled) PyTorch symbolic controllers are included in the respective folders;
  4. To distill a small NN, use the python/smallNN_distill.py script (the first argument is the name of the directory with the trained sb3 checkpoint);
  5. To test the controllers contained in a hall-of-fame CSV file (e.g., computing the average rewards for different discretizations), use the python/pysr_controller_test_hof.py script;
  6. Finally, apply the fine-tuning scripts python/pendulum_cma.py and python/swingup_cma.py for the pendulum and cartpole-swingup problems, respectively;
  7. The persistent solutions search is implemented in python/pendulum_transientsCMA.py and python/swingup_transientsCMA.py for the pendulum and cartpole-swingup problems, respectively.
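The idea behind the persistent-solutions search (step 7) can be illustrated on the pendulum: look for initial states from which the closed-loop system keeps paying a penalty long after it should have settled. The sketch below is not the repository's method: it replaces CMA-ES with plain random search, assumes Gym Pendulum-v1 dynamics (3g/(2l) = 15), uses a hypothetical controller `u = -k1*sin(th) - k2*thdot`, and its objective (the summed cost over the last part of a rollout, called the tail penalty here) is our own choice.

```python
import math
import random
from typing import Callable, Tuple

Controller = Callable[[float, float], float]


def step(th: float, thdot: float, u: float, dt: float = 0.05) -> Tuple[float, float, float]:
    """Gym Pendulum-v1-style step (assumed dynamics); returns (th, thdot, cost)."""
    u = max(-2.0, min(2.0, u))
    cost = (((th + math.pi) % (2 * math.pi)) - math.pi) ** 2 + 0.1 * thdot ** 2 + 0.001 * u ** 2
    thdot = max(-8.0, min(8.0, thdot + (15.0 * math.sin(th) + 3.0 * u) * dt))
    return th + thdot * dt, thdot, cost


def tail_penalty(ctrl: Controller, th: float, thdot: float, steps: int = 400, tail: int = 100) -> float:
    """Sum of per-step costs over the last `tail` steps of a closed-loop rollout."""
    total = 0.0
    for k in range(steps):
        th, thdot, cost = step(th, thdot, ctrl(th, thdot))
        if k >= steps - tail:
            total += cost
    return total


def search_persistent(ctrl: Controller, n_samples: int = 200, seed: int = 0) -> Tuple[Tuple[float, float], float]:
    """Random search (a stand-in for CMA-ES) for the initial state with the largest tail penalty."""
    rng = random.Random(seed)
    best_x0, best_pen = (0.0, 0.0), tail_penalty(ctrl, 0.0, 0.0)
    for _ in range(n_samples):
        x0 = (rng.uniform(-math.pi, math.pi), rng.uniform(-1.0, 1.0))
        pen = tail_penalty(ctrl, *x0)
        if pen > best_pen:
            best_x0, best_pen = x0, pen
    return best_x0, best_pen


def linear_ctrl(th: float, thdot: float, k1: float = 6.0, k2: float = 1.0) -> float:
    """Hypothetical controller: locally stabilizes th = 0, but not from every initial state."""
    return -k1 * math.sin(th) - k2 * thdot
```

Scoring only the tail of the rollout ignores the transient swing and favours initial states whose trajectories never settle, which is the kind of behaviour a persistent-solution search is after; here, starting at rest in the hanging position (th = pi) is such a state.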

Julia micro-documentation

Dependencies:

  • ForwardDiff
  • RadiiPolynomial
  • CSV
  • DataFrames

Note that all dependencies are downloaded and installed on instantiation. Use Pkg to activate and instantiate the environment and to load the package as follows:

import Pkg
Pkg.activate("path/to/julia/VerifySolutionsOrbits") # edit the path accordingly
Pkg.instantiate()
using VerifySolutionsOrbits

Structure of the package:

  • Pendulum.jl and CartPoleSwingup.jl : structs for pendulum and cartpole vector fields, penalty functions, and other pendulum/cartpole specific functions.
  • controllers.jl : Landajuela controllers as well as "generic" controllers for pendulum and cartpole.
  • integrator.jl : generic integrator. Uses steppers from Pendulum.jl/CartPoleSwingup.jl.
  • network.jl : constructing neural network controllers.
  • numerics.jl : Newton's method implementation, orbit finders.
  • proofs.jl : Implements abstract theorems on computer-assisted proof. The function check_contraction is generic; all other functions are tailored to the structures of Pendulum.jl/CartPoleSwingup.jl.
  • batch.jl : Batch proof functions to reproduce tables from the appendices.
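The check_contraction function in proofs.jl presumably implements a contraction argument of the radii-polynomial (Newton-Kantorovich) type, given that the package depends on RadiiPolynomial. Here is a minimal 1-D sketch of that argument for the toy problem F(x) = x^2 - 2: with A approximating DF(xbar)^-1, a bound Y0 on |A F(xbar)|, a bound Z1 on |1 - A DF(xbar)| and Z2 such that |A (DF(x) - DF(xbar))| <= Z2 |x - xbar|, if p(r) = Z2 r^2 - (1 - Z1) r + Y0 < 0 for some r > 0, then T(x) = x - A F(x) is a contraction on the ball of radius r around xbar, which therefore contains a unique zero of F. This toy example is ours, not the package's code; instead of the interval arithmetic a real proof would use, it relies on exact rational arithmetic (fractions.Fraction), which is rigorous here only because F is a polynomial with rational data.

```python
from fractions import Fraction


def radii_polynomial_proof(xbar: Fraction, r: Fraction) -> bool:
    """Try to prove that F(x) = x**2 - 2 has a unique zero in [xbar - r, xbar + r].

    All arithmetic is exact (Fraction), so the bounds are not affected by rounding.
    """
    F = xbar * xbar - 2   # F(xbar)
    DF = 2 * xbar         # F'(xbar)
    A = 1 / DF            # approximate inverse of F'(xbar); exact in this 1-D example
    Y0 = abs(A * F)       # |T(xbar) - xbar| with T(x) = x - A F(x)
    Z1 = abs(1 - A * DF)  # |1 - A F'(xbar)|, zero here because A is exact
    Z2 = 2 * abs(A)       # |A (F'(x) - F'(xbar))| = 2|A| |x - xbar|
    p = Z2 * r * r - (1 - Z1) * r + Y0  # the radii polynomial evaluated at r
    return A != 0 and p < 0
```

For example, with xbar = Fraction(141421356, 10**8) the check succeeds for r = Fraction(5, 10**9), enclosing sqrt(2) in an interval of width 1e-8, and fails for r = Fraction(1, 10**9), which is smaller than the actual error of xbar.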

Usage Example 1: Loading the large NN, then finding and proving all periodic orbits in the cartpole swingup model. This has been scripted for ease of reproduction.

  • Run the following from the REPL: julia> proof_cartpole_LargeNet(file,folder); where file is a suitably formatted (string) path to ...\julia\VerifySolutionsOrbits\orbits_400_300_ReLU_cartpole\relu_400_300_cartpole_swingup.jld2, and folder is a suitably formatted (string) path to ...\controllers_data\cartpole_swingup\ReLU_400_300
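Numerically locating a periodic orbit before proving it amounts to solving G(x) = F^p(x) - x = 0 for a period-p point of a map F. The sketch below applies Newton's method to that equation in 2-D, using a forward-difference Jacobian in place of the automatic differentiation (ForwardDiff) listed among the Julia dependencies. The map is a stand-in (the Henon map with a = 1.4, b = 0.3), not the cartpole closed-loop dynamics, and the function names are hypothetical rather than those of numerics.jl.

```python
from typing import Callable, Tuple

Vec = Tuple[float, float]


def henon(x: Vec, a: float = 1.4, b: float = 0.3) -> Vec:
    """Stand-in 2-D map (Henon map); replace with the closed-loop dynamics of interest."""
    return (1.0 - a * x[0] ** 2 + x[1], b * x[0])


def iterate(F: Callable[[Vec], Vec], x: Vec, p: int) -> Vec:
    """Apply F to x p times."""
    for _ in range(p):
        x = F(x)
    return x


def newton_periodic(F: Callable[[Vec], Vec], x: Vec, p: int,
                    tol: float = 1e-12, max_iter: int = 50, h: float = 1e-7) -> Vec:
    """Newton's method for G(x) = F^p(x) - x = 0 with a forward-difference Jacobian."""
    def G(z: Vec) -> Vec:
        fz = iterate(F, z, p)
        return (fz[0] - z[0], fz[1] - z[1])

    for _ in range(max_iter):
        g = G(x)
        if max(abs(g[0]), abs(g[1])) < tol:
            return x
        # Jacobian J = [[a, b], [c, d]], one column per perturbed coordinate.
        gx = G((x[0] + h, x[1]))
        gy = G((x[0], x[1] + h))
        a, c = (gx[0] - g[0]) / h, (gx[1] - g[1]) / h
        b, d = (gy[0] - g[0]) / h, (gy[1] - g[1]) / h
        det = a * d - b * c
        if det == 0.0:
            raise ZeroDivisionError("singular Jacobian")
        # Solve J dx = -g by Cramer's rule.
        dx0 = (-g[0] * d + b * g[1]) / det
        dx1 = (-a * g[1] + c * g[0]) / det
        x = (x[0] + dx0, x[1] + dx1)
    raise RuntimeError("Newton's method did not converge")
```

Starting from (1.0, -0.15), this converges to a period-2 point of the Henon map near (0.9758, -0.1427). The result is only an approximate orbit; turning it into a theorem is the job of the proof step.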

Usage Example 2: Loading the small NN, proving all persistent solutions, and reconstructing the table (LaTeX code) from the paper for the cartpole swingup model.

  • Import the weights and biases: julia> W,B = load_tensors_zero_bias("SmallNet25_cma_swingup_final_model.pt_fc",path,2); where path is a suitably formatted (string) path to ...\controllers_data\cartpole_swingup\SmallNet

  • Convert the weights and biases to BigFloat (4096 bits, which is required because of the extreme wrapping effect): julia> W,B = convert_weight_bias_bigfloat(W,B;precision=4096);

  • Pass the network to the batch proof function, where csv_path is a suitably formatted (string) path to ...\controllers_data\cartpole_swingup\transients_infiniteray_cma\SmallNet25_cma_swingup_final_model_transients_hof.csv:

activations = ["ReLU","Tanh"];
scaling(x) = x;
solutions, penalty, escape, escape_flag, step_size = prove_transients_cartpole_NeuralNet(x->Network(x,W,B,activations,scaling),csv_path);
  • Print the LaTeX table, with csv_path as above: julia> str = output_LaTeX_table_cartpole("Small NN", csv_path, solutions, penalty, escape_flag); julia> println(str)
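To make concrete what the closure x->Network(x,W,B,activations,scaling) computes, here is a minimal Python sketch of a feedforward pass with one activation per layer followed by an output scaling. We assume that activation i is applied after affine layer i (so ["ReLU","Tanh"] would mean a ReLU hidden layer and a tanh output) and that scaling is applied last; the actual Network function in network.jl may differ in these details, and it works in BigFloat rather than double precision.

```python
import math
from typing import Callable, Dict, List, Sequence

ACTIVATIONS: Dict[str, Callable[[float], float]] = {
    "ReLU": lambda z: max(z, 0.0),
    "Tanh": math.tanh,
    "Identity": lambda z: z,
}


def network(x: Sequence[float], W: List[List[List[float]]], B: List[List[float]],
            activations: List[str], scaling: Callable[[List[float]], List[float]]) -> List[float]:
    """Feedforward pass h <- act_i(W_i h + b_i) for each layer i, then scaling(h).

    Assumes activations[i] is applied right after layer i (an assumption, see above).
    """
    if not (len(W) == len(B) == len(activations)):
        raise ValueError("need one weight matrix, bias vector and activation per layer")
    h = list(x)
    for Wi, bi, name in zip(W, B, activations):
        act = ACTIVATIONS[name]
        h = [act(sum(w * v for w, v in zip(row, h)) + b) for row, b in zip(Wi, bi)]
    return scaling(h)
```

Mirroring the Julia snippet, `lambda x: network(x, W, B, ["ReLU", "Tanh"], lambda y: y)` plays the role of the controller closure, with the identity as scaling.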

