Recursion and Divide-and-Conquer
Introduction
Recursion is a problem-solving technique where a function calls itself to solve smaller subproblems. Divide-and-conquer is a strategy that uses recursion to break problems into independent subproblems, solve them, and combine the results.
Learning Objectives
By the end of this reading, you will be able to:
- Understand how recursion works
- Identify base cases and recursive cases
- Apply divide-and-conquer to solve problems
- Analyze recursive algorithms using recurrence relations
- Convert between recursive and iterative solutions
1. Understanding Recursion
The Recursive Pattern
Every recursive function needs:
- Base case(s): Condition(s) to stop recursion
- Recursive case(s): Call itself with a smaller problem
- Progress: Each call must move toward a base case
Simple Example: Factorial
def factorial(n):
    # Base case
    if n <= 1:
        return 1
    # Recursive case
    return n * factorial(n - 1)

# factorial(5) = 5 * factorial(4)
#              = 5 * 4 * factorial(3)
#              = 5 * 4 * 3 * factorial(2)
#              = 5 * 4 * 3 * 2 * factorial(1)
#              = 5 * 4 * 3 * 2 * 1
#              = 120
The Call Stack
Each recursive call adds a frame to the call stack:
factorial(5)
├── factorial(4)
│   ├── factorial(3)
│   │   ├── factorial(2)
│   │   │   ├── factorial(1)
│   │   │   │   └── return 1
│   │   │   └── return 2 * 1 = 2
│   │   └── return 3 * 2 = 6
│   └── return 4 * 6 = 24
└── return 5 * 24 = 120
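To make the frames in the diagram visible at run time, here is a minimal sketch that records the depth of each call. The `depth` and `log` parameters and the name `factorial_traced` are illustrative additions, not part of the factorial function shown earlier.

```python
import sys

def factorial_traced(n, depth=0, log=None):
    """Factorial that also records (n, depth) for every call it makes.

    `depth` and `log` are hypothetical helpers added for illustration.
    """
    if log is None:
        log = []
    log.append((n, depth))              # one entry per stack frame
    if n <= 1:
        return 1, log
    result, _ = factorial_traced(n - 1, depth + 1, log)
    return n * result, log

value, frames = factorial_traced(5)
print(value)    # 120
print(frames)   # [(5, 0), (4, 1), (3, 2), (2, 3), (1, 4)]

# The interpreter caps how deep the stack may grow.
print(sys.getrecursionlimit())
```

Each entry in `frames` corresponds to one level of the tree above. The last line prints the interpreter's current cap on stack depth.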
Common Recursion Mistakes
# Missing base case - infinite recursion!
def bad_factorial(n):
    return n * bad_factorial(n - 1)

# Not making progress toward base case
def bad_countdown(n):
    if n == 0:
        return
    print(n)
    bad_countdown(n)  # Should be n - 1

# Wrong base case
def wrong_sum(lst):
    if len(lst) == 0:
        return lst[0]  # Error! Empty list has no elements
    return lst[0] + wrong_sum(lst[1:])
2. Types of Recursion
Linear Recursion
One recursive call per function invocation.
def sum_list(lst):
    if not lst:
        return 0
    return lst[0] + sum_list(lst[1:])
Binary/Multiple Recursion
Multiple recursive calls per invocation.
def fibonacci(n):
    if n <= 1:
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)
Tail Recursion
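The recursive call is the last operation in the function, so nothing remains to be done after it returns. Language implementations that support tail-call elimination can reuse the current stack frame for such a call.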
# Not tail recursive - multiplication after recursive call
def factorial(n):
    if n <= 1:
        return 1
    return n * factorial(n - 1)

# Tail recursive version - uses accumulator
def factorial_tail(n, acc=1):
    if n <= 1:
        return acc
    return factorial_tail(n - 1, n * acc)
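Note: CPython does not eliminate tail calls, so factorial_tail still uses one stack frame per call and can hit the recursion limit. Some languages, such as Scheme, guarantee tail-call elimination.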
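Because the interpreter will not do it for us, a tail call can be rewritten by hand as a loop that updates the parameters in place. The sketch below assumes the same accumulator scheme as factorial_tail. The name `factorial_loop` is hypothetical.

```python
import sys

def factorial_tail(n, acc=1):
    # Same tail-recursive definition as in the text.
    if n <= 1:
        return acc
    return factorial_tail(n - 1, n * acc)

def factorial_loop(n):
    """Loop obtained by turning the tail call into a parameter update."""
    acc = 1
    while n > 1:
        n, acc = n - 1, n * acc   # mirrors factorial_tail(n - 1, n * acc)
    return acc

# An input deeper than the recursion limit breaks the recursive version...
big = sys.getrecursionlimit() + 100
try:
    factorial_tail(big)
except RecursionError:
    print("factorial_tail ran out of stack")

# ...while the loop handles the same input without growing the stack.
print(factorial_loop(big) > 0)
```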
Mutual Recursion
Functions call each other.
# Both functions assume n is a non-negative integer;
# a negative n never reaches the base case.
def is_even(n):
    if n == 0:
        return True
    return is_odd(n - 1)

def is_odd(n):
    if n == 0:
        return False
    return is_even(n - 1)
3. Classic Recursive Problems
Sum of Digits
def sum_digits(n):
    if n < 10:
        return n
    return n % 10 + sum_digits(n // 10)

# sum_digits(1234) = 4 + sum_digits(123)
#                  = 4 + 3 + sum_digits(12)
#                  = 4 + 3 + 2 + sum_digits(1)
#                  = 4 + 3 + 2 + 1 = 10
Reverse String
def reverse_string(s):
    if len(s) <= 1:
        return s
    return reverse_string(s[1:]) + s[0]

# reverse_string("hello")
# = reverse_string("ello") + "h"
# = reverse_string("llo") + "e" + "h"
# = reverse_string("lo") + "l" + "e" + "h"
# = reverse_string("o") + "l" + "l" + "e" + "h"
# = "o" + "l" + "l" + "e" + "h" = "olleh"
Power Function
def power(base, exp):
    if exp == 0:
        return 1
    if exp < 0:
        return 1 / power(base, -exp)
    return base * power(base, exp - 1)

# Efficient version: O(log n)
def power_fast(base, exp):
    if exp == 0:
        return 1
    if exp < 0:
        return 1 / power_fast(base, -exp)
    if exp % 2 == 0:
        half = power_fast(base, exp // 2)
        return half * half
    return base * power_fast(base, exp - 1)
Tower of Hanoi
def hanoi(n, source, auxiliary, target):
    """Move n disks from source to target using auxiliary."""
    if n == 1:
        print(f"Move disk from {source} to {target}")
        return
    hanoi(n - 1, source, target, auxiliary)
    print(f"Move disk from {source} to {target}")
    hanoi(n - 1, auxiliary, source, target)

# hanoi(3, 'A', 'B', 'C')
# 2^n - 1 = 7 moves
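The move count follows from the recurrence M(n) = 2M(n-1) + 1 with M(1) = 1, which gives 2^n - 1. As a quick check, the sketch below collects the moves in a list instead of printing them. The name `hanoi_moves` and the `moves` parameter are illustrative assumptions.

```python
def hanoi_moves(n, source, auxiliary, target, moves=None):
    """Return the list of (from_peg, to_peg) moves instead of printing."""
    if moves is None:
        moves = []
    if n == 0:                       # nothing to move
        return moves
    hanoi_moves(n - 1, source, target, auxiliary, moves)
    moves.append((source, target))   # move the largest remaining disk
    hanoi_moves(n - 1, auxiliary, source, target, moves)
    return moves

# The number of moves matches 2^n - 1 for small n.
for n in range(1, 6):
    assert len(hanoi_moves(n, 'A', 'B', 'C')) == 2**n - 1

print(hanoi_moves(2, 'A', 'B', 'C'))
# [('A', 'B'), ('A', 'C'), ('B', 'C')]
```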
4. Divide and Conquer
The Strategy
- Divide: Break problem into smaller subproblems
- Conquer: Solve subproblems recursively
- Combine: Merge subproblem solutions
Merge Sort
def merge_sort(arr):
    # Base case
    if len(arr) <= 1:
        return arr
    # Divide
    mid = len(arr) // 2
    left = arr[:mid]
    right = arr[mid:]
    # Conquer
    left = merge_sort(left)
    right = merge_sort(right)
    # Combine
    return merge(left, right)

def merge(left, right):
    result = []
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
    result.extend(left[i:])
    result.extend(right[j:])
    return result
Binary Search (Recursive)
def binary_search(arr, target, left=0, right=None):
    if right is None:
        right = len(arr) - 1
    # Base case
    if left > right:
        return -1
    # Divide
    mid = (left + right) // 2
    # Conquer
    if arr[mid] == target:
        return mid
    elif arr[mid] < target:
        return binary_search(arr, target, mid + 1, right)
    else:
        return binary_search(arr, target, left, mid - 1)
Quick Sort
def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    # Divide
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    # Conquer and Combine
    return quick_sort(left) + middle + quick_sort(right)
Maximum Subarray (Divide and Conquer)
def max_crossing_sum(arr, left, mid, right):
    # Left sum
    left_sum = float('-inf')
    total = 0
    for i in range(mid, left - 1, -1):
        total += arr[i]
        left_sum = max(left_sum, total)
    # Right sum
    right_sum = float('-inf')
    total = 0
    for i in range(mid + 1, right + 1):
        total += arr[i]
        right_sum = max(right_sum, total)
    return left_sum + right_sum

def max_subarray(arr, left=0, right=None):
    if right is None:
        right = len(arr) - 1
    # Base case
    if left == right:
        return arr[left]
    mid = (left + right) // 2
    # Maximum of:
    # 1. Max subarray in left half
    # 2. Max subarray in right half
    # 3. Max subarray crossing midpoint
    return max(
        max_subarray(arr, left, mid),
        max_subarray(arr, mid + 1, right),
        max_crossing_sum(arr, left, mid, right)
    )
5. Recurrence Relations
Writing Recurrences
For a recursive algorithm, express time complexity as a recurrence:
Merge Sort:
T(n) = 2T(n/2) + O(n)
          ↑       ↑
         Two    Merge
        halves  step
Binary Search:
T(n) = T(n/2) + O(1)
Fibonacci:
T(n) = T(n-1) + T(n-2) + O(1)
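A recurrence can be checked empirically by counting calls. The sketch below counts the calls a recursive binary search makes on n elements when the target is larger than every element. This is a worst case, chosen here as an assumption. The helper name `search_calls` is hypothetical. The count should grow like log n, as T(n) = T(n/2) + O(1) predicts.

```python
import math

def search_calls(n):
    """Count calls made by a recursive binary search on n elements
    when the target is larger than every element."""
    def go(left, right):
        if left > right:
            return 1                      # the final call that fails
        mid = (left + right) // 2
        return 1 + go(mid + 1, right)     # target is always to the right
    return go(0, n - 1)

# Compare the measured count with floor(log2(n)) + 2.
for n in (1, 8, 1000, 10**6):
    print(n, search_calls(n), math.floor(math.log2(n)) + 2)
```

Each call halves the remaining range, so multiplying n by 1000 adds only about ten calls.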
Solving Recurrences
Substitution Method:
- Guess the form of solution
- Substitute and verify
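As a worked instance of these two steps, take the merge sort recurrence T(n) = 2T(n/2) + n and guess T(n) ≤ c·n·log₂ n for some constant c. Substituting the guess for T(n/2) gives the following. Only the inductive step is shown here. The base case still has to be checked separately.

```latex
\begin{align*}
T(n) &= 2\,T(n/2) + n \\
     &\le 2\left(c\,\frac{n}{2}\log_2\frac{n}{2}\right) + n \\
     &= c\,n\log_2 n - c\,n + n \\
     &\le c\,n\log_2 n \qquad \text{whenever } c \ge 1 .
\end{align*}
```

The guess survives the substitution, so T(n) = O(n log n).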
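Master Theorem: For T(n) = aT(n/b) + f(n) with constants a ≥ 1 and b > 1, and some constant ε > 0:
- Case 1: If f(n) = O(n^(log_b(a) - ε)), then T(n) = Θ(n^log_b(a))
- Case 2: If f(n) = Θ(n^log_b(a)), then T(n) = Θ(n^log_b(a) * log n)
- Case 3: If f(n) = Ω(n^(log_b(a) + ε)) and a*f(n/b) ≤ c*f(n) for some constant c < 1 and all sufficiently large n, then T(n) = Θ(f(n))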
Examples:
Merge Sort: T(n) = 2T(n/2) + n
a=2, b=2, f(n)=n, log_2(2)=1
f(n) = Θ(n^1), case 2: T(n) = Θ(n log n)
Binary Search: T(n) = T(n/2) + 1
a=1, b=2, f(n)=1, log_2(1)=0
f(n) = Θ(n^0) = Θ(1), case 2: T(n) = Θ(log n)
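The case analysis can be mechanized when f(n) is a plain polynomial. The sketch below assumes f(n) = Θ(n^d), which is a simplification: for such f the regularity condition of case 3 holds automatically. The function name `master_theorem` and its parameters are hypothetical.

```python
import math

def master_theorem(a, b, d):
    """Classify T(n) = a*T(n/b) + Theta(n^d), assuming a >= 1, b > 1, d >= 0.

    Restricting f(n) to n^d is a simplifying assumption of this sketch.
    """
    critical = math.log(a, b)             # the exponent log_b(a)
    if math.isclose(d, critical):
        return f"case 2: Theta(n^{d:g} log n)"
    if d < critical:
        return f"case 1: Theta(n^{critical:g})"
    return f"case 3: Theta(n^{d:g})"

print(master_theorem(2, 2, 1))   # merge sort
print(master_theorem(1, 2, 0))   # binary search
print(master_theorem(4, 2, 1))   # four half-size calls plus linear work
print(master_theorem(2, 2, 2))   # two half-size calls plus quadratic work
```

The first two calls reproduce the examples above. The last two illustrate cases 1 and 3.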
6. Recursion with Data Structures
Linked Lists
def length(head):
    if head is None:
        return 0
    return 1 + length(head.next)

def reverse(head, prev=None):
    if head is None:
        return prev
    next_node = head.next
    head.next = prev
    return reverse(next_node, head)
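The functions above rely on nodes with a `next` attribute. Here is a minimal usage sketch. The `ListNode` class and the `to_list` helper are assumptions made for illustration. The text does not define a node type.

```python
class ListNode:
    """Minimal singly linked node (hypothetical; not defined in the text)."""
    def __init__(self, val, next=None):
        self.val = val
        self.next = next

def length(head):
    if head is None:
        return 0
    return 1 + length(head.next)

def reverse(head, prev=None):
    if head is None:
        return prev
    next_node = head.next
    head.next = prev
    return reverse(next_node, head)

def to_list(head):
    """Helper for inspection: collect node values into a Python list."""
    return [] if head is None else [head.val] + to_list(head.next)

head = ListNode(1, ListNode(2, ListNode(3)))
print(length(head))             # 3
print(to_list(reverse(head)))   # [3, 2, 1]
```

Note that reverse relinks the existing nodes, so the original `head` becomes the last node afterwards.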
Trees
def tree_height(node):
    # Height counted in edges: an empty tree is -1, a single node is 0
    if node is None:
        return -1
    return 1 + max(tree_height(node.left), tree_height(node.right))

def tree_sum(node):
    if node is None:
        return 0
    return node.val + tree_sum(node.left) + tree_sum(node.right)

def inorder(node, result=None):
    if result is None:
        result = []
    if node:
        inorder(node.left, result)
        result.append(node.val)
        inorder(node.right, result)
    return result
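To try the tree functions, a node type with `val`, `left`, and `right` is needed. The `TreeNode` class below is an assumption made for illustration, and the sample tree is made up.

```python
class TreeNode:
    """Minimal binary tree node (hypothetical; not defined in the text)."""
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right

def tree_height(node):
    if node is None:
        return -1
    return 1 + max(tree_height(node.left), tree_height(node.right))

def tree_sum(node):
    if node is None:
        return 0
    return node.val + tree_sum(node.left) + tree_sum(node.right)

def inorder(node, result=None):
    if result is None:
        result = []
    if node:
        inorder(node.left, result)
        result.append(node.val)
        inorder(node.right, result)
    return result

#       4
#      / \
#     2   6
#    / \
#   1   3
root = TreeNode(4, TreeNode(2, TreeNode(1), TreeNode(3)), TreeNode(6))
print(tree_height(root))   # 2
print(tree_sum(root))      # 16
print(inorder(root))       # [1, 2, 3, 4, 6]
```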
Graphs
def dfs(graph, node, visited=None):
    if visited is None:
        visited = set()
    visited.add(node)
    print(node)
    for neighbor in graph[node]:
        if neighbor not in visited:
            dfs(graph, neighbor, visited)
7. Recursion to Iteration
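Any recursive algorithm can be converted to iteration by managing an explicit stack. When the recursion is linear, as in factorial, a plain loop is enough and no stack is needed.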
Factorial
# Recursive
def factorial_rec(n):
    if n <= 1:
        return 1
    return n * factorial_rec(n - 1)

# Iterative
def factorial_iter(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
Tree Traversal
# Recursive inorder
def inorder_rec(node):
    if node:
        inorder_rec(node.left)
        print(node.val)
        inorder_rec(node.right)

# Iterative inorder
def inorder_iter(root):
    stack = []
    current = root
    while stack or current:
        while current:
            stack.append(current)
            current = current.left
        current = stack.pop()
        print(current.val)
        current = current.right
DFS
# Recursive DFS
def dfs_rec(graph, start, visited=None):
    if visited is None:
        visited = set()
    visited.add(start)
    for neighbor in graph[start]:
        if neighbor not in visited:
            dfs_rec(graph, neighbor, visited)

# Iterative DFS
def dfs_iter(graph, start):
    visited = set()
    stack = [start]
    while stack:
        node = stack.pop()
        if node not in visited:
            visited.add(node)
            for neighbor in graph[node]:
                if neighbor not in visited:
                    stack.append(neighbor)
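The two versions reach the same set of nodes, but not necessarily in the same order, because the stack pops the most recently pushed neighbor first. The sketch below returns the visit order so the difference can be seen. The names `dfs_rec_order` and `dfs_iter_order`, the `reverse_neighbors` flag, and the sample graph are all assumptions made for illustration.

```python
def dfs_rec_order(graph, start, visited=None, order=None):
    """Recursive DFS that returns the order in which nodes are visited."""
    if visited is None:
        visited, order = set(), []
    visited.add(start)
    order.append(start)
    for neighbor in graph[start]:
        if neighbor not in visited:
            dfs_rec_order(graph, neighbor, visited, order)
    return order

def dfs_iter_order(graph, start, reverse_neighbors=False):
    """Stack-based DFS; optionally push neighbors in reverse order."""
    visited, order, stack = set(), [], [start]
    while stack:
        node = stack.pop()
        if node in visited:
            continue
        visited.add(node)
        order.append(node)
        neighbors = graph[node]
        if reverse_neighbors:
            neighbors = reversed(neighbors)
        for neighbor in neighbors:
            if neighbor not in visited:
                stack.append(neighbor)
    return order

graph = {'A': ['B', 'C'], 'B': ['D'], 'C': ['D'], 'D': []}
print(dfs_rec_order(graph, 'A'))                           # ['A', 'B', 'D', 'C']
print(dfs_iter_order(graph, 'A'))                          # ['A', 'C', 'D', 'B']
print(dfs_iter_order(graph, 'A', reverse_neighbors=True))  # ['A', 'B', 'D', 'C']
```

On this sample graph, pushing neighbors in reverse makes the iterative order match the recursive one.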
8. Memoization
Cache results to avoid redundant computation.
# Naive Fibonacci - O(2^n)
def fib_naive(n):
    if n <= 1:
        return n
    return fib_naive(n - 1) + fib_naive(n - 2)

# Memoized Fibonacci - O(n)
def fib_memo(n, cache=None):
    if cache is None:
        cache = {}
    if n in cache:
        return cache[n]
    if n <= 1:
        return n
    cache[n] = fib_memo(n - 1, cache) + fib_memo(n - 2, cache)
    return cache[n]

# Using decorator
from functools import lru_cache

@lru_cache(maxsize=None)
def fib_cached(n):
    if n <= 1:
        return n
    return fib_cached(n - 1) + fib_cached(n - 2)
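The saving is easy to measure by counting how often each function body runs. In the sketch below, the `calls` dictionary is instrumentation added for illustration. It is not part of the functions above.

```python
from functools import lru_cache

calls = {"naive": 0, "cached": 0}   # hypothetical counters for this demo

def fib_naive(n):
    calls["naive"] += 1
    if n <= 1:
        return n
    return fib_naive(n - 1) + fib_naive(n - 2)

@lru_cache(maxsize=None)
def fib_cached(n):
    calls["cached"] += 1            # runs only on a cache miss
    if n <= 1:
        return n
    return fib_cached(n - 1) + fib_cached(n - 2)

print(fib_naive(20), calls["naive"])     # 6765 21891
print(fib_cached(20), calls["cached"])   # 6765 21
print(fib_cached.cache_info())           # hit/miss statistics from lru_cache
```

The cached version evaluates each argument from 0 to 20 exactly once. The naive version recomputes the same subproblems thousands of times.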
9. Common Patterns
Generate All Subsets
def subsets(arr):
    if not arr:
        return [[]]
    rest = subsets(arr[1:])
    return rest + [[arr[0]] + s for s in rest]

# subsets([1,2,3]) = [[], [3], [2], [2,3], [1], [1,3], [1,2], [1,2,3]]
Generate All Permutations
def permutations(arr):
    if len(arr) <= 1:
        return [arr]
    result = []
    for i, elem in enumerate(arr):
        rest = arr[:i] + arr[i+1:]
        for perm in permutations(rest):
            result.append([elem] + perm)
    return result
Backtracking Template
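# Template only: is_solution, process_solution, is_valid, make_choice,
# undo_choice and next_choices are problem-specific placeholders.
def backtrack(state, choices):
    if is_solution(state):
        process_solution(state)
        return
    for choice in choices:
        if is_valid(choice, state):
            make_choice(state, choice)
            updated_choices = next_choices(state, choices)
            backtrack(state, updated_choices)
            undo_choice(state, choice)  # Backtrack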
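As one concrete instance of the template, the sketch below generates permutations by extending a shared partial solution and undoing each choice after the recursive call returns. The name `permutations_bt` and the `used` bookkeeping list are illustrative choices. Other problems would fill in the placeholders differently.

```python
def permutations_bt(items):
    """Generate all permutations of items by backtracking."""
    results = []
    state = []                        # the partial permutation so far
    used = [False] * len(items)       # which positions are already taken

    def backtrack():
        if len(state) == len(items):          # is_solution
            results.append(state.copy())      # process_solution
            return
        for i, item in enumerate(items):
            if used[i]:                       # is_valid
                continue
            used[i] = True                    # make_choice
            state.append(item)
            backtrack()
            state.pop()                       # undo_choice
            used[i] = False

    backtrack()
    return results

print(permutations_bt([1, 2, 3]))
# [[1, 2, 3], [1, 3, 2], [2, 1, 3], [2, 3, 1], [3, 1, 2], [3, 2, 1]]
```

Copying `state` before storing it matters: the same list is mutated throughout the search.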
Exercises
Basic
- Write a recursive function to compute the sum of an array.
- Write a recursive function to check if a string is a palindrome.
- Write a recursive function to count occurrences of a character in a string.
Intermediate
- Implement recursive binary search.
- Write a function to find all subsets of a set that sum to a target.
- Solve the Tower of Hanoi and analyze its time complexity.
Advanced
- Implement merge sort and analyze it using recurrence relations.
- Write a recursive function to generate all valid combinations of n pairs of parentheses.
- Solve the N-Queens problem using backtracking.
Summary
- Recursion needs a base case, a recursive case, and progress toward the base case
- The call stack stores the state of each recursive call
- Divide-and-conquer: divide, conquer, combine
- The master theorem solves common divide-and-conquer recurrences
- Memoization avoids redundant computation
- Any recursion can be converted to iteration with an explicit stack
- Use recursion when the problem has a recursive structure