Code Generation

Code generation is the final phase of compilation. The code generator transforms the AST (often after optimization) into executable code.

1. Intermediate Representations

Why IR?

  • Decouple front-end (parsing) from back-end (code generation)
  • Enable machine-independent optimizations
  • Support multiple source languages and target platforms

Three-Address Code

from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, List, Optional, Union

class IROpcode(Enum):
    """Opcodes of the three-address intermediate representation.

    Members are numbered consecutively from 1 in declaration order.
    """

    # --- arithmetic ---
    ADD = 1
    SUB = 2
    MUL = 3
    DIV = 4
    NEG = 5

    # --- comparison ---
    LT = 6
    GT = 7
    EQ = 8
    NEQ = 9

    # --- data movement ---
    COPY = 10
    LOAD_CONST = 11
    LOAD_VAR = 12
    STORE_VAR = 13

    # --- control flow ---
    LABEL = 14
    JUMP = 15
    JUMP_IF_TRUE = 16
    JUMP_IF_FALSE = 17

    # --- functions ---
    CALL = 18
    RETURN = 19
    PARAM = 20

@dataclass
class IRInstruction:
    """One three-address instruction: an opcode plus up to three operands.

    Which fields are used depends on the opcode; unused fields stay ``None``.
    ``repr()`` renders the instruction in a readable assembly-like form,
    with every non-label instruction indented by two spaces.
    """

    opcode: IROpcode
    dest: Optional[str] = None  # Destination (temporary, variable, or label name)
    arg1: Optional[Union[str, int, float]] = None
    arg2: Optional[Union[str, int, float]] = None

    def __repr__(self):
        op = self.opcode

        if op == IROpcode.LABEL:
            return f"{self.dest}:"
        if op == IROpcode.JUMP:
            return f"  jump {self.arg1}"
        if op == IROpcode.JUMP_IF_FALSE:
            return f"  if not {self.arg1} goto {self.arg2}"
        if op == IROpcode.JUMP_IF_TRUE:
            return f"  if {self.arg1} goto {self.arg2}"
        if op == IROpcode.LOAD_CONST:
            # repr() keeps numbers unchanged but quotes strings, so a string
            # constant cannot be mistaken for a variable or temporary name.
            return f"  {self.dest} = {self.arg1!r}"
        if op in (IROpcode.COPY, IROpcode.LOAD_VAR, IROpcode.STORE_VAR):
            return f"  {self.dest} = {self.arg1}"
        if op == IROpcode.PARAM:
            return f"  param {self.arg1}"
        if op == IROpcode.CALL:
            return f"  {self.dest} = call {self.arg1}, {self.arg2} args"
        if op == IROpcode.RETURN:
            # Test against None explicitly: a falsy operand such as 0 is
            # still a value and must be shown.
            if self.arg1 is None:
                return "  return"
            return f"  return {self.arg1}"
        if op == IROpcode.NEG:
            return f"  {self.dest} = -{self.arg1}"

        binary_symbols = {
            IROpcode.ADD: '+', IROpcode.SUB: '-',
            IROpcode.MUL: '*', IROpcode.DIV: '/',
            IROpcode.LT: '<', IROpcode.GT: '>',
            IROpcode.EQ: '==', IROpcode.NEQ: '!=',
        }
        if op in binary_symbols:
            return f"  {self.dest} = {self.arg1} {binary_symbols[op]} {self.arg2}"

        # Fallback for any opcode without a dedicated format: raw field dump.
        return f"  {op.name} {self.dest} {self.arg1} {self.arg2}"

class IRGenerator:
    """Generate three-address code from AST.

    Instructions are appended to ``self.instructions`` in emission order.
    Expression generators return the name (temporary or variable) holding
    the result; statement generators return ``None``.
    """

    def __init__(self):
        self.instructions: List[IRInstruction] = []
        self.temp_counter = 0
        self.label_counter = 0

    def new_temp(self) -> str:
        """Return a fresh temporary name: ``t1``, ``t2``, ..."""
        self.temp_counter += 1
        return f"t{self.temp_counter}"

    def new_label(self) -> str:
        """Return a fresh label name: ``L1``, ``L2``, ..."""
        self.label_counter += 1
        return f"L{self.label_counter}"

    def emit(self, opcode: IROpcode, dest=None, arg1=None, arg2=None):
        """Append one instruction to the instruction list."""
        self.instructions.append(IRInstruction(opcode, dest, arg1, arg2))

    def generate(self, node: AST) -> str:
        """Generate IR for node, return temp holding result.

        Dispatches on the node's class name to ``gen_<ClassName>``; node
        types without such a method go to ``gen_default``.
        """
        method_name = f'gen_{type(node).__name__}'
        generator = getattr(self, method_name, self.gen_default)
        return generator(node)

    def gen_default(self, node):
        """Reject node types that have no dedicated generator.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(f"No IR generator for {type(node).__name__}")

    def gen_Program(self, node: Program):
        """Generate IR for every top-level statement, in order."""
        for stmt in node.statements:
            self.generate(stmt)

    def gen_NumberLiteral(self, node: NumberLiteral) -> str:
        """Load the literal's value into a fresh temporary."""
        temp = self.new_temp()
        self.emit(IROpcode.LOAD_CONST, temp, node.value)
        return temp

    def gen_StringLiteral(self, node: StringLiteral) -> str:
        """Load the literal's value into a fresh temporary."""
        temp = self.new_temp()
        self.emit(IROpcode.LOAD_CONST, temp, node.value)
        return temp

    def gen_BoolLiteral(self, node: BoolLiteral) -> str:
        """Load the literal as the integer 1 (true) or 0 (false)."""
        temp = self.new_temp()
        self.emit(IROpcode.LOAD_CONST, temp, 1 if node.value else 0)
        return temp

    def gen_Identifier(self, node: Identifier) -> str:
        """Load the named variable into a fresh temporary."""
        temp = self.new_temp()
        self.emit(IROpcode.LOAD_VAR, temp, node.name)
        return temp

    def gen_BinaryOp(self, node: BinaryOp) -> str:
        """Evaluate left then right operand and combine them into a new temp.

        Raises:
            KeyError: if ``node.op`` is not one of the supported operators.
        """
        op_map = {
            '+': IROpcode.ADD,
            '-': IROpcode.SUB,
            '*': IROpcode.MUL,
            '/': IROpcode.DIV,
            '<': IROpcode.LT,
            '>': IROpcode.GT,
            '==': IROpcode.EQ,
            '!=': IROpcode.NEQ,
        }
        # Resolve the opcode first so an unsupported operator fails before
        # any operand code has been emitted.
        opcode = op_map[node.op]

        left = self.generate(node.left)
        right = self.generate(node.right)
        result = self.new_temp()

        self.emit(opcode, result, left, right)
        return result

    def gen_UnaryOp(self, node: UnaryOp) -> str:
        """Generate IR for a unary expression.

        ``-`` emits NEG into a fresh temporary; ``+`` is the identity and
        returns the operand's own result name.

        Raises:
            NotImplementedError: for any other operator. Returning a fresh
                temporary without emitting an instruction would hand the
                caller a name that nothing defines.
        """
        if node.op not in ('-', '+'):
            raise NotImplementedError(
                f"No IR generator for unary operator {node.op!r}")

        operand = self.generate(node.operand)
        if node.op == '+':
            return operand

        result = self.new_temp()
        self.emit(IROpcode.NEG, result, operand)
        return result

    def gen_Assignment(self, node: Assignment) -> str:
        """Evaluate the right-hand side and store it; return the target name."""
        value = self.generate(node.value)
        self.emit(IROpcode.STORE_VAR, node.target.name, value)
        return node.target.name

    def gen_ExpressionStmt(self, node: ExpressionStmt):
        """Evaluate the expression; the result name is discarded."""
        self.generate(node.expr)

    def gen_IfStmt(self, node: IfStmt):
        """Generate a conditional, with or without an else branch."""
        cond = self.generate(node.condition)

        if node.else_branch:
            else_label = self.new_label()
            end_label = self.new_label()

            self.emit(IROpcode.JUMP_IF_FALSE, None, cond, else_label)

            for stmt in node.then_branch:
                self.generate(stmt)
            # Skip over the else branch once the then branch has run.
            self.emit(IROpcode.JUMP, None, end_label)

            self.emit(IROpcode.LABEL, else_label)
            for stmt in node.else_branch:
                self.generate(stmt)

            self.emit(IROpcode.LABEL, end_label)
        else:
            end_label = self.new_label()
            self.emit(IROpcode.JUMP_IF_FALSE, None, cond, end_label)

            for stmt in node.then_branch:
                self.generate(stmt)

            self.emit(IROpcode.LABEL, end_label)

    def gen_WhileStmt(self, node: WhileStmt):
        """Generate a loop whose condition is re-evaluated at the start label."""
        start_label = self.new_label()
        end_label = self.new_label()

        self.emit(IROpcode.LABEL, start_label)

        cond = self.generate(node.condition)
        self.emit(IROpcode.JUMP_IF_FALSE, None, cond, end_label)

        for stmt in node.body:
            self.generate(stmt)

        self.emit(IROpcode.JUMP, None, start_label)
        self.emit(IROpcode.LABEL, end_label)

    def gen_Call(self, node: Call) -> str:
        """Generate a call: evaluate all arguments, emit PARAMs, then CALL.

        Returns the temporary that receives the call's result.
        """
        # Generate arguments
        arg_temps = [self.generate(arg) for arg in node.arguments]

        # Push arguments
        for temp in arg_temps:
            self.emit(IROpcode.PARAM, None, temp)

        # Call function
        result = self.new_temp()
        if isinstance(node.callee, Identifier):
            self.emit(IROpcode.CALL, result, node.callee.name, len(arg_temps))
        else:
            # Computed callee: evaluate it and call through its result name.
            func = self.generate(node.callee)
            self.emit(IROpcode.CALL, result, func, len(arg_temps))

        return result

    def gen_Return(self, node: Return):
        """Emit RETURN, with the evaluated value when the node carries one."""
        if node.value:
            value = self.generate(node.value)
            self.emit(IROpcode.RETURN, None, value)
        else:
            self.emit(IROpcode.RETURN)

    def print_ir(self):
        """Print generated IR"""
        for instr in self.instructions:
            print(instr)

# Example: generate and print three-address code for a small program
source = '''
x = 10
y = 20
if x < y:
    z = x + y
'''

# Tokenize, dropping NEWLINE tokens before the parser sees them
lexer = Lexer(source)
tokens = [t for t in lexer.tokenize() if t.type != TokenType.NEWLINE]
parser = Parser(tokens)
ast = parser.parse()

# Walk the AST, then print the collected instructions one per line
ir_gen = IRGenerator()
ir_gen.generate(ast)
print("Three-Address Code:")
ir_gen.print_ir()

2. Stack-Based Virtual Machine

Bytecode Definition

class Opcode(Enum):
    """Instruction set of the stack-based virtual machine.

    Members are numbered consecutively from 1 in declaration order.
    """

    # --- stack operations ---
    PUSH = 1            # Push constant
    POP = 2             # Pop top
    DUP = 3             # Duplicate top

    # --- arithmetic ---
    ADD = 4
    SUB = 5
    MUL = 6
    DIV = 7
    NEG = 8

    # --- comparison ---
    LT = 9
    GT = 10
    EQ = 11
    NEQ = 12
    LEQ = 13
    GEQ = 14

    # --- variables ---
    LOAD = 15           # Load variable
    STORE = 16          # Store variable
    LOAD_GLOBAL = 17
    STORE_GLOBAL = 18

    # --- control flow ---
    JUMP = 19
    JUMP_IF_TRUE = 20
    JUMP_IF_FALSE = 21

    # --- functions ---
    CALL = 22
    RETURN = 23

    # --- other ---
    PRINT = 24
    HALT = 25

@dataclass
class Bytecode:
    """A single VM instruction: an opcode plus an optional operand.

    ``repr()`` gives the disassembly form: the opcode name, followed by the
    operand when one is present.
    """

    opcode: Opcode
    # typing.Any, not the builtin function ``any``, which is not a type.
    operand: Any = None

    def __repr__(self):
        # Compare against None so that a 0 operand (a valid index or
        # address) is still printed.
        if self.operand is not None:
            return f"{self.opcode.name} {self.operand}"
        return self.opcode.name

Bytecode Generator

class BytecodeGenerator:
    """Generate bytecode from AST.

    Output accumulates in three attributes: ``code`` (the instruction
    list), ``constants`` (the constant pool that PUSH operands index into)
    and ``variables`` (variable name -> slot index).
    """

    def __init__(self):
        self.code: List[Bytecode] = []
        self.constants: List[Any] = []
        self.variables: Dict[str, int] = {}  # Variable name -> slot index
        self.var_counter = 0

    def emit(self, opcode: Opcode, operand=None):
        """Append an instruction and return its index in ``self.code``."""
        self.code.append(Bytecode(opcode, operand))
        return len(self.code) - 1  # Return instruction index

    def add_constant(self, value) -> int:
        """Add constant to pool and return index.

        An existing entry is reused only when it has the same type *and*
        value. Plain ``==`` would treat ``1``, ``1.0`` and ``True`` as the
        same constant and hand back an entry of the wrong type.
        """
        for idx, existing in enumerate(self.constants):
            if type(existing) is type(value) and existing == value:
                return idx
        self.constants.append(value)
        return len(self.constants) - 1

    def get_var_slot(self, name: str) -> int:
        """Get or create variable slot"""
        if name not in self.variables:
            self.variables[name] = self.var_counter
            self.var_counter += 1
        return self.variables[name]

    def current_addr(self) -> int:
        """Get current code address (index of the next emitted instruction)"""
        return len(self.code)

    def patch_jump(self, addr: int, target: int):
        """Patch jump instruction with target address"""
        self.code[addr].operand = target

    def generate(self, node: AST):
        """Dispatch to ``gen_<ClassName>`` for the node's class.

        Node types without such a method go to ``gen_default``.
        """
        method_name = f'gen_{type(node).__name__}'
        generator = getattr(self, method_name, self.gen_default)
        return generator(node)

    def gen_default(self, node):
        """Reject node types that have no dedicated generator.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(f"No generator for {type(node).__name__}")

    def gen_Program(self, node: Program):
        """Generate every top-level statement, then a terminating HALT."""
        for stmt in node.statements:
            self.generate(stmt)
        self.emit(Opcode.HALT)

    def gen_NumberLiteral(self, node: NumberLiteral):
        """Push the literal's value via the constant pool."""
        idx = self.add_constant(node.value)
        self.emit(Opcode.PUSH, idx)

    def gen_StringLiteral(self, node: StringLiteral):
        """Push the literal's value via the constant pool."""
        idx = self.add_constant(node.value)
        self.emit(Opcode.PUSH, idx)

    def gen_BoolLiteral(self, node: BoolLiteral):
        """Push the literal as the integer 1 (true) or 0 (false)."""
        idx = self.add_constant(1 if node.value else 0)
        self.emit(Opcode.PUSH, idx)

    def gen_Identifier(self, node: Identifier):
        """Push the value held in the variable's slot."""
        slot = self.get_var_slot(node.name)
        self.emit(Opcode.LOAD, slot)

    def gen_BinaryOp(self, node: BinaryOp):
        """Push left operand, then right operand, then emit the operator.

        Raises:
            KeyError: if ``node.op`` is not one of the supported operators.
        """
        op_map = {
            '+': Opcode.ADD,
            '-': Opcode.SUB,
            '*': Opcode.MUL,
            '/': Opcode.DIV,
            '<': Opcode.LT,
            '>': Opcode.GT,
            '==': Opcode.EQ,
            '!=': Opcode.NEQ,
        }
        # Resolve the opcode first so an unsupported operator fails before
        # any operand code has been emitted.
        opcode = op_map[node.op]

        self.generate(node.left)
        self.generate(node.right)
        self.emit(opcode)

    def gen_UnaryOp(self, node: UnaryOp):
        """Generate a unary expression.

        ``-`` emits NEG after the operand; ``+`` is the identity and leaves
        the operand on the stack unchanged.

        Raises:
            NotImplementedError: for any other operator, which would
                otherwise compile to the bare operand.
        """
        if node.op not in ('-', '+'):
            raise NotImplementedError(
                f"No generator for unary operator {node.op!r}")

        self.generate(node.operand)
        if node.op == '-':
            self.emit(Opcode.NEG)

    def gen_Assignment(self, node: Assignment):
        """Evaluate the right-hand side and STORE it into the target's slot."""
        self.generate(node.value)
        slot = self.get_var_slot(node.target.name)
        self.emit(Opcode.STORE, slot)

    def gen_ExpressionStmt(self, node: ExpressionStmt):
        """Evaluate the expression, then POP its result off the stack."""
        self.generate(node.expr)
        self.emit(Opcode.POP)

    def gen_IfStmt(self, node: IfStmt):
        """Generate a conditional using placeholder jumps patched afterwards."""
        self.generate(node.condition)

        if node.else_branch:
            jump_else = self.emit(Opcode.JUMP_IF_FALSE, 0)  # Placeholder

            for stmt in node.then_branch:
                self.generate(stmt)
            jump_end = self.emit(Opcode.JUMP, 0)  # Placeholder

            else_addr = self.current_addr()
            for stmt in node.else_branch:
                self.generate(stmt)

            end_addr = self.current_addr()
            self.patch_jump(jump_else, else_addr)
            self.patch_jump(jump_end, end_addr)
        else:
            jump_end = self.emit(Opcode.JUMP_IF_FALSE, 0)  # Placeholder

            for stmt in node.then_branch:
                self.generate(stmt)

            self.patch_jump(jump_end, self.current_addr())

    def gen_WhileStmt(self, node: WhileStmt):
        """Generate a loop that jumps back to re-evaluate its condition."""
        start_addr = self.current_addr()

        self.generate(node.condition)
        jump_end = self.emit(Opcode.JUMP_IF_FALSE, 0)  # Placeholder

        for stmt in node.body:
            self.generate(stmt)

        self.emit(Opcode.JUMP, start_addr)
        self.patch_jump(jump_end, self.current_addr())

    def gen_Call(self, node: Call):
        """Generate a call to a named function.

        ``print`` is compiled inline: each argument is evaluated and
        followed by PRINT, then ``None`` is pushed as the call's value.
        Other names push all arguments and emit CALL with the operand
        ``(constant index of the function name, argument count)``.

        Raises:
            NotImplementedError: if the callee is not a plain identifier.
                Pushing the arguments without emitting a call would leave
                them stranded on the stack.
        """
        if not isinstance(node.callee, Identifier):
            raise NotImplementedError(
                "Only calls to named functions are supported")

        if node.callee.name == 'print':
            # Special handling for print
            for arg in node.arguments:
                self.generate(arg)
                self.emit(Opcode.PRINT)
            # Push None as return value
            self.emit(Opcode.PUSH, self.add_constant(None))
            return

        # General function call
        for arg in node.arguments:
            self.generate(arg)

        func_idx = self.add_constant(node.callee.name)
        self.emit(Opcode.CALL, (func_idx, len(node.arguments)))

    def disassemble(self):
        """Print bytecode disassembly"""
        print("Constants:", self.constants)
        print("Variables:", self.variables)
        print("\nBytecode:")
        for i, instr in enumerate(self.code):
            print(f"  {i:4d}: {instr}")

Virtual Machine

class VM:
    """Stack-based virtual machine.

    Executes a list of ``Bytecode`` instructions against an operand stack,
    a constant pool, and a table of variable slots.
    """

    def __init__(self, bytecode: List[Bytecode], constants: List[Any]):
        self.code = bytecode
        self.constants = constants
        self.stack: List[Any] = []
        self.variables: Dict[int, Any] = {}
        self.ip = 0  # Instruction pointer

    def push(self, value):
        """Push a value onto the operand stack."""
        self.stack.append(value)

    def pop(self):
        """Remove and return the top of the operand stack."""
        return self.stack.pop()

    def peek(self):
        """Return the top of the operand stack without removing it."""
        return self.stack[-1]

    def run(self):
        """Execute bytecode.

        Runs from the current instruction pointer until HALT or the end of
        the code.

        Returns:
            The value left on top of the stack, or ``None`` if it is empty.

        Raises:
            NotImplementedError: on an opcode the VM has no implementation
                for (CALL, RETURN, LOAD_GLOBAL, STORE_GLOBAL). Skipping it
                would leave the stack in a state the following
                instructions do not expect.
        """
        # Binary operators receive (a, b) where b was on top of the stack.
        arithmetic = {
            Opcode.ADD: lambda a, b: a + b,
            Opcode.SUB: lambda a, b: a - b,
            Opcode.MUL: lambda a, b: a * b,
            Opcode.DIV: lambda a, b: a / b,
        }
        # Comparison results are pushed as the integers 1 / 0.
        comparisons = {
            Opcode.LT: lambda a, b: a < b,
            Opcode.GT: lambda a, b: a > b,
            Opcode.EQ: lambda a, b: a == b,
            Opcode.NEQ: lambda a, b: a != b,
            Opcode.LEQ: lambda a, b: a <= b,
            Opcode.GEQ: lambda a, b: a >= b,
        }

        while self.ip < len(self.code):
            instr = self.code[self.ip]
            # Advance first so that a jump can simply overwrite ip.
            self.ip += 1
            opcode = instr.opcode

            if opcode == Opcode.HALT:
                break

            elif opcode == Opcode.PUSH:
                self.push(self.constants[instr.operand])

            elif opcode == Opcode.POP:
                self.pop()

            elif opcode == Opcode.DUP:
                self.push(self.peek())

            elif opcode == Opcode.LOAD:
                # A slot that was never stored to reads as 0.
                self.push(self.variables.get(instr.operand, 0))

            elif opcode == Opcode.STORE:
                self.variables[instr.operand] = self.pop()

            elif opcode in arithmetic:
                b, a = self.pop(), self.pop()
                self.push(arithmetic[opcode](a, b))

            elif opcode == Opcode.NEG:
                self.push(-self.pop())

            elif opcode in comparisons:
                b, a = self.pop(), self.pop()
                self.push(1 if comparisons[opcode](a, b) else 0)

            elif opcode == Opcode.JUMP:
                self.ip = instr.operand

            elif opcode == Opcode.JUMP_IF_FALSE:
                cond = self.pop()
                if not cond:
                    self.ip = instr.operand

            elif opcode == Opcode.JUMP_IF_TRUE:
                cond = self.pop()
                if cond:
                    self.ip = instr.operand

            elif opcode == Opcode.PRINT:
                print(self.pop())

            else:
                raise NotImplementedError(
                    f"VM does not implement opcode {opcode.name}")

        return self.peek() if self.stack else None

# Complete example
def compile_and_run(source: str):
    """Run the whole pipeline on ``source``: lex, parse, compile, execute.

    The bytecode disassembly is printed before execution starts. Returns
    whatever ``VM.run()`` returns.
    """
    # Front end: token stream (NEWLINE tokens removed) -> AST
    token_stream = Lexer(source).tokenize()
    significant = [tok for tok in token_stream if tok.type != TokenType.NEWLINE]
    tree = Parser(significant).parse()

    # Back end: AST -> bytecode
    generator = BytecodeGenerator()
    generator.generate(tree)

    print("=== Bytecode ===")
    generator.disassemble()

    # Execute on a fresh VM
    print("\n=== Output ===")
    machine = VM(generator.code, generator.constants)
    return machine.run()

# Test: run the full pipeline on a short program that assigns three
# variables and prints z
source = '''
x = 10
y = 20
z = x + y * 2
print(z)
'''

# Prints the bytecode disassembly, then the program's own output
compile_and_run(source)

3. Optimization Passes

class Optimizer:
    """Simple bytecode optimizer.

    Both passes delete instructions, which shifts the address of everything
    after them. Each pass therefore records where every original address
    ended up and rewrites jump operands to match, and never merges
    instructions across an address that some jump lands on.
    """

    def optimize(self, code: List[Bytecode], constants: List[Any]):
        """Apply optimization passes.

        Returns a new instruction list; the input list and its instructions
        are not modified. ``constants`` may grow, because folded results
        are appended to it.
        """
        code = self.constant_folding(code, constants)
        code = self.peephole(code)
        return code

    @staticmethod
    def _is_jump(instr) -> bool:
        """Return True for instructions whose operand is a code address."""
        return instr.opcode in (
            Opcode.JUMP, Opcode.JUMP_IF_TRUE, Opcode.JUMP_IF_FALSE)

    def _jump_targets(self, code) -> set:
        """Collect every address that some jump in ``code`` lands on."""
        return {instr.operand for instr in code if self._is_jump(instr)}

    def _retarget(self, code, index_map):
        """Rewrite jump operands from old addresses to new ones.

        Jumps are rebuilt as new ``Bytecode`` objects so the caller's
        instructions stay untouched. An operand missing from ``index_map``
        is left as it is.
        """
        return [
            Bytecode(instr.opcode, index_map.get(instr.operand, instr.operand))
            if self._is_jump(instr) else instr
            for instr in code
        ]

    @staticmethod
    def _intern(constants, value) -> int:
        """Return the pool index of ``value``, appending it if absent.

        Matching requires equal type as well as equal value: ``10 / 2`` is
        the float ``5.0`` and must not reuse an existing int ``5``.
        """
        for idx, existing in enumerate(constants):
            if type(existing) is type(value) and existing == value:
                return idx
        constants.append(value)
        return len(constants) - 1

    @staticmethod
    def _fold(opcode, a, b):
        """Evaluate ``a <op> b`` at compile time.

        Returns ``(True, value)`` on success and ``(False, None)`` when the
        operation raises (division by zero, overflow, incompatible operand
        types). In that case the instructions are left alone so the error
        surfaces at run time, as it would without optimization.
        """
        try:
            if opcode == Opcode.ADD:
                return True, a + b
            if opcode == Opcode.SUB:
                return True, a - b
            if opcode == Opcode.MUL:
                return True, a * b
            return True, a / b
        except (ArithmeticError, TypeError):
            return False, None

    def constant_folding(self, code: List[Bytecode], constants: List[Any]):
        """Fold constant expressions at compile time.

        Replaces each ``PUSH a; PUSH b; <ADD|SUB|MUL|DIV>`` triple with a
        single PUSH of the computed result.
        """
        foldable = (Opcode.ADD, Opcode.SUB, Opcode.MUL, Opcode.DIV)
        targets = self._jump_targets(code)
        result = []
        index_map = {}  # old address -> new address
        n = len(code)
        i = 0

        while i < n:
            index_map[i] = len(result)
            folded = False

            # Look for PUSH PUSH OP pattern. A jump landing on the second
            # PUSH or on the operator depends on the triple staying intact.
            if (i + 2 < n and
                    code[i].opcode == Opcode.PUSH and
                    code[i+1].opcode == Opcode.PUSH and
                    code[i+2].opcode in foldable and
                    (i + 1) not in targets and
                    (i + 2) not in targets):
                a = constants[code[i].operand]
                b = constants[code[i+1].operand]
                folded, val = self._fold(code[i+2].opcode, a, b)

            if folded:
                index_map[i+1] = index_map[i+2] = len(result)
                result.append(Bytecode(Opcode.PUSH, self._intern(constants, val)))
                i += 3
            else:
                result.append(code[i])
                i += 1

        # A jump may target the address just past the last instruction.
        index_map[n] = len(result)
        return self._retarget(result, index_map)

    def peephole(self, code: List[Bytecode]):
        """Simple peephole optimizations.

        Removes ``PUSH; POP`` pairs and unconditional jumps to the
        immediately following instruction.
        """
        targets = self._jump_targets(code)
        result = []
        index_map = {}  # old address -> new address
        n = len(code)
        i = 0

        while i < n:
            # A removed instruction maps to whatever is emitted next.
            index_map[i] = len(result)
            instr = code[i]

            # PUSH followed by POP - remove both, unless a jump lands on
            # the POP and so needs it to consume its own value.
            if (i + 1 < n and
                    instr.opcode == Opcode.PUSH and
                    code[i+1].opcode == Opcode.POP and
                    (i + 1) not in targets):
                index_map[i+1] = len(result)
                i += 2
                continue

            # JUMP to next instruction - remove
            if instr.opcode == Opcode.JUMP and instr.operand == i + 1:
                i += 1
                continue

            result.append(instr)
            i += 1

        index_map[n] = len(result)
        return self._retarget(result, index_map)

Summary

  • Three-address code is a common IR with simple instructions
  • Stack-based VMs use a stack for operands (simpler code generation)
  • Bytecode is a compact, portable intermediate format
  • Instruction selection maps IR to target instructions
  • Optimizations can be applied at IR or bytecode level

Module Complete

This completes the Compilers module! You've learned:

  1. Lexical analysis with regular expressions and finite automata
  2. Parsing with recursive descent and precedence climbing
  3. Semantic analysis with type checking
  4. Code generation for a stack-based VM

These concepts form the foundation for understanding programming language implementation.

← Previous: Semantic Analysis | Back to Course Index