How does a PyTorch module do backprop?

Not having to implement backward() is the reason PyTorch or any other DL framework is so valuable. In fact, implementing backward() should only be done in very specific cases where you need to mess with the network's gradient (or when you create a custom Function that can't be expressed using PyTorch's built-in functions).
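
For example (a minimal sketch of my own, using the standard torch.nn API): a module whose forward pass uses only built-in ops gets its gradients for free, with no backward() anywhere:

import torch
import torch.nn as nn

# No backward() is defined anywhere: autograd derives the gradients
# from the operations recorded during the forward pass.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x):
        return torch.relu(self.linear(x)).sum()

net = TinyNet()
loss = net(torch.randn(3, 4))
loss.backward()
print(net.linear.weight.grad.shape)  # torch.Size([1, 4])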

PyTorch computes backward gradients using a computational graph that keeps track of the operations performed during your forward pass. Any operation done on a Variable gets implicitly registered in that graph. Backpropagation is then a matter of traversing the graph backward from the variable on which backward() was called, applying the chain rule to compute the gradients.
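
A minimal sketch of what that looks like from Python (written against current PyTorch, where Variable has been merged into Tensor):

import torch

# Every operation on a tensor that requires grad is recorded in the graph.
x = torch.tensor(2.0, requires_grad=True)
y = x ** 2
z = 3 * y + 1

z.backward()   # traverse the graph backward from z, applying the chain rule
print(x.grad)  # dz/dx = dz/dy * dy/dx = 3 * 2x = tensor(12.)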

PyTorch's About page has a nice visualization of the graph and how it generally works. I'd also recommend looking up computational graphs and the autograd mechanics on Google if you want more details.

EDIT: The source code where all this happens would be in the C++ part of PyTorch's codebase, where the actual graph is implemented. After some digging around, I found this:

/// Evaluates the function on the given inputs and returns the result of the
/// function call.
variable_list operator()(const variable_list& inputs) {
    profiler::RecordFunction rec(this);
    if (jit::tracer::isTracingVar(inputs)) {
        return traced_apply(inputs);
    }
    return apply(inputs);
}

So in each Function, PyTorch first checks whether its inputs need tracing, and if so performs traced_apply(), implemented here. You can see the node being created and appended to the graph:

// Insert a CppOp in the trace.
auto& graph = state->graph;
std::vector<VariableFlags> var_flags;
for(auto & input: inputs) {
    var_flags.push_back(VariableFlags::of(input));
}
auto* this_node = graph->createCppOp(get_shared_ptr(), std::move(var_flags));
// ...
for (auto& input: inputs) {
    this_node->addInput(tracer::getValueTrace(state, input));
}
graph->appendNode(this_node);

My best guess here is that every Function object registers itself and its inputs (if needed) upon execution. Every non-functional call (e.g. variable.dot()) simply defers to the corresponding Function, so this still applies.
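
A quick way to see that registration from Python (a sketch assuming current PyTorch; the exact node names may differ between versions):

import torch

v = torch.randn(4, requires_grad=True)
w = torch.randn(4)

a = v.dot(w)         # method call on a tensor
b = torch.dot(v, w)  # equivalent functional call

# Both defer to the same underlying Function, so both results carry
# the same kind of backward node registered in the graph.
print(a.grad_fn)                 # e.g. <DotBackward0 object at ...>
print(b.grad_fn)                 # same node type
print(a.grad_fn.next_functions)  # edges back to the inputs that were recorded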

NOTE: I don't take part in PyTorch's development and am in no way an expert on its architecture. Any corrections or additions would be welcome.


Maybe I'm not correct, but I have a different kind of view.

The backward function is defined alongside the forward function and gets called through the forward function.

For example:

#!/usr/bin/env python
# encoding: utf-8

###############################################################
# Parametrized example
# --------------------
#
# This implements a layer with learnable weights.
#
# It implements the Cross-correlation with a learnable kernel.
#
# In deep learning literature, it’s confusingly referred to as
# Convolution.
#
# The backward computes the gradients wrt the input and gradients wrt the
# filter.
#
# **Implementation:**
#
# *Please note that the implementation serves as an illustration, and we
# did not verify its correctness*

import torch
from torch.autograd import Function
from torch.autograd import Variable

from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


class ScipyConv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, filter):
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        ctx.save_for_backward(input, filter)
        return input.new(result)

    @staticmethod
    def backward(ctx, grad_output):
        input, filter = ctx.saved_tensors
        grad_output = grad_output.data
        grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full')
        grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid')

        return Variable(grad_output.new(grad_input)), \
            Variable(grad_output.new(grad_filter))


class ScipyConv2d(Module):

    def __init__(self, kh, kw):
        super(ScipyConv2d, self).__init__()
        self.filter = Parameter(torch.randn(kh, kw))

    def forward(self, input):
        return ScipyConv2dFunction.apply(input, self.filter)

###############################################################
# **Example usage:**

module = ScipyConv2d(3, 3)
print(list(module.parameters()))
input = Variable(torch.randn(10, 10), requires_grad=True)
output = module(input)
print(output)
output.backward(torch.randn(8, 8))
print(input.grad)

In this example, the backward function is defined in the ScipyConv2dFunction class.

And ScipyConv2dFunction is called by the module's forward function (via ScipyConv2dFunction.apply).
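
As a quick check I added (not part of the tutorial code, and assuming recent PyTorch naming), inspecting output.grad_fn after the forward pass shows the node that backward() will call:

module = ScipyConv2d(3, 3)
input = Variable(torch.randn(10, 10), requires_grad=True)
output = module(input)

# The node attached to output is the backward of the custom Function,
# registered by ScipyConv2dFunction.apply during the forward pass.
print(output.grad_fn)  # e.g. <torch.autograd.function.ScipyConv2dFunctionBackward ...>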

Am I correct?