PyTorch RuntimeError: Expected object of scalar type Double but got scalar type Float

Now that I have more experience with PyTorch, I think I can explain the error message. It seems that the line

RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 'mat2' in call to _th_mm

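is actually referring to the weights of the linear layer, which are passed as argument #2 ('mat2') when the matrix multiplication is called. Since the input is double (torch.float64) while the weights are float (torch.float32, the default dtype for nn.Linear parameters), it makes sense for the line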

output = input.matmul(weight.t())

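to expect the weights to be double: the input is the first operand, and the multiplication requires both operands to share a dtype rather than converting the float weights automatically.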
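A minimal sketch that reproduces the mismatch, assuming a recent PyTorch. The exact error wording has changed between versions, so the check only looks for a RuntimeError rather than matching the message text. The layer size (3 to 2), the input shape, and the helper name raises_runtime_error are arbitrary, hypothetical choices for illustration.

```python
import copy

import torch
import torch.nn as nn

# nn.Linear parameters are created as torch.float32 ("Float") by default.
layer = nn.Linear(3, 2)
# A float64 ("Double") input; torch.from_numpy on a default NumPy float array
# is a common way to end up with one of these.
x = torch.randn(4, 3, dtype=torch.float64)

print(layer.weight.dtype)  # torch.float32
print(x.dtype)             # torch.float64

def raises_runtime_error(fn):
    """Hypothetical helper: return True if calling fn() raises RuntimeError."""
    try:
        fn()
    except RuntimeError:
        return True
    return False

# The layer call and the hand-written matmul fail on the same dtype mismatch.
layer_fails = raises_runtime_error(lambda: layer(x))
matmul_fails = raises_runtime_error(lambda: x.matmul(layer.weight.t()))
print(layer_fails, matmul_fails)  # True True

# Making the dtypes agree removes the error: either cast the input down...
out_float = layer(x.float())
# ...or cast the parameters up (on a copy here, so `layer` stays float32).
out_double = copy.deepcopy(layer).double()(x)
print(out_float.dtype, out_double.dtype)  # torch.float32 torch.float64
```

Which side to cast is a judgment call: x.float() keeps the model in its default precision, while .double() on the model keeps the input's precision at the cost of twice the parameter memory.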