How to use model.reset_states() in Keras?

reset_states resets only the hidden states of your network (for an LSTM, its hidden and cell states), not its weights. It's worth mentioning that the behaviour of this function depends on whether the option stateful=True was set in your network. If it is not set, all states are automatically reset after every batch your network processes, so no state is carried over between batches or between calls to fit, predict and evaluate. If it is set, you should call reset_states every time you want consecutive model calls to be independent.
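A framework-free toy can make the difference concrete. ToyRecurrentLayer below is hypothetical (not a Keras class), and its "recurrence" is just a running sum per sample; it only imitates how state is kept or dropped between calls, under the assumption that a real RNN layer behaves the same way at that level. With stateful=False every call starts from zero; with stateful=True the state carries over until reset_states() is called.

```python
class ToyRecurrentLayer:
    """Hypothetical, framework-free stand-in for an RNN layer (not Keras).
    Its per-sample 'state' is a running sum, so carry-over is easy to see."""

    def __init__(self, batch_size, stateful=False):
        self.batch_size = batch_size
        self.stateful = stateful
        self.states = [0.0] * batch_size

    def __call__(self, batch):
        if self.stateful:
            # A stateful layer needs a fixed batch size: row i reuses states[i].
            if len(batch) != self.batch_size:
                raise ValueError("stateful layer expects a fixed batch size")
        else:
            # Non-stateful: every batch starts again from a zero state.
            self.states = [0.0] * len(batch)
        outputs = []
        for i, sequence in enumerate(batch):
            state = self.states[i]
            for x in sequence:
                state += x  # toy recurrence: accumulate the inputs
            self.states[i] = state
            outputs.append(state)
        return outputs

    def reset_states(self):
        # Mirrors the AttributeError guard in the Keras source excerpt of this answer.
        if not self.stateful:
            raise AttributeError("Layer must be stateful.")
        self.states = [0.0] * self.batch_size


batch = [[1.0, 2.0], [10.0]]

plain = ToyRecurrentLayer(batch_size=2, stateful=False)
print(plain(batch), plain(batch))        # [3.0, 10.0] [3.0, 10.0]

stateful = ToyRecurrentLayer(batch_size=2, stateful=True)
print(stateful(batch), stateful(batch))  # [3.0, 10.0] [6.0, 20.0]
stateful.reset_states()
print(stateful(batch))                   # [3.0, 10.0]
```

The second stateful call continues from where the first one stopped, which is exactly what reset_states() undoes.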


If you explicitly call either

model.reset_states() 

to reset the states of all stateful layers in the model, or

layer.reset_states() 

to reset the states of a specific stateful RNN layer (an LSTM layer, for example), implemented here:

def reset_states(self, states=None):
    if not self.stateful:
        raise AttributeError('Layer must be stateful.')
    # ... (the rest of the method, which actually resets the states, is omitted)

this means your layer(s) must be stateful; calling reset_states() on a non-stateful layer raises the AttributeError above.
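The difference between the model-level and the layer-level call can be sketched with another toy. ToyLayer and ToyModel are hypothetical classes, not Keras. ToyModel.reset_states() loops over its layers and resets only those with stateful=True; this is an assumption modelled on the Keras 2.x Model.reset_states() loop, and other Keras versions may behave differently.

```python
class ToyLayer:
    """Hypothetical layer with one scalar state (not a Keras class)."""

    def __init__(self, name, stateful):
        self.name = name
        self.stateful = stateful
        self.state = 0.0

    def step(self, x):
        if not self.stateful:
            self.state = 0.0  # non-stateful: nothing survives between calls
        self.state += x  # toy recurrence: running sum
        return self.state

    def reset_states(self):
        if not self.stateful:
            raise AttributeError("Layer must be stateful.")
        self.state = 0.0


class ToyModel:
    """Hypothetical container whose reset_states() follows the Keras 2.x
    pattern (assumption): reset only the layers that are stateful."""

    def __init__(self, layers):
        self.layers = layers

    def reset_states(self):
        for layer in self.layers:
            if hasattr(layer, "reset_states") and getattr(layer, "stateful", False):
                layer.reset_states()


lstm_a = ToyLayer("lstm_a", stateful=True)
lstm_b = ToyLayer("lstm_b", stateful=True)
plain_rnn = ToyLayer("plain_rnn", stateful=False)
model = ToyModel([lstm_a, lstm_b, plain_rnn])

lstm_a.step(5.0)
lstm_b.step(7.0)
lstm_a.reset_states()              # layer level: only lstm_a is cleared
print(lstm_a.state, lstm_b.state)  # 0.0 7.0
model.reset_states()               # model level: every stateful layer, plain_rnn skipped
print(lstm_a.state, lstm_b.state)  # 0.0 0.0

try:
    plain_rnn.reset_states()       # layer level on a non-stateful layer
except AttributeError as err:
    print(err)                     # Layer must be stateful.
```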

To make an LSTM stateful, you need to:

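  • explicitly specify the batch size you are using, by passing a batch_size argument (or a batch_input_shape argument) to the first layer in your model, e.g. batch_size=32 for batches of 32 sequences, each with 10 timesteps of 16 features, which corresponds to batch_input_shape=(32, 10, 16).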

  • set stateful=True in your RNN layer(s).

  • specify shuffle=False when calling fit().
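The reason for the fixed batch size and shuffle=False: a stateful layer uses the last state of the sample at index i in one batch as the initial state of the sample at index i in the next batch, so row i of every batch must continue the same sequence. The framework-free sketch below shows such a layout; make_stateful_batches and run_epoch are hypothetical helpers, the "recurrence" is a plain running sum, and the series are 1-D lists (real Keras input would be arrays of shape (batch_size, timesteps, features)).

```python
def make_stateful_batches(sequences, batch_size, window):
    """Hypothetical helper (not part of Keras).

    Cuts batch_size equally long sequences into consecutive windows so that
    row i of batch k+1 continues row i of batch k.
    """
    if len(sequences) != batch_size:
        raise ValueError("need exactly one long sequence per batch row")
    length = len(sequences[0])
    if any(len(seq) != length for seq in sequences):
        raise ValueError("all sequences must have the same length")
    if length % window != 0:
        raise ValueError("sequence length must be a multiple of window")
    batches = []
    for start in range(0, length, window):
        # Row i always comes from sequences[i]; shuffling rows or batches
        # would hand one sequence's state to another sequence.
        batches.append([seq[start:start + window] for seq in sequences])
    return batches


def run_epoch(batches, states):
    """Toy stand-in for one epoch of fit(..., shuffle=False): batches are
    consumed in order and the state of row i carries over to row i of the
    next batch. The 'recurrence' is just a running sum."""
    for batch in batches:
        for i, chunk in enumerate(batch):
            states[i] += sum(chunk)
    return states


sequences = [[1, 2, 3, 4], [10, 20, 30, 40]]
batches = make_stateful_batches(sequences, batch_size=2, window=2)
print(batches)  # [[[1, 2], [10, 20]], [[3, 4], [30, 40]]]

for epoch in range(2):
    states = [0] * 2  # plays the role of reset_states() between epochs
    print(run_epoch(batches, states))  # [10, 100] in both epochs
```

Feeding the windows in order gives the same per-row totals as processing each whole sequence at once, and resetting before each epoch makes every epoch start from the beginning of the sequences.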


The benefits of using stateful models are probably best explained here.