What does padding_idx do in nn.Embedding()?

padding_idx is indeed quite badly described in the documentation.

Basically, it specifies which index, when passed in a call, will map to the "zero vector" (which is often useful in NLP when a token is missing, i.e. for padding). By default no index maps to a zero vector, as you can see in the example below:

import torch

# 10 embeddings, each 3-dimensional; valid indices are 0..9
embedding = torch.nn.Embedding(10, 3)
input = torch.LongTensor([[0, 1, 0, 5]])
print(embedding(input))

This will print something like (the weights are randomly initialized, so your exact values will differ):

tensor([[[ 0.1280, -1.1390, -2.5007],
         [ 0.3617, -0.9280,  1.2894],
         [ 0.1280, -1.1390, -2.5007],
         [-1.3135, -0.0229,  0.2451]]], grad_fn=<EmbeddingBackward>)

If you specify padding_idx=0, every position in the input whose value equals 0 (here the zeroth and second entries) will be zeroed out, like this (code: embedding = torch.nn.Embedding(10, 3, padding_idx=0)):

tensor([[[ 0.0000,  0.0000,  0.0000],
         [-0.4448, -0.2076,  1.1575],
         [ 0.0000,  0.0000,  0.0000],
         [ 1.3602, -0.6299, -0.5809]]], grad_fn=<EmbeddingBackward>)

If you were to specify padding_idx=5 instead, the last entry (whose value is 5) would be all zeros.
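A minimal sketch of that case (the non-padding values are random; only the zeroed row is deterministic):

embedding = torch.nn.Embedding(10, 3, padding_idx=5)
input = torch.LongTensor([[0, 1, 0, 5]])
print(embedding(input))  # the last row is [0., 0., 0.]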


As per the docs, padding_idx pads the output with the embedding vector at padding_idx (initialized to zeros) whenever it encounters the index.

What this means is that wherever an item in the input equals padding_idx, the output of the embedding layer at that position will be all zeros.

Here is an example: let us say you have word embeddings for 1000 words, each 50-dimensional, i.e. num_embeddings=1000, embedding_dim=50. Then torch.nn.Embedding works like a lookup table (a trainable one, though):

emb_layer = torch.nn.Embedding(1000, 50)
x = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
y = emb_layer(x)

y will be a tensor of shape 2x4x50. I hope this part is clear to you.
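You can verify the shape directly:

print(y.shape)  # torch.Size([2, 4, 50])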

Now if I specify padding_idx=2, i.e.

emb_layer = torch.nn.Embedding(1000, 50, padding_idx=2)
x = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
y = emb_layer(x)

then the output will still be 2x4x50, but the 50-dim vectors at x[0, 1] and x[1, 2] will be all zeros, since those entries equal padding_idx. You can think of it as the word at index 2 of the lookup table (the third row, since the table is 0-indexed) not being used for training: its vector stays fixed at zeros and receives no gradient updates.
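To see the "not used for training" part concretely, here is a minimal sketch; per the PyTorch docs, the gradient for the row at padding_idx is always zero, so an optimizer step never moves it:

import torch

emb_layer = torch.nn.Embedding(1000, 50, padding_idx=2)
x = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])

loss = emb_layer(x).sum()
loss.backward()

# The gradient row for the padding index is all zeros, so the padding
# vector stays fixed (at zeros) no matter how long you train.
print(emb_layer.weight.grad[2])  # all zeros
print(emb_layer.weight.grad[4])  # non-zero, since index 4 appears in x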