Python: how does the functools cmp_to_key function work?

I just realized that, despite not being a function, the K class is callable, because it's a class! Classes are callables that, when called, create a new instance, initialize it by calling the corresponding __init__, and then return that instance.

This way it behaves as a key function: when called, K receives the object and wraps it in a K instance, which can be compared against other K instances.

Correct me if I'm wrong. I feel I'm getting into metaclass territory, which is unfamiliar to me.


No, the sorted function (or list.sort) does not internally need to check whether the object it received is a function or a class. All it cares about is that the object passed as the key argument is callable, and that calling it returns a value that can be compared to the other values it returns.

Classes are also callable: when you call a class, you get back a new instance of that class.
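To see that sorted() does not care whether the key is a function or a class, here is a small sketch. ByLength is a hypothetical class made up for this illustration: passing the class itself as key works because calling it returns an instance, and those instances only need to support < among themselves. The built-in int is also a class, so key=int is the same idea.

```python
class ByLength:
    """Hypothetical key class: wraps a string and orders wrappers by length."""

    def __init__(self, s):
        self.s = s

    def __lt__(self, other):
        # sort only needs "<" between two key objects
        return len(self.s) < len(other.s)


words = ["banana", "fig", "apple"]

# Calling the class builds an instance, so the class itself works as a key.
print(callable(ByLength))           # True
print(sorted(words, key=ByLength))  # ['fig', 'apple', 'banana']

# Built-in classes work the same way: int is a class, and int("10") gives 10.
print(sorted(["10", "9", "100"], key=int))  # ['9', '10', '100']
```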

To answer your question, we first need to understand (at least at a basic level) how the key argument works:

  1. The key callable is called once for each element, and the value it returns is the object that element will be sorted by.

  2. After receiving the new object, the sort compares it to the other key objects (also obtained by calling the key callable on the other elements).
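The two steps above can be modelled roughly in plain Python. This is a simplified sketch, not how CPython actually sorts (it uses Timsort, implemented in C); sort_by_key and tracing_key are hypothetical names for illustration. It shows the key callable being called exactly once per element, and the sort comparing only the returned key objects with <.

```python
def sort_by_key(items, key):
    """Simplified model of sorted(items, key=key); not CPython's actual Timsort."""
    # Step 1: call the key callable exactly once per element and keep the
    # element next to the key object it produced.
    keyed = [(key(item), item) for item in items]

    # Step 2: order the pairs by comparing only the key objects with "<".
    # A plain stable insertion sort stands in for the real algorithm.
    ordered = []
    for k, item in keyed:
        pos = len(ordered)
        while pos > 0 and k < ordered[pos - 1][0]:
            pos -= 1
        ordered.insert(pos, (k, item))

    # Finally, drop the key objects and return the original elements.
    return [item for _, item in ordered]


calls = []


def tracing_key(x):
    calls.append(x)  # record every call so we can count them
    return -x        # negate to sort in descending order


print(sort_by_key([3, 4, 2, 5], tracing_key))  # [5, 4, 3, 2]
print(len(calls))                              # 4: one key call per element
```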

The important thing to note here is that each new object is compared only against the other objects returned by the same key callable, never against the original elements.
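A quick way to check this: the elements below are dicts, which cannot be compared with <, yet sorting them with a key works, because only the values returned by the key are ever compared. The records list is made-up sample data for this illustration.

```python
records = [{"name": "bob", "age": 30}, {"name": "ann", "age": 25}]

# The elements themselves are dicts, and dicts do not support "<".
try:
    sorted(records)
except TypeError as exc:
    print("sorting the dicts directly fails:", exc)

# With a key, sort only ever compares the ints returned by the key,
# never the dicts themselves, so this works.
print(sorted(records, key=lambda r: r["age"]))
# [{'name': 'ann', 'age': 25}, {'name': 'bob', 'age': 30}]
```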

Now onto your equivalent code: when you create an instance of that class, it can be compared to other instances of the same class using your mycmp function. While sorting, sort compares these K objects, which in effect calls your mycmp() function to decide whether one value is less than or greater than the other.
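Python's Sorting HOWTO states that the sort routines use < when comparing two objects, so a stripped-down class that defines only __lt__ is already enough for sorted(). The sketch below, minimal_cmp_to_key, is a hypothetical name for illustration only. The real functools.cmp_to_key supports all six comparison operators (and CPython ships a C implementation of it), so its key objects also work outside of sorting.

```python
import functools


def minimal_cmp_to_key(mycmp):
    """Hypothetical stripped-down cmp_to_key: only __lt__ is defined."""

    class K:
        __slots__ = ("obj",)

        def __init__(self, obj):
            self.obj = obj  # remember the original element

        def __lt__(self, other):
            # Translate "self < other" into the cmp protocol: negative means less.
            return mycmp(self.obj, other.obj) < 0

    return K


def reverse_numeric(a, b):
    # cmp-style function (negative, zero or positive); larger values sort first
    return b - a


data = [3, 4, 2, 5]
print(sorted(data, key=minimal_cmp_to_key(reverse_numeric)))    # [5, 4, 3, 2]
print(sorted(data, key=functools.cmp_to_key(reverse_numeric)))  # [5, 4, 3, 2]
```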

Example with print statements:

>>> def cmp_to_key(mycmp):
...     'Convert a cmp= function into a key= function'
...     class K(object):
...         def __init__(self, obj, *args):
...             print('obj created with ',obj)
...             self.obj = obj
...         def __lt__(self, other):
...             print('comparing less than ',self.obj)
...             return mycmp(self.obj, other.obj) < 0
...         def __gt__(self, other):
...             print('comparing greater than ',self.obj)
...             return mycmp(self.obj, other.obj) > 0
...         def __eq__(self, other):
...             print('comparing equal to ',self.obj)
...             return mycmp(self.obj, other.obj) == 0
...         def __le__(self, other):
...             print('comparing less than equal ',self.obj)
...             return mycmp(self.obj, other.obj) <= 0
...         def __ge__(self, other):
...             print('comparing greater than equal ',self.obj)
...             return mycmp(self.obj, other.obj) >= 0
...         def __ne__(self, other):
...             print('comparing not equal ',self.obj)
...             return mycmp(self.obj, other.obj) != 0
...     return K
...
>>> def mycmp(a, b):
...     print("In Mycmp for", a, ' ', b)
...     if a < b:
...         return -1
...     elif a > b:
...         return 1
...     return 0
...
>>> print(sorted([3,4,2,5],key=cmp_to_key(mycmp)))
obj created with  3
obj created with  4
obj created with  2
obj created with  5
comparing less than  4
In Mycmp for 4   3
comparing less than  2
In Mycmp for 2   4
comparing less than  2
In Mycmp for 2   4
comparing less than  2
In Mycmp for 2   3
comparing less than  5
In Mycmp for 5   3
comparing less than  5
In Mycmp for 5   4
[2, 3, 4, 5]