Higher-Order Functions

Introduction

A higher-order function is a function that takes one or more functions as arguments, returns a function as its result, or both. Higher-order functions enable powerful abstractions for transforming and combining data. This reading covers the essential higher-order functions used in functional programming, with examples in Python.

Learning Objectives

By the end of this reading, you will be able to:

  • Use map, filter, and reduce effectively
  • Compose functions to build complex transformations
  • Apply partial application and currying
  • Implement common higher-order function patterns
  • Chain operations in a functional pipeline

1. Map

Concept

Map applies a function to each element of a collection, returning a new collection with transformed elements.

from typing import List, Callable, TypeVar, Iterable

# Type variables shared by the generic helpers in this section.
T = TypeVar('T')  # input element type
U = TypeVar('U')  # output element type

# Built-in map (list() forces the lazy map object to produce every result)
numbers = list(range(1, 6))
squared = list(map(lambda n: n ** 2, numbers))  # [1, 4, 9, 16, 25]

# Implementing map
def my_map(func: Callable[[T], U], items: Iterable[T]) -> List[U]:
    """Return a new list holding ``func(item)`` for every item, in order."""
    results: List[U] = []
    for element in items:
        results.append(func(element))
    return results

# Recursive implementation
def map_recursive(func: Callable[[T], U], items: List[T]) -> List[U]:
    """Recursive map: transform the head, then recurse on the tail.

    Slices the list at every level and recurses once per element, so it is
    O(n^2) and bounded by the recursion limit -- illustrative, not practical.
    """
    if not items:
        return []
    head, tail = items[0], items[1:]
    return [func(head), *map_recursive(func, tail)]

# Examples
names = ['alice', 'bob', 'charlie']
upper_names = list(map(str.upper, names))  # ['ALICE', 'BOB', 'CHARLIE']

# Map with multiple iterables: the function receives one item from each
a, b = [1, 2, 3], [10, 20, 30]
sums = list(map(lambda left, right: left + right, a, b))  # [11, 22, 33]

Practical Applications

from dataclasses import dataclass
from typing import List

@dataclass
class User:
    """A user record with a name, an email address, and an age."""

    name: str
    email: str
    age: int

users = [
    User("Alice", "alice@example.com", 30),
    User("Bob", "bob@example.com", 25),
    User("Charlie", "charlie@example.com", 35),
]

# Extract a single field from every record
emails = list(map(lambda user: user.email, users))

# Transform each record into a different type (here, a display string)
user_summaries = list(map(lambda user: f"{user.name} ({user.age})", users))

# Apply method to all
class Temperature:
    """A temperature stored in degrees Celsius."""

    def __init__(self, celsius: float):
        self.celsius = celsius

    def to_fahrenheit(self) -> float:
        """Return this temperature in degrees Fahrenheit."""
        scaled = self.celsius * 9 / 5
        return scaled + 32

temps = [Temperature(c) for c in (0, 100, 37)]
fahrenheits = list(map(Temperature.to_fahrenheit, temps))
# Or, with a lambda: list(map(lambda t: t.to_fahrenheit(), temps))

2. Filter

Concept

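Filter selects the elements of a collection that satisfy a predicate (a function that returns a boolean). Python's `filter` actually accepts any truthy or falsy result, as the `filter(None, ...)` example below shows.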

from typing import Callable, Iterable, List

# Built-in filter
numbers = list(range(1, 11))
evens = list(filter(lambda n: n % 2 == 0, numbers))  # [2, 4, 6, 8, 10]

# Implementing filter
def my_filter(predicate: Callable[[T], bool], items: Iterable[T]) -> List[T]:
    """Return a new list of the items for which ``predicate`` is truthy."""
    kept: List[T] = []
    for candidate in items:
        if predicate(candidate):
            kept.append(candidate)
    return kept

# Recursive implementation
def filter_recursive(predicate: Callable[[T], bool], items: List[T]) -> List[T]:
    """Recursive filter: decide on the head, then recurse on the tail.

    Slices the list at every level and recurses once per element, so it is
    O(n^2) and bounded by the recursion limit -- illustrative, not practical.
    """
    if not items:
        return []
    head, tail = items[0], items[1:]
    keep_head = predicate(head)  # decided before recursing: left-to-right order
    rest = filter_recursive(predicate, tail)
    return [head, *rest] if keep_head else rest

# Examples
words = ['hello', 'world', 'functional', 'programming']
long_words = list(filter(lambda word: len(word) > 5, words))
# -> ['functional', 'programming']

# Passing None instead of a predicate keeps only the truthy values
mixed = [0, 1, '', 'hello', None, [], [1, 2]]
truthy = list(filter(None, mixed))
# -> [1, 'hello', [1, 2]]

Combining Filter and Map

from typing import Optional

# Common pattern: filter then map
def get_adult_names(users: List[User]) -> List[str]:
    """Names of users aged 18 or older, built step by step: filter, then map."""
    adult_users = filter(lambda user: user.age >= 18, users)
    adult_names = map(lambda user: user.name, adult_users)
    return list(adult_names)

# Or as a pipeline
def get_adult_names_v2(users: List[User]) -> List[str]:
    """Same result, written as one nested map(filter(...)) expression."""
    return list(
        map(
            lambda user: user.name,
            filter(lambda user: user.age >= 18, users),
        )
    )

# Using list comprehension (often clearer in Python)
def get_adult_names_v3(users: List[User]) -> List[str]:
    """Same result as a list comprehension -- usually the clearest in Python."""
    return [user.name for user in users if user.age >= 18]

# filter_map: filter and transform in one pass
def filter_map(
    predicate: Callable[[T], bool],
    transform: Callable[[T], U],
    items: Iterable[T],
) -> List[U]:
    """Keep the items that satisfy ``predicate`` and transform them.

    Makes a single pass over ``items``; ``transform`` is called only for
    items whose predicate result is truthy.
    """
    out: List[U] = []
    for item in items:
        if not predicate(item):
            continue
        out.append(transform(item))
    return out

3. Reduce (Fold)

Concept

Reduce (or fold) combines all elements of a collection into a single value using a binary function.

from functools import reduce
from typing import Callable, Iterable, TypeVar

T = TypeVar('T')

# Built-in reduce: fold the sequence from the left, two values at a time
numbers = list(range(1, 6))
total = reduce(lambda running, n: running + n, numbers)    # 15
product = reduce(lambda running, n: running * n, numbers)  # 120

# With an initial value, which is also what an empty input returns
total_with_init = reduce(lambda running, n: running + n, numbers, 0)  # 15
empty_safe = reduce(lambda running, n: running + n, [], 0)  # 0 (no error)

# Implementing reduce (left fold)
_NO_INITIAL = object()  # private sentinel meaning "no initial value given"


def my_reduce(
    func: Callable[[T, T], T],
    items: Iterable[T],
    initial: T = _NO_INITIAL,  # type: ignore[assignment]
) -> T:
    """Reduce items to a single value using func (a left fold).

    Like ``functools.reduce``: when *initial* is supplied it seeds the
    accumulator (and is returned for empty input); otherwise the first item
    seeds it. A private sentinel marks "not supplied", so ``None`` -- or any
    other value -- is accepted as a genuine initial value.

    Args:
        func: Binary function ``func(accumulator, item) -> accumulator``.
        items: Any iterable; it is consumed once.
        initial: Optional starting accumulator.

    Returns:
        The final accumulator.

    Raises:
        TypeError: If *items* is empty and no *initial* was supplied.
    """
    iterator = iter(items)

    if initial is _NO_INITIAL:
        try:
            accumulator = next(iterator)
        except StopIteration:
            # `from None` keeps the internal StopIteration out of the traceback.
            raise TypeError(
                "reduce() of empty sequence with no initial value"
            ) from None
    else:
        accumulator = initial

    for item in iterator:
        accumulator = func(accumulator, item)

    return accumulator

# Recursive implementation
def reduce_recursive(func: Callable[[T, T], T], items: List[T], acc: T) -> T:
    """Recursive left fold: fold the head into ``acc``, then recurse on the tail.

    Each element adds one level of recursion and one list slice.
    """
    if items:
        head, tail = items[0], items[1:]
        return reduce_recursive(func, tail, func(acc, head))
    return acc

Fold Left vs Fold Right

from typing import List

# Fold left: ((((init op a) op b) op c) op d)
def fold_left(func, initial, items: List):
    """Left fold: ((((initial op a) op b) op c) ...), where op is ``func``."""
    accumulated = initial
    for element in items:
        accumulated = func(accumulated, element)
    return accumulated

# Fold right: (a op (b op (c op (d op init))))
def fold_right(func, initial, items: List):
    """Right fold, recursively: func(a, func(b, ... func(z, initial))).

    Recurses once per element (slicing each time), so very long lists can
    hit Python's recursion limit; fold_right_iter gives the same result
    without recursion.
    """
    if not items:
        return initial
    first, rest = items[0], items[1:]
    return func(first, fold_right(func, initial, rest))

# Iterative fold right (more efficient)
def fold_right_iter(func, initial, items: List):
    """Right fold without recursion: walk ``items`` backwards from ``initial``.

    Computes the same func(a, func(b, ... func(z, initial))) as the
    recursive version, with no slicing and no stack growth.
    """
    accumulated = initial
    for element in reversed(items):
        accumulated = func(element, accumulated)
    return accumulated

# Example where order matters
items = [1, 2, 3]

# Subtraction is not associative, so the fold direction changes the answer
left_result = fold_left(lambda total, n: total - n, 0, items)   # ((0-1)-2)-3 = -6
right_result = fold_right(lambda n, rest: n - rest, 0, items)  # 1-(2-(3-0)) = 2

# List building: right fold is natural
def build_list_right(items):
    """Rebuild ``items`` as a list with a right fold (each element is prepended)."""
    return fold_right(lambda item, rest: [item] + rest, [], items)

def build_list_left(items):
    """Rebuild ``items`` as a list with a left fold (each element is appended)."""
    return fold_left(lambda so_far, item: so_far + [item], [], items)

Practical Reduce Examples

from typing import Dict, Any

# Finding maximum
numbers = [3, 1, 4, 1, 5, 9, 2, 6]
maximum = reduce(lambda best, n: best if best > n else n, numbers)

# Flattening nested lists (concatenation, starting from an empty list)
nested = [[1, 2], [3, 4], [5, 6]]
flat = reduce(lambda so_far, sub: so_far + sub, nested, [])  # [1, 2, 3, 4, 5, 6]

# Building dictionary from pairs (each step returns a new, extended dict)
pairs = [('a', 1), ('b', 2), ('c', 3)]
dictionary = reduce(
    lambda mapping, kv: {**mapping, kv[0]: kv[1]},
    pairs,
    {},
)

# Counting occurrences
words = ['apple', 'banana', 'apple', 'cherry', 'banana', 'apple']
counts = reduce(
    lambda tally, word: {**tally, word: tally.get(word, 0) + 1},
    words,
    {},
)

# Group by department
users = [
    {'name': 'Alice', 'dept': 'Engineering'},
    {'name': 'Bob', 'dept': 'Sales'},
    {'name': 'Charlie', 'dept': 'Engineering'},
]

by_dept = reduce(
    lambda groups, user: {
        **groups,
        user['dept']: groups.get(user['dept'], []) + [user['name']],
    },
    users,
    {},
)
# {'Engineering': ['Alice', 'Charlie'], 'Sales': ['Bob']}

4. Function Composition

Concept

Composition combines functions to create new functions: (f ∘ g)(x) = f(g(x))

from typing import Callable, TypeVar

# For compose(f, g): g maps A -> B, then f maps B -> C.
A = TypeVar('A')  # input of g
B = TypeVar('B')  # output of g, input of f
C = TypeVar('C')  # output of f

# Basic composition
def compose(f: Callable[[B], C], g: Callable[[A], B]) -> Callable[[A], C]:
    """Compose two functions: compose(f, g)(x) == f(g(x)), so g runs first."""
    def composed(x):
        return f(g(x))
    return composed

# Example
def add_one(x: int) -> int:
    """Return ``x`` plus one."""
    return x + 1

def double(x: int) -> int:
    """Return ``x`` times two."""
    return x * 2

def square(x: int) -> int:
    """Return ``x`` multiplied by itself."""
    return x * x

# Compose functions
add_then_double = compose(double, add_one)  # computes double(add_one(x))
print(add_then_double(5))  # add_one(5) = 6, then double(6) = 12

double_then_add = compose(add_one, double)  # computes add_one(double(x))
print(double_then_add(5))  # double(5) = 10, then add_one(10) = 11

# Multiple composition
def compose_all(*funcs: Callable) -> Callable:
    """Compose multiple functions right-to-left.

    ``compose_all(f, g, h)(x) == f(g(h(x)))``: the last function runs first.
    Composing zero functions yields the identity function, matching ``pipe()``
    with no steps, instead of letting ``reduce`` raise ``TypeError`` on an
    empty sequence.
    """
    if not funcs:
        # Identity is the neutral element of composition.
        return lambda x: x
    return reduce(compose, funcs)

transform = compose_all(str, square, double, add_one)
# Applied right to left: str(square(double(add_one(x))))
print(transform(2))  # add_one -> 3, double -> 6, square -> 36, str -> "36"

Pipe (Left-to-Right Composition)

def pipe(*funcs: Callable) -> Callable:
    """Compose functions left-to-right: pipe(f, g, h)(x) == h(g(f(x))).

    The steps read in the order they are applied; with no functions the
    result returns its argument unchanged.
    """
    def run(value):
        for step in funcs:
            value = step(value)
        return value
    return run

# Same as compose_all but reversed order
transform = pipe(add_one, double, square, str)
# Same as str(square(double(add_one(x)))), listed in application order
print(transform(2))  # "36"

# Often used with data processing: each step receives the previous result
process_user = pipe(
    lambda s: s.strip(),           # Clean whitespace
    lambda s: s.lower(),           # Lowercase
    lambda s: s.replace(' ', '_'), # Replace spaces
    lambda s: f"user_{s}",         # Add prefix
)

print(process_user("  John Doe  "))  # "user_john_doe"

Point-Free Style

# Point-free: Define functions without mentioning arguments

# Not point-free
def get_lengths(strings):
    """Lengths of ``strings``; the lambda names its argument, so not point-free."""
    return list(map(lambda text: len(text), strings))

# Point-free where it matters: `len` is passed to map directly instead of
# being wrapped in `lambda s: len(s)`. The outer lambda still names
# `strings`; a fully point-free version is built by composition, e.g.
# compose2(list, partial(map, len)) with the helpers defined further down.
get_lengths_pf = lambda strings: list(map(len, strings))

# More examples
from operator import add, mul

# Not point-free
def sum_list(numbers):
    """Sum ``numbers`` with reduce, spelling addition as an explicit lambda."""
    return reduce(lambda total, n: total + n, numbers, 0)

# Point-free reducer: `add` replaces the explicit `lambda a, b: a + b`.
# It is reached through the `operator` module on purpose: the bare name
# `add` is rebound to a three-argument function in the currying section,
# and a lambda looks up globals when it is called, not when it is defined.
import operator
sum_list_pf = lambda nums: reduce(operator.add, nums, 0)

# Building complex functions point-free
from functools import partial

def compose2(f, g):
    """Return the composition x -> f(g(x))."""
    def composed(x):
        return f(g(x))
    return composed

# is_even: number -> bool
is_even = lambda n: n % 2 == 0

# negate: bool -> bool (logical NOT of any truthy/falsy value)
negate = lambda flag: not flag

# is_odd without mentioning argument
is_odd = compose2(negate, is_even)  # negate ∘ is_even, argument never named

print(is_odd(5))  # True  (5 is not even)
print(is_odd(4))  # False (4 is even)

5. Partial Application and Currying

Partial Application

Partial application fixes some arguments of a function, producing a function with fewer arguments.

from functools import partial

def greet(greeting: str, name: str) -> str:
    """Build a greeting of the form ``"<greeting>, <name>!"``."""
    message = f"{greeting}, {name}!"
    return message

# Partial application
say_hello = partial(greet, "Hello")      # greeting fixed; name comes later
say_goodbye = partial(greet, "Goodbye")

print(say_hello("Alice"))   # greet("Hello", "Alice") -> "Hello, Alice!"
print(say_goodbye("Bob"))   # greet("Goodbye", "Bob") -> "Goodbye, Bob!"

# Practical example: configuration
def fetch_data(base_url: str, endpoint: str, params: dict) -> dict:
    """Simulated API fetch: echo back the joined URL and the given params.

    No request is made; the result dict is built locally.
    """
    url = f"{base_url}/{endpoint}"
    return {'url': url, 'params': params}

# Create specialized fetchers
fetch_from_api_v1 = partial(fetch_data, "https://api.example.com/v1")  # base_url fixed
fetch_users = partial(fetch_from_api_v1, "users")                      # endpoint fixed too

print(fetch_users({'limit': 10}))
# {'url': 'https://api.example.com/v1/users', 'params': {'limit': 10}}

Currying

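Currying transforms a function of several arguments into a chain of functions that each take a single argument: `f(a, b, c)` becomes `f(a)(b)(c)`. The generic `curry` decorator below is more permissive ("auto-currying"). Each call may supply one or more of the remaining arguments, and the original function runs once all of them have been collected.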

from typing import Callable

# Regular function
def add(a: int, b: int, c: int) -> int:
    """Return a + b + c -- the ordinary, uncurried three-argument form."""
    # NOTE: this rebinds the module-level name `add` that was imported from
    # `operator` earlier in this file.
    return a + b + c

# Curried version
def add_curried(a: int) -> Callable[[int], Callable[[int], int]]:
    """Curried add: ``add_curried(a)(b)(c) == a + b + c``.

    Each call captures one argument in a closure and returns a function
    waiting for the next one.
    """
    return lambda b: lambda c: a + b + c

# Using curried function
add_5 = add_curried(5)   # a fixed; waits for b
add_5_3 = add_5(3)       # b fixed; waits for c
result = add_5_3(2)  # 10

# Or supply every argument in one chained expression
result2 = add_curried(5)(3)(2)  # 10

# Generic curry decorator
def curry(func: Callable, arity: Optional[int] = None) -> Callable:
    """Convert func to curried version (auto-currying).

    The returned function collects positional arguments across calls and
    invokes *func* once at least *arity* of them have been supplied. Each
    call may pass one or several arguments, so ``f(1)(2)(3)``,
    ``f(1, 2)(3)`` and ``f(1, 2, 3)`` are equivalent.

    Args:
        func: The function to curry.
        arity: How many arguments to collect before calling *func*. Defaults
            to the number of parameters in ``func``'s signature (every
            parameter is counted, including ones with defaults). Pass it
            explicitly for callables whose signature cannot be inspected,
            such as some builtins.

    Returns:
        The curried wrapper; it carries ``func``'s name and docstring.
    """
    import functools
    import inspect

    num_args = len(inspect.signature(func).parameters) if arity is None else arity

    @functools.wraps(func)
    def curried(*args):
        if len(args) >= num_args:
            return func(*args)
        # Not enough arguments yet: remember these and wait for more.
        return lambda *more_args: curried(*args, *more_args)

    return curried

@curry
def multiply(a: int, b: int, c: int) -> int:
    """Return the product a * b * c; @curry lets the arguments arrive over
    several calls."""
    return a * b * c

# All equivalent
print(multiply(2, 3, 4))      # 24  (all arguments at once)
print(multiply(2)(3)(4))      # 24  (one argument per call)
print(multiply(2, 3)(4))      # 24  (two, then one)
print(multiply(2)(3, 4))      # 24  (one, then two)

# Practical curried functions
@curry
def map_over(func, items):
    """Curried map: ``map_over(func)`` returns a function awaiting ``items``."""
    return list(map(func, items))

@curry
def filter_by(predicate, items):
    """Curried filter: ``filter_by(pred)`` returns a function awaiting ``items``."""
    return list(filter(predicate, items))

# Create reusable transformations
double_all = map_over(lambda n: n * 2)        # still waiting for items
keep_positive = filter_by(lambda n: n > 0)    # still waiting for items

numbers = [-2, -1, 0, 1, 2, 3]
print(double_all(numbers))     # [-4, -2, 0, 2, 4, 6]
print(keep_positive(numbers))  # [1, 2, 3]

6. Functional Pipelines

Building Data Pipelines

from typing import List, Dict, Any, Callable
from functools import reduce

# Pipeline helper
def pipeline(*steps: Callable) -> Callable:
    """Create a data processing pipeline.

    The returned function feeds its argument through ``steps`` in order,
    each step receiving the previous step's output. With no steps it
    returns its input unchanged.
    """
    def execute(data):
        current = data
        for step in steps:
            current = step(current)
        return current
    return execute

# Example: Processing user data
raw_users = [
    dict(name='ALICE', age=17, active=True),
    dict(name='BOB', age=25, active=False),
    dict(name='CHARLIE', age=30, active=True),
    dict(name='DIANA', age=22, active=True),
]

# Define transformation steps
def normalize_names(users: List[Dict]) -> List[Dict]:
    """Return shallow copies of the user dicts with 'name' in title case."""
    normalized = []
    for user in users:
        renamed = dict(user)
        renamed['name'] = user['name'].title()
        normalized.append(renamed)
    return normalized

def filter_active(users: List[Dict]) -> List[Dict]:
    """Keep the users whose 'active' flag is truthy."""
    return list(filter(lambda user: user['active'], users))

def filter_adults(users: List[Dict]) -> List[Dict]:
    """Keep the users whose 'age' is 18 or more."""
    return list(filter(lambda user: user['age'] >= 18, users))

def extract_names(users: List[Dict]) -> List[str]:
    """Return the 'name' of each user, in input order."""
    return list(map(lambda user: user['name'], users))

def sort_names(names: List[str]) -> List[str]:
    """Return a new alphabetically sorted list; the input is left unchanged."""
    ordered = list(names)
    ordered.sort()
    return ordered

# Build pipeline
process_users = pipeline(
    normalize_names,   # 'ALICE' -> 'Alice'
    filter_active,     # drop inactive users
    filter_adults,     # drop users under 18
    extract_names,     # user dicts -> names
    sort_names,        # alphabetical order
)

result = process_users(raw_users)
print(result)  # ['Charlie', 'Diana']

Lazy Pipelines with Generators

from typing import Iterator, Callable, TypeVar, Iterable

T = TypeVar('T')  # element type flowing into a step
U = TypeVar('U')  # element type flowing out of a mapping step

def lazy_map(func: Callable[[T], U]) -> Callable[[Iterable[T]], Iterator[U]]:
    """Lazy map that returns generator.

    Returns a one-argument step; calling it on an iterable gives a generator
    that applies ``func`` to each item only when that item is requested.
    """
    def mapper(items: Iterable[T]) -> Iterator[U]:
        yield from (func(element) for element in items)
    return mapper

def lazy_filter(predicate: Callable[[T], bool]) -> Callable[[Iterable[T]], Iterator[T]]:
    """Lazy filter that returns generator.

    Returns a one-argument step; calling it on an iterable gives a generator
    that tests each item with ``predicate`` only when the next item is
    requested, yielding the items whose result is truthy.
    """
    def filterer(items: Iterable[T]) -> Iterator[T]:
        yield from (element for element in items if predicate(element))
    return filterer

def lazy_take(n: int) -> Callable[[Iterable[T]], Iterator[T]]:
    """Take first n items.

    Returns a one-argument step; calling it on an iterable gives a generator
    of at most the first ``n`` items (none when ``n <= 0``).

    The upstream iterable is never advanced past the n-th item. The previous
    enumerate-based loop pulled item n+1 before noticing it was done, which
    wasted work. With n <= 0 it also consumed one item, which could block
    forever on an infinite source whose filter never matches.
    """
    from itertools import islice

    def taker(items: Iterable[T]) -> Iterator[T]:
        # islice checks its count before requesting the next item, and a
        # stop of 0 requests nothing; it rejects negative stops, hence max().
        yield from islice(items, max(n, 0))

    return taker

# Lazy pipeline
def lazy_pipeline(*steps: Callable) -> Callable:
    """Create a lazy pipeline (generators all the way).

    Each step is applied once, in order, to the previous step's result.
    When the steps return generators (as the lazy_map / lazy_filter /
    lazy_take steps do), building the chain does no work; items flow only
    when the final result is iterated.
    """
    def execute(data):
        stage = data
        for step in steps:
            stage = step(stage)
        return stage
    return execute

# Process potentially infinite sequence
def integers_from(start: int) -> Iterator[int]:
    """Generate infinite sequence: start, start + 1, start + 2, ..."""
    current = start
    while True:
        yield current
        current = current + 1

# Find first 5 even squares greater than 100
find_numbers = lazy_pipeline(
    lazy_map(lambda n: n * n),            # Square
    lazy_filter(lambda sq: sq > 100),     # Greater than 100
    lazy_filter(lambda sq: sq % 2 == 0),  # Even
    lazy_take(5),                         # First 5
)

result = list(find_numbers(integers_from(1)))
print(result)  # [144, 196, 256, 324, 400]
# Nothing runs until list() pulls values, so the infinite source is safe.

Fluent Interface

from typing import List, TypeVar, Callable, Generic
from functools import reduce

T = TypeVar('T')  # element type of a Stream
U = TypeVar('U')  # element type after map / flat_map

class Stream(Generic[T]):
    """Fluent functional stream interface.

    Intermediate operations (``map``, ``filter``, ``take``, ``skip``,
    ``flat_map``) wrap the current items in a generator and return a new
    Stream. Terminal operations (``reduce``, ``to_list``, ``first``,
    ``count``) iterate the items.

    NOTE: a Stream over a one-shot iterator can be consumed only once; this
    includes every Stream returned by an intermediate operation. A second
    terminal operation sees only the items that are left.
    """

    # Private sentinel meaning "reduce() was called without an initial
    # value", so that None can still be passed as a real initial value.
    _MISSING = object()

    def __init__(self, items: Iterable[T]):
        self._items = items

    def map(self, func: Callable[[T], U]) -> 'Stream[U]':
        """Return a stream of ``func(item)`` for each item."""
        return Stream(func(item) for item in self._items)

    def filter(self, predicate: Callable[[T], bool]) -> 'Stream[T]':
        """Return a stream of the items for which ``predicate`` is truthy."""
        return Stream(item for item in self._items if predicate(item))

    def reduce(self, func: Callable[[T, T], T], initial: T = _MISSING) -> T:  # type: ignore[assignment]
        """Fold the items from the left with ``func``.

        If ``initial`` is given (``None`` included) it seeds the fold and is
        returned for an empty stream. Otherwise the first item seeds it, and
        ``functools.reduce`` raises ``TypeError`` for an empty stream.
        """
        # The bare name `reduce` is the module-level functools.reduce:
        # methods are not visible as bare names inside method bodies.
        if initial is self._MISSING:
            return reduce(func, self._items)
        return reduce(func, self._items, initial)

    def take(self, n: int) -> 'Stream[T]':
        """Return a stream of at most the first ``n`` items (none if n <= 0).

        ``islice`` stops before requesting item ``n + 1``, so the source is
        never advanced further than needed. This matters for infinite or
        expensive upstream generators.
        """
        import itertools

        def take_gen():
            # islice rejects a negative stop; clamp so n <= 0 yields nothing.
            yield from itertools.islice(self._items, max(n, 0))

        return Stream(take_gen())

    def skip(self, n: int) -> 'Stream[T]':
        """Return a stream without its first ``n`` items."""
        def skip_gen():
            for i, item in enumerate(self._items):
                if i >= n:
                    yield item
        return Stream(skip_gen())

    def flat_map(self, func: Callable[[T], Iterable[U]]) -> 'Stream[U]':
        """Apply ``func`` to each item and chain the resulting iterables."""
        def flat_gen():
            for item in self._items:
                yield from func(item)
        return Stream(flat_gen())

    def to_list(self) -> List[T]:
        """Collect the remaining items into a list."""
        return list(self._items)

    def first(self) -> T:
        """Return the first item; ``next`` raises StopIteration if there is none."""
        return next(iter(self._items))

    def count(self) -> int:
        """Return the number of items by iterating through all of them."""
        return sum(1 for _ in self._items)

# Usage
result = (
    Stream([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
    .filter(lambda n: n % 2 == 0)   # 2, 4, 6, 8, 10
    .map(lambda n: n ** 2)          # 4, 16, 36, 64, 100
    .take(3)                        # 4, 16, 36
    .to_list()
)
print(result)  # [4, 16, 36]

# More complex example
words = ['hello', 'world', 'functional', 'programming']
result = (
    Stream(words)
    .flat_map(lambda word: list(word))   # Flatten to chars
    .filter(lambda ch: ch in 'aeiou')    # Keep vowels
    .map(str.upper)
    .to_list()
)
print(result)  # ['E', 'O', 'O', 'U', 'I', 'O', 'A', 'O', 'A', 'I']

Exercises

Basic

  1. Use map to convert a list of temperatures from Celsius to Fahrenheit.

  2. Use filter to get all palindromes from a list of words.

  3. Use reduce to find the longest string in a list.

Intermediate

  1. Implement zip_with(func, list1, list2) that combines two lists using func.

  2. Create a group_by(key_func, items) function using reduce.

  3. Build a data processing pipeline that: filters, transforms, and aggregates.

Advanced

  1. Implement a lazy flat_map that works with infinite sequences.

  2. Create a curried version of reduce that can be partially applied.

  3. Design a type-safe fluent API for building SQL-like queries.


Summary

  • Map transforms each element: [a, b, c] -> [f(a), f(b), f(c)]
  • Filter selects elements: [a, b, c] -> [x for x in [a, b, c] if p(x)]
  • Reduce combines elements: [a, b, c] -> f(f(f(init, a), b), c)
  • Composition builds complex functions from simple ones
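  • Currying turns a multi-argument function into a chain of one-argument calls, so leading arguments can be fixed one at a time (partial application from the left)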
  • Pipelines chain transformations for readable data processing

Next Reading

Advanced Functional Concepts →