Code Generation

Introduction

Code generation is the final phase of compilation. The code generator transforms the AST (often after optimization) into executable code. This reading covers intermediate representations, instruction selection, register allocation, and generating code for a stack-based virtual machine.

Learning Objectives

By the end of this reading, you will be able to:

  • Design intermediate representations (IR)
  • Generate code for a stack-based VM
  • Implement a simple bytecode interpreter
  • Understand register allocation basics
  • Generate assembly code (conceptually)

1. Intermediate Representations

Why IR?

  • Decouple the front end (parsing) from the back end (code generation)
  • Enable machine-independent optimizations
  • Support multiple source languages and target platforms

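Three-Address Code

Each three-address instruction applies at most one operator and names at most three "addresses" — a destination and up to two operands — for example `t3 = t1 + t2`. Intermediate values live in compiler-generated temporaries (`t1`, `t2`, ...), and control flow is expressed with labels and jumps.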

from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Union

class IROpcode(Enum):
    """Operations of the three-address-code IR.

    Member values are spelled out explicitly; they are exactly what ``auto()``
    would assign in declaration order (1, 2, 3, ...).
    """

    # Arithmetic: dest = arg1 <op> arg2   (NEG: dest = -arg1)
    ADD = 1
    SUB = 2
    MUL = 3
    DIV = 4
    NEG = 5

    # Comparison: dest = arg1 <cmp> arg2
    LT = 6
    GT = 7
    EQ = 8
    NEQ = 9

    # Data movement between constants, temps and named variables
    COPY = 10
    LOAD_CONST = 11
    LOAD_VAR = 12
    STORE_VAR = 13

    # Control flow: labels and (conditional) jumps to them
    LABEL = 14
    JUMP = 15
    JUMP_IF_TRUE = 16
    JUMP_IF_FALSE = 17

    # Function calls: PARAM pushes an argument, CALL invokes, RETURN exits
    CALL = 18
    RETURN = 19
    PARAM = 20

@dataclass
class IRInstruction:
    """One three-address-code instruction.

    How the fields are used depends on ``opcode``: ``dest`` is the result
    temp/variable (or the label name for LABEL), while ``arg1``/``arg2`` hold
    operands, a jump condition and target label, or a callee and argument count.
    """
    opcode: IROpcode
    dest: Optional[str] = None  # Destination (temporary, variable, or label name)
    arg1: Optional[Union[str, int, float]] = None
    arg2: Optional[Union[str, int, float]] = None

    # Opcodes rendered in infix form "dest = arg1 <symbol> arg2".
    # (Unannotated, so @dataclass does not treat it as a field.)
    _INFIX = {
        IROpcode.ADD: '+',
        IROpcode.SUB: '-',
        IROpcode.MUL: '*',
        IROpcode.DIV: '/',
        IROpcode.LT: '<',
        IROpcode.GT: '>',
        IROpcode.EQ: '==',
        IROpcode.NEQ: '!=',
    }

    def __repr__(self):
        """Render the instruction in conventional three-address-code notation.

        Every opcode the IR defines gets a readable form; the raw
        ``OPCODE dest arg1 arg2`` form is kept only as a fallback.
        """
        op = self.opcode
        if op == IROpcode.LABEL:
            return f"{self.dest}:"
        if op == IROpcode.JUMP:
            return f"  jump {self.arg1}"
        if op == IROpcode.JUMP_IF_FALSE:
            return f"  if not {self.arg1} goto {self.arg2}"
        if op == IROpcode.JUMP_IF_TRUE:
            return f"  if {self.arg1} goto {self.arg2}"
        if op == IROpcode.LOAD_CONST:
            # repr() quotes string constants so they cannot be mistaken for
            # variable names; numbers print the same as with str().
            return f"  {self.dest} = {self.arg1!r}"
        if op in (IROpcode.COPY, IROpcode.LOAD_VAR, IROpcode.STORE_VAR):
            # LOAD_VAR: dest is a temp, arg1 the variable name.
            # STORE_VAR: dest is the variable name, arg1 the source temp.
            return f"  {self.dest} = {self.arg1}"
        if op == IROpcode.PARAM:
            return f"  param {self.arg1}"
        if op == IROpcode.CALL:
            return f"  {self.dest} = call {self.arg1}, {self.arg2} args"
        if op == IROpcode.RETURN:
            return "  return" if self.arg1 is None else f"  return {self.arg1}"
        if op == IROpcode.NEG:
            return f"  {self.dest} = -{self.arg1}"
        if op in self._INFIX:
            return f"  {self.dest} = {self.arg1} {self._INFIX[op]} {self.arg2}"
        return f"  {op.name} {self.dest} {self.arg1} {self.arg2}"

class IRGenerator:
    """Generate three-address code from AST.

    Expression generators return the name of the temp (``t1``, ``t2``, ...)
    that holds their value; statement generators return None. Emitted
    instructions accumulate in ``self.instructions``.
    """

    # Binary operators that map directly onto a single IR opcode.
    _BINARY_OPS = {
        '+': IROpcode.ADD,
        '-': IROpcode.SUB,
        '*': IROpcode.MUL,
        '/': IROpcode.DIV,
        '<': IROpcode.LT,
        '>': IROpcode.GT,
        '==': IROpcode.EQ,
        '!=': IROpcode.NEQ,
    }

    # '<=' and '>=' have no IR opcode; lower them as "(a > b) == 0" and
    # "(a < b) == 0" respectively (comparisons yield 1/0 in this compiler).
    # NOTE: this differs from IEEE '<=' / '>=' only when an operand is NaN.
    _NEGATED_COMPARISONS = {
        '<=': IROpcode.GT,
        '>=': IROpcode.LT,
    }

    def __init__(self):
        self.instructions: List[IRInstruction] = []
        self.temp_counter = 0
        self.label_counter = 0

    def new_temp(self) -> str:
        """Return a fresh temporary name (t1, t2, ...)."""
        self.temp_counter += 1
        return f"t{self.temp_counter}"

    def new_label(self) -> str:
        """Return a fresh label name (L1, L2, ...)."""
        self.label_counter += 1
        return f"L{self.label_counter}"

    def emit(self, opcode: IROpcode, dest=None, arg1=None, arg2=None):
        """Append one instruction to the output."""
        self.instructions.append(IRInstruction(opcode, dest, arg1, arg2))

    def generate(self, node: AST) -> Optional[str]:
        """Generate IR for node; return the temp holding its value (None for statements)."""
        method_name = f'gen_{type(node).__name__}'
        generator = getattr(self, method_name, self.gen_default)
        return generator(node)

    def gen_default(self, node):
        raise NotImplementedError(f"No IR generator for {type(node).__name__}")

    def gen_Program(self, node: Program):
        for stmt in node.statements:
            self.generate(stmt)

    def gen_NumberLiteral(self, node: NumberLiteral) -> str:
        temp = self.new_temp()
        self.emit(IROpcode.LOAD_CONST, temp, node.value)
        return temp

    def gen_StringLiteral(self, node: StringLiteral) -> str:
        temp = self.new_temp()
        self.emit(IROpcode.LOAD_CONST, temp, node.value)
        return temp

    def gen_BoolLiteral(self, node: BoolLiteral) -> str:
        # Booleans are represented as the integers 1 / 0.
        temp = self.new_temp()
        self.emit(IROpcode.LOAD_CONST, temp, 1 if node.value else 0)
        return temp

    def gen_Identifier(self, node: Identifier) -> str:
        temp = self.new_temp()
        self.emit(IROpcode.LOAD_VAR, temp, node.name)
        return temp

    def gen_BinaryOp(self, node: BinaryOp) -> str:
        """Evaluate both operands, then combine them into a fresh temp.

        Raises:
            NotImplementedError: for an operator with no IR lowering.
        """
        left = self.generate(node.left)
        right = self.generate(node.right)

        if node.op in self._BINARY_OPS:
            result = self.new_temp()
            self.emit(self._BINARY_OPS[node.op], result, left, right)
            return result

        if node.op in self._NEGATED_COMPARISONS:
            strict = self.new_temp()
            self.emit(self._NEGATED_COMPARISONS[node.op], strict, left, right)
            result = self.new_temp()
            self.emit(IROpcode.EQ, result, strict, 0)
            return result

        raise NotImplementedError(f"Unsupported binary operator {node.op!r}")

    def gen_UnaryOp(self, node: UnaryOp) -> str:
        """Generate a unary operation.

        Previously any operator other than '-' returned a temp that was never
        assigned; unsupported operators now raise instead.

        Raises:
            NotImplementedError: for an operator with no IR lowering.
        """
        operand = self.generate(node.operand)

        if node.op == '+':
            return operand  # unary plus is the identity
        if node.op == '-':
            result = self.new_temp()
            self.emit(IROpcode.NEG, result, operand)
            return result
        if node.op in ('not', '!'):
            # Logical negation on 1/0 truth values: result = (operand == 0).
            result = self.new_temp()
            self.emit(IROpcode.EQ, result, operand, 0)
            return result

        raise NotImplementedError(f"Unsupported unary operator {node.op!r}")

    def gen_Assignment(self, node: Assignment) -> str:
        value = self.generate(node.value)
        self.emit(IROpcode.STORE_VAR, node.target.name, value)
        return node.target.name

    def gen_ExpressionStmt(self, node: ExpressionStmt):
        self.generate(node.expr)

    def gen_IfStmt(self, node: IfStmt):
        cond = self.generate(node.condition)

        if node.else_branch:
            else_label = self.new_label()
            end_label = self.new_label()

            self.emit(IROpcode.JUMP_IF_FALSE, None, cond, else_label)

            for stmt in node.then_branch:
                self.generate(stmt)
            self.emit(IROpcode.JUMP, None, end_label)

            self.emit(IROpcode.LABEL, else_label)
            for stmt in node.else_branch:
                self.generate(stmt)

            self.emit(IROpcode.LABEL, end_label)
        else:
            end_label = self.new_label()
            self.emit(IROpcode.JUMP_IF_FALSE, None, cond, end_label)

            for stmt in node.then_branch:
                self.generate(stmt)

            self.emit(IROpcode.LABEL, end_label)

    def gen_WhileStmt(self, node: WhileStmt):
        start_label = self.new_label()
        end_label = self.new_label()

        self.emit(IROpcode.LABEL, start_label)

        # The condition is re-evaluated on every iteration.
        cond = self.generate(node.condition)
        self.emit(IROpcode.JUMP_IF_FALSE, None, cond, end_label)

        for stmt in node.body:
            self.generate(stmt)

        self.emit(IROpcode.JUMP, None, start_label)
        self.emit(IROpcode.LABEL, end_label)

    def gen_Call(self, node: Call) -> str:
        """Generate a call.

        The callee is evaluated first (left-to-right order), then every
        argument, and only then are the PARAMs emitted. This keeps the PARAM
        sequence immediately followed by its CALL.
        """
        if isinstance(node.callee, Identifier):
            func = node.callee.name
        else:
            func = self.generate(node.callee)

        arg_temps = [self.generate(arg) for arg in node.arguments]
        for temp in arg_temps:
            self.emit(IROpcode.PARAM, None, temp)

        result = self.new_temp()
        self.emit(IROpcode.CALL, result, func, len(arg_temps))
        return result

    def gen_Return(self, node: Return):
        if node.value is not None:
            value = self.generate(node.value)
            self.emit(IROpcode.RETURN, None, value)
        else:
            self.emit(IROpcode.RETURN)

    def print_ir(self):
        """Print generated IR"""
        for instr in self.instructions:
            print(instr)

# Example: lower a small program to three-address code and print it.
source = "\nx = 10\ny = 20\nif x < y:\n    z = x + y\n"

lexer = Lexer(source)
# NEWLINE tokens are discarded before the token list reaches the parser.
tokens = list(filter(lambda tok: tok.type != TokenType.NEWLINE, lexer.tokenize()))
parser = Parser(tokens)
ast = parser.parse()

ir_gen = IRGenerator()
ir_gen.generate(ast)
print("Three-Address Code:")
ir_gen.print_ir()

2. Stack-Based Virtual Machine

Bytecode Definition

class Opcode(Enum):
    """Instruction set of the stack VM.

    Member values are spelled out explicitly; they are exactly what ``auto()``
    would assign in declaration order (1, 2, 3, ...).
    """

    # Stack manipulation (PUSH's operand is an index into the constant pool)
    PUSH = 1
    POP = 2
    DUP = 3

    # Arithmetic on the top stack values
    ADD = 4
    SUB = 5
    MUL = 6
    DIV = 7
    NEG = 8

    # Comparisons
    LT = 9
    GT = 10
    EQ = 11
    NEQ = 12
    LEQ = 13
    GEQ = 14

    # Variable access (LOAD/STORE operand is a variable slot number)
    LOAD = 15
    STORE = 16
    LOAD_GLOBAL = 17
    STORE_GLOBAL = 18

    # Control flow (operand is an absolute instruction index)
    JUMP = 19
    JUMP_IF_TRUE = 20
    JUMP_IF_FALSE = 21

    # Function calls
    CALL = 22
    RETURN = 23

    # Miscellaneous
    PRINT = 24
    HALT = 25

@dataclass
class Bytecode:
    """One VM instruction: an opcode plus an optional operand.

    The operand's meaning depends on the opcode. It is a constant-pool index
    for PUSH, a variable slot for LOAD/STORE, an absolute instruction index for
    jumps, and a ``(name_constant_index, arg_count)`` tuple for CALL.
    """
    opcode: Opcode
    # typing.Any, not the builtin function `any`.
    operand: Any = None

    def __repr__(self):
        """Disassembly form: ``NAME`` when there is no operand, else ``NAME operand``."""
        if self.operand is None:
            return self.opcode.name
        return f"{self.opcode.name} {self.operand}"

Bytecode Generator

class BytecodeGenerator:
    """Generate stack-VM bytecode from an AST.

    After ``generate(program)``:
    - ``code`` holds the instruction list;
    - ``constants`` holds the pool that PUSH operands index into;
    - ``variables`` maps each variable name to the slot used by LOAD/STORE.
    """

    # Binary operators with a dedicated opcode.
    _BINARY_OPS = {
        '+': Opcode.ADD,
        '-': Opcode.SUB,
        '*': Opcode.MUL,
        '/': Opcode.DIV,
        '<': Opcode.LT,
        '>': Opcode.GT,
        '==': Opcode.EQ,
        '!=': Opcode.NEQ,
    }

    # '<=' / '>=' are lowered to "not (a > b)" / "not (a < b)". This uses only
    # opcodes the VM's run loop executes.
    # NOTE: this differs from IEEE '<=' / '>=' only when an operand is NaN.
    _NEGATED_COMPARISONS = {
        '<=': Opcode.GT,
        '>=': Opcode.LT,
    }

    def __init__(self):
        self.code: List[Bytecode] = []
        self.constants: List[Any] = []
        self.variables: Dict[str, int] = {}  # Variable name -> slot index
        self.var_counter = 0

    def emit(self, opcode: Opcode, operand=None):
        """Append an instruction and return its index (used for back-patching)."""
        self.code.append(Bytecode(opcode, operand))
        return len(self.code) - 1

    def add_constant(self, value) -> int:
        """Add value to the constant pool (deduplicated) and return its index.

        Deduplication compares types as well as values. Python treats ``1``,
        ``1.0`` and ``True`` as equal, and plain ``==`` would make later
        literals reuse an earlier literal of a different type.
        """
        for idx, existing in enumerate(self.constants):
            if type(existing) is type(value) and existing == value:
                return idx
        self.constants.append(value)
        return len(self.constants) - 1

    def get_var_slot(self, name: str) -> int:
        """Get or create variable slot"""
        if name not in self.variables:
            self.variables[name] = self.var_counter
            self.var_counter += 1
        return self.variables[name]

    def current_addr(self) -> int:
        """Return the address the next emitted instruction will get."""
        return len(self.code)

    def patch_jump(self, addr: int, target: int):
        """Patch jump instruction with target address"""
        self.code[addr].operand = target

    def generate(self, node: AST):
        method_name = f'gen_{type(node).__name__}'
        generator = getattr(self, method_name, self.gen_default)
        return generator(node)

    def gen_default(self, node):
        raise NotImplementedError(f"No generator for {type(node).__name__}")

    def gen_Program(self, node: Program):
        for stmt in node.statements:
            self.generate(stmt)
        self.emit(Opcode.HALT)

    def gen_NumberLiteral(self, node: NumberLiteral):
        self.emit(Opcode.PUSH, self.add_constant(node.value))

    def gen_StringLiteral(self, node: StringLiteral):
        self.emit(Opcode.PUSH, self.add_constant(node.value))

    def gen_BoolLiteral(self, node: BoolLiteral):
        # Booleans are represented as the integers 1 / 0.
        self.emit(Opcode.PUSH, self.add_constant(1 if node.value else 0))

    def gen_Identifier(self, node: Identifier):
        self.emit(Opcode.LOAD, self.get_var_slot(node.name))

    def _emit_logical_not(self):
        """Replace the top of stack x with 1 if x == 0 else 0."""
        self.emit(Opcode.PUSH, self.add_constant(0))
        self.emit(Opcode.EQ)

    def gen_BinaryOp(self, node: BinaryOp):
        """Push both operands, then emit the operator.

        Raises:
            NotImplementedError: for an operator with no bytecode lowering.
        """
        self.generate(node.left)
        self.generate(node.right)

        if node.op in self._BINARY_OPS:
            self.emit(self._BINARY_OPS[node.op])
        elif node.op in self._NEGATED_COMPARISONS:
            self.emit(self._NEGATED_COMPARISONS[node.op])
            self._emit_logical_not()
        else:
            raise NotImplementedError(f"Unsupported binary operator {node.op!r}")

    def gen_UnaryOp(self, node: UnaryOp):
        """Push the operand, then apply the operator.

        Raises:
            NotImplementedError: for an operator with no bytecode lowering.
                Previously such operators were silently treated as identity.
        """
        self.generate(node.operand)
        if node.op == '-':
            self.emit(Opcode.NEG)
        elif node.op == '+':
            pass  # unary plus is the identity
        elif node.op in ('not', '!'):
            self._emit_logical_not()
        else:
            raise NotImplementedError(f"Unsupported unary operator {node.op!r}")

    def gen_Assignment(self, node: Assignment):
        self.generate(node.value)
        self.emit(Opcode.STORE, self.get_var_slot(node.target.name))

    def gen_ExpressionStmt(self, node: ExpressionStmt):
        # Discard the expression's value so the stack stays balanced.
        self.generate(node.expr)
        self.emit(Opcode.POP)

    def gen_IfStmt(self, node: IfStmt):
        self.generate(node.condition)

        if node.else_branch:
            jump_else = self.emit(Opcode.JUMP_IF_FALSE, 0)  # Placeholder

            for stmt in node.then_branch:
                self.generate(stmt)
            jump_end = self.emit(Opcode.JUMP, 0)  # Placeholder

            else_addr = self.current_addr()
            for stmt in node.else_branch:
                self.generate(stmt)

            end_addr = self.current_addr()
            self.patch_jump(jump_else, else_addr)
            self.patch_jump(jump_end, end_addr)
        else:
            jump_end = self.emit(Opcode.JUMP_IF_FALSE, 0)  # Placeholder

            for stmt in node.then_branch:
                self.generate(stmt)

            self.patch_jump(jump_end, self.current_addr())

    def gen_WhileStmt(self, node: WhileStmt):
        start_addr = self.current_addr()

        self.generate(node.condition)
        jump_end = self.emit(Opcode.JUMP_IF_FALSE, 0)  # Placeholder

        for stmt in node.body:
            self.generate(stmt)

        self.emit(Opcode.JUMP, start_addr)
        self.patch_jump(jump_end, self.current_addr())

    def gen_Call(self, node: Call):
        """Generate a call expression; its result is left on the stack.

        Raises:
            NotImplementedError: when the callee is not a plain name.
                Previously this case pushed the arguments but emitted no call.
        """
        if not isinstance(node.callee, Identifier):
            raise NotImplementedError("Only calls to named functions are supported")

        if node.callee.name == 'print':
            # Built-in print: PRINT each argument, then push None as the
            # call's value.
            for arg in node.arguments:
                self.generate(arg)
                self.emit(Opcode.PRINT)
            self.emit(Opcode.PUSH, self.add_constant(None))
            return

        # General function call: arguments, then CALL (name_const_idx, argc).
        for arg in node.arguments:
            self.generate(arg)
        func_idx = self.add_constant(node.callee.name)
        self.emit(Opcode.CALL, (func_idx, len(node.arguments)))

    def disassemble(self):
        """Print bytecode disassembly"""
        print("Constants:", self.constants)
        print("Variables:", self.variables)
        print("\nBytecode:")
        for i, instr in enumerate(self.code):
            print(f"  {i:4d}: {instr}")

Virtual Machine

class VM:
    """Stack-based virtual machine.

    Executes a list of ``Bytecode`` against a constant pool. Operands live on
    ``stack``. Variables live in ``variables``, keyed by the integer slot that
    LOAD/STORE carry. ``ip`` is the index of the next instruction to run.
    """

    # Binary opcodes: pop b (top of stack), then a, and push f(a, b).
    # Comparisons push 1/0 rather than True/False.
    _BINARY_OPS = {
        Opcode.ADD: lambda a, b: a + b,
        Opcode.SUB: lambda a, b: a - b,
        Opcode.MUL: lambda a, b: a * b,
        Opcode.DIV: lambda a, b: a / b,
        Opcode.LT: lambda a, b: 1 if a < b else 0,
        Opcode.GT: lambda a, b: 1 if a > b else 0,
        Opcode.EQ: lambda a, b: 1 if a == b else 0,
        Opcode.NEQ: lambda a, b: 1 if a != b else 0,
        Opcode.LEQ: lambda a, b: 1 if a <= b else 0,
        Opcode.GEQ: lambda a, b: 1 if a >= b else 0,
    }

    def __init__(self, bytecode: List[Bytecode], constants: List[Any]):
        self.code = bytecode
        self.constants = constants
        self.stack: List[Any] = []
        self.variables: Dict[int, Any] = {}
        self.ip = 0  # Instruction pointer: index of the next instruction

    def push(self, value):
        """Push value onto the operand stack."""
        self.stack.append(value)

    def pop(self):
        """Remove and return the top of the stack (IndexError if empty)."""
        return self.stack.pop()

    def peek(self):
        """Return the top of the stack without removing it (IndexError if empty)."""
        return self.stack[-1]

    def run(self):
        """Execute from ``ip`` until HALT or the end of the code.

        Returns:
            The value on top of the stack when execution stops, or None if the
            stack is empty.

        Raises:
            NotImplementedError: on an opcode this VM does not interpret
                (e.g. CALL, RETURN, LOAD_GLOBAL, STORE_GLOBAL). Previously such
                instructions were skipped silently, leaving the stack corrupted.
        """
        while self.ip < len(self.code):
            instr = self.code[self.ip]
            self.ip += 1
            op = instr.opcode

            binary = self._BINARY_OPS.get(op)
            if binary is not None:
                b = self.pop()
                a = self.pop()
                self.push(binary(a, b))
            elif op == Opcode.HALT:
                break
            elif op == Opcode.PUSH:
                self.push(self.constants[instr.operand])
            elif op == Opcode.POP:
                self.pop()
            elif op == Opcode.DUP:
                self.push(self.peek())
            elif op == Opcode.LOAD:
                # A slot that was never stored reads as 0.
                self.push(self.variables.get(instr.operand, 0))
            elif op == Opcode.STORE:
                self.variables[instr.operand] = self.pop()
            elif op == Opcode.NEG:
                self.push(-self.pop())
            elif op == Opcode.JUMP:
                self.ip = instr.operand
            elif op == Opcode.JUMP_IF_FALSE:
                if not self.pop():
                    self.ip = instr.operand
            elif op == Opcode.JUMP_IF_TRUE:
                if self.pop():
                    self.ip = instr.operand
            elif op == Opcode.PRINT:
                print(self.pop())
            else:
                raise NotImplementedError(
                    f"VM cannot execute {op.name} (address {self.ip - 1})"
                )

        return self.peek() if self.stack else None

# Complete example: the whole pipeline from source text to execution
def compile_and_run(source: str):
    """Lex, parse and compile source to bytecode, show the bytecode, then run it.

    Returns whatever ``VM.run()`` returns: the value left on top of the stack,
    or None when the stack is empty.
    """
    # Front end: tokens (minus NEWLINEs) -> AST
    token_stream = [tok for tok in Lexer(source).tokenize()
                    if tok.type != TokenType.NEWLINE]
    program = Parser(token_stream).parse()

    # Back end: AST -> bytecode + constant pool
    emitter = BytecodeGenerator()
    emitter.generate(program)

    print("=== Bytecode ===")
    emitter.disassemble()

    print("\n=== Output ===")
    machine = VM(emitter.code, emitter.constants)
    return machine.run()

# Test: compile and run a small program. It should print 50, given the usual
# precedence of * over +.
source = "\nx = 10\ny = 20\nz = x + y * 2\nprint(z)\n"

compile_and_run(source)

3. Optimization Passes

class Optimizer:
    """Simple bytecode optimizer.

    Both passes delete instructions, which shifts the absolute addresses that
    jump operands refer to. Each pass therefore records where every original
    instruction ended up and rewrites jump operands to match. It also never
    folds away an instruction that some jump lands on.
    """

    # Operators evaluated at compile time when both operands are PUSHed constants.
    _FOLDABLE = {
        Opcode.ADD: lambda a, b: a + b,
        Opcode.SUB: lambda a, b: a - b,
        Opcode.MUL: lambda a, b: a * b,
        Opcode.DIV: lambda a, b: a / b,
    }

    # Opcodes whose operand is an absolute instruction index.
    _JUMPS = (Opcode.JUMP, Opcode.JUMP_IF_TRUE, Opcode.JUMP_IF_FALSE)

    def optimize(self, code: List[Bytecode], constants: List[Any]) -> List[Bytecode]:
        """Apply all optimization passes and return the new instruction list.

        ``code`` itself is not modified. ``constants`` may gain entries for
        folded values.
        """
        code = self.constant_folding(code, constants)
        code = self.peephole(code)
        return code

    def constant_folding(self, code: List[Bytecode], constants: List[Any]) -> List[Bytecode]:
        """Replace ``PUSH a; PUSH b; OP`` with ``PUSH (a OP b)``.

        The output list is used as a stack, so a folded PUSH can take part in a
        further fold. For example, ``1 + 2 + 3`` collapses to a single PUSH.
        """
        targets = self._jump_targets(code)
        result: List[Bytecode] = []
        pinned = set()   # indices in `result` that some jump lands on
        old_to_new = []  # old_to_new[i]: new address of original instruction i

        for i, instr in enumerate(code):
            old_to_new.append(len(result))

            # Fold only if nothing jumps into the middle of the pattern: neither
            # the operator nor the second PUSH may be a jump target. The first
            # PUSH may be one, since the folded PUSH takes its address.
            if (instr.opcode in self._FOLDABLE
                    and i not in targets
                    and len(result) >= 2
                    and (len(result) - 1) not in pinned
                    and result[-1].opcode == Opcode.PUSH
                    and result[-2].opcode == Opcode.PUSH):
                folded = self._fold(instr.opcode, result[-2], result[-1], constants)
                if folded is not None:
                    del result[-2:]
                    result.append(folded)
                    continue

            if i in targets:
                pinned.add(len(result))
            result.append(instr)

        old_to_new.append(len(result))  # jumps to "one past the end"
        return self._retarget(result, old_to_new)

    def peephole(self, code: List[Bytecode]) -> List[Bytecode]:
        """Simple peephole optimizations.

        - ``PUSH; POP`` is removed. The pair is kept when a jump lands on the
          POP, because there the POP discards a different value.
        - ``JUMP i+1`` (a jump to the next instruction) is removed.
        """
        targets = self._jump_targets(code)
        kept: List[Bytecode] = []
        old_to_new = [0] * (len(code) + 1)
        i = 0

        while i < len(code):
            instr = code[i]
            old_to_new[i] = len(kept)

            if (instr.opcode == Opcode.PUSH
                    and i + 1 < len(code)
                    and code[i + 1].opcode == Opcode.POP
                    and (i + 1) not in targets):
                old_to_new[i + 1] = len(kept)
                i += 2
                continue

            if instr.opcode == Opcode.JUMP and instr.operand == i + 1:
                i += 1
                continue

            kept.append(instr)
            i += 1

        old_to_new[len(code)] = len(kept)
        return self._retarget(kept, old_to_new)

    def _fold(self, op, push_a: Bytecode, push_b: Bytecode, constants: List[Any]):
        """Return ``PUSH (a op b)`` evaluated now, or None if it cannot be folded."""
        a = constants[push_a.operand]
        b = constants[push_b.operand]
        try:
            value = self._FOLDABLE[op](a, b)
        except (ArithmeticError, TypeError):
            # E.g. division by zero or str - int. Leave the expression for run
            # time, where the VM raises the same error at the right moment.
            return None
        return Bytecode(Opcode.PUSH, self._intern(constants, value))

    @staticmethod
    def _intern(constants: List[Any], value) -> int:
        """Return the pool index of value, appending it if absent.

        Types are compared as well as values, so that ``1``, ``1.0`` and
        ``True`` stay distinct.
        """
        for idx, existing in enumerate(constants):
            if type(existing) is type(value) and existing == value:
                return idx
        constants.append(value)
        return len(constants) - 1

    @classmethod
    def _jump_targets(cls, code: List[Bytecode]) -> set:
        """Addresses that any jump instruction in code can transfer control to."""
        return {instr.operand for instr in code if instr.opcode in cls._JUMPS}

    @classmethod
    def _retarget(cls, code: List[Bytecode], old_to_new: List[int]) -> List[Bytecode]:
        """Translate every jump operand in code through old_to_new.

        Jumps are rebuilt as new Bytecode objects, so the caller's input list
        is never mutated.
        """
        return [
            Bytecode(instr.opcode, old_to_new[instr.operand])
            if instr.opcode in cls._JUMPS else instr
            for instr in code
        ]

Summary

  • Three-address code is a common IR with simple instructions
  • Stack-based VMs keep operands on an implicit stack, which simplifies code generation (no register allocation is needed)
  • Bytecode is a compact, portable intermediate format
  • Instruction selection maps IR to target instructions
  • Optimizations can be applied at the IR level or the bytecode level

Module Complete

This completes the Compilers module! You've learned:

  1. Lexical analysis with regular expressions and finite automata
  2. Parsing with recursive descent and precedence climbing
  3. Semantic analysis with type checking
  4. Code generation for a stack-based VM

These concepts form the foundation for understanding programming language implementation.

← Previous: Semantic Analysis | Back to Course Index