Pruning and the deep double descent

September 27, 2023   
Keywords: machine learning, deep learning, double descent, generalization, pruning, model compression
Prerequisites: Deep Learning
Difficulty: Medium/Hard (M.Sc.). Not suitable for B.Sc.

Abstract

The deep double descent (DD) phenomenon is a pattern observed when training deep neural networks (DNNs) of varying complexity on the same dataset. As the complexity (usually approximated by the number of parameters) increases, the test loss decreases and then, after reaching a minimum, grows again, while the training loss keeps decreasing. Up to this point, the behavior conforms to the classical bias-variance tradeoff, which states that overly complex/flexible models have trouble generalizing and thus end up overfitting. However, DNNs usually showcase a remarkable behavior: further increasing the complexity causes the test loss to peak (at the interpolation threshold, IT), after which it decreases again, even below the previous minimum, as can be appreciated in the figure below.

Figure: Double descent in a two-layer neural network, provided by HaeB under the license CC BY.
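The same qualitative curve can be reproduced in minutes without deep networks, using minimum-norm least squares on random ReLU features — a common didactic surrogate for DD, not the exact setup of the figure above. All sizes, seeds, and function names below are illustrative choices:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic regression task: targets are a noisy linear function of the inputs.
n_train, n_test, d = 40, 200, 10
w_true = rng.normal(size=d)
X_train = rng.normal(size=(n_train, d))
X_test = rng.normal(size=(n_test, d))
y_train = X_train @ w_true + 0.5 * rng.normal(size=n_train)
y_test = X_test @ w_true + 0.5 * rng.normal(size=n_test)

def fit_random_features(p, seed=0):
    """Min-norm least squares on p random ReLU features; returns (train_mse, test_mse)."""
    feat_rng = np.random.default_rng(seed)
    W = feat_rng.normal(size=(d, p)) / np.sqrt(d)   # fixed random first layer
    phi_tr = np.maximum(X_train @ W, 0.0)
    phi_te = np.maximum(X_test @ W, 0.0)
    beta = np.linalg.pinv(phi_tr) @ y_train          # minimum-norm solution
    train_mse = np.mean((phi_tr @ beta - y_train) ** 2)
    test_mse = np.mean((phi_te @ beta - y_test) ** 2)
    return train_mse, test_mse

# p = n_train = 40 is the interpolation threshold of this toy model;
# the test loss typically peaks near it and falls again for much larger p.
widths = [5, 20, 40, 80, 400]
results = {p: fit_random_features(p) for p in widths}
```

Sweeping `widths` and plotting the two losses reproduces both descents of the curve; the same protocol (sweep complexity, record train/test loss) is what the project applies to actual DNNs.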

Pruning is a common strategy employed to reduce the number of parameters in a machine learning model. It is employed mainly for two goals: (a) to reduce the size of a model for a smaller memory footprint, or (b) to act as a regularizer, since DNNs pruned at a shallow level often tend to perform better than their dense counterparts. Specifically, dense-to-sparse pruning requires a baseline dense model to be fully trained; pruning is then applied to reach the target sparsity, and the DNN is trained or fine-tuned again to recover from a potential decrease in accuracy.
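As an illustration, here is a minimal sketch of unstructured magnitude pruning, the most common criterion in dense-to-sparse pipelines; `magnitude_prune` and the toy weight matrix are hypothetical names chosen for this example, not taken from a specific library:

```python
import numpy as np

def magnitude_prune(weights, sparsity):
    """Return a boolean mask that zeroes the smallest-magnitude entries.

    `weights`: weight array; `sparsity`: fraction in [0, 1) of entries to remove.
    """
    flat = np.abs(weights).ravel()
    k = int(sparsity * flat.size)
    if k == 0:
        return np.ones_like(weights, dtype=bool)
    # k-th smallest magnitude becomes the pruning threshold; note that ties at
    # the threshold are also pruned in this simple sketch.
    threshold = np.partition(flat, k - 1)[k - 1]
    return np.abs(weights) > threshold

W = np.array([[0.5, -0.01],
              [0.2, -0.9]])
mask = magnitude_prune(W, sparsity=0.5)   # keeps the two largest-magnitude weights
W_pruned = W * mask
```

In practice the mask is applied after (or during) training and kept fixed while the surviving weights are fine-tuned, which is exactly the dense-to-sparse recipe described above.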

Although several works have already identified ties between DD and pruning, unexplored areas remain. For instance, what happens when we prune a model that is exactly at, or close to, the IT? Does this critical state act as a bad starting point for pruning a DNN, or can the model recover from it? If the pruned model shows a recovery, does it somehow follow the trajectory of the DD curve backwards, or is the observed trend different?

Required work

  • Literature review on the concept of DD
  • Literature review on pruning
  • Pick multiple datasets, possibly one simple (not MNIST), one medium (e.g., CIFAR10), and one hard (e.g., Tiny-Imagenet or CIFAR100).
  • Try to reproduce the DD phenomenon on the simple dataset
  • Identify a set of checkpoints around the IT, then apply (iterative) pruning to the models with these weight configurations
  • (control sample) Apply pruning also to other weight configurations before/after the IT
  • Analyze the data
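For the iterative-pruning step above, one common choice is the cubic sparsity schedule of Zhu & Gupta (2017), which ramps the target sparsity smoothly from an initial to a final value over several pruning rounds; the helper name below is a hypothetical sketch:

```python
def gradual_sparsity_schedule(s_final, n_rounds, s_init=0.0):
    """Cubic sparsity ramp (Zhu & Gupta, 2017): s_t = s_f + (s_i - s_f)(1 - t/T)^3.

    Returns the per-round target sparsities, from s_init at round 0
    to s_final at round n_rounds.
    """
    return [s_final + (s_init - s_final) * (1 - t / n_rounds) ** 3
            for t in range(n_rounds + 1)]

# e.g. five pruning rounds ramping from dense (0.0) to 90% sparsity;
# each round prunes to the scheduled sparsity and then fine-tunes.
schedule = gradual_sparsity_schedule(s_final=0.9, n_rounds=5)
```

The schedule prunes aggressively early, when there is ample redundancy, and slows down as the target sparsity is approached, which tends to preserve accuracy better than one-shot pruning.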

Relevant literature