How to implement neural network pruning?

If you add a mask, then only a subset of your weights will contribute to the computation, hence your model will be pruned. For instance, autoregressive models use a mask to mask out the weights that refer to future data so that the output at time step t only depends on time steps 0, 1, ..., t-1.
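
As a rough illustration of the masking idea, here is a minimal NumPy sketch. The layer sizes, the hand-written mask and the variable names are made up for the example; they are not taken from your model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical dense layer: 4 inputs, 3 outputs.
w = rng.normal(size=(4, 3))
x = rng.normal(size=(2, 4))  # batch of 2 examples

# Binary mask with the same shape as w: 0 disables a connection, 1 keeps it.
mask = np.array(
    [
        [1, 0, 1],
        [1, 1, 0],
        [0, 1, 1],
        [1, 1, 1],
    ],
    dtype=w.dtype,
)

# Only the unmasked weights contribute to the output.
y_masked = x @ (w * mask)

# Causal (autoregressive) mask over T time steps: entry (t, s) is 1 only
# when s < t, so step t depends on steps 0, ..., t-1 only.
T = 4
causal_mask = np.tril(np.ones((T, T)), k=-1)
```

Because the masked entries are multiplied by zero, they have no influence on `y_masked`, which is exactly the effect pruning is after.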

In your case, since you have a simple fully connected layer, it is better to use dropout. It randomly turns off some neurons at each training iteration, so fewer neurons take part in each step. Keep in mind that dropout is only active during training: at inference time all neurons are used, so it does not shrink the final model. The main reason dropout was invented is to tackle overfitting: by having some neurons turned off randomly, you reduce the neurons' co-dependencies, i.e. you prevent some neurons from relying on others. Moreover, at each iteration your model will be different (a different number of active neurons and different connections between them), hence your final model can be interpreted as an ensemble (collection) of several different models, each specialized (we hope) in understanding a specific subset of the input space.
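
To make the mechanism concrete, here is a small NumPy sketch of the common "inverted dropout" formulation. The function name `dropout` and its parameters are my own choices for the example, not a library API; in practice you would use the dropout layer of your framework.

```python
import numpy as np


def dropout(activations, rate, rng, training=True):
    """Inverted dropout: zero a random subset of units and rescale the rest."""
    if not training or rate == 0.0:
        # At inference time every unit is kept.
        return activations
    keep_prob = 1.0 - rate
    # A fresh random mask on every call, so each training step sees a
    # different sub-network.
    mask = rng.random(activations.shape) < keep_prob
    # Rescale the kept units so the expected activation stays the same.
    return activations * mask / keep_prob


rng = np.random.default_rng(0)
h = np.ones((2, 8))
h_train = dropout(h, rate=0.5, rng=rng)                  # entries are 0.0 or 2.0
h_eval = dropout(h, rate=0.5, rng=rng, training=False)   # unchanged
```

Note that the mask changes on every call and is dropped at evaluation time, which is why dropout on its own regularizes the model rather than permanently removing weights.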


Based on the discussion in the comments, here is a way to prune a layer (a weight matrix) of your neural network. The method selects the k% of weights (elements of the matrix) with the smallest magnitude (absolute value) and sets them to zero. Note that k is passed as a fraction between 0 and 1. The resulting matrix can then be treated as a sparse matrix, and we can perform dense-sparse matrix multiplication, which can be faster if enough weights are pruned.

import tensorflow as tf


def weight_pruning(w: tf.Variable, k: float) -> tf.Variable:
    """Performs pruning on a weight matrix w in the following way:

    - The absolute value of all elements in the weight matrix is computed.
    - The indices of the smallest k% elements based on their absolute values are selected.
    - All elements with the matching indices are set to 0.

    Args:
        w: The weight matrix.
        k: The fraction (between 0 and 1) of values (weights) that should be pruned from the matrix.

    Returns:
        The weight pruned weight matrix.

    """
    # Number of weights to prune: the fraction k of all elements, rounded.
    k = tf.cast(tf.round(tf.cast(tf.size(w), tf.float32) * tf.constant(k)), dtype=tf.int32)
    w_reshaped = tf.reshape(w, [-1])
    # top_k on the negated magnitudes gives the indices of the k smallest |w|.
    _, indices = tf.nn.top_k(tf.negative(tf.abs(w_reshaped)), k, sorted=True, name=None)
    # Binary mask: 1 everywhere, 0 at the pruned positions.
    # tf.tensor_scatter_nd_update is used because tf.scatter_nd_update only
    # exists as tf.compat.v1.scatter_nd_update in TensorFlow 2.x.
    mask = tf.tensor_scatter_nd_update(
        tf.ones_like(w_reshaped, dtype=tf.float32),
        tf.reshape(indices, [-1, 1]),
        tf.zeros([k], tf.float32),
    )

    return w.assign(tf.reshape(w_reshaped * mask, tf.shape(w)))
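
If you want to sanity-check the behaviour without TensorFlow, the same idea can be sketched in plain NumPy. The helper name `weight_pruning_np` and the toy matrix are hypothetical and only meant as a reference to compare against; unlike the function above it returns a pruned copy instead of assigning in place.

```python
import numpy as np


def weight_pruning_np(w, k):
    """Return a copy of w with the fraction k of smallest-magnitude entries zeroed."""
    n_prune = int(round(w.size * k))
    flat = w.reshape(-1).copy()
    if n_prune > 0:
        # argpartition places the indices of the n_prune smallest magnitudes first.
        idx = np.argpartition(np.abs(flat), n_prune - 1)[:n_prune]
        flat[idx] = 0.0
    return flat.reshape(w.shape)


w = np.array([[0.1, -2.0, 0.5],
              [-0.05, 1.5, -0.3]])
pruned = weight_pruning_np(w, k=0.5)
# The three smallest magnitudes (0.05, 0.1, 0.3) are zeroed:
# [[0.0, -2.0, 0.5],
#  [0.0,  1.5, 0.0]]
```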

While the method above prunes individual connections (weights), the method below prunes whole neurons from a weight matrix. Namely, it selects the k% of neurons (columns of the weight matrix) with the smallest Euclidean norm and sets them to zero.

def unit_pruning(w: tf.Variable, k: float) -> tf.Variable:
    """Performs pruning on a weight matrix w in the following way:

    - The Euclidean norm of each column is computed.
    - The indices of the smallest k% columns based on their Euclidean norms are selected.
    - All elements in the columns that have the matching indices are set to 0.

    Args:
        w: The weight matrix.
        k: The fraction (between 0 and 1) of columns that should be pruned from the matrix.

    Returns:
        The unit pruned weight matrix.

    """
    # Number of columns (units) to prune: the fraction k of all columns, rounded.
    k = tf.cast(
        tf.round(tf.cast(tf.shape(w)[1], tf.float32) * tf.constant(k)), dtype=tf.int32
    )
    norm = tf.norm(w, axis=0)
    # Build (row, column) index pairs covering every element of the k
    # columns with the smallest norm.
    row_indices = tf.tile(tf.range(tf.shape(w)[0]), [k])
    _, col_indices = tf.nn.top_k(tf.negative(norm), k, sorted=True, name=None)
    col_indices = tf.reshape(
        tf.tile(tf.reshape(col_indices, [-1, 1]), [1, tf.shape(w)[0]]), [-1]
    )
    indices = tf.stack([row_indices, col_indices], axis=1)

    # tf.tensor_scatter_nd_update is used because tf.scatter_nd_update only
    # exists as tf.compat.v1.scatter_nd_update in TensorFlow 2.x.
    return w.assign(
        tf.tensor_scatter_nd_update(w, indices, tf.zeros([tf.shape(w)[0] * k], tf.float32))
    )
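
As with weight pruning, a NumPy sketch can serve as a quick reference for what unit pruning should produce. The helper name `unit_pruning_np` and the toy matrix are again hypothetical, and the helper returns a pruned copy rather than updating a variable.

```python
import numpy as np


def unit_pruning_np(w, k):
    """Return a copy of w with the fraction k of smallest-norm columns zeroed."""
    n_prune = int(round(w.shape[1] * k))
    pruned = w.copy()
    if n_prune > 0:
        # Euclidean norm of each column (one column per neuron).
        col_norms = np.linalg.norm(w, axis=0)
        cols = np.argsort(col_norms)[:n_prune]
        pruned[:, cols] = 0.0
    return pruned


w = np.array([[3.0, 0.1, -1.0, 0.2],
              [4.0, -0.1, 1.0, 0.0]])
pruned = unit_pruning_np(w, k=0.5)
# Column norms are roughly 5.0, 0.14, 1.41 and 0.2, so columns 1 and 3 are zeroed:
# [[3.0, 0.0, -1.0, 0.0],
#  [4.0, 0.0,  1.0, 0.0]]
```

Since whole columns become zero, the pruned neurons could in principle be removed from the matrix altogether, which shrinks the dense layer instead of relying on sparse kernels.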

Finally, this GitHub repository goes through the pruning methods explained here and performs experiments on the MNIST dataset.