Quick example
With NNGeometry, you can easily manipulate \(d \times d\) matrices and \(d\)-dimensional vectors, where \(d\) is the number of parameters of your neural network, even for modern neural networks where \(d\) can be as large as \(10^8\). These matrices include, for instance:
The Fisher Information Matrix (FIM), used in statistics, in the natural gradient algorithm, or as an approximation of the Hessian matrix in some applications.
Posterior covariances in Bayesian Deep Learning.
You can also compute finite tangent kernels.
A naive computation of the FIM would require storing \(d \times d\) scalars in memory. This is prohibitively large for modern neural network architectures, and a line of research has focused on finding less memory-intensive approximations specific to neural networks, such as KFAC, EKFAC, low-rank approximations, etc. This library provides a common interface for manipulating these different approximations, called representations.
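To make the memory argument concrete, here is a rough back-of-the-envelope sketch. It assumes 4-byte (float32) scalars and, for the Kronecker-factored count, a single fully connected layer without bias whose block is stored as one input-side factor and one output-side factor. The function names are illustrative only and are not part of NNGeometry.

```python
def dense_bytes(d, bytes_per_scalar=4):
    """Bytes needed to store a dense d x d matrix (float32 assumed)."""
    return d * d * bytes_per_scalar


def linear_layer_dense_scalars(n_in, n_out):
    """Scalars in the dense block of one linear layer (bias ignored)."""
    n_params = n_in * n_out
    return n_params * n_params


def linear_layer_kron_scalars(n_in, n_out):
    """Scalars in two Kronecker factors: one n_in x n_in, one n_out x n_out."""
    return n_in * n_in + n_out * n_out


d = 10**8
print(dense_bytes(d) / 10**15, "petabytes for a dense d x d matrix")
print(linear_layer_dense_scalars(1000, 1000), "scalars, dense block")
print(linear_layer_kron_scalars(1000, 1000), "scalars, two Kronecker factors")
```

Under these assumptions, the dense matrix for \(d = 10^8\) needs \(4 \times 10^{16}\) bytes, while a 1000-by-1000 layer goes from \(10^{12}\) scalars for its dense block to \(2 \times 10^6\) for two Kronecker factors.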
Let us now illustrate this by computing the FIM using the KFAC representation.
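>>> from nngeometry.metrics import FIM
>>> from nngeometry.object import PMatKFAC
>>> F_kfac = FIM(model=model,
...              loader=loader,
...              representation=PMatKFAC,
...              n_output=10,
...              variant='classif_logits',
...              device='cuda')
>>> print(F_kfac.trace())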
Computing the FIM requires the following arguments:
The model argument (a torch.nn.Module) is the PyTorch model used as our neural network.
The loader argument (a torch.utils.data.DataLoader) is the dataloader that contains the examples used for computing the FIM.
The representation argument, here PMatKFAC, specifies which representation to use in order to store the FIM.
We will next define a vector in parameter space, using the current parameter values of our model:
>>> from nngeometry.object.vector import PVector
>>> v = PVector.from_model(model)
We can now compute the matrix-vector product \(F v\) by simply calling:
>>> Fv = F_kfac.mv(v)
Note that switching from the PMatKFAC representation to any other representation, such as PMatDense, is as simple as passing representation=PMatDense when building the F_kfac object.
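As an illustration of why a Kronecker-factored representation can compute a matrix-vector product without ever forming the full matrix, the sketch below checks the standard identity \((A \otimes B)\,\mathrm{vec}(X) = \mathrm{vec}(A X B^\top)\) with row-major flattening. It is a pure-Python toy under stated assumptions, not NNGeometry's implementation: the helper names are hypothetical, and the library's actual KFAC code may order factors or handle biases differently.

```python
def matmul(P, Q):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*Q)]
            for row in P]


def transpose(M):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*M)]


def kron(A, B):
    """Dense Kronecker product; row (i, j) sits at index i * len(B) + j."""
    return [[a * b for a in row_a for b in row_b]
            for row_a in A for row_b in B]


def dense_mv(A, B, X):
    """Form the full (m*n) x (m*n) matrix, then multiply by flattened X."""
    K = kron(A, B)
    x = [value for row in X for value in row]
    return [sum(k * xi for k, xi in zip(row, x)) for row in K]


def factored_mv(A, B, X):
    """Same product using only the two factors: flatten(A X B^T)."""
    Y = matmul(matmul(A, X), transpose(B))
    return [value for row in Y for value in row]


A = [[2, 1], [1, 3]]                    # m x m factor (m = 2)
B = [[1, 0, 2], [0, 1, 0], [2, 0, 5]]   # n x n factor (n = 3)
X = [[1, 2, 3], [4, 5, 6]]              # the vector, reshaped to m x n

# Both routes should give the same flattened result.
print(dense_mv(A, B, X) == factored_mv(A, B, X))
```

The factored route only touches the \(m \times m\) and \(n \times n\) factors, which is the kind of saving a call such as mv on a factored representation can exploit.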
More examples
More notebook examples can be found at https://github.com/tfjgeorge/nngeometry/tree/master/examples