Semantic Analysis and Type Checking

Introduction

Semantic analysis checks that a syntactically valid program is also meaningful. This phase checks types, resolves names, enforces scope rules, and annotates the AST with semantic information. This reading covers symbol tables, type systems, and type-checking algorithms.

Learning Objectives

By the end of this reading, you will be able to:

  • Build and manage symbol tables for name resolution
  • Implement scope handling (lexical scoping)
  • Design a type system
  • Implement type checking for expressions and statements
  • Handle type inference basics

1. Symbol Tables

Purpose

The symbol table tracks all identifiers (variables, functions, types) and their attributes:

  • Name
  • Type
  • Scope
  • Memory location (for code generation)

from dataclasses import dataclass, field
from typing import Dict, Optional, Any, List
from enum import Enum, auto

class SymbolKind(Enum):
    VARIABLE = auto()
    FUNCTION = auto()
    PARAMETER = auto()
    CLASS = auto()
    TYPE = auto()

@dataclass
class Symbol:
    name: str
    kind: SymbolKind
    type: 'Type'
    scope_level: int
    # Additional attributes
    is_initialized: bool = False
    is_const: bool = False
    # For functions
    parameters: List['Type'] = field(default_factory=list)
    return_type: Optional['Type'] = None

class SymbolTable:
    """Hierarchical symbol table with scope support"""

    def __init__(self, parent: Optional['SymbolTable'] = None):
        self.symbols: Dict[str, Symbol] = {}
        self.parent = parent
        self.children: List['SymbolTable'] = []
        self.scope_level = 0 if parent is None else parent.scope_level + 1

        if parent:
            parent.children.append(self)

    def define(self, name: str, symbol: Symbol) -> bool:
        """Define a new symbol in current scope"""
        if name in self.symbols:
            return False  # Already defined in this scope
        self.symbols[name] = symbol
        symbol.scope_level = self.scope_level
        return True

    def lookup(self, name: str) -> Optional[Symbol]:
        """Look up symbol in current scope and parent scopes"""
        if name in self.symbols:
            return self.symbols[name]
        if self.parent:
            return self.parent.lookup(name)
        return None

    def lookup_local(self, name: str) -> Optional[Symbol]:
        """Look up symbol only in current scope"""
        return self.symbols.get(name)

    def create_child_scope(self) -> 'SymbolTable':
        """Create a new child scope"""
        return SymbolTable(self)

    def __repr__(self):
        symbols_str = ', '.join(f"{k}: {v.type}" for k, v in self.symbols.items())
        return f"SymbolTable(level={self.scope_level}, symbols={{{symbols_str}}})"
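
The scope chain is easiest to see with a small experiment. The sketch below is self-contained, so it uses a stripped-down `Scope` class: a hypothetical stand-in for `SymbolTable` that stores only a type name per identifier. It illustrates shadowing, duplicate detection within one scope, and lookup through the parent chain.

```python
from typing import Dict, Optional


class Scope:
    """Hypothetical, stripped-down stand-in for SymbolTable (name -> type name)."""

    def __init__(self, parent: Optional["Scope"] = None):
        self.symbols: Dict[str, str] = {}
        self.parent = parent
        self.level = 0 if parent is None else parent.level + 1

    def define(self, name: str, type_name: str) -> bool:
        """Define a name in this scope; refuse redefinition in the same scope."""
        if name in self.symbols:
            return False
        self.symbols[name] = type_name
        return True

    def lookup(self, name: str) -> Optional[str]:
        """Search this scope first, then each enclosing scope in turn."""
        scope: Optional["Scope"] = self
        while scope is not None:
            if name in scope.symbols:
                return scope.symbols[name]
            scope = scope.parent
        return None


global_scope = Scope()
global_scope.define("x", "int")
global_scope.define("y", "string")

inner = Scope(global_scope)                 # e.g. a function body or block
shadowed_ok = inner.define("x", "float")    # allowed: shadows the outer x
duplicate_ok = inner.define("x", "bool")    # rejected: x already defined here

inner_x = inner.lookup("x")                 # found in the inner scope
inner_y = inner.lookup("y")                 # found by walking up to the parent
outer_x = global_scope.lookup("x")          # the outer binding is untouched
missing = inner.lookup("z")                 # not defined in any scope
```

Shadowing is permitted because `define` only inspects the current scope, while `lookup` stops at the innermost match. Whether a language should allow shadowing at all is a design decision; this sketch simply mirrors the behavior of the table above.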

2. Type System

Type Representation

from abc import ABC, abstractmethod

class Type(ABC):
    """Base class for all types"""

    @abstractmethod
    def __eq__(self, other):
        pass

    @abstractmethod
    def __hash__(self):
        pass

@dataclass(frozen=True)
class PrimitiveType(Type):
    """Primitive types: int, float, bool, string"""
    name: str

    def __eq__(self, other):
        return isinstance(other, PrimitiveType) and self.name == other.name

    def __hash__(self):
        return hash(('primitive', self.name))

    def __repr__(self):
        return self.name

@dataclass(frozen=True)
class ArrayType(Type):
    """Array type: array<element_type>"""
    element_type: Type

    def __eq__(self, other):
        return isinstance(other, ArrayType) and self.element_type == other.element_type

    def __hash__(self):
        return hash(('array', self.element_type))

    def __repr__(self):
        return f"array<{self.element_type}>"

@dataclass(frozen=True)
class FunctionType(Type):
    """Function type: (param_types) -> return_type"""
    param_types: tuple
    return_type: Type

    def __eq__(self, other):
        return (isinstance(other, FunctionType) and
                self.param_types == other.param_types and
                self.return_type == other.return_type)

    def __hash__(self):
        return hash(('function', self.param_types, self.return_type))

    def __repr__(self):
        params = ', '.join(str(p) for p in self.param_types)
        return f"({params}) -> {self.return_type}"

@dataclass(frozen=True)
class ClassType(Type):
    """Class/struct type"""
    name: str
    fields: tuple  # ((name, type), ...)
    methods: tuple = ()

    def __eq__(self, other):
        return isinstance(other, ClassType) and self.name == other.name

    def __hash__(self):
        return hash(('class', self.name))

    def __repr__(self):
        return self.name

@dataclass(frozen=True)
class NoneType(Type):
    """Type for null/None/void"""

    def __eq__(self, other):
        return isinstance(other, NoneType)

    def __hash__(self):
        return hash('none')

    def __repr__(self):
        return "None"

# Built-in types
INT = PrimitiveType('int')
FLOAT = PrimitiveType('float')
BOOL = PrimitiveType('bool')
STRING = PrimitiveType('string')
NONE = NoneType()
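
Most of the types above compare structurally: two arrays are equal when their element types are equal, and two function types are equal when their parameter and return types match. (`ClassType` is the exception, since it compares by name only.) The following self-contained sketch shows why that matters in practice. The classes `Prim`, `Arr`, and `Fn` are hypothetical, shortened stand-ins for `PrimitiveType`, `ArrayType`, and `FunctionType`, and they rely on the equality and hash methods that frozen dataclasses generate.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Prim:
    """Hypothetical stand-in for PrimitiveType."""
    name: str


@dataclass(frozen=True)
class Arr:
    """Hypothetical stand-in for ArrayType."""
    element: object


@dataclass(frozen=True)
class Fn:
    """Hypothetical stand-in for FunctionType."""
    params: Tuple[object, ...]
    result: object


INT = Prim("int")
STRING = Prim("string")

# Two separately constructed types with the same structure compare equal...
same_array = Arr(INT) == Arr(Prim("int"))
# ...and differ as soon as any component differs.
different_array = Arr(INT) == Arr(STRING)

# Equal types hash equally, so types can key a dictionary or sit in a set.
size_in_bytes = {INT: 8, Arr(INT): 16}      # sizes are made up for illustration
looked_up = size_in_bytes[Arr(Prim("int"))]

len_type = Fn((STRING,), INT)
same_fn = len_type == Fn((Prim("string"),), Prim("int"))
unique_types = {INT, Prim("int"), STRING}   # the duplicate int collapses
```

Keeping equality and hashing consistent is what lets a checker compare types with `==` and cache per-type information in dictionaries.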

3. Type Checker

Semantic Analyzer

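# The AST node classes used below (AST, NumberLiteral, BinaryOp, ...) are
# assumed to come from the parser built in the earlier readings.
class SemanticError(Exception):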
    def __init__(self, message, node=None):
        self.message = message
        self.node = node
        super().__init__(message)

class TypeChecker:
    """Type checker and semantic analyzer"""

    def __init__(self):
        self.global_scope = SymbolTable()
        self.current_scope = self.global_scope
        self.errors: List[SemanticError] = []
        self.current_function: Optional[FunctionType] = None

        # Initialize built-in functions
        self._init_builtins()

    def _init_builtins(self):
        """Define built-in functions and types"""
        # print function (simplified: accepts exactly one string)
        print_type = FunctionType((STRING,), NONE)
        self.global_scope.define('print', Symbol(
            'print', SymbolKind.FUNCTION, print_type, 0
        ))

        # len function
        len_type = FunctionType((STRING,), INT)  # Simplified
        self.global_scope.define('len', Symbol(
            'len', SymbolKind.FUNCTION, len_type, 0
        ))

    def error(self, message, node=None):
        """Record semantic error"""
        error = SemanticError(message, node)
        self.errors.append(error)
        return error

    def enter_scope(self):
        """Enter new scope"""
        self.current_scope = self.current_scope.create_child_scope()

    def exit_scope(self):
        """Exit current scope"""
        if self.current_scope.parent:
            self.current_scope = self.current_scope.parent

    def check(self, node: AST) -> Type:
        """Type check a node and return its type"""
        method_name = f'check_{type(node).__name__}'
        checker = getattr(self, method_name, self.generic_check)
        return checker(node)

    def generic_check(self, node):
        self.error(f"No type checker for {type(node).__name__}", node)
        return NONE

    # ===== Expression Type Checking =====

    def check_NumberLiteral(self, node: NumberLiteral) -> Type:
        if isinstance(node.value, int):
            return INT
        return FLOAT

    def check_StringLiteral(self, node: StringLiteral) -> Type:
        return STRING

    def check_BoolLiteral(self, node: BoolLiteral) -> Type:
        return BOOL

    def check_Identifier(self, node: Identifier) -> Type:
        symbol = self.current_scope.lookup(node.name)
        if symbol is None:
            self.error(f"Undefined variable '{node.name}'", node)
            return NONE
        return symbol.type

    def check_BinaryOp(self, node: BinaryOp) -> Type:
        left_type = self.check(node.left)
        right_type = self.check(node.right)

        op = node.op

        # Arithmetic operators
        if op in ('+', '-', '*', '/'):
            if left_type == INT and right_type == INT:
                return INT if op != '/' else FLOAT
            if left_type in (INT, FLOAT) and right_type in (INT, FLOAT):
                return FLOAT
            if op == '+' and left_type == STRING and right_type == STRING:
                return STRING
            self.error(
                f"Invalid operands for '{op}': {left_type} and {right_type}",
                node
            )
            return NONE

        # Comparison operators
        if op in ('<', '>', '<=', '>='):
            if left_type in (INT, FLOAT) and right_type in (INT, FLOAT):
                return BOOL
            self.error(
                f"Invalid operands for '{op}': {left_type} and {right_type}",
                node
            )
            return BOOL

        # Equality operators
        if op in ('==', '!='):
            if left_type != right_type:
                self.error(
                    f"Cannot compare {left_type} with {right_type}",
                    node
                )
            return BOOL

        self.error(f"Unknown operator '{op}'", node)
        return NONE

    def check_UnaryOp(self, node: UnaryOp) -> Type:
        operand_type = self.check(node.operand)

        if node.op == '-':
            if operand_type in (INT, FLOAT):
                return operand_type
            self.error(f"Cannot negate {operand_type}", node)
            return NONE

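        if node.op == 'not':
            # Mirror the rule for conditions: logical negation requires a bool
            if operand_type != BOOL:
                self.error(f"Operand of 'not' must be bool, got {operand_type}", node)
            return BOOL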

        self.error(f"Unknown unary operator '{node.op}'", node)
        return NONE

    def check_Call(self, node: Call) -> Type:
        callee_type = self.check(node.callee)

        if not isinstance(callee_type, FunctionType):
            self.error(f"Cannot call non-function type {callee_type}", node)
            return NONE

        # Check every argument, so errors inside surplus arguments are still reported
        arg_types = [self.check(arg) for arg in node.arguments]

        # Check argument count
        if len(arg_types) != len(callee_type.param_types):
            self.error(
                f"Expected {len(callee_type.param_types)} arguments, "
                f"got {len(arg_types)}",
                node
            )

        # Check argument types (zip stops at the shorter of the two sequences)
        for i, (arg_type, param_type) in enumerate(
            zip(arg_types, callee_type.param_types)
        ):
            if not self.is_compatible(arg_type, param_type):
                self.error(
                    f"Argument {i+1}: expected {param_type}, got {arg_type}",
                    node
                )

        return callee_type.return_type

    # ===== Statement Type Checking =====

    def check_Program(self, node: Program) -> Type:
        for stmt in node.statements:
            self.check(stmt)
        return NONE

    def check_ExpressionStmt(self, node: ExpressionStmt) -> Type:
        self.check(node.expr)
        return NONE

    def check_Assignment(self, node: Assignment) -> Type:
        value_type = self.check(node.value)

        # Check if variable exists
        symbol = self.current_scope.lookup(node.target.name)

        if symbol is None:
            # New variable - define it
            symbol = Symbol(
                node.target.name,
                SymbolKind.VARIABLE,
                value_type,
                self.current_scope.scope_level,
                is_initialized=True
            )
            self.current_scope.define(node.target.name, symbol)
        else:
            # Existing variable - check type compatibility
            if not self.is_compatible(value_type, symbol.type):
                self.error(
                    f"Cannot assign {value_type} to {symbol.type}",
                    node
                )
            symbol.is_initialized = True

        return NONE

    def check_IfStmt(self, node: IfStmt) -> Type:
        cond_type = self.check(node.condition)
        if cond_type != BOOL:
            self.error(f"Condition must be bool, got {cond_type}", node)

        self.enter_scope()
        for stmt in node.then_branch:
            self.check(stmt)
        self.exit_scope()

        if node.else_branch:
            self.enter_scope()
            for stmt in node.else_branch:
                self.check(stmt)
            self.exit_scope()

        return NONE

    def check_WhileStmt(self, node: WhileStmt) -> Type:
        cond_type = self.check(node.condition)
        if cond_type != BOOL:
            self.error(f"Condition must be bool, got {cond_type}", node)

        self.enter_scope()
        for stmt in node.body:
            self.check(stmt)
        self.exit_scope()

        return NONE

    def check_FunctionDef(self, node: FunctionDef) -> Type:
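        # For simplicity, parameters and the return value get the placeholder type
        # NONE, which is_compatible treats as 'any'. Operators such as '+' still
        # reject NONE, so bodies that do arithmetic on parameters report errors.
        # In a real system, you'd need type annotations or inference.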

        param_types = tuple(NONE for _ in node.params)  # Placeholder
        func_type = FunctionType(param_types, NONE)

        # Define function in current scope
        symbol = Symbol(
            node.name,
            SymbolKind.FUNCTION,
            func_type,
            self.current_scope.scope_level
        )
        if not self.current_scope.define(node.name, symbol):
            self.error(f"Function '{node.name}' already defined", node)

        # Enter function scope
        self.enter_scope()
        old_function = self.current_function
        self.current_function = func_type

        # Define parameters
        for i, param in enumerate(node.params):
            param_symbol = Symbol(
                param,
                SymbolKind.PARAMETER,
                param_types[i] if i < len(param_types) else NONE,
                self.current_scope.scope_level,
                is_initialized=True
            )
            self.current_scope.define(param, param_symbol)

        # Check body
        for stmt in node.body:
            self.check(stmt)

        self.current_function = old_function
        self.exit_scope()

        return NONE

    def check_Return(self, node: Return) -> Type:
        if self.current_function is None:
            self.error("Return outside of function", node)
            return NONE

        if node.value:
            value_type = self.check(node.value)
            expected = self.current_function.return_type
            if expected != NONE and not self.is_compatible(value_type, expected):
                self.error(
                    f"Return type {value_type} doesn't match {expected}",
                    node
                )
        return NONE

    def is_compatible(self, source_type: Type, target_type: Type) -> bool:
        """Check if source_type can be assigned to target_type"""
        if source_type == target_type:
            return True
        if target_type == NONE:  # NONE doubles as a placeholder 'any' type for now
            return True
        # Allow int -> float coercion
        if source_type == INT and target_type == FLOAT:
            return True
        return False

# Example usage (Lexer, Parser, and TokenType are assumed to come from the earlier readings)
def analyze(source: str):
    lexer = Lexer(source)
    tokens = [t for t in lexer.tokenize() if t.type != TokenType.NEWLINE]
    parser = Parser(tokens)
    ast = parser.parse()

    checker = TypeChecker()
    checker.check(ast)

    if checker.errors:
        print("Semantic errors found:")
        for error in checker.errors:
            print(f"  {error.message}")
    else:
        print("No semantic errors")

    return ast, checker

# Test
source = '''
x = 10
y = 20
z = x + y
'''
ast, checker = analyze(source)
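
The test above only exercises a program without errors. The sketch below shows the same operator rules rejecting an ill-typed expression. It is self-contained and deliberately tiny: expressions are Python literals or `(op, left, right)` tuples, which is a hypothetical encoding and not the AST from the earlier readings. It also makes one design change that the checker above does not: a dedicated `"error"` type is returned after a failure, so one mistake is reported once instead of cascading through the enclosing expressions.

```python
from typing import List


def type_of(expr: object, errors: List[str]) -> str:
    """Return the type name of expr, recording errors instead of raising."""
    if isinstance(expr, bool):          # test bool first: bool is a subclass of int
        return "bool"
    if isinstance(expr, int):
        return "int"
    if isinstance(expr, float):
        return "float"
    if isinstance(expr, str):
        return "string"

    op, left, right = expr              # hypothetical encoding: (op, left, right)
    lt, rt = type_of(left, errors), type_of(right, errors)
    if "error" in (lt, rt):
        return "error"                  # already reported; do not report again

    numeric = ("int", "float")
    if op in ("+", "-", "*", "/"):
        if lt == "int" and rt == "int":
            return "float" if op == "/" else "int"
        if lt in numeric and rt in numeric:
            return "float"
        if op == "+" and lt == rt == "string":
            return "string"
    elif op in ("<", ">", "<=", ">="):
        if lt in numeric and rt in numeric:
            return "bool"
    elif op in ("==", "!="):
        if lt == rt:
            return "bool"

    # Reached for ill-typed operands (and, in this sketch, unknown operators too)
    errors.append(f"Invalid operands for '{op}': {lt} and {rt}")
    return "error"


errors: List[str] = []
ok_type = type_of(("+", 1, ("*", 2, 3.5)), errors)     # int + (int * float)
div_type = type_of(("/", 4, 2), errors)                # int / int
bad_type = type_of(("<", ("+", 1, "a"), 2), errors)    # only the inner '+' is reported
```

After running this, `errors` holds a single message for the `1 + "a"` subexpression, even though the enclosing comparison is also affected.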

4. Type Inference

Hindley-Milner Type Inference (Simplified)

class TypeVariable(Type):
    """Type variable for inference"""
    _counter = 0

    def __init__(self):
        TypeVariable._counter += 1
        self.id = TypeVariable._counter
        self.bound_to: Optional[Type] = None

    def __eq__(self, other):
        if isinstance(other, TypeVariable):
            return self.id == other.id
        return False

    def __hash__(self):
        return hash(('typevar', self.id))

    def __repr__(self):
        if self.bound_to is not None:
            return str(self.bound_to)
        return f"T{self.id}"

    def resolve(self) -> Type:
        """Follow chain of bindings to get concrete type"""
        if self.bound_to is None:
            return self
        if isinstance(self.bound_to, TypeVariable):
            return self.bound_to.resolve()
        return self.bound_to

class TypeInference:
    """Simple type inference engine"""

    def __init__(self):
        self.constraints: List[tuple] = []  # (type1, type2) pairs to unify

    def fresh_type_var(self) -> TypeVariable:
        """Create fresh type variable"""
        return TypeVariable()

    def add_constraint(self, t1: Type, t2: Type):
        """Add constraint that t1 = t2"""
        self.constraints.append((t1, t2))

    def unify(self, t1: Type, t2: Type) -> bool:
        """Unify two types"""
        t1 = self.resolve(t1)
        t2 = self.resolve(t2)

        if t1 == t2:
            return True

        if isinstance(t1, TypeVariable):
            return self.bind(t1, t2)

        if isinstance(t2, TypeVariable):
            return self.bind(t2, t1)

        if isinstance(t1, FunctionType) and isinstance(t2, FunctionType):
            if len(t1.param_types) != len(t2.param_types):
                return False
            for p1, p2 in zip(t1.param_types, t2.param_types):
                if not self.unify(p1, p2):
                    return False
            return self.unify(t1.return_type, t2.return_type)

        if isinstance(t1, ArrayType) and isinstance(t2, ArrayType):
            return self.unify(t1.element_type, t2.element_type)

        return False

    def resolve(self, t: Type) -> Type:
        """Resolve type variable to concrete type"""
        if isinstance(t, TypeVariable):
            return t.resolve()
        return t

    def bind(self, var: TypeVariable, t: Type) -> bool:
        """Bind type variable to type"""
        if self.occurs_in(var, t):
            return False  # Infinite type
        var.bound_to = t
        return True

    def occurs_in(self, var: TypeVariable, t: Type) -> bool:
        """Check if var occurs in t (prevents infinite types)"""
        t = self.resolve(t)
        if isinstance(t, TypeVariable):
            return t.id == var.id
        if isinstance(t, FunctionType):
            return (any(self.occurs_in(var, p) for p in t.param_types) or
                    self.occurs_in(var, t.return_type))
        if isinstance(t, ArrayType):
            return self.occurs_in(var, t.element_type)
        return False

    def solve(self) -> bool:
        """Solve all constraints"""
        for t1, t2 in self.constraints:
            if not self.unify(t1, t2):
                return False
        return True
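
To see unification at work, consider unifying `(a) -> int` with `(string) -> b`, which should yield `a = string` and `b = int`. The sketch below is a compact, self-contained variant of the engine above. It differs in one respect: bindings are kept in a substitution dictionary instead of on the variable objects. The names `Var`, `Con`, and `Subst` are hypothetical and chosen for this example only.

```python
from dataclasses import dataclass
from typing import Dict, Tuple, Union


@dataclass(frozen=True)
class Var:
    """Type variable identified by name (stand-in for TypeVariable)."""
    name: str


@dataclass(frozen=True)
class Con:
    """Type constructor application, e.g. Con('int') or Con('->', (a, b))."""
    name: str
    args: Tuple["Ty", ...] = ()


Ty = Union[Var, Con]
Subst = Dict[str, Ty]


def resolve(t: Ty, subst: Subst) -> Ty:
    """Follow bindings until reaching an unbound variable or a constructor."""
    while isinstance(t, Var) and t.name in subst:
        t = subst[t.name]
    return t


def occurs(var: Var, t: Ty, subst: Subst) -> bool:
    """True if var appears inside t (binding it would create an infinite type)."""
    t = resolve(t, subst)
    if isinstance(t, Var):
        return t == var
    return any(occurs(var, arg, subst) for arg in t.args)


def unify(t1: Ty, t2: Ty, subst: Subst) -> bool:
    """Extend subst so that t1 and t2 become equal; return False if impossible."""
    t1, t2 = resolve(t1, subst), resolve(t2, subst)
    if t1 == t2:
        return True
    if isinstance(t1, Var):
        if occurs(t1, t2, subst):
            return False
        subst[t1.name] = t2
        return True
    if isinstance(t2, Var):
        return unify(t2, t1, subst)
    if t1.name != t2.name or len(t1.args) != len(t2.args):
        return False
    return all(unify(x, y, subst) for x, y in zip(t1.args, t2.args))


INT, STRING = Con("int"), Con("string")
a, b = Var("a"), Var("b")

subst: Subst = {}
# (a) -> int  unified with  (string) -> b  gives  a = string, b = int
ok = unify(Con("->", (a, INT)), Con("->", (STRING, b)), subst)
a_type = resolve(a, subst)
b_type = resolve(b, subst)

# c = array<c> has no finite solution, so the occurs check rejects it.
infinite = unify(Var("c"), Con("array", (Var("c"),)), {})
# int and string are different constructors and cannot be unified.
clash = unify(INT, STRING, {})
```

As in the engine above, a failed unification may leave some bindings behind, so a caller that wants to recover should unify against a copy of the substitution.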

5. Additional Semantic Checks

class SemanticAnalyzer(TypeChecker):
    """Extended semantic analyzer with additional checks"""

    def __init__(self):
        super().__init__()
        self.in_loop = False
        self.return_found = False

    def check_WhileStmt(self, node):
        old_in_loop = self.in_loop
        self.in_loop = True
        result = super().check_WhileStmt(node)
        self.in_loop = old_in_loop
        return result

    def check_Break(self, node):
        if not self.in_loop:
            self.error("'break' outside of loop", node)
        return NONE

    def check_Continue(self, node):
        if not self.in_loop:
            self.error("'continue' outside of loop", node)
        return NONE

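    def check_Return(self, node):
        self.return_found = True  # Record that the current function contains a return
        return super().check_Return(node)

    def check_FunctionDef(self, node):
        # Save per-function state so an enclosing loop or function is unaffected
        old_return_found, old_in_loop = self.return_found, self.in_loop
        self.return_found = False
        self.in_loop = False  # A loop around the definition does not enclose its body
        result = super().check_FunctionDef(node)

        # Check that the function contains a return (simplified: this does not
        # prove that every path returns). super() has already restored
        # current_function, so look the function's type up by name instead.
        symbol = self.current_scope.lookup(node.name)
        func_type = symbol.type if symbol else None
        if isinstance(func_type, FunctionType) and func_type.return_type != NONE:
            if not self.return_found:
                self.error(
                    f"Function '{node.name}' may not return a value",
                    node
                )
        self.return_found, self.in_loop = old_return_found, old_in_loop
        return result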

    def check_unused_variables(self):
        """Check for variables that are defined but never used"""
        # Implementation would track variable usage
        pass

    def check_uninitialized_variables(self):
        """Check for variables used before initialization"""
        # Implementation would track initialization flow
        pass
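
The return check in `SemanticAnalyzer` is marked as simplified. A slightly more careful version asks whether every path through a function body reaches a `return`. The sketch below shows one conservative way to do that. It is self-contained and uses a hypothetical tuple encoding of statements (described in the comments) instead of the AST classes from the earlier readings.

```python
from typing import List, Tuple

# Hypothetical statement encoding, for illustration only:
#   ("return",)                     a return statement
#   ("if", then_body, else_body)    else_body is [] when there is no else
#   ("while", body)                 a loop whose body may run zero times
#   ("expr",)                       any other statement
Stmt = Tuple


def always_returns(body: List[Stmt]) -> bool:
    """True if every path through body is guaranteed to reach a return."""
    for stmt in body:
        kind = stmt[0]
        if kind == "return":
            return True
        if kind == "if":
            _, then_body, else_body = stmt
            # Both branches must return; a missing else falls through.
            if always_returns(then_body) and always_returns(else_body):
                return True
        # A while loop never guarantees a return: its body may not execute.
    return False


both_branches = [("if", [("return",)], [("expr",), ("return",)])]
missing_else = [("if", [("return",)], [])]
return_after_if = [("if", [("return",)], []), ("return",)]
only_in_loop = [("while", [("return",)])]

results = [
    always_returns(body)
    for body in (both_branches, missing_else, return_after_if, only_in_loop)
]
```

The analysis is conservative: it would also reject a function whose only return sits inside a `while True:` loop, because it does not evaluate loop conditions. Erring in that direction means a function is never wrongly accepted as always returning.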

Exercises

Basic

  1. Extend the type checker to handle array types and indexing.

  2. Add support for type annotations: x: int = 10

  3. Implement detection of unused variables.

Intermediate

  1. Add support for classes with fields and methods.

  2. Implement Algorithm W for Hindley-Milner type inference.

  3. Add support for generic/parametric types like List<T>.

Advanced

  1. Implement flow-sensitive type analysis (types can change based on control flow).

  2. Add support for nullable types and null safety checks.

  3. Implement a type system with subtyping and variance.


Summary

  • Symbol tables track identifiers and their attributes
  • Scope handling manages visibility and lifetime of names
  • Type systems define what types exist and their relationships
  • Type checking ensures operations are valid for their operand types
  • Type inference automatically determines types without annotations
  • Semantic analysis catches errors that syntax analysis cannot

Next Reading

Code Generation →