-
- Downloads
Add GPVI training (#932)
This commits adds GPVI to alf, main changes to alf/algorithms/generator.py and tests are * Add _create_mvp_network method, which instantiates an encoding network to compute a matrix-vector product, used for computing the inverse jacobian vector product in GPVI. * Add InverseMVPAlgorithm to specify a training step of the above InverseMVP network. * Add rkhs_func_grad function to generator.py This is the main update routine for GPVI, and is related to the InverseMVPAlgorithm. I have not added a function value version of GPVI, nor have I added a minmax version ``rkhs_func_grad`` is called whenever the generator argument ``functional_gradient`` is set to ``rkhs``. Right now, ``functional_gradient`` only supports a ReluMLP generator, as it relies on the fast ``compute_vjp`` function, rather than autograd * Add GPVI tests to generator_test.py. * Add GPVI tests to hypernetwork_test.py. --Small Changes-- * ReluMLP method ``compute_vjp`` now returns the output of the forward evaluation. needed for GPVI update * Added arguments for GPVI to generator.py, such as force_fullrank, and fullrank_diag_weight * Added arguments for pinverse_network to hypernetwork_algorithm.py * Added arguments for GPVI to hypernetwork_algorithm.py * addressing comments in PR discussion * renamed pinverse network to InverseMVPNetwork for (matrix vector product) * added a test for InverseMVPNetwork that explicitly computes the inverse Jacobian vector product of an MLP with respect to some random input. This is evaluated against the solution found by training an InverseMVPNetwork. * Implemented fixes and suggestions to generator and hypernetwork files. * Changed naming of Pinverse network to the new InverseMVP network. * Addressing PR comments. * Removed InverseMVPNetwork file. Opted for EncodingNetwork as suggested by Wei. * Refactored InverseMVP test accordingly, to show that the idea still works. * Added helper function to generator.py, to create this network * Added/changed docstrings to generator.py as suggested.
Showing
- alf/algorithms/generator.py 299 additions, 10 deletionsalf/algorithms/generator.py
- alf/algorithms/generator_test.py 37 additions, 4 deletionsalf/algorithms/generator_test.py
- alf/algorithms/hypernetwork_algorithm.py 73 additions, 12 deletionsalf/algorithms/hypernetwork_algorithm.py
- alf/algorithms/hypernetwork_algorithm_test.py 14 additions, 3 deletionsalf/algorithms/hypernetwork_algorithm_test.py
- alf/networks/inverse_mvp_network_test.py 168 additions, 0 deletionsalf/networks/inverse_mvp_network_test.py
- alf/networks/pinverse_network.py 151 additions, 0 deletionsalf/networks/pinverse_network.py
- alf/networks/relu_mlp.py 2 additions, 2 deletionsalf/networks/relu_mlp.py
- alf/networks/relu_mlp_test.py 1 addition, 1 deletionalf/networks/relu_mlp_test.py
Loading
Please register or sign in to comment