Code Generation
Introduction
Code generation is the final phase of compilation. The code generator transforms the AST (often after optimization) into executable code. This reading covers intermediate representations, instruction selection, register allocation, and generating code for a stack-based virtual machine.
Learning Objectives
By the end of this reading, you will be able to:
- Design intermediate representations (IR)
- Generate code for a stack-based VM
- Implement a simple bytecode interpreter
- Understand register allocation basics
- Generate assembly code (conceptually)
1. Intermediate Representations
Why IR?
- Decouple front-end (parsing) from back-end (code generation)
- Enable machine-independent optimizations
- Support multiple source languages and target platforms
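Three-Address Code
In three-address code each instruction performs at most one operation and names at most three addresses: a destination and up to two operands, for example `t3 = t1 + t2`. Intermediate results are held in compiler-generated temporaries (`t1`, `t2`, ...).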
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Union
class IROpcode(Enum):
    """Operation codes of the three-address intermediate representation.

    Members are numbered 1..20 in declaration order.
    """

    # Arithmetic
    ADD = 1
    SUB = 2
    MUL = 3
    DIV = 4
    NEG = 5

    # Comparison
    LT = 6
    GT = 7
    EQ = 8
    NEQ = 9

    # Data movement
    COPY = 10
    LOAD_CONST = 11
    LOAD_VAR = 12
    STORE_VAR = 13

    # Control flow
    LABEL = 14
    JUMP = 15
    JUMP_IF_TRUE = 16
    JUMP_IF_FALSE = 17

    # Functions
    CALL = 18
    RETURN = 19
    PARAM = 20
@dataclass
class IRInstruction:
    """One three-address instruction.

    How the three operand fields are used depends on the opcode:

    * ``LABEL``: ``dest`` is the label name.
    * ``JUMP``: ``arg1`` is the target label.
    * ``JUMP_IF_TRUE`` / ``JUMP_IF_FALSE``: ``arg1`` is the condition,
      ``arg2`` the target label.
    * ``LOAD_CONST`` / ``COPY`` / ``LOAD_VAR`` / ``STORE_VAR``: ``dest``
      receives ``arg1``.
    * ``PARAM``: ``arg1`` is the argument being passed.
    * ``CALL``: ``dest`` receives the result, ``arg1`` is the callee and
      ``arg2`` the argument count.
    * ``RETURN``: ``arg1`` is the returned value, or ``None`` for a bare
      return.
    * Arithmetic / comparison: ``dest = arg1 <op> arg2`` (``NEG`` uses only
      ``arg1``).
    """

    opcode: IROpcode
    dest: Optional[str] = None  # Destination (temporary, variable or label)
    arg1: Optional[Union[str, int, float]] = None
    arg2: Optional[Union[str, int, float]] = None

    def __repr__(self):
        """Return the instruction in readable three-address notation."""
        opcode = self.opcode
        if opcode == IROpcode.LABEL:
            return f"{self.dest}:"
        if opcode == IROpcode.JUMP:
            return f" jump {self.arg1}"
        if opcode == IROpcode.JUMP_IF_FALSE:
            return f" if not {self.arg1} goto {self.arg2}"
        if opcode == IROpcode.JUMP_IF_TRUE:
            # Mirrors JUMP_IF_FALSE instead of falling into the generic dump.
            return f" if {self.arg1} goto {self.arg2}"
        if opcode in (IROpcode.LOAD_CONST, IROpcode.COPY,
                      IROpcode.LOAD_VAR, IROpcode.STORE_VAR):
            # All four are plain moves of arg1 into dest.
            return f" {self.dest} = {self.arg1}"
        if opcode == IROpcode.PARAM:
            return f" param {self.arg1}"
        if opcode == IROpcode.CALL:
            return f" {self.dest} = call {self.arg1}, {self.arg2} args"
        if opcode == IROpcode.RETURN:
            # Test against None explicitly: `arg1 or ''` would hide a
            # legitimate falsy operand such as 0.
            shown = '' if self.arg1 is None else self.arg1
            return f" return {shown}"
        if opcode in (IROpcode.ADD, IROpcode.SUB, IROpcode.MUL, IROpcode.DIV):
            symbol = {IROpcode.ADD: '+', IROpcode.SUB: '-',
                      IROpcode.MUL: '*', IROpcode.DIV: '/'}[opcode]
            return f" {self.dest} = {self.arg1} {symbol} {self.arg2}"
        if opcode == IROpcode.NEG:
            return f" {self.dest} = -{self.arg1}"
        if opcode in (IROpcode.LT, IROpcode.GT, IROpcode.EQ, IROpcode.NEQ):
            symbol = {IROpcode.LT: '<', IROpcode.GT: '>',
                      IROpcode.EQ: '==', IROpcode.NEQ: '!='}[opcode]
            return f" {self.dest} = {self.arg1} {symbol} {self.arg2}"
        # Fallback for any opcode without a dedicated format.
        return f" {self.opcode.name} {self.dest} {self.arg1} {self.arg2}"
class IRGenerator:
    """Generate three-address code from AST

    ``generate`` dispatches on the node's class name to a
    ``gen_<ClassName>`` method.  Expression generators return the name of
    the temporary (or variable) that holds the value; statement generators
    return ``None``.  Emitted instructions accumulate in
    ``self.instructions`` in program order.
    """

    def __init__(self):
        self.instructions: List[IRInstruction] = []
        self.temp_counter = 0
        self.label_counter = 0

    def new_temp(self) -> str:
        """Return a fresh temporary name (``t1``, ``t2``, ...)."""
        self.temp_counter += 1
        return f"t{self.temp_counter}"

    def new_label(self) -> str:
        """Return a fresh label name (``L1``, ``L2``, ...)."""
        self.label_counter += 1
        return f"L{self.label_counter}"

    def emit(self, opcode: IROpcode, dest=None, arg1=None, arg2=None):
        """Append one instruction to the instruction list."""
        self.instructions.append(IRInstruction(opcode, dest, arg1, arg2))

    def generate(self, node: AST) -> Optional[str]:
        """Generate IR for node, return temp holding result

        Statement nodes produce no value and yield ``None``.

        Raises:
            NotImplementedError: if no ``gen_<ClassName>`` method exists for
                the node, or the node uses an unsupported unary operator.
        """
        method_name = f'gen_{type(node).__name__}'
        generator = getattr(self, method_name, self.gen_default)
        return generator(node)

    def gen_default(self, node):
        raise NotImplementedError(f"No IR generator for {type(node).__name__}")

    def gen_Program(self, node: Program):
        for stmt in node.statements:
            self.generate(stmt)

    def gen_NumberLiteral(self, node: NumberLiteral) -> str:
        temp = self.new_temp()
        self.emit(IROpcode.LOAD_CONST, temp, node.value)
        return temp

    def gen_StringLiteral(self, node: StringLiteral) -> str:
        temp = self.new_temp()
        self.emit(IROpcode.LOAD_CONST, temp, node.value)
        return temp

    def gen_BoolLiteral(self, node: BoolLiteral) -> str:
        # Booleans are lowered to the integers 1 / 0.
        temp = self.new_temp()
        self.emit(IROpcode.LOAD_CONST, temp, 1 if node.value else 0)
        return temp

    def gen_Identifier(self, node: Identifier) -> str:
        temp = self.new_temp()
        self.emit(IROpcode.LOAD_VAR, temp, node.name)
        return temp

    def gen_BinaryOp(self, node: BinaryOp) -> str:
        # Operands are evaluated left to right before the operation itself.
        left = self.generate(node.left)
        right = self.generate(node.right)
        result = self.new_temp()
        op_map = {
            '+': IROpcode.ADD,
            '-': IROpcode.SUB,
            '*': IROpcode.MUL,
            '/': IROpcode.DIV,
            '<': IROpcode.LT,
            '>': IROpcode.GT,
            '==': IROpcode.EQ,
            '!=': IROpcode.NEQ,
        }
        self.emit(op_map[node.op], result, left, right)
        return result

    def gen_UnaryOp(self, node: UnaryOp) -> str:
        # Reject unsupported operators up front.  Previously any operator
        # other than '-' returned a temporary that no instruction ever
        # assigned, producing silently wrong IR.
        if node.op != '-':
            raise NotImplementedError(
                f"No IR generator for unary operator {node.op!r}")
        operand = self.generate(node.operand)
        result = self.new_temp()
        self.emit(IROpcode.NEG, result, operand)
        return result

    def gen_Assignment(self, node: Assignment) -> str:
        value = self.generate(node.value)
        self.emit(IROpcode.STORE_VAR, node.target.name, value)
        # The assigned variable's name doubles as the expression's result.
        return node.target.name

    def gen_ExpressionStmt(self, node: ExpressionStmt):
        self.generate(node.expr)

    def gen_IfStmt(self, node: IfStmt):
        cond = self.generate(node.condition)
        if node.else_branch:
            else_label = self.new_label()
            end_label = self.new_label()
            self.emit(IROpcode.JUMP_IF_FALSE, None, cond, else_label)
            for stmt in node.then_branch:
                self.generate(stmt)
            # Skip over the else branch once the then branch has run.
            self.emit(IROpcode.JUMP, None, end_label)
            self.emit(IROpcode.LABEL, else_label)
            for stmt in node.else_branch:
                self.generate(stmt)
            self.emit(IROpcode.LABEL, end_label)
        else:
            end_label = self.new_label()
            self.emit(IROpcode.JUMP_IF_FALSE, None, cond, end_label)
            for stmt in node.then_branch:
                self.generate(stmt)
            self.emit(IROpcode.LABEL, end_label)

    def gen_WhileStmt(self, node: WhileStmt):
        start_label = self.new_label()
        end_label = self.new_label()
        # The condition is re-evaluated at start_label on every iteration.
        self.emit(IROpcode.LABEL, start_label)
        cond = self.generate(node.condition)
        self.emit(IROpcode.JUMP_IF_FALSE, None, cond, end_label)
        for stmt in node.body:
            self.generate(stmt)
        self.emit(IROpcode.JUMP, None, start_label)
        self.emit(IROpcode.LABEL, end_label)

    def gen_Call(self, node: Call) -> str:
        # Evaluate every argument first, then emit the PARAMs together so
        # they sit directly before the CALL.
        arg_temps = [self.generate(arg) for arg in node.arguments]
        for temp in arg_temps:
            self.emit(IROpcode.PARAM, None, temp)
        result = self.new_temp()
        if isinstance(node.callee, Identifier):
            # Named callee: call by name.
            self.emit(IROpcode.CALL, result, node.callee.name, len(arg_temps))
        else:
            # Computed callee: evaluate it and call through the temporary.
            func = self.generate(node.callee)
            self.emit(IROpcode.CALL, result, func, len(arg_temps))
        return result

    def gen_Return(self, node: Return):
        # `is not None` rather than truthiness: the value is an AST node and
        # its presence, not its truth value, is what matters.
        if node.value is not None:
            value = self.generate(node.value)
            self.emit(IROpcode.RETURN, None, value)
        else:
            self.emit(IROpcode.RETURN)

    def print_ir(self):
        """Print generated IR"""
        for instr in self.instructions:
            print(instr)
# Example: lex, parse and lower a small program, then print its
# three-address code.
# Lexer, TokenType and Parser are not defined in this section; presumably
# they come from the earlier readings of this module.
# NOTE(review): the body of the `if` in the sample program is not indented
# here; confirm whether the source language requires indentation.
source = '''
x = 10
y = 20
if x < y:
z = x + y
'''
lexer = Lexer(source)
# Drop NEWLINE tokens before handing the token stream to the parser.
tokens = [t for t in lexer.tokenize() if t.type != TokenType.NEWLINE]
parser = Parser(tokens)
ast = parser.parse()
ir_gen = IRGenerator()
# Fills ir_gen.instructions; the return value is not used here.
ir_gen.generate(ast)
print("Three-Address Code:")
ir_gen.print_ir()
2. Stack-Based Virtual Machine
Bytecode Definition
class Opcode(Enum):
    """Instruction set of the stack-based virtual machine.

    Members are numbered 1..25 in declaration order.
    """

    # Stack operations
    PUSH = 1  # Push constant
    POP = 2  # Pop top
    DUP = 3  # Duplicate top

    # Arithmetic
    ADD = 4
    SUB = 5
    MUL = 6
    DIV = 7
    NEG = 8

    # Comparison
    LT = 9
    GT = 10
    EQ = 11
    NEQ = 12
    LEQ = 13
    GEQ = 14

    # Variables
    LOAD = 15  # Load variable
    STORE = 16  # Store variable
    LOAD_GLOBAL = 17
    STORE_GLOBAL = 18

    # Control flow
    JUMP = 19
    JUMP_IF_TRUE = 20
    JUMP_IF_FALSE = 21

    # Functions
    CALL = 22
    RETURN = 23

    # Other
    PRINT = 24
    HALT = 25
@dataclass
class Bytecode:
    """A single VM instruction: an opcode plus an optional operand.

    The operand's meaning depends on the opcode.  As emitted by the bytecode
    generator in this reading it is a constant-pool index for ``PUSH``, a
    variable slot for ``LOAD``/``STORE``, a code address for the jump
    opcodes and a ``(function constant index, argument count)`` tuple for
    ``CALL``.
    """

    opcode: Opcode
    # typing.Any, not the builtin any(): the latter is a function, not a type.
    operand: Any = None

    def __repr__(self):
        """Return ``NAME operand``, or just ``NAME`` when there is no operand."""
        # Compare with None so that a legitimate operand of 0 is still shown.
        if self.operand is not None:
            return f"{self.opcode.name} {self.operand}"
        return self.opcode.name
Bytecode Generator
class BytecodeGenerator:
    """Generate bytecode from AST

    ``generate`` dispatches on the node's class name to a
    ``gen_<ClassName>`` method.  Results accumulate in three places:
    ``code`` (the instruction list), ``constants`` (the constant pool that
    ``PUSH`` operands index into) and ``variables`` (name -> slot index used
    by ``LOAD``/``STORE``).
    """

    def __init__(self):
        self.code: List[Bytecode] = []
        self.constants: List[Any] = []
        self.variables: Dict[str, int] = {}  # Variable name -> slot index
        self.var_counter = 0

    def emit(self, opcode: Opcode, operand=None):
        """Append an instruction and return its index in ``code``."""
        self.code.append(Bytecode(opcode, operand))
        return len(self.code) - 1  # Return instruction index

    def add_constant(self, value) -> int:
        """Add constant to pool and return index

        An existing entry is reused only when it has the same type *and*
        value.  Plain ``==`` is not enough: ``1 == 1.0 == True`` in Python,
        so a float literal could otherwise be loaded back as an int.
        """
        for index, existing in enumerate(self.constants):
            if existing is value or (
                    type(existing) is type(value) and existing == value):
                return index
        self.constants.append(value)
        return len(self.constants) - 1

    def get_var_slot(self, name: str) -> int:
        """Get or create variable slot"""
        if name not in self.variables:
            self.variables[name] = self.var_counter
            self.var_counter += 1
        return self.variables[name]

    def current_addr(self) -> int:
        """Get current code address"""
        return len(self.code)

    def patch_jump(self, addr: int, target: int):
        """Patch jump instruction with target address"""
        self.code[addr].operand = target

    def generate(self, node: AST):
        """Emit bytecode for *node* via its ``gen_<ClassName>`` method.

        Raises:
            NotImplementedError: if the node type, a unary operator or a
                call form is not supported.
        """
        method_name = f'gen_{type(node).__name__}'
        generator = getattr(self, method_name, self.gen_default)
        return generator(node)

    def gen_default(self, node):
        raise NotImplementedError(f"No generator for {type(node).__name__}")

    def gen_Program(self, node: Program):
        for stmt in node.statements:
            self.generate(stmt)
        self.emit(Opcode.HALT)

    def gen_NumberLiteral(self, node: NumberLiteral):
        idx = self.add_constant(node.value)
        self.emit(Opcode.PUSH, idx)

    def gen_StringLiteral(self, node: StringLiteral):
        idx = self.add_constant(node.value)
        self.emit(Opcode.PUSH, idx)

    def gen_BoolLiteral(self, node: BoolLiteral):
        # Booleans are lowered to the integers 1 / 0.
        idx = self.add_constant(1 if node.value else 0)
        self.emit(Opcode.PUSH, idx)

    def gen_Identifier(self, node: Identifier):
        slot = self.get_var_slot(node.name)
        self.emit(Opcode.LOAD, slot)

    def gen_BinaryOp(self, node: BinaryOp):
        # Left operand is pushed first, so it sits below the right operand.
        self.generate(node.left)
        self.generate(node.right)
        op_map = {
            '+': Opcode.ADD,
            '-': Opcode.SUB,
            '*': Opcode.MUL,
            '/': Opcode.DIV,
            '<': Opcode.LT,
            '>': Opcode.GT,
            '==': Opcode.EQ,
            '!=': Opcode.NEQ,
        }
        self.emit(op_map[node.op])

    def gen_UnaryOp(self, node: UnaryOp):
        # Reject unsupported operators up front.  Previously they were
        # silently dropped, leaving the operand on the stack unchanged.
        if node.op != '-':
            raise NotImplementedError(
                f"No generator for unary operator {node.op!r}")
        self.generate(node.operand)
        self.emit(Opcode.NEG)

    def gen_Assignment(self, node: Assignment):
        self.generate(node.value)
        slot = self.get_var_slot(node.target.name)
        self.emit(Opcode.STORE, slot)

    def gen_ExpressionStmt(self, node: ExpressionStmt):
        self.generate(node.expr)
        # Discard the expression's value so statements leave the stack
        # balanced.
        self.emit(Opcode.POP)

    def gen_IfStmt(self, node: IfStmt):
        self.generate(node.condition)
        if node.else_branch:
            jump_else = self.emit(Opcode.JUMP_IF_FALSE, 0)  # Placeholder
            for stmt in node.then_branch:
                self.generate(stmt)
            jump_end = self.emit(Opcode.JUMP, 0)  # Placeholder
            else_addr = self.current_addr()
            for stmt in node.else_branch:
                self.generate(stmt)
            end_addr = self.current_addr()
            # Both targets are known only now; back-patch the placeholders.
            self.patch_jump(jump_else, else_addr)
            self.patch_jump(jump_end, end_addr)
        else:
            jump_end = self.emit(Opcode.JUMP_IF_FALSE, 0)
            for stmt in node.then_branch:
                self.generate(stmt)
            self.patch_jump(jump_end, self.current_addr())

    def gen_WhileStmt(self, node: WhileStmt):
        start_addr = self.current_addr()
        self.generate(node.condition)
        jump_end = self.emit(Opcode.JUMP_IF_FALSE, 0)
        for stmt in node.body:
            self.generate(stmt)
        # Loop back so the condition is re-evaluated, then patch the exit.
        self.emit(Opcode.JUMP, start_addr)
        self.patch_jump(jump_end, self.current_addr())

    def gen_Call(self, node: Call):
        # Only named callees can be compiled.  Previously a computed callee
        # had its arguments pushed but no call emitted, which left the stack
        # unbalanced without any error.
        if not isinstance(node.callee, Identifier):
            raise NotImplementedError(
                "Only calls to named functions are supported")
        if node.callee.name == 'print':
            # Special handling for print: one PRINT per argument.
            for arg in node.arguments:
                self.generate(arg)
                self.emit(Opcode.PRINT)
            # Push None as return value
            self.emit(Opcode.PUSH, self.add_constant(None))
            return
        # General function call: arguments first, then CALL carrying the
        # function-name constant index and the argument count.
        # NOTE: the VM in this section has no handler for CALL.
        for arg in node.arguments:
            self.generate(arg)
        func_idx = self.add_constant(node.callee.name)
        self.emit(Opcode.CALL, (func_idx, len(node.arguments)))

    def disassemble(self):
        """Print bytecode disassembly"""
        print("Constants:", self.constants)
        print("Variables:", self.variables)
        print("\nBytecode:")
        for i, instr in enumerate(self.code):
            print(f" {i:4d}: {instr}")
Virtual Machine
class VM:
    """Stack-based virtual machine

    Executes a list of ``Bytecode`` instructions against an operand stack.
    ``constants`` is the pool that ``PUSH`` operands index into and
    ``variables`` maps slot numbers to values for ``LOAD``/``STORE``.
    """

    def __init__(self, bytecode: List[Bytecode], constants: List[Any]):
        self.code = bytecode
        self.constants = constants
        self.stack: List[Any] = []
        self.variables: Dict[int, Any] = {}
        self.ip = 0  # Instruction pointer

    def push(self, value):
        self.stack.append(value)

    def pop(self):
        return self.stack.pop()

    def peek(self):
        return self.stack[-1]

    def run(self):
        """Execute bytecode

        Runs from the current instruction pointer until ``HALT`` or the end
        of the code.

        Returns:
            The value on top of the stack when execution stops, or ``None``
            if the stack is empty.

        Raises:
            NotImplementedError: on an opcode this VM has no handler for
                (``CALL``, ``RETURN``, ``LOAD_GLOBAL``, ``STORE_GLOBAL``).
                Skipping such an instruction would silently corrupt the
                stack.
        """
        # Binary operators pop the right operand first, then the left.
        arithmetic = {
            Opcode.ADD: lambda a, b: a + b,
            Opcode.SUB: lambda a, b: a - b,
            Opcode.MUL: lambda a, b: a * b,
            Opcode.DIV: lambda a, b: a / b,
        }
        # Comparisons push the integers 1 / 0 rather than booleans.
        comparisons = {
            Opcode.LT: lambda a, b: a < b,
            Opcode.GT: lambda a, b: a > b,
            Opcode.EQ: lambda a, b: a == b,
            Opcode.NEQ: lambda a, b: a != b,
            Opcode.LEQ: lambda a, b: a <= b,
            Opcode.GEQ: lambda a, b: a >= b,
        }

        while self.ip < len(self.code):
            instr = self.code[self.ip]
            # Advance first, so a jump simply overwrites ip.
            self.ip += 1
            opcode = instr.opcode

            if opcode == Opcode.HALT:
                break
            elif opcode == Opcode.PUSH:
                self.push(self.constants[instr.operand])
            elif opcode == Opcode.POP:
                self.pop()
            elif opcode == Opcode.DUP:
                self.push(self.peek())
            elif opcode == Opcode.LOAD:
                # A slot that was never stored to reads as 0.
                self.push(self.variables.get(instr.operand, 0))
            elif opcode == Opcode.STORE:
                self.variables[instr.operand] = self.pop()
            elif opcode in arithmetic:
                b, a = self.pop(), self.pop()
                self.push(arithmetic[opcode](a, b))
            elif opcode == Opcode.NEG:
                self.push(-self.pop())
            elif opcode in comparisons:
                b, a = self.pop(), self.pop()
                self.push(1 if comparisons[opcode](a, b) else 0)
            elif opcode == Opcode.JUMP:
                self.ip = instr.operand
            elif opcode == Opcode.JUMP_IF_FALSE:
                # Conditional jumps always consume the condition.
                if not self.pop():
                    self.ip = instr.operand
            elif opcode == Opcode.JUMP_IF_TRUE:
                if self.pop():
                    self.ip = instr.operand
            elif opcode == Opcode.PRINT:
                print(self.pop())
            else:
                raise NotImplementedError(
                    f"VM does not implement opcode {opcode.name}")

        return self.peek() if self.stack else None
# Complete example
def compile_and_run(source: str):
    """Lex, parse, compile and execute *source*.

    Prints the bytecode disassembly followed by the program's own output,
    and returns whatever ``VM.run`` returns.
    """
    # Front end: source text -> tokens (NEWLINE tokens dropped) -> AST
    token_stream = Lexer(source).tokenize()
    significant = [tok for tok in token_stream if tok.type != TokenType.NEWLINE]
    tree = Parser(significant).parse()

    # Back end: AST -> bytecode, shown as a disassembly
    generator = BytecodeGenerator()
    generator.generate(tree)
    print("=== Bytecode ===")
    generator.disassemble()

    # Execute on a fresh VM
    print("\n=== Output ===")
    machine = VM(generator.code, generator.constants)
    return machine.run()
# Test: push a small program through the whole pipeline (lex, parse,
# compile to bytecode, run on the VM).
# With the usual precedence (`*` binds tighter than `+`) z would be
# 10 + 20 * 2 = 50; that depends on the Parser, which is not shown in this
# section.
source = '''
x = 10
y = 20
z = x + y * 2
print(z)
'''
compile_and_run(source)
3. Optimization Passes
class Optimizer:
    """Simple bytecode optimizer

    Every pass that removes or merges instructions also re-maps the operands
    of ``JUMP`` / ``JUMP_IF_TRUE`` / ``JUMP_IF_FALSE``, because those
    operands are absolute instruction indices.  The input instruction list
    and its ``Bytecode`` objects are never mutated; only ``constants`` may
    grow.
    """

    def optimize(self, code: List[Bytecode], constants: List[Any]):
        """Apply optimization passes

        Args:
            code: the instruction list to optimize.
            constants: the constant pool; folded results are appended to it
                in place.

        Returns:
            A new instruction list.
        """
        code = self.constant_folding(code, constants)
        code = self.peephole(code)
        return code

    def constant_folding(self, code: List[Bytecode], constants: List[Any]):
        """Fold constant expressions at compile time

        Replaces ``PUSH a; PUSH b; <ADD|SUB|MUL|DIV>`` with a single ``PUSH``
        of the computed value, repeating until nothing more folds so that
        chains such as ``1 + 2 + 3`` collapse completely.  An expression
        that would raise (division by zero, incompatible operand types) is
        left alone so the error still happens at run time.
        """
        while True:
            folded = self._fold_once(code, constants)
            # Every fold shortens the code, so equal length means no change.
            if len(folded) == len(code):
                return folded
            code = folded

    def peephole(self, code: List[Bytecode]):
        """Simple peephole optimizations

        Removes ``PUSH`` immediately followed by ``POP`` and any ``JUMP``
        whose target is the next instruction.
        """
        targets = self._jump_targets(code)
        result = []
        # new_addr[old index] -> index in `result` where execution continues.
        new_addr = [0] * (len(code) + 1)
        i = 0
        while i < len(code):
            new_addr[i] = len(result)
            # PUSH followed by POP - remove both.  Not when the POP is a
            # jump target: a jump landing there pops a value pushed
            # elsewhere, so the POP must stay.
            if (i + 1 < len(code) and
                    code[i].opcode == Opcode.PUSH and
                    code[i + 1].opcode == Opcode.POP and
                    (i + 1) not in targets):
                new_addr[i + 1] = len(result)
                i += 2
                continue
            # JUMP to next instruction - remove
            if (code[i].opcode == Opcode.JUMP and
                    code[i].operand == i + 1):
                i += 1
                continue
            result.append(code[i])
            i += 1
        new_addr[len(code)] = len(result)
        return self._retarget(result, new_addr)

    def _fold_once(self, code, constants):
        """Run one left-to-right folding pass and re-map jumps."""
        targets = self._jump_targets(code)
        result = []
        new_addr = [0] * (len(code) + 1)
        i = 0
        while i < len(code):
            new_addr[i] = len(result)
            outcome = None
            # Look for PUSH PUSH OP pattern.  The second PUSH and the
            # operator must not be jump targets, otherwise a jump into the
            # middle of the pattern would lose its landing point.
            if (i + 2 < len(code) and
                    code[i].opcode == Opcode.PUSH and
                    code[i + 1].opcode == Opcode.PUSH and
                    code[i + 2].opcode in (Opcode.ADD, Opcode.SUB,
                                           Opcode.MUL, Opcode.DIV) and
                    (i + 1) not in targets and
                    (i + 2) not in targets):
                outcome = self._try_fold(code[i + 2].opcode,
                                         constants[code[i].operand],
                                         constants[code[i + 1].operand])
            if outcome is None:
                result.append(code[i])
                i += 1
                continue
            idx = self._constant_index(constants, outcome[0])
            result.append(Bytecode(Opcode.PUSH, idx))
            new_addr[i + 1] = new_addr[i + 2] = new_addr[i]
            i += 3
        new_addr[len(code)] = len(result)
        return self._retarget(result, new_addr)

    @staticmethod
    def _try_fold(op, a, b):
        """Return ``(value,)`` for ``a <op> b``, or ``None`` if it raises."""
        try:
            if op == Opcode.ADD:
                return (a + b,)
            if op == Opcode.SUB:
                return (a - b,)
            if op == Opcode.MUL:
                return (a * b,)
            return (a / b,)
        except (ArithmeticError, TypeError):
            # e.g. division by zero or str - str: keep the run-time error.
            return None

    @staticmethod
    def _constant_index(constants, value):
        """Return the pool index of *value*, appending it if absent.

        Matches on type as well as value because ``1 == 1.0 == True``; a
        folded ``1.0`` must not reuse the slot of the int ``1``.
        """
        for idx, existing in enumerate(constants):
            if type(existing) is type(value) and existing == value:
                return idx
        constants.append(value)
        return len(constants) - 1

    @staticmethod
    def _is_jump(instr):
        """Tell whether *instr* carries a code address as its operand."""
        return instr.opcode in (Opcode.JUMP, Opcode.JUMP_IF_TRUE,
                                Opcode.JUMP_IF_FALSE)

    def _jump_targets(self, code):
        """Collect the addresses that some jump in *code* points at."""
        return {instr.operand for instr in code if self._is_jump(instr)}

    def _retarget(self, code, new_addr):
        """Return *code* with jump operands translated through *new_addr*."""
        remapped = []
        for instr in code:
            target = instr.operand
            if (self._is_jump(instr) and isinstance(target, int) and
                    0 <= target < len(new_addr)):
                # Build a new instruction so the caller's input list is not
                # mutated.
                instr = Bytecode(instr.opcode, new_addr[target])
            remapped.append(instr)
        return remapped
Summary
- Three-address code is a common IR with simple instructions
- Stack-based VMs keep operands on an evaluation stack, so instructions need no register operands and code generation becomes a simple tree walk
- Bytecode is a compact, portable intermediate format
- Instruction selection maps IR to target instructions
- Optimizations can be applied at the IR or bytecode level; any pass that removes or merges instructions must also re-map jump targets
Module Complete
This completes the Compilers module! You've learned:
- Lexical analysis with regular expressions and finite automata
- Parsing with recursive descent and precedence climbing
- Semantic analysis with type checking
- Code generation for a stack-based VM
These concepts form the foundation for understanding programming language implementation.