Code Generation
Code generation is the final phase of compilation. The code generator transforms the AST (often after optimization) into executable code.
## 1. Intermediate Representations
Why IR?
- Decouple front-end (parsing) from back-end (code generation)
- Enable machine-independent optimizations
- Support multiple source languages and target platforms
Three-Address Code
import operator
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Union
# Opcodes of the three-address-code IR. Built with the functional Enum API:
# members are numbered 1, 2, 3, ... in the order listed, exactly as auto()
# numbers them in a class body.
IROpcode = Enum(
    "IROpcode",
    [
        # Arithmetic
        "ADD",
        "SUB",
        "MUL",
        "DIV",
        "NEG",
        # Comparison
        "LT",
        "GT",
        "EQ",
        "NEQ",
        # Data movement
        "COPY",
        "LOAD_CONST",
        "LOAD_VAR",
        "STORE_VAR",
        # Control flow
        "LABEL",
        "JUMP",
        "JUMP_IF_TRUE",
        "JUMP_IF_FALSE",
        # Functions
        "CALL",
        "RETURN",
        "PARAM",
    ],
)
@dataclass
class IRInstruction:
    """A single three-address-code instruction.

    Which fields are meaningful depends on ``opcode``:
    ``dest`` is the destination temporary/variable (the label name for
    LABEL), and ``arg1``/``arg2`` are the operands (for jumps, ``arg1`` is the
    condition/target and ``arg2`` the target; for CALL, ``arg1`` is the
    callee and ``arg2`` the argument count).
    """
    opcode: IROpcode
    dest: Optional[str] = None  # Destination (temporary or variable), or label name
    arg1: Optional[Union[str, int, float]] = None
    arg2: Optional[Union[str, int, float]] = None

    def __repr__(self):
        """Render the instruction in a readable three-address form."""
        op = self.opcode
        # Binary operators (arithmetic and comparison) share one layout.
        symbols = {
            IROpcode.ADD: '+', IROpcode.SUB: '-',
            IROpcode.MUL: '*', IROpcode.DIV: '/',
            IROpcode.LT: '<', IROpcode.GT: '>',
            IROpcode.EQ: '==', IROpcode.NEQ: '!=',
        }
        if op in symbols:
            return f" {self.dest} = {self.arg1} {symbols[op]} {self.arg2}"
        if op == IROpcode.LABEL:
            return f"{self.dest}:"
        if op == IROpcode.JUMP:
            return f" jump {self.arg1}"
        if op == IROpcode.JUMP_IF_FALSE:
            return f" if not {self.arg1} goto {self.arg2}"
        if op == IROpcode.JUMP_IF_TRUE:
            return f" if {self.arg1} goto {self.arg2}"
        if op in (IROpcode.LOAD_CONST, IROpcode.COPY,
                  IROpcode.LOAD_VAR, IROpcode.STORE_VAR):
            # All four move arg1 into dest; STORE_VAR's dest is the variable.
            return f" {self.dest} = {self.arg1}"
        if op == IROpcode.PARAM:
            return f" param {self.arg1}"
        if op == IROpcode.CALL:
            return f" {self.dest} = call {self.arg1}, {self.arg2} args"
        if op == IROpcode.RETURN:
            # Test against None, not truthiness, so an operand of 0 still shows.
            if self.arg1 is None:
                return " return"
            return f" return {self.arg1}"
        if op == IROpcode.NEG:
            return f" {self.dest} = -{self.arg1}"
        return f" {op.name} {self.dest} {self.arg1} {self.arg2}"
class IRGenerator:
    """Generate three-address code from an AST.

    Each ``gen_<NodeType>`` method appends IRInstructions to
    ``self.instructions``. Expression generators return the name of the
    temporary (or variable) holding their result; statement generators
    return None.
    """

    def __init__(self):
        self.instructions: List[IRInstruction] = []
        self.temp_counter = 0   # Number of the last temporary handed out
        self.label_counter = 0  # Number of the last label handed out

    def new_temp(self) -> str:
        """Return a fresh temporary name: t1, t2, ..."""
        self.temp_counter += 1
        return f"t{self.temp_counter}"

    def new_label(self) -> str:
        """Return a fresh label name: L1, L2, ..."""
        self.label_counter += 1
        return f"L{self.label_counter}"

    def emit(self, opcode: IROpcode, dest=None, arg1=None, arg2=None):
        """Append one instruction to the output."""
        self.instructions.append(IRInstruction(opcode, dest, arg1, arg2))

    def generate(self, node: AST) -> str:
        """Generate IR for node, return temp holding result"""
        method_name = f'gen_{type(node).__name__}'
        generator = getattr(self, method_name, self.gen_default)
        return generator(node)

    def gen_default(self, node):
        raise NotImplementedError(f"No IR generator for {type(node).__name__}")

    def _gen_block(self, statements):
        """Generate IR for each statement of a block, in order."""
        for stmt in statements:
            self.generate(stmt)

    def _load_const(self, value) -> str:
        """Load a literal constant into a fresh temporary and return its name."""
        temp = self.new_temp()
        self.emit(IROpcode.LOAD_CONST, temp, value)
        return temp

    def gen_Program(self, node: Program):
        self._gen_block(node.statements)

    def gen_NumberLiteral(self, node: NumberLiteral) -> str:
        return self._load_const(node.value)

    def gen_StringLiteral(self, node: StringLiteral) -> str:
        return self._load_const(node.value)

    def gen_BoolLiteral(self, node: BoolLiteral) -> str:
        # Booleans are represented as the integers 1 and 0.
        return self._load_const(1 if node.value else 0)

    def gen_Identifier(self, node: Identifier) -> str:
        temp = self.new_temp()
        self.emit(IROpcode.LOAD_VAR, temp, node.name)
        return temp

    def gen_BinaryOp(self, node: BinaryOp) -> str:
        left = self.generate(node.left)
        right = self.generate(node.right)
        result = self.new_temp()
        op_map = {
            '+': IROpcode.ADD,
            '-': IROpcode.SUB,
            '*': IROpcode.MUL,
            '/': IROpcode.DIV,
            '<': IROpcode.LT,
            '>': IROpcode.GT,
            '==': IROpcode.EQ,
            '!=': IROpcode.NEQ,
        }
        self.emit(op_map[node.op], result, left, right)
        return result

    def gen_UnaryOp(self, node: UnaryOp) -> str:
        """Generate IR for a unary operator.

        Raises NotImplementedError for operators without a lowering, instead
        of returning a temporary that no instruction ever assigns.
        """
        operand = self.generate(node.operand)
        result = self.new_temp()
        if node.op == '-':
            self.emit(IROpcode.NEG, result, operand)
        elif node.op == '+':
            # Unary plus is the identity.
            self.emit(IROpcode.COPY, result, operand)
        elif node.op in ('not', '!'):
            # There is no NOT opcode; "operand == 0" gives 1/0, matching the
            # integer booleans produced by gen_BoolLiteral.
            self.emit(IROpcode.EQ, result, operand, 0)
        else:
            raise NotImplementedError(f"No IR for unary operator {node.op!r}")
        return result

    def gen_Assignment(self, node: Assignment) -> str:
        """Store the value into the target; return the variable name."""
        value = self.generate(node.value)
        self.emit(IROpcode.STORE_VAR, node.target.name, value)
        return node.target.name

    def gen_ExpressionStmt(self, node: ExpressionStmt):
        self.generate(node.expr)

    def gen_IfStmt(self, node: IfStmt):
        cond = self.generate(node.condition)
        if node.else_branch:
            else_label = self.new_label()
            end_label = self.new_label()
            self.emit(IROpcode.JUMP_IF_FALSE, None, cond, else_label)
            self._gen_block(node.then_branch)
            self.emit(IROpcode.JUMP, None, end_label)
            self.emit(IROpcode.LABEL, else_label)
            self._gen_block(node.else_branch)
        else:
            end_label = self.new_label()
            self.emit(IROpcode.JUMP_IF_FALSE, None, cond, end_label)
            self._gen_block(node.then_branch)
        self.emit(IROpcode.LABEL, end_label)

    def gen_WhileStmt(self, node: WhileStmt):
        start_label = self.new_label()
        end_label = self.new_label()
        self.emit(IROpcode.LABEL, start_label)
        cond = self.generate(node.condition)
        self.emit(IROpcode.JUMP_IF_FALSE, None, cond, end_label)
        self._gen_block(node.body)
        self.emit(IROpcode.JUMP, None, start_label)
        self.emit(IROpcode.LABEL, end_label)

    def gen_Call(self, node: Call) -> str:
        """Generate IR for a call; return the temp receiving its result.

        The callee is resolved first (left-to-right evaluation), then every
        argument, and only then the PARAM run. This keeps the PARAMs
        contiguous and directly in front of their CALL, even when the callee
        or an argument is itself a call.
        """
        if isinstance(node.callee, Identifier):
            func = node.callee.name
        else:
            func = self.generate(node.callee)
        arg_temps = [self.generate(arg) for arg in node.arguments]
        for temp in arg_temps:
            self.emit(IROpcode.PARAM, None, temp)
        result = self.new_temp()
        self.emit(IROpcode.CALL, result, func, len(arg_temps))
        return result

    def gen_Return(self, node: Return):
        if node.value is not None:
            value = self.generate(node.value)
            self.emit(IROpcode.RETURN, None, value)
        else:
            self.emit(IROpcode.RETURN)

    def print_ir(self):
        """Print generated IR"""
        for instr in self.instructions:
            print(instr)
# Example: lex, parse, and lower a small program to three-address code.
# NOTE(review): Lexer, Parser and TokenType are not defined in this section;
# presumably they come from the earlier lexer/parser chapters — confirm.
source = '''
x = 10
y = 20
if x < y:
z = x + y
'''
# NOTE(review): indentation inside this sample string may have been lost;
# confirm whether the lexer expects "z = x + y" indented under the if.
lexer = Lexer(source)
# NEWLINE tokens are dropped before parsing.
tokens = [t for t in lexer.tokenize() if t.type != TokenType.NEWLINE]
parser = Parser(tokens)
# Module-level name; it would shadow the stdlib `ast` module if that were imported.
ast = parser.parse()
ir_gen = IRGenerator()
ir_gen.generate(ast)
print("Three-Address Code:")
ir_gen.print_ir()
## 2. Stack-Based Virtual Machine
Bytecode Definition
# Bytecode instruction set of the stack VM. Built with the functional Enum
# API: members are numbered 1, 2, 3, ... in the order listed, exactly as
# auto() numbers them in a class body.
Opcode = Enum(
    "Opcode",
    [
        # Stack operations
        "PUSH",  # Push constant
        "POP",  # Pop top
        "DUP",  # Duplicate top
        # Arithmetic
        "ADD",
        "SUB",
        "MUL",
        "DIV",
        "NEG",
        # Comparison
        "LT",
        "GT",
        "EQ",
        "NEQ",
        "LEQ",
        "GEQ",
        # Variables
        "LOAD",  # Load variable
        "STORE",  # Store variable
        "LOAD_GLOBAL",
        "STORE_GLOBAL",
        # Control flow
        "JUMP",
        "JUMP_IF_TRUE",
        "JUMP_IF_FALSE",
        # Functions
        "CALL",
        "RETURN",
        # Other
        "PRINT",
        "HALT",
    ],
)
@dataclass
class Bytecode:
    """One VM instruction: an opcode plus an optional operand.

    The operand's meaning depends on the opcode (a constant-pool index for
    PUSH, a variable slot for LOAD/STORE, a code address for jumps, a
    ``(name_index, argc)`` tuple for CALL as emitted by the generator).
    """
    opcode: Opcode
    operand: Any = None  # typing.Any — the builtin ``any`` is a function, not a type

    def __repr__(self):
        if self.operand is not None:
            return f"{self.opcode.name} {self.operand}"
        return self.opcode.name
Bytecode Generator
class BytecodeGenerator:
    """Generate stack-VM bytecode from an AST.

    Output lives in ``code`` (the Bytecode list), ``constants`` (the pool
    that PUSH operands index into) and ``variables`` (variable name -> slot
    number used by LOAD/STORE). Expression generators leave one value on the
    VM stack; statement generators leave the stack balanced.
    """

    def __init__(self):
        self.code: List[Bytecode] = []
        self.constants: List[Any] = []
        self.variables: Dict[str, int] = {}  # Variable name -> slot index
        self.var_counter = 0  # Next unallocated slot index

    def emit(self, opcode: Opcode, operand=None):
        """Append an instruction and return its address (for later patching)."""
        self.code.append(Bytecode(opcode, operand))
        return len(self.code) - 1

    def add_constant(self, value) -> int:
        """Add constant to pool and return index.

        An existing entry is reused only if it has the same type as well as
        an equal value. A plain ``in``/``index`` lookup treats 1, 1.0 and True
        as the same constant, so a later PUSH would produce a value of the
        wrong type.
        """
        for idx, existing in enumerate(self.constants):
            if existing is value or (type(existing) is type(value) and existing == value):
                return idx
        self.constants.append(value)
        return len(self.constants) - 1

    def get_var_slot(self, name: str) -> int:
        """Get or create variable slot"""
        if name not in self.variables:
            self.variables[name] = self.var_counter
            self.var_counter += 1
        return self.variables[name]

    def current_addr(self) -> int:
        """Get current code address"""
        return len(self.code)

    def patch_jump(self, addr: int, target: int):
        """Patch jump instruction with target address"""
        self.code[addr].operand = target

    def generate(self, node: AST):
        """Dispatch to ``gen_<NodeType>`` for *node*."""
        method_name = f'gen_{type(node).__name__}'
        generator = getattr(self, method_name, self.gen_default)
        return generator(node)

    def gen_default(self, node):
        raise NotImplementedError(f"No generator for {type(node).__name__}")

    def _gen_block(self, statements):
        """Generate code for each statement of a block, in order."""
        for stmt in statements:
            self.generate(stmt)

    def gen_Program(self, node: Program):
        self._gen_block(node.statements)
        self.emit(Opcode.HALT)

    def gen_NumberLiteral(self, node: NumberLiteral):
        self.emit(Opcode.PUSH, self.add_constant(node.value))

    def gen_StringLiteral(self, node: StringLiteral):
        self.emit(Opcode.PUSH, self.add_constant(node.value))

    def gen_BoolLiteral(self, node: BoolLiteral):
        # Booleans are represented as the integers 1 and 0.
        self.emit(Opcode.PUSH, self.add_constant(1 if node.value else 0))

    def gen_Identifier(self, node: Identifier):
        self.emit(Opcode.LOAD, self.get_var_slot(node.name))

    def gen_BinaryOp(self, node: BinaryOp):
        self.generate(node.left)
        self.generate(node.right)
        op_map = {
            '+': Opcode.ADD,
            '-': Opcode.SUB,
            '*': Opcode.MUL,
            '/': Opcode.DIV,
            '<': Opcode.LT,
            '>': Opcode.GT,
            '==': Opcode.EQ,
            '!=': Opcode.NEQ,
        }
        self.emit(op_map[node.op])

    def gen_UnaryOp(self, node: UnaryOp):
        """Generate code for a unary operator.

        Raises NotImplementedError for unsupported operators instead of
        silently leaving the unmodified operand on the stack.
        """
        self.generate(node.operand)
        if node.op == '-':
            self.emit(Opcode.NEG)
        elif node.op == '+':
            pass  # Identity: the operand is already on the stack.
        elif node.op in ('not', '!'):
            # There is no NOT opcode; "operand == 0" pushes 1/0, matching the
            # integer booleans produced by gen_BoolLiteral.
            self.emit(Opcode.PUSH, self.add_constant(0))
            self.emit(Opcode.EQ)
        else:
            raise NotImplementedError(f"No bytecode for unary operator {node.op!r}")

    def gen_Assignment(self, node: Assignment):
        self.generate(node.value)
        self.emit(Opcode.STORE, self.get_var_slot(node.target.name))

    def gen_ExpressionStmt(self, node: ExpressionStmt):
        # Discard the expression's value to keep the stack balanced.
        self.generate(node.expr)
        self.emit(Opcode.POP)

    def gen_IfStmt(self, node: IfStmt):
        self.generate(node.condition)
        jump_false = self.emit(Opcode.JUMP_IF_FALSE, 0)  # Placeholder, patched below
        self._gen_block(node.then_branch)
        if node.else_branch:
            jump_end = self.emit(Opcode.JUMP, 0)  # Placeholder, patched below
            self.patch_jump(jump_false, self.current_addr())
            self._gen_block(node.else_branch)
            self.patch_jump(jump_end, self.current_addr())
        else:
            self.patch_jump(jump_false, self.current_addr())

    def gen_WhileStmt(self, node: WhileStmt):
        start_addr = self.current_addr()
        self.generate(node.condition)
        jump_end = self.emit(Opcode.JUMP_IF_FALSE, 0)
        self._gen_block(node.body)
        self.emit(Opcode.JUMP, start_addr)
        self.patch_jump(jump_end, self.current_addr())

    def gen_Call(self, node: Call):
        """Generate code for a call expression (leaves one value on the stack).

        Raises NotImplementedError for callees that are not plain names.
        Previously those silently emitted no CALL and leaked their arguments
        on the stack.
        """
        callee = node.callee
        if isinstance(callee, Identifier) and callee.name == 'print':
            # PRINT pops exactly one value, so each argument gets its own
            # PRINT (one line per argument). A bare print() prints an empty
            # line, like Python's.
            if node.arguments:
                for arg in node.arguments:
                    self.generate(arg)
                    self.emit(Opcode.PRINT)
            else:
                self.emit(Opcode.PUSH, self.add_constant(""))
                self.emit(Opcode.PRINT)
            # Push None as return value
            self.emit(Opcode.PUSH, self.add_constant(None))
            return
        if not isinstance(callee, Identifier):
            raise NotImplementedError(
                f"Calls through {type(callee).__name__} callees are not supported")
        # General function call
        for arg in node.arguments:
            self.generate(arg)
        func_idx = self.add_constant(callee.name)
        self.emit(Opcode.CALL, (func_idx, len(node.arguments)))

    def disassemble(self):
        """Print bytecode disassembly"""
        print("Constants:", self.constants)
        print("Variables:", self.variables)
        print("\nBytecode:")
        for i, instr in enumerate(self.code):
            print(f" {i:4d}: {instr}")
Virtual Machine
class VM:
    """Stack-based virtual machine.

    State: an operand stack, a slot-indexed variable table and an
    instruction pointer. PUSH operands index into ``constants``; LOAD/STORE
    operands are variable slots; jump operands are absolute code addresses.
    """

    def __init__(self, bytecode: List[Bytecode], constants: List[Any]):
        self.code = bytecode
        self.constants = constants
        self.stack: List[Any] = []
        self.variables: Dict[int, Any] = {}
        self.ip = 0  # Instruction pointer: index of the next instruction

    def push(self, value):
        """Push *value* onto the operand stack."""
        self.stack.append(value)

    def pop(self):
        """Remove and return the top of the stack."""
        return self.stack.pop()

    def peek(self):
        """Return the top of the stack without removing it."""
        return self.stack[-1]

    def run(self):
        """Execute bytecode until HALT or the end of the code.

        Returns the top of the stack when execution stops, or None if the
        stack is empty. Raises NotImplementedError on an opcode the VM does
        not implement (CALL, RETURN, LOAD_GLOBAL, STORE_GLOBAL). Silently
        skipping it would leave the stack in an inconsistent state.
        """
        # Binary operators pop the right operand first (it is on top).
        # Comparisons push 1/0 rather than True/False.
        binary_ops = {
            Opcode.ADD: operator.add,
            Opcode.SUB: operator.sub,
            Opcode.MUL: operator.mul,
            Opcode.DIV: operator.truediv,
            Opcode.LT: lambda a, b: 1 if a < b else 0,
            Opcode.GT: lambda a, b: 1 if a > b else 0,
            Opcode.EQ: lambda a, b: 1 if a == b else 0,
            Opcode.NEQ: lambda a, b: 1 if a != b else 0,
            Opcode.LEQ: lambda a, b: 1 if a <= b else 0,
            Opcode.GEQ: lambda a, b: 1 if a >= b else 0,
        }
        while self.ip < len(self.code):
            instr = self.code[self.ip]
            self.ip += 1
            op = instr.opcode
            if op in binary_ops:
                b, a = self.pop(), self.pop()
                self.push(binary_ops[op](a, b))
            elif op == Opcode.HALT:
                break
            elif op == Opcode.PUSH:
                self.push(self.constants[instr.operand])
            elif op == Opcode.POP:
                self.pop()
            elif op == Opcode.DUP:
                self.push(self.peek())
            elif op == Opcode.LOAD:
                # A slot that was never stored reads as 0.
                self.push(self.variables.get(instr.operand, 0))
            elif op == Opcode.STORE:
                self.variables[instr.operand] = self.pop()
            elif op == Opcode.NEG:
                self.push(-self.pop())
            elif op == Opcode.JUMP:
                self.ip = instr.operand
            elif op == Opcode.JUMP_IF_FALSE:
                if not self.pop():
                    self.ip = instr.operand
            elif op == Opcode.JUMP_IF_TRUE:
                if self.pop():
                    self.ip = instr.operand
            elif op == Opcode.PRINT:
                print(self.pop())
            else:
                raise NotImplementedError(f"VM cannot execute opcode {op.name}")
        return self.peek() if self.stack else None
# Complete example: the whole pipeline from source text to VM execution.
def compile_and_run(source: str):
    """Lex, parse and compile *source*, print its disassembly, then run it.

    Prints a "=== Bytecode ===" section (the generator's disassembly) and an
    "=== Output ===" header before executing on a fresh VM. Returns whatever
    ``VM.run()`` returns.
    """
    # Front end: tokens without NEWLINEs, then the AST.
    significant_tokens = [
        tok for tok in Lexer(source).tokenize() if tok.type != TokenType.NEWLINE
    ]
    program = Parser(significant_tokens).parse()

    # Back end: bytecode plus its constant pool.
    generator = BytecodeGenerator()
    generator.generate(program)
    print("=== Bytecode ===")
    generator.disassemble()

    print("\n=== Output ===")
    machine = VM(generator.code, generator.constants)
    return machine.run()
# Test: compile a small program, print its disassembly, and run it on the VM.
# Runs at import time and rebinds the module-level `source` used above.
source = '''
x = 10
y = 20
z = x + y * 2
print(z)
'''
# print(z) should output 50, assuming the parser gives * higher precedence than +.
compile_and_run(source)
## 3. Optimization Passes
class Optimizer:
    """Simple bytecode optimizer.

    Passes delete or merge instructions, which shifts the address of
    everything after them. Each pass therefore records where every old
    address ended up and rewrites JUMP / JUMP_IF_TRUE / JUMP_IF_FALSE
    operands to match. It also never merges away an instruction that some
    jump targets. The input code list and its Bytecode objects are never
    mutated; ``constants`` may be extended in place with folded values.
    """

    def optimize(self, code: List[Bytecode], constants: List[Any]):
        """Apply optimization passes"""
        code = self.constant_folding(code, constants)
        code = self.peephole(code)
        return code

    @staticmethod
    def _is_jump(instr: Bytecode) -> bool:
        """True for instructions whose operand is an absolute code address."""
        return instr.opcode in (Opcode.JUMP, Opcode.JUMP_IF_TRUE, Opcode.JUMP_IF_FALSE)

    def _jump_targets(self, code: List[Bytecode]) -> set:
        """Addresses that some jump in *code* can transfer control to."""
        return {instr.operand for instr in code if self._is_jump(instr)}

    def _relocate(self, code: List[Bytecode], addr_map: List[int]) -> List[Bytecode]:
        """Return *code* with jump operands translated through *addr_map*.

        ``addr_map[old]`` is the new address of old instruction ``old`` (or
        of the next surviving instruction if it was removed); the final entry
        maps the one-past-the-end address. Jumps are rebuilt, not patched, so
        Bytecode objects shared with the caller's list stay untouched.
        Operands outside the map are left as they are.
        """
        relocated = []
        for instr in code:
            target = instr.operand
            if self._is_jump(instr) and isinstance(target, int) and 0 <= target < len(addr_map):
                instr = Bytecode(instr.opcode, addr_map[target])
            relocated.append(instr)
        return relocated

    @staticmethod
    def _constant_index(constants: List[Any], value) -> int:
        """Index of *value* in *constants*, appending it if absent.

        Matches on type as well as equality so that e.g. 1, 1.0 and True
        stay distinct constants.
        """
        for idx, existing in enumerate(constants):
            if existing is value or (type(existing) is type(value) and existing == value):
                return idx
        constants.append(value)
        return len(constants) - 1

    def constant_folding(self, code: List[Bytecode], constants: List[Any]):
        """Fold ``PUSH a; PUSH b; <ADD|SUB|MUL|DIV>`` into ``PUSH (a op b)``.

        Folding is checked against the already-emitted output, so nested
        constant expressions such as ``1 + 2 * 3`` collapse in one pass.
        A fold is skipped when it would swallow a jump target, or when
        evaluating it raises (e.g. division by zero). In that case the error
        is left to happen at run time, as it would without optimization.
        """
        fold_ops = {
            Opcode.ADD: operator.add,
            Opcode.SUB: operator.sub,
            Opcode.MUL: operator.mul,
            Opcode.DIV: operator.truediv,  # Same semantics as the VM's DIV
        }
        targets = self._jump_targets(code)
        result: List[Bytecode] = []
        origin: List[int] = []  # origin[k]: first old address merged into result[k]
        addr_map = [0] * (len(code) + 1)
        for addr, instr in enumerate(code):
            addr_map[addr] = len(result)
            fold = fold_ops.get(instr.opcode)
            if (fold is not None
                    and len(result) >= 2
                    and result[-2].opcode == Opcode.PUSH
                    and result[-1].opcode == Opcode.PUSH
                    # Only the start of the first PUSH may be a jump target;
                    # a jump into the second PUSH or the operator would land
                    # in the middle of the folded instruction.
                    and targets.isdisjoint(range(origin[-1], addr + 1))):
                try:
                    value = fold(constants[result[-2].operand],
                                 constants[result[-1].operand])
                except (ArithmeticError, TypeError):
                    pass  # Not foldable; keep the operator for run time.
                else:
                    start = origin[-2]
                    del result[-2:]
                    del origin[-2:]
                    addr_map[addr] = len(result)
                    result.append(Bytecode(Opcode.PUSH, self._constant_index(constants, value)))
                    origin.append(start)
                    continue
            result.append(instr)
            origin.append(addr)
        addr_map[len(code)] = len(result)
        return self._relocate(result, addr_map)

    def peephole(self, code: List[Bytecode]):
        """Remove ``PUSH; POP`` pairs and jumps to the very next instruction."""
        targets = self._jump_targets(code)
        result: List[Bytecode] = []
        addr_map = [0] * (len(code) + 1)
        addr = 0
        while addr < len(code):
            instr = code[addr]
            # A removed instruction maps to the next survivor, i.e. the
            # current end of ``result``.
            addr_map[addr] = len(result)
            # PUSH followed by POP - remove both, unless something jumps to
            # the POP (that path expects a value on the stack to discard).
            if (addr + 1 < len(code)
                    and instr.opcode == Opcode.PUSH
                    and code[addr + 1].opcode == Opcode.POP
                    and addr + 1 not in targets):
                addr_map[addr + 1] = len(result)
                addr += 2
                continue
            # JUMP to next instruction - remove (falling through is equivalent)
            if instr.opcode == Opcode.JUMP and instr.operand == addr + 1:
                addr += 1
                continue
            result.append(instr)
            addr += 1
        addr_map[len(code)] = len(result)
        return self._relocate(result, addr_map)
Summary
- Three-address code is a common IR with simple instructions
- Stack-based VMs keep operands on a stack instead of in named registers, which simplifies code generation
- Bytecode is a compact, portable intermediate format
- Instruction selection maps IR operations to a real target's machine instructions (not implemented above, where the target is a custom bytecode VM)
- Optimizations can be applied at IR or bytecode level
Module Complete
This completes the Compilers module! You've learned:
- Lexical analysis with regular expressions and finite automata
- Parsing with recursive descent and precedence climbing
- Semantic analysis with type checking
- Code generation for a stack-based VM
These concepts form the foundation for understanding programming language implementation.