How to accumulate gradients for large batch sizes in Keras

We have published an open-source tool that automatically adds gradient accumulation support to Keras models. We implemented it at Run:AI to help us with batch sizing issues.

Using gradient accumulation in our models allowed us to use large batch sizes despite being limited by GPU memory. Specifically, it allowed us to run neural networks with large batch sizes using only a single GPU.
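To make the idea concrete, here is a minimal pure-Python sketch (not taken from the Run:AI package) of why accumulation works: for a mean loss and equal-sized micro-batches, averaging the micro-batch gradients reproduces the gradient of the full batch. The toy model y = w * x, the data, and the helper names mse_grad and accumulated_grad are assumptions made for illustration.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    n = len(xs)
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, micro_batch_size):
    """Average the gradients of consecutive, equal-sized micro-batches."""
    if len(xs) % micro_batch_size != 0:
        raise ValueError("data must split evenly into micro-batches")
    steps = len(xs) // micro_batch_size
    total = 0.0
    for i in range(steps):
        lo = i * micro_batch_size
        hi = lo + micro_batch_size
        # Only one micro-batch needs to be "in memory" at a time.
        total += mse_grad(w, xs[lo:hi], ys[lo:hi])
    return total / steps


# Illustrative data (assumed, not from the original post).
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.2, 13.8, 16.1]
w = 0.5

full = mse_grad(w, xs, ys)
accumulated = accumulated_grad(w, xs, ys, micro_batch_size=2)
print(abs(full - accumulated) < 1e-9)
```

The two values agree up to floating-point rounding, which is why a small per-step batch plus accumulation can stand in for a large batch that would not fit in GPU memory.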

The project is available at https://github.com/run-ai/runai/tree/master/runai/ga along with explanations and examples you can use right out of the box.

Using this tool, all you have to do is add a single line of code to your Python script, and you can add gradient accumulation support to your optimizer.

The Python package is available on PyPI and can be installed using the command: pip install runai.

Adding gradient accumulation support to Keras models is straightforward. First, import the package into your code: import runai.ga. Then create a gradient accumulation optimizer. There are two ways to do so:

1. Wrap an existing Keras optimizer

You can take any Keras optimizer, whether it is a built-in one (SGD, Adam, etc.) or a custom optimizer with your own algorithm implementation, and add gradient accumulation support to it using the following line:

optimizer = runai.ga.keras.optimizers.Optimizer(optimizer, steps=STEPS)

Where optimizer is your optimizer, and STEPS is the number of steps you want to accumulate gradients over.
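Conceptually, a wrapper of this kind can be pictured as follows. This is a minimal pure-Python sketch of the wrapping pattern, not the Run:AI implementation: the PlainSGD and AccumulatingWrapper classes and their apply method are hypothetical names invented for illustration, and the real package operates on Keras optimizers and tensors rather than lists of floats.

```python
class PlainSGD:
    """Hypothetical stand-in for an existing optimizer."""

    def __init__(self, lr):
        self.lr = lr

    def apply(self, weights, grads):
        # One ordinary SGD step on a list of scalar weights.
        return [w - self.lr * g for w, g in zip(weights, grads)]


class AccumulatingWrapper:
    """Hypothetical wrapper: forwards an averaged gradient every `steps` calls."""

    def __init__(self, optimizer, steps):
        if steps < 1:
            raise ValueError("steps must be >= 1")
        self.optimizer = optimizer
        self.steps = steps
        self._sums = None
        self._count = 0

    def apply(self, weights, grads):
        if self._sums is None:
            self._sums = [0.0] * len(grads)
        self._sums = [s + g for s, g in zip(self._sums, grads)]
        self._count += 1
        if self._count < self.steps:
            return list(weights)  # still accumulating: weights unchanged
        averaged = [s / self.steps for s in self._sums]
        self._sums = None
        self._count = 0
        # Delegate the real update to the wrapped optimizer.
        return self.optimizer.apply(weights, averaged)


opt = AccumulatingWrapper(PlainSGD(lr=0.1), steps=2)
weights = [1.0]
weights = opt.apply(weights, [4.0])  # first call only accumulates
after_first = list(weights)
weights = opt.apply(weights, [2.0])  # second call applies the mean gradient
```

The wrapped optimizer never sees the individual gradients, only their average, so any update rule (SGD, Adam, a custom one) can be reused unchanged.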

2. Create a gradient accumulation version of any of the built-in optimizers

Gradient accumulation versions of all built-in optimizers (SGD, Adam, etc.) are available in the package. They can be created using this line:

optimizer = runai.ga.keras.optimizers.Adam(steps=STEPS)

Here, we create a gradient accumulation version of the Adam optimizer that accumulates gradients over STEPS steps.

More information, explanations, and examples are available on GitHub.

In addition to the open-source tool itself, we have published a series of three articles on Towards Data Science (Medium). They explain the issues that arise with large batch sizes, what gradient accumulation is and how it helps solve these issues, how it works, and how we implemented it. Here are the articles:

  • The problem of batch sizing and limited GPU memory

  • What is Gradient Accumulation and how does it help?

  • How-to guide to using the gradient accumulation mechanism and how we implemented it

Let us know if the tool helped you in using gradient accumulation in your own Keras models. We are here to give any support and help with the problems you encounter when using it in your own models.


As mentioned in the question, there is no off-the-shelf function or method to achieve this with Keras/TensorFlow. However, it can be done by writing a custom optimizer for Keras.

The main idea is to use a flag to determine whether to update the weights during each batch.
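The flag can be sketched in plain Python as follows. This is an illustration under assumed helper names (update_switch as a function, blend), using ordinary integers and floats in place of Keras tensors. A switch of 1 selects the new value and a switch of 0 keeps the old one, so a single update expression can serve every batch.

```python
def update_switch(iterations, accum_iters):
    """Return 1.0 on the batch that completes an accumulation window, else 0.0."""
    return 1.0 if (iterations + 1) % accum_iters == 0 else 0.0


def blend(old, new, switch):
    """Keep `old` when switch is 0.0, take `new` when switch is 1.0."""
    return (1.0 - switch) * old + switch * new


accum_iters = 4
# `iterations` counts batches already processed, so batch k sees k - 1.
schedule = [update_switch(i, accum_iters) for i in range(9)]
updating_batches = [i + 1 for i, s in enumerate(schedule) if s == 1.0]
print(updating_batches)
```

With accum_iters=4, only every fourth batch produces a switch of 1, and on all other batches blend returns the old weight untouched.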

The following implementation, an accumulating Adam optimizer, is based on this GitHub post by "alexeydevederkin":

import keras.backend as K
from keras.legacy import interfaces
from keras.optimizers import Optimizer


class AdamAccumulate(Optimizer):

    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
                 epsilon=None, decay=0., amsgrad=False, accum_iters=1, **kwargs):
        if accum_iters < 1:
            raise ValueError('accum_iters must be >= 1')
        super(AdamAccumulate, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.beta_1 = K.variable(beta_1, name='beta_1')
            self.beta_2 = K.variable(beta_2, name='beta_2')
            self.decay = K.variable(decay, name='decay')
        if epsilon is None:
            epsilon = K.epsilon()
        self.epsilon = epsilon
        self.initial_decay = decay
        self.amsgrad = amsgrad
        self.accum_iters = K.variable(accum_iters, K.dtype(self.iterations))
        self.accum_iters_float = K.cast(self.accum_iters, K.floatx())

    @interfaces.legacy_get_updates_support
    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr

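        # Number of weight updates completed so far. The `//` operator avoids
        # `K.tf`, which some Keras releases do not expose.
        completed_updates = K.cast(self.iterations // self.accum_iters, K.floatx())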

        if self.initial_decay > 0:
            lr = lr * (1. / (1. + self.decay * completed_updates))

        t = completed_updates + 1

        lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t)))

        # self.iterations incremented after processing a batch
        # batch:              1 2 3 4 5 6 7 8 9
        # self.iterations:    0 1 2 3 4 5 6 7 8
        # update_switch = 1:        x       x    (if accum_iters=4)  
        update_switch = K.equal((self.iterations + 1) % self.accum_iters, 0)
        update_switch = K.cast(update_switch, K.floatx())

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        gs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]

        if self.amsgrad:
            vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        else:
            vhats = [K.zeros(1) for _ in params]

        self.weights = [self.iterations] + ms + vs + vhats

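        for p, g, m, v, vhat, tg in zip(params, grads, ms, vs, vhats, gs):

            # Running gradient sum including this batch, and its average
            # over the accumulation window.
            sum_grad = tg + g
            avg_grad = sum_grad / self.accum_iters_float

            # Candidate Adam moments; they are committed only when
            # update_switch is 1.
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * avg_grad
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(avg_grad)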

            if self.amsgrad:
                vhat_t = K.maximum(vhat, v_t)
                p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
                self.updates.append(K.update(vhat, (1 - update_switch) * vhat + update_switch * vhat_t))
            else:
                p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)

            self.updates.append(K.update(m, (1 - update_switch) * m + update_switch * m_t))
            self.updates.append(K.update(v, (1 - update_switch) * v + update_switch * v_t))
            self.updates.append(K.update(tg, (1 - update_switch) * sum_grad))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(K.update(p, (1 - update_switch) * p + update_switch * new_p))
        return self.updates

    def get_config(self):
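        config = {'lr': float(K.get_value(self.lr)),
                  'beta_1': float(K.get_value(self.beta_1)),
                  'beta_2': float(K.get_value(self.beta_2)),
                  'decay': float(K.get_value(self.decay)),
                  'epsilon': self.epsilon,
                  'amsgrad': self.amsgrad,
                  # Needed so a reloaded optimizer keeps its accumulation window.
                  'accum_iters': int(K.get_value(self.accum_iters))}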
        base_config = super(AdamAccumulate, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

It can be used in the following way:

opt = AdamAccumulate(lr=0.001, decay=1e-5, accum_iters=5)
model.compile(loss='categorical_crossentropy',  # Loss function
              optimizer=opt,                    # Optimization technique
              metrics=['accuracy'])             # Accuracy metric
model.fit(X_train, y_train, batch_size=10)

In this example, the model processes 10 samples in every iteration ("batch_size"), but the weights are updated only after accumulating 5 such batches ("accum_iters"). So the effective batch size for each weight update is 50.
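The arithmetic can be checked with a few lines of Python. The helper name accumulation_summary and the training-set size of 1,000 samples are assumptions for illustration; they do not appear in the example above.

```python
def accumulation_summary(num_samples, batch_size, accum_iters):
    """Return (effective batch size, batches per epoch, full windows per epoch)."""
    effective_batch = batch_size * accum_iters
    # Ceiling division: a final partial batch still counts as one batch.
    batches_per_epoch = -(-num_samples // batch_size)
    # Complete accumulation windows that fit into one epoch's batches.
    updates_per_epoch = batches_per_epoch // accum_iters
    return effective_batch, batches_per_epoch, updates_per_epoch


effective, batches, updates = accumulation_summary(
    num_samples=1000, batch_size=10, accum_iters=5)
print(effective, batches, updates)
```

Under these assumed numbers, one epoch consists of 100 batches but only 20 weight updates, each based on 50 samples.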