How to flatten a list that contains primitive data types, lists and generators?

A faster approach is to avoid the use of global variables:

def to_flatten3(my_list, primitives=(bool, str, int, float)):
    flatten = []
    for item in my_list:
        if isinstance(item, primitives):
            flatten.append(item)
        else:
            flatten.extend(item)
    return flatten

whose timings are:

list_1 = [1, 2, 3, 'ID45785', False, '', 2.85, [1, 2, 'ID85639', True, 1.8], (e for e in range(589, 591))]

%timeit to_flatten(list_1 * 100)
# 1000 loops, best of 3: 296 µs per loop
%timeit to_flatten1(list_1 * 100)
# 1000 loops, best of 3: 255 µs per loop
%timeit to_flatten2(list_1 * 100)
# 10000 loops, best of 3: 183 µs per loop
%timeit to_flatten3(list_1 * 100)
# 10000 loops, best of 3: 168 µs per loop
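
One way to see where the difference between to_flatten2() and to_flatten3() comes from is to inspect the code objects. The sketch below uses two illustrative copies (the names flatten_global and flatten_local are made up for this example): in the first, `primitives` is a global name that has to be looked up on every iteration; in the second, it is a parameter and hence a local variable.

```python
primitives = (bool, str, int, float)


def flatten_global(my_list):
    # `primitives` is resolved as a global name on every loop iteration
    flat = []
    for item in my_list:
        if isinstance(item, primitives):
            flat.append(item)
        else:
            flat.extend(item)
    return flat


def flatten_local(my_list, primitives=(bool, str, int, float)):
    # `primitives` is a parameter, i.e. a local variable of the function
    flat = []
    for item in my_list:
        if isinstance(item, primitives):
            flat.append(item)
        else:
            flat.extend(item)
    return flat


# co_names holds the global/attribute names used in the function body,
# co_varnames holds its local variables (parameters first).
print('primitives' in flatten_global.__code__.co_names)    # True
print('primitives' in flatten_local.__code__.co_varnames)  # True
print('primitives' in flatten_local.__code__.co_names)     # False
```

Local variables are accessed by index rather than through a dictionary lookup, which is the usual explanation for the small speed-up; how large it is depends on the interpreter version.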

Note that this flattens only a single nesting level, not arbitrarily nested inputs.


To flatten arbitrarily nested inputs, one could use:

def flatten_iter(items, primitives=(bool, int, float, str)):
    buffer = []
    iter_items = iter(items)
    while True:
        try:
            item = next(iter_items)
            if isinstance(item, primitives) or not hasattr(item, '__iter__'):
                yield item
            else:
                buffer.append(iter_items)
                iter_items = iter(item)
        except StopIteration:
            if buffer:
                iter_items = buffer.pop()
            else:
                break

or:

def flatten_recursive(
        items,
        primitives=(bool, int, float, str)):
    for item in items:
        if isinstance(item, primitives) or not hasattr(item, '__iter__'):
            yield item
        else:
            for subitem in flatten_recursive(item, primitives):
                yield subitem

which are both slower, but work correctly for deeper nesting (the result of to_flatten3(), like the original approach, is not flat):

list_2 = [list_1, [[[[1], 2], 3], 4], 5]
print(to_flatten3(list_2))
# [1, 2, 3, 'ID45785', False, '', 2.85, [1, 2, 'ID85639', True, 1.8], <generator object <genexpr> at 0x7f1c92dff6d0>, [[[1], 2], 3], 4, 5]
print(list(flatten_iter(list_2)))
# [1, 2, 3, 'ID45785', False, '', 2.85, 1, 2, 'ID85639', True, 1.8, 1, 2, 3, 4, 5]
print(list(flatten_recursive(list_2)))
# [1, 2, 3, 'ID45785', False, '', 2.85, 1, 2, 'ID85639', True, 1.8, 1, 2, 3, 4, 5]

(Note that the generator expression is already consumed here and hence produces no objects.)
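
A minimal illustration of that effect, using the same generator expression as in list_1: a generator can be iterated only once, so any later pass over it yields nothing.

```python
gen = (e for e in range(589, 591))

first_pass = list(gen)   # consumes the generator
second_pass = list(gen)  # nothing left to yield

print(first_pass)   # [589, 590]
print(second_pass)  # []
```

This also means that in the timings above only the very first call actually sees the values 589 and 590; every later call iterates over an exhausted generator.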

Timewise, for the tested input, the iterative solution proposed here is ~3x slower than to_flatten3() and the recursive solution is ~2x slower. This input has only one nesting level, so to_flatten3() also handles it correctly:

%timeit list(flatten_iter(list_1 * 100))
# 1000 loops, best of 3: 450 µs per loop
%timeit list(flatten_recursive(list_1 * 100))
# 1000 loops, best of 3: 291 µs per loop

When the input has more nesting levels, the timings are:

%timeit list(flatten_iter(list_2 * 100))
# 1000 loops, best of 3: 953 µs per loop
%timeit list(flatten_recursive(list_2 * 100))
# 1000 loops, best of 3: 714 µs per loop

And the recursive solution is again faster (by approx. 30% for the tested input) than the iterative one.

While iterative methods typically execute faster in Python because they avoid expensive function calls, in the proposed iterative solution those savings are offset by the cost of the try / except clause and the repeated use of iter().

These timings can be slightly improved with Cython.


Hey, I made this recursive function in case there are lists inside the list:

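def flatten(list_to_flatten):
    # Base case: a primitive value is wrapped in a one-element list
    if type(list_to_flatten) in (str, bool, int, float):
        return [list_to_flatten]
    # Otherwise, flatten every element and concatenate the results
    flattened_list = []
    for item in list_to_flatten:
        flattened_list.extend(flatten(item))
    return flattened_list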
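
A short usage sketch for the function above (its definition is repeated so the snippet runs on its own; the sample input is made up):

```python
def flatten(list_to_flatten):
    # Same logic as the function above
    if type(list_to_flatten) in (str, bool, int, float):
        return [list_to_flatten]
    flattened_list = []
    for item in list_to_flatten:
        flattened_list.extend(flatten(item))
    return flattened_list


sample = [1, 'ID45785', [2, [3, [4.5, True]]], (e for e in range(589, 591))]
print(flatten(sample))
# [1, 'ID45785', 2, 3, 4.5, True, 589, 590]
```

Note that any value that is neither one of the four listed types nor iterable (None, for instance) raises a TypeError, because the function then tries to loop over it.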

After more tests, I found that @juanpa.arrivillaga's suggestion made my code ~10% faster, and that putting the primitive types in a variable made it ~20% faster:

def to_flatten(my_list):
    flatten = []
    for item in my_list:
        if isinstance(item, (str, bool, int, float)):
            flatten.append(item)
        else:
            flatten.extend(list(item))

    return flatten

def to_flatten1(my_list):
    """with @juanpa.arrivillaga suggestion"""

    flatten = []
    for item in my_list:
        if isinstance(item, (bool, str, int, float)):
            flatten.append(item)
        else:
            flatten.extend(item)

    return flatten

primitives = (bool, str, int, float)


def to_flatten2(my_list):
    flatten = []
    for item in my_list:
        if isinstance(item, primitives):
            flatten.append(item)
        else:
            flatten.extend(item)

    return flatten

%timeit to_flatten(list_1)
%timeit to_flatten1(list_1)
%timeit to_flatten2(list_1)

output:

3.5 µs ± 18.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
3.15 µs ± 35.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
2.31 µs ± 12.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
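
%timeit is an IPython magic. Outside IPython, a comparable measurement can be sketched with the standard timeit module. In the snippet below the generator from list_1 is replaced by a plain list (an assumption made here so that every call does the same amount of work), and the repeat/number values are arbitrary:

```python
import timeit


def to_flatten(my_list):
    flatten = []
    for item in my_list:
        if isinstance(item, (str, bool, int, float)):
            flatten.append(item)
        else:
            flatten.extend(list(item))
    return flatten


def to_flatten1(my_list):
    flatten = []
    for item in my_list:
        if isinstance(item, (bool, str, int, float)):
            flatten.append(item)
        else:
            flatten.extend(item)
    return flatten


# Same values as list_1, but with a list instead of the one-shot generator
data = [1, 2, 3, 'ID45785', False, '', 2.85,
        [1, 2, 'ID85639', True, 1.8], list(range(589, 591))]

for func in (to_flatten, to_flatten1):
    # 3 rounds of 200 calls each; the minimum is the least noisy estimate
    best = min(timeit.repeat(lambda: func(data), repeat=3, number=200))
    print('{}: {:.2f} us per call'.format(func.__name__, best / 200 * 1e6))
```

The absolute numbers depend on the machine and Python version, so only the relative ordering is meaningful.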