Welcome to NNGeometry’s documentation!¶
NNGeometry is a library built on top of PyTorch aiming at giving tools to easily manipulate and study properties of Fisher Information Matrices and tangent kernels.
You can start by looking at the quick start example below. Convinced? Then install NNGeometry, try the tutorials or explore the API reference.
Warning
NNGeometry is under development; as such, it is possible that core components change between versions.
Quick example¶
Computing the Fisher Information Matrix on a given PyTorch model using a KFAC representation, and then computing its trace is as simple as:
>>> F_kfac = FIM(model=model,
                 loader=loader,
                 representation=PMatKFAC,
                 n_output=10,
                 variant='classif_logits',
                 device='cuda')
>>> print(F_kfac.trace())
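Here model is your torch.nn.Module and loader a torch.utils.data.DataLoader providing the examples. The snippet assumes FIM and PMatKFAC have already been imported; a minimal sketch of the imports (we assume FIM is exposed by nngeometry.metrics and PMatKFAC by nngeometry.object, see the API reference):
>>> from nngeometry.metrics import FIM
>>> from nngeometry.object import PMatKFAC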
If we instead wanted to choose a nngeometry.object.pspace.PMatBlockDiag representation, we can just replace representation=PMatKFAC with representation=PMatBlockDiag in the above.
This example is further detailed in Quick example. Other available parameter space representations are listed in Parameter space representations.
More examples¶
More notebook examples can be found at https://github.com/tfjgeorge/nngeometry/tree/master/examples
In-depth¶
Quick example¶
With NNGeometry, you can easily manipulate \(d \times d\) matrices and \(d\)-dimensional vectors, where \(d\) is the number of parameters of your neural network, even for modern networks where \(d\) can be as large as \(10^8\). These matrices include for instance:
The Fisher Information Matrix (FIM) used in statistics, in the natural gradient algorithm, or as an approximation of the Hessian matrix in some applications.
Posterior covariances in Bayesian Deep Learning.
You can also compute finite tangent kernels.
A naive computation of the FIM would require storing \(d \times d\) scalars in memory. This is prohibitively large for modern neural network architectures, and a line of research has focused on finding less memory-intensive approximations specific to neural networks, such as KFAC, EKFAC, low-rank approximations, etc. This library proposes a common interface for manipulating these different approximations, called representations.
Let us now illustrate this by computing the FIM using the KFAC representation.
>>> F_kfac = FIM(model=model,
                 loader=loader,
                 representation=PMatKFAC,
                 n_output=10,
                 variant='classif_logits',
                 device='cuda')
>>> print(F_kfac.trace())
Computing the FIM requires the following arguments:
- The model object (a torch.nn.Module) is the PyTorch model used as our neural network.
- The loader object (a torch.utils.data.DataLoader) is the dataloader that contains the examples used for computing the FIM.
- The PMatKFAC argument (nngeometry.object.pspace.PMatKFAC) specifies which representation to use in order to store the FIM.

We will next define a vector in parameter space, using the current parameter values of our model:

>>> v = PVector.from_model(model)

We can now compute the matrix-vector product \(F v\) by simply calling:

>>> Fv = F_kfac.mv(v)

Note that switching from the PMatKFAC representation to any other representation such as PMatDense is as simple as passing representation=PMatDense when building the F_kfac object.
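Other linear algebra operations documented in the API reference can be called directly on the representation object. A short sketch continuing with the F_kfac and v objects defined above (see the API reference below for the exact signatures):
>>> vFv = F_kfac.vTMv(v)             # quadratic form v^T F v
>>> x = F_kfac.solve(v, regul=1e-8)  # solve F x = v with Tikhonov regularization
>>> diag = F_kfac.get_diag()         # diagonal elements of F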
More examples¶
More notebook examples can be found at https://github.com/tfjgeorge/nngeometry/tree/master/examples
Installing NNGeometry¶
The development version of NNGeometry can be installed directly from the master branch of the repository using pip:
$ pip install git+https://github.com/tfjgeorge/nngeometry.git
The only dependency is PyTorch.
Parameter space representations¶
Parameter space representations are \(d \times d\) objects that define metrics in parameter space such as:
Fisher Information Matrices/Gauss-Newton matrix
Gradient second moment (sometimes called the empirical Fisher)
Other covariances such as in Bayesian Deep Learning
These matrices are often too large to fit in memory, for instance when \(d\) is on the order of \(10^6\) to \(10^8\), as is typical in current deep networks. Here is a list of the parameter space representations available in NNGeometry, computed on a small network and represented as images where each pixel represents a component of the matrix and the color indicates the magnitude of that component. These matrices are normalized by their diagonal (i.e. they are correlation matrices) for better visualization. A back-of-the-envelope sketch of the memory costs follows the list:
nngeometry.object.pspace.PMatDense
representation: this is the usual dense matrix. Memory cost: \(d \times d\)

nngeometry.object.pspace.PMatBlockDiag
representation: a block-diagonal representation where diagonal blocks are
dense matrices corresponding to parameters of a single layer, and cross-layer interactions are ignored (their coefficients are
set to \(0\)). Memory cost: \(\sum_l d_l \times d_l\) where \(d_l\) is the number of parameters of layer \(l\).

nngeometry.object.pspace.PMatKFAC
representation [GM16, MG15]: a block-diagonal representation where diagonal blocks are
factored as the Kronecker product of two smaller matrices, and cross-layer interactions are ignored (their coefficients are
set to \(0\)). Memory cost: \(\sum_l \left(g_l \times g_l + a_l \times a_l\right)\) where \(a_l\) is the number of input neurons of layer \(l\) and \(g_l\) is the number of output pre-activations of layer \(l\).

nngeometry.object.pspace.PMatEKFAC
representation [GLB+18]: a block-diagonal representation where diagonal blocks are
factored as a diagonal matrix in a Kronecker-factored eigenbasis, and cross-layer interactions are ignored (their coefficients are
set to \(0\)). Memory cost: \(\sum_l \left(g_l \times g_l + a_l \times a_l + d_l\right)\) where \(a_l\) is the number of input neurons of layer \(l\), \(g_l\) is the number of output pre-activations of layer \(l\), and \(d_l\) is the number of parameters of layer \(l\).

nngeometry.object.pspace.PMatDiag
representation: a diagonal representation that ignores all interactions between parameters.
Memory cost: \(d\)

nngeometry.object.pspace.PMatQuasiDiag
representation [Oll15]: a diagonal representation where for each neuron, a coefficient is also
stored that measures the interaction between this neuron’s weights and the corresponding bias.
Memory cost: \(2 \times d\)
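To make the memory costs above concrete, here is a small back-of-the-envelope sketch for a single fully connected layer (the sizes are made up for illustration, and the computation is plain Python, independent of NNGeometry):

# Hypothetical sizes for one linear layer l
a_l = 784                # input neurons of layer l
g_l = 100                # output pre-activations of layer l
d_l = g_l * (a_l + 1)    # parameters of layer l: weight matrix plus bias

print('PMatDense / PMatBlockDiag block:', d_l * d_l)    # d_l x d_l coefficients
print('PMatKFAC block:', g_l * g_l + a_l * a_l)         # two Kronecker factors
print('PMatEKFAC block:', g_l * g_l + a_l * a_l + d_l)  # factors plus the diagonal in the KFE
print('PMatDiag:', d_l)                                  # diagonal only
print('PMatQuasiDiag:', 2 * d_l)                         # stated cost of 2 x d, per the list above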

API Reference¶
Warning
This API reference is currently nothing but a dump of docstrings, ordered alphabetically.
Generators¶
The spirit of NNGeometry is that you do not directly manipulate Generator objects: they can be considered a backend that you do not need to worry about once instantiated. You instead instantiate concrete representations such as PMatDense or PMatKFAC and call linear algebra operations directly on these concrete representations.
Layer collection¶
Layer collections describe the structure of parameters that will be differentiated. We need the LayerCollection object in order to be able to map components of different objects together. As an example, when performing a matrix-vector product using a block diagonal representation, we need to make sure that elements of the vector corresponding to parameters from layer 1 are multiplied with the diagonal block also corresponding to parameters from layer 1, and so on.
Typical use cases include:
- All parameters of your network: for this you can simply use the constructor nngeometry.layercollection.LayerCollection.from_model().
- Only parameters from some layers of your network: in this case you need to instantiate a new LayerCollection object, then add your layers one at a time using nngeometry.layercollection.LayerCollection.add_layer_from_model(). A short sketch of both use cases follows this list.
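A minimal sketch of both use cases (model is assumed to be a torch.nn.Module; fc1 and fc2 are hypothetical submodule names used only for illustration):

from nngeometry.layercollection import LayerCollection

# Use all (supported) parameters of the network:
lc_full = LayerCollection.from_model(model)

# ...or track only selected layers, added one at a time:
lc_partial = LayerCollection()
lc_partial.add_layer_from_model(model, model.fc1)
lc_partial.add_layer_from_model(model, model.fc2)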
- class nngeometry.layercollection.AbstractLayer¶
- class nngeometry.layercollection.Conv2dLayer(in_channels, out_channels, kernel_size, bias=True)¶
- numel()¶
- class nngeometry.layercollection.ConvTranspose2dLayer(in_channels, out_channels, kernel_size, bias=True)¶
- numel()¶
- class nngeometry.layercollection.LayerCollection(layers=None)¶
This class describes a set or subset of layers that can be used in order to instantiate nngeometry.object.PVector or nngeometry.object.PMatDense objects.
- Parameters
layers –
- add_layer(name, layer)¶
- add_layer_from_model(model, module)¶
Add a layer by specifying the module corresponding to this layer (e.g. torch.nn.Linear or torch.nn.BatchNorm1d)
- Parameters
model – The model defining the neural network
module – The layer to be added
- from_model(model, ignore_unsupported_layers=False)¶
Constructs a new LayerCollection object by using all parameters of the model passed as argument.
- Parameters
model (nn.Module) – The PyTorch model
ignore_unsupported_layers (bool) – If False, will raise an error when the model contains layers that are not supported yet. If True, will silently ignore such layers.
- get_layerid_module_maps(model)¶
- numel()¶
Total number of scalar parameters in this LayerCollection object
- Returns
number of scalar parameters
- Return type
int
- parameters(layerid_to_module)¶
Metrics¶
Metrics introduction
Parameter space matrix representations¶
All parameter space matrix representations inherit from nngeometry.object.pspace.PMatAbstract. This abstract class defines all methods that can be used with all representations (with some exceptions!). nngeometry.object.pspace.PMatAbstract cannot be instantiated; you instead have to choose one of the concrete representations below.
- class nngeometry.object.pspace.PMatAbstract(generator, data=None, examples=None)¶
A \(d \times d\) matrix in parameter space. This abstract class defines common methods used in concrete representations.
- Parameters
generator (nngeometry.generator.jacobian.Jacobian) – The generator
data – if None, it requires examples to be different from None, and it uses the generator to populate the matrix data
examples – if data is None, it uses these examples to populate the matrix using the generator. examples is either a DataLoader, or a single mini-batch of (inputs, targets) from a DataLoader
Note
Exactly one of data and examples must be different from None; they cannot both be provided at the same time.
- abstract get_diag()¶
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
- size(dim=None)¶
Size of the matrix as a tuple, regardless of the actual size in memory.
- Parameters
dim (int or None) – dimension
>>> M.size()
(1254, 1254)
>>> M.size(0)
1254
- abstract solve(v, regul)¶
Solves \(Fx = v\) for \(x\)
- Parameters
v (PVector) – The vector \(v\)
regul (float) – Tikhonov regularization
- abstract vTMv(v)¶
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)
- Parameters
v (object.vector.PVector) – vector \(v\)
Concrete representations¶
NNGeometry allows you to switch between representations easily. Each representation comes with a trade-off between accuracy and memory/computational cost. If testing a new algorithm, we recommend first experimenting on a small network using the most accurate representation that fits in memory (typically nngeometry.object.pspace.PMatDense), then switching to a larger-scale experiment and a lower-memory representation, as sketched below.
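For instance, a sketch of this workflow, reusing the model and loader from the quick example (we assume here that FIM can be imported from nngeometry.metrics):

from nngeometry.metrics import FIM
from nngeometry.object.pspace import PMatDense, PMatKFAC

# Prototype on a small network with the exact, dense representation...
F_dense = FIM(model=model, loader=loader, representation=PMatDense,
              n_output=10, variant='classif_logits', device='cuda')

# ...then scale up by changing only the representation argument.
F_kfac = FIM(model=model, loader=loader, representation=PMatKFAC,
             n_output=10, variant='classif_logits', device='cuda')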
- class nngeometry.object.pspace.PMatBlockDiag(generator, data=None, examples=None)¶
- get_diag()¶
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
- mm(other)¶
Matrix-matrix product where other is another instance of PMatBlockDiag
- Parameters
other (nngeometry.object.PMatBlockDiag) – Other FIM matrix
- Returns
The matrix-matrix product
- Return type
nngeometry.object.PMatBlockDiag
- solve(vs, regul=1e-08)¶
Solves \(Fx = v\) for \(x\)
- Parameters
v (PVector) – The vector \(v\)
regul (float) – Tikhonov regularization
- vTMv(vector)¶
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)
- Parameters
v (object.vector.PVector) – vector \(v\)
- class nngeometry.object.pspace.PMatDense(generator, data=None, examples=None)¶
- get_diag()¶
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
- mm(other)¶
Matrix-matrix product where other is another instance of PMatDense
- Parameters
other (nngeometry.object.PMatDense) – Other FIM matrix
- Returns
The matrix-matrix product
- Return type
nngeometry.object.PMatDense
- solve(v, regul=1e-08, impl='solve')¶
Solves \(Ax = v\) for \(x\)
- vTMv(v)¶
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)
- Parameters
v (object.vector.PVector) – vector \(v\)
- class nngeometry.object.pspace.PMatDiag(generator, data=None, examples=None)¶
- get_diag()¶
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
- mm(other)¶
Matrix-matrix product where other is another instance of PMatDiag
- Parameters
other (nngeometry.object.PMatDiag) – Other FIM matrix
- Returns
The matrix-matrix product
- Return type
nngeometry.object.PMatDiag
- solve(v, regul=1e-08)¶
Solves \(Ax = v\) for \(x\)
- vTMv(v)¶
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)
- Parameters
v (object.vector.PVector) – vector \(v\)
- class nngeometry.object.pspace.PMatEKFAC(generator, data=None, examples=None)¶
EKFAC representation from George, Laurent et al., Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis, NIPS 2018
- get_KFE(split_weight_bias=True)¶
Returns a dict indexed by layers, of dense eigenvectors constructed from Kronecker-factored eigenvectors
split_weight_bias (bool): if True then the parameters are ordered in the same way as in the dense or blockdiag representation, but this involves more operations. Otherwise the coefficients corresponding to the bias are mixed in with the coefficients of the weight matrix.
- get_dense_tensor(split_weight_bias=True)¶
split_weight_bias (bool): if True then the parameters are ordered in the same way as in the dense or blockdiag representation, but this involves more operations. Otherwise the coefficients corresponding to the bias are mixed in with the coefficients of the weight matrix.
- get_diag(v)¶
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
- solve(vs, regul=1e-08)¶
Solves \(Fx = v\) for \(x\)
- Parameters
v (PVector) – The vector \(v\)
regul (float) – Tikhonov regularization
- update_diag(examples)¶
Will update the diagonal in the KFE (aka the approximate eigenvalues) using current values of the model’s parameters
- vTMv(vector)¶
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)
- Parameters
v (object.vector.PVector) – vector \(v\)
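A possible use of update_diag, sketched under the assumption that F_ekfac was built with representation=PMatEKFAC and that it accepts the same examples argument (a dataloader or a single mini-batch) as the constructor:

# After the model's parameters have changed (e.g. following an optimizer step),
# re-estimate the diagonal scalings in the Kronecker-factored eigenbasis:
F_ekfac.update_diag(loader)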
- class nngeometry.object.pspace.PMatImplicit(generator, data=None, examples=None)¶
PMatImplicit is a very special representation, since no element of the matrix is ever computed; instead, various linear algebra operations are performed implicitly using efficient tricks.
The computations are done exactly, meaning that there is no approximation involved. This is useful for networks too big to fit in memory.
- get_diag()¶
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
- vTMv(v)¶
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)
- Parameters
v (object.vector.PVector) – vector \(v\)
- class nngeometry.object.pspace.PMatKFAC(generator, data=None, examples=None)¶
- get_dense_tensor(split_weight_bias=True)¶
split_weight_bias (bool): if True then the parameters are ordered in the same way as in the dense or blockdiag representation, but this involves more operations. Otherwise the coefficients corresponding to the bias are mixed in with the coefficients of the weight matrix.
- get_diag(split_weight_bias=True)¶
split_weight_bias (bool): if True then the parameters are ordered in the same way as in the dense or blockdiag representation, but this involves more operations. Otherwise the coefficients corresponding to the bias are mixed in with the coefficients of the weight matrix.
- mm(other)¶
Matrix-matrix product where other is another instance of PMatKFAC
- Parameters
other (nngeometry.object.PMatKFAC) – Other FIM matrix
- Returns
The matrix-matrix product
- Return type
nngeometry.object.PMatKFAC
- solve(vs, regul=1e-08, use_pi=True)¶
Solves \(Fx = v\) for \(x\)
- Parameters
v (PVector) – The vector \(v\)
regul (float) – Tikhonov regularization
- vTMv(vector)¶
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)
- Parameters
v (object.vector.PVector) – vector \(v\)
- class nngeometry.object.pspace.PMatLowRank(generator, data=None, examples=None)¶
- get_diag()¶
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
- solve(b, regul=1e-08)¶
Solves \(Fx = v\) for \(x\)
- Parameters
v (PVector) – The vector \(v\)
regul (float) – Tikhonov regularization
- vTMv(v)¶
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)
- Parameters
v (object.vector.PVector) – vector \(v\)
- class nngeometry.object.pspace.PMatQuasiDiag(generator, data=None, examples=None)¶
Quasi-diagonal approximation as described in Ollivier, Riemannian metrics for neural networks I: feedforward networks, Information and Inference: A Journal of the IMA, 2015
- get_diag()¶
Computes and returns the diagonal elements of this matrix.
- Returns
a PyTorch Tensor
- solve(vs, regul=1e-08)¶
Solves \(Fx = v\) for \(x\)
- Parameters
v (PVector) – The vector \(v\)
regul (float) – Tikhonov regularization
- vTMv(vs)¶
Computes the quadratic form defined by M in v, namely the product \(v^\top M v\)
- Parameters
v (object.vector.PVector) – vector \(v\)
- nngeometry.object.pspace.bdot(A, B)¶
batched dot product
Vector representations¶
In NNGeometry, vectors are not just a bunch of scalars; they have a semantic meaning.
nngeometry.object.vector.PVector objects are vectors living in the parameter space of a neural network model. An example of such a vector is \(\delta \mathbf w\) in the EWC penalty \(\delta \mathbf w^\top F \delta \mathbf w\).
nngeometry.object.vector.FVector objects are vectors living in the function space of a neural network model. An example of such a vector is \(\mathbf{f}=\left(f\left(x_{1}\right),\ldots,f\left(x_{n}\right)\right)^{\top}\) where \(f\) is a neural network and \(x_1,\ldots,x_n\) are examples from a training dataset.
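As an illustration of this semantic, here is a sketch of the EWC penalty mentioned above. F stands for any parameter space representation of the FIM built as in the quick example; model and model_prev are two instances of the same architecture, and the subtraction of two PVector objects is assumed to be supported:

from nngeometry.object.vector import PVector

w = PVector.from_model(model)                     # current parameters (views on the model)
w_star = PVector.from_model(model_prev).detach()  # parameters remembered from a previous task
delta_w = w - w_star                              # assumed elementwise difference in parameter space
penalty = F.vTMv(delta_w)                         # the EWC penalty delta_w^T F delta_w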
- class nngeometry.object.vector.FVector(vector_repr=None)¶
Bases:
object
A vector in function space
- get_flat_representation()¶
- class nngeometry.object.vector.PVector(layer_collection, vector_repr=None, dict_repr=None)¶
Bases:
object
A vector in parameter space
- add_to_model(model)¶
Updates model parameter values by adding the current PVector
Note: this is an in-place operation
- clone()¶
Returns a clone of the current object
- copy_to_model(model)¶
Updates model parameter values with the current PVector
Note: this is an in-place operation
- detach()¶
Detaches the current PVector from the computation graph
- dot(other)¶
Computes the dot product between self and other
- Parameters
other – The other PVector
- static from_model(model)¶
Creates a PVector using the current values of the given model
- static from_model_grad(model)¶
Creates a PVector using the current values of the .grad fields of parameters of the given model
- get_dict_representation()¶
- get_flat_representation()¶
Returns a PyTorch 1d tensor of the flattened vector.
Warning
The ordering in which the parameters are flattened can seem arbitrary. It is in fact the same ordering as specified by the layercollection.LayerCollection object.
- Returns
a PyTorch Tensor
- norm(p=2)¶
Computes the Lp norm of the PVector
- size()¶
The size of the PVector, or equivalently the number of parameters of the layer collection
- nngeometry.object.vector.random_fvector(n_samples, n_output=1, device=None)¶
- nngeometry.object.vector.random_pvector(layer_collection, device=None)¶
Returns a random nngeometry.object.PVector object using the structure defined by the layer_collection parameter, with each component drawn from a normal distribution with mean 0 and standard deviation 1. The returned PVector will internally use a flat representation.
- Parameters
layer_collection – The nngeometry.layercollection.LayerCollection describing the structure of the random pvector
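A possible use, sketched with a FIM representation F built as in the quick example: probing the quadratic form of F along a random direction in parameter space.

from nngeometry.layercollection import LayerCollection
from nngeometry.object.vector import random_pvector

lc = LayerCollection.from_model(model)  # structure of the parameter space
v = random_pvector(lc, device='cuda')   # random direction in parameter space
quad = F.vTMv(v)                        # quadratic form of F along that direction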
- nngeometry.object.vector.random_pvector_dict(layer_collection, device=None)¶
Returns a random nngeometry.object.PVector object using the structure defined by the layer_collection parameter, with each component drawn from a normal distribution with mean 0 and standard deviation 1. The returned PVector will internally use a dict representation.
- Parameters
layer_collection – The nngeometry.layercollection.LayerCollection describing the structure of the random pvector
References¶
- GLB+18
Thomas George, César Laurent, Xavier Bouthillier, Nicolas Ballas, and Pascal Vincent. Fast approximate natural gradient descent in a kronecker factored eigenbasis. Advances in Neural Information Processing Systems, 31:9550–9560, 2018.
- GM16
Roger Grosse and James Martens. A kronecker-factored approximate fisher matrix for convolution layers. In International Conference on Machine Learning, 573–582. PMLR, 2016.
- MG15
James Martens and Roger Grosse. Optimizing neural networks with kronecker-factored approximate curvature. In International conference on machine learning, 2408–2417. PMLR, 2015.
- Oll15
Yann Ollivier. Riemannian metrics for neural networks i: feedforward networks. Information and Inference: A Journal of the IMA, 4(2):108–153, 2015.