测试生成 Agent:测试用例设计与自动化测试

测试生成 Agent 通过代码理解、符号执行和边界分析,自动生成高质量测试用例,提升测试覆盖率,降低测试维护成本,让开发者能够更专注于业务逻辑实现。

测试生成 Agent:测试用例设计与自动化测试

概述与动机

在软件开发过程中,测试是保证代码质量和功能正确性的关键环节。然而,编写高质量的测试用例往往比实现业务逻辑还要耗时耗力。开发者需要在正常用例、边界用例、异常用例之间找到平衡,确保测试既能覆盖主要功能,又能捕获潜在的边界问题。随着代码复杂度的增长,手动编写和维护测试用件变得越来越困难。

测试生成 Agent 通过结合静态分析、符号执行和机器学习技术,自动化测试用例的生成和维护。它能够理解代码的输入输出契约,识别代码中的分支和边界条件,生成全面且高效的测试用例。更重要的是,Agent 能够随着代码的演进而自动更新测试,保持测试的同步性和有效性。

从业务价值角度看,高质量的测试用例能够显著降低线上故障率,提升系统稳定性。测试生成 Agent 的价值不仅在于减少编写测试的时间,更在于提升测试的质量和覆盖率。它能发现手工测试容易遗漏的边界情况和潜在 bug,提供更高的测试置信度。同时,通过自动化的测试维护,降低了代码重构的成本,让开发者更有信心进行代码优化和重构。

核心概念与架构设计

代码理解与契约推断

测试用例生成的基础是对代码的深度理解。Agent 需要理解代码的功能、输入输出契约、边界条件和异常情况。这种理解不仅包括语法层面的分析,还包括语义层面的推断。

代码理解的核心是构建代码的控制流图和数据流图。控制流图展示代码的执行路径,包括条件分支、循环、异常处理等。数据流图展示数据在代码中的流动,包括变量定义、使用、修改等。通过分析这两个图,Agent 能够识别代码中的关键路径和边界条件。

契约推断是代码理解的高级形式。Agent 能够从代码的实现中推断出输入输出的约束条件,比如函数的参数类型、返回值范围、前提条件和后置条件。这些契约信息对于生成有效的测试用例至关重要。

边界条件识别

边界条件是测试的关键,很多 bug 就隐藏在边界情况中。Agent 需要识别各种类型的边界条件,包括数值边界、字符串边界、集合边界、时间边界等。

数值边界包括最大值、最小值、零值、负值、浮点数边界(如 NaN、Infinity)。字符串边界包括空字符串、单字符、最大长度字符串、特殊字符(如 Unicode、控制字符)。集合边界包括空集合、单元素集合、最大容量集合、重复元素。时间边界包括时间戳零值、最大时间戳、闰秒、时区边界。

边界条件识别的一个挑战是代码可能显式或隐式地假设某些边界。Agent 需要通过代码分析发现这些假设,并生成相应的测试用例。比如,代码中可能有 if (index < array.length) 这样的检查,Agent 应该生成 index = array.length - 1index = array.length 的测试用例。

符号执行与路径覆盖

符号执行是一种程序分析技术,通过将输入视为符号变量而不是具体值,分析程序的所有可能执行路径。Agent 可以利用符号执行自动生成能够覆盖代码中每条路径的测试用例。

符号执行的核心是路径条件(Path Condition)。对于代码中的每个条件分支,符号执行会产生不同的路径条件。通过求解路径条件,Agent 能够生成触发该路径的具体输入值。

路径覆盖的一个挑战是路径爆炸问题。随着代码复杂度的增加,可能的执行路径呈指数级增长。Agent 需要采用启发式方法,优先选择重要的路径(如错误处理路径、边界路径)进行覆盖,而不是盲目地覆盖所有路径。

断言生成策略

测试用例的价值不仅在于输入设计,还在于断言的正确性。Agent 需要生成能够验证代码行为的断言,包括输出断言、状态断言和异常断言。

输出断言验证函数的返回值是否符合预期。Agent 可以通过分析代码逻辑推断输出值,或者通过执行代码并观察输出值来生成断言。状态断言验证函数执行后的状态变化,比如对象属性的修改、全局变量的变化等。异常断言验证函数在特定输入下是否抛出预期的异常。

断言生成的一个挑战是平衡断言的严格性和灵活性。过于严格的断言可能因为实现细节的变化而失效,过于宽松的断言则无法有效验证代码行为。Agent 需要理解代码的契约,生成基于契约的断言,而不是基于实现的断言。

Agent 架构设计

测试生成 Agent 的架构采用流水线设计,从代码分析到测试生成的各个阶段职责清晰。

Rendering diagram...

代码解析模块负责将源代码转换为中间表示,包括 AST、控制流图、数据流图等。契约推断模块分析代码的输入输出契约,包括类型约束、值域约束、前置条件、后置条件等。边界条件识别模块识别代码中的各种边界条件。路径分析模块分析代码的控制流,识别需要覆盖的关键路径。测试用例生成模块基于契约、边界条件和路径分析,生成测试输入。断言生成模块为每个测试用例生成相应的断言。测试代码生成模块将测试用例和断言转换为可执行的测试代码。测试执行模块执行测试代码,收集测试结果。覆盖率分析模块分析测试覆盖率,识别未覆盖的代码区域。测试优化模块基于覆盖率分析结果,优化测试用例,提高覆盖率。

关键技术实现

代码解析与契约推断

代码解析是测试生成的基础。我们实现一个代码解析器和契约推断器。

import ast
import inspect
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from enum import Enum

class ParamType(Enum):
    """参数类型"""
    INT = "int"
    FLOAT = "float"
    STR = "str"
    BOOL = "bool"
    LIST = "list"
    DICT = "dict"
    CUSTOM = "custom"

@dataclass
class FunctionContract:
    """函数契约"""
    name: str
    parameters: List[Dict[str, Any]]
    return_type: Optional[ParamType]
    preconditions: List[str]
    postconditions: List[str]
    exceptions: List[str]

class ContractInferencer:
    """契约推断器"""

    def __init__(self):
        self.type_hints = {
            'int': ParamType.INT,
            'float': ParamType.FLOAT,
            'str': ParamType.STR,
            'bool': ParamType.BOOL,
            'list': ParamType.LIST,
            'dict': ParamType.DICT,
        }

    def infer_contract(self, func) -> FunctionContract:
        """推断函数契约"""
        # 获取函数信息
        func_name = func.__name__
        func_code = inspect.getsource(func)

        # 解析函数
        tree = ast.parse(func_code)
        func_node = self._find_function_node(tree, func_name)

        if not func_node:
            raise ValueError(f"无法找到函数节点: {func_name}")

        # 推断参数
        parameters = self._infer_parameters(func, func_node)

        # 推断返回类型
        return_type = self._infer_return_type(func, func_node)

        # 推断前置条件
        preconditions = self._infer_preconditions(func_node)

        # 推断后置条件
        postconditions = self._infer_postconditions(func_node)

        # 推断异常
        exceptions = self._infer_exceptions(func_node)

        return FunctionContract(
            name=func_name,
            parameters=parameters,
            return_type=return_type,
            preconditions=preconditions,
            postconditions=postconditions,
            exceptions=exceptions
        )

    def _find_function_node(self, tree: ast.AST, func_name: str) -> Optional[ast.FunctionDef]:
        """查找函数节点"""
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef) and node.name == func_name:
                return node
        return None

    def _infer_parameters(self, func, func_node: ast.FunctionDef) -> List[Dict[str, Any]]:
        """推断参数"""
        parameters = []

        # 获取类型注解
        type_hints = inspect.getfullargspec(func).annotations

        # 分析每个参数
        for arg in func_node.args.args:
            param_name = arg.arg
            param_type = self._infer_param_type(arg, type_hints)

            # 推断参数约束
            constraints = self._infer_param_constraints(func_node, param_name)

            parameters.append({
                "name": param_name,
                "type": param_type,
                "constraints": constraints,
                "default": self._get_default_value(arg)
            })

        return parameters

    def _infer_param_type(self, arg: ast.arg, type_hints: Dict[str, Any]) -> ParamType:
        """推断参数类型"""
        # 检查类型注解
        if arg.arg in type_hints:
            type_hint = type_hints[arg.arg]
            if hasattr(type_hint, '__name__'):
                type_name = type_hint.__name__
            elif hasattr(type_hint, '_name'):
                type_name = type_hint._name
            else:
                type_name = str(type_hint)

            if type_name in self.type_hints:
                return self.type_hints[type_name]

        # 尝试从默认值推断
        if arg.default:
            return self._infer_type_from_value(arg.default)

        return ParamType.CUSTOM

    def _infer_type_from_value(self, node: ast.AST) -> ParamType:
        """从值推断类型"""
        if isinstance(node, ast.Num):
            return ParamType.INT if isinstance(node.n, int) else ParamType.FLOAT
        elif isinstance(node, ast.Str):
            return ParamType.STR
        elif isinstance(node, ast.NameConstant):
            return ParamType.BOOL
        elif isinstance(node, ast.List):
            return ParamType.LIST
        elif isinstance(node, ast.Dict):
            return ParamType.DICT
        else:
            return ParamType.CUSTOM

    def _infer_param_constraints(self, func_node: ast.FunctionDef, param_name: str) -> List[str]:
        """推断参数约束"""
        constraints = []

        # 查找参数检查
        for node in ast.walk(func_node):
            if isinstance(node, ast.Compare):
                # 检查是否是对该参数的比较
                if isinstance(node.left, ast.Name) and node.left.id == param_name:
                    constraints.append(self._generate_constraint(node))
                elif any(
                    isinstance(operand, ast.Name) and operand.id == param_name
                    for operand in node.comparators
                ):
                    constraints.append(self._generate_constraint(node))

        return constraints

    def _generate_constraint(self, node: ast.Compare) -> str:
        """生成约束表达式"""
        import astor

        left = astor.to_source(node.left).strip()
        ops = ['<' if isinstance(op, ast.Lt) else '>' if isinstance(op, ast.Gt) else
                '<=' if isinstance(op, ast.LtE) else '>=' if isinstance(op, ast.GtE) else
                '==' if isinstance(op, ast.Eq) else '!=' if isinstance(op, ast.NotEq) else
                'is' if isinstance(op, ast.Is) else 'is not' if isinstance(op, ast.IsNot) else
                'in' if isinstance(op, ast.In) else 'not in' if isinstance(op, ast.NotIn) else '?'
               for op in node.ops]
        comparators = [astor.to_source(comp).strip() for comp in node.comparators]

        constraint = f"{left} {ops[0]} {comparators[0]}"
        return constraint

    def _get_default_value(self, arg: ast.arg) -> Any:
        """获取默认值"""
        if arg.default:
            import astor
            return ast.literal_eval(astor.to_source(arg.default).strip())
        return None

    def _infer_return_type(self, func, func_node: ast.FunctionDef) -> Optional[ParamType]:
        """推断返回类型"""
        # 检查返回类型注解
        type_hints = inspect.getfullargspec(func).annotations
        if 'return' in type_hints:
            return_type = type_hints['return']
            if hasattr(return_type, '__name__'):
                type_name = return_type.__name__
            elif hasattr(return_type, '_name'):
                type_name = return_type._name
            else:
                type_name = str(return_type)

            if type_name in self.type_hints:
                return self.type_hints[type_name]

        # 从 return 语句推断
        return_nodes = [node for node in ast.walk(func_node) if isinstance(node, ast.Return)]
        if return_nodes:
            return_node = return_nodes[0]
            if return_node.value:
                return self._infer_type_from_value(return_node.value)

        return None

    def _infer_preconditions(self, func_node: ast.FunctionDef) -> List[str]:
        """推断前置条件"""
        preconditions = []

        # 查找 if 检查
        for node in ast.walk(func_node):
            if isinstance(node, ast.If):
                # 分析条件
                condition = self._extract_condition(node.test)
                if condition:
                    preconditions.append(f"确保 {condition}")

        return preconditions

    def _infer_postconditions(self, func_node: ast.FunctionDef) -> List[str]:
        """推断后置条件"""
        postconditions = []

        # 查找 return 语句
        return_nodes = [node for node in ast.walk(func_node) if isinstance(node, ast.Return)]
        if return_nodes:
            postconditions.append("返回正确的值")

        # 查找状态修改
        # (简化版,实际应该分析对象状态变化)
        return postconditions

    def _infer_exceptions(self, func_node: ast.FunctionDef) -> List[str]:
        """推断异常"""
        exceptions = []

        # 查找 raise 语句
        raise_nodes = [node for node in ast.walk(func_node) if isinstance(node, ast.Raise)]
        for raise_node in raise_nodes:
            if raise_node.exc:
                if isinstance(raise_node.exc, ast.Name):
                    exceptions.append(raise_node.exc.id)
                elif isinstance(raise_node.exc, ast.Call):
                    if isinstance(raise_node.exc.func, ast.Name):
                        exceptions.append(raise_node.exc.func.id)

        return list(set(exceptions))  # 去重

    def _extract_condition(self, node: ast.AST) -> str:
        """提取条件表达式"""
        import astor
        return astor.to_source(node).strip()

契约推断器分析函数的参数、返回值、前置条件、后置条件和异常,为测试用例生成提供基础信息。

边界条件识别器

边界条件识别器识别各种类型的边界条件,为测试用例生成提供边界值。

from typing import Dict, List, Optional, Any
from dataclasses import dataclass
import math

@dataclass
class BoundaryValue:
    """边界值"""
    name: str
    value: Any
    description: str

class BoundaryConditionIdentifier:
    """边界条件识别器"""

    def __init__(self):
        # 基本类型的边界值
        self.type_boundaries = {
            ParamType.INT: [
                BoundaryValue("min_int", -2**31, "32位整数最小值"),
                BoundaryValue("max_int", 2**31 - 1, "32位整数最大值"),
                BoundaryValue("zero", 0, "零值"),
                BoundaryValue("negative_one", -1, "负一"),
                BoundaryValue("positive_one", 1, "正一"),
            ],
            ParamType.FLOAT: [
                BoundaryValue("min_float", float('-inf'), "浮点数负无穷"),
                BoundaryValue("max_float", float('inf'), "浮点数正无穷"),
                BoundaryValue("zero", 0.0, "浮点数零值"),
                BoundaryValue("nan", float('nan'), "非数值"),
                BoundaryValue("epsilon", 1e-10, "极小值"),
            ],
            ParamType.STR: [
                BoundaryValue("empty_string", "", "空字符串"),
                BoundaryValue("single_char", "a", "单字符"),
                BoundaryValue("max_length", "a" * 10000, "长字符串"),
                BoundaryValue("unicode", "你好世界", "Unicode 字符"),
                BoundaryValue("whitespace", " ", "空白字符"),
            ],
            ParamType.BOOL: [
                BoundaryValue("true", True, "布尔真值"),
                BoundaryValue("false", False, "布尔假值"),
            ],
            ParamType.LIST: [
                BoundaryValue("empty_list", [], "空列表"),
                BoundaryValue("single_element", [1], "单元素列表"),
                BoundaryValue("max_size", list(range(1000)), "大列表"),
                BoundaryValue("duplicate_elements", [1, 1, 1], "重复元素"),
            ],
            ParamType.DICT: [
                BoundaryValue("empty_dict", {}, "空字典"),
                BoundaryValue("single_key", {"key": "value"}, "单键字典"),
                BoundaryValue("max_size", {str(i): i for i in range(1000)}, "大字典"),
            ],
        }

    def identify_boundaries(self, param_type: ParamType, constraints: List[str]) -> List[BoundaryValue]:
        """识别边界值"""
        boundaries = []

        # 获取基本类型的边界值
        if param_type in self.type_boundaries:
            boundaries.extend(self.type_boundaries[param_type])

        # 基于约束识别额外边界
        constraint_boundaries = self._identify_constraint_boundaries(constraints)
        boundaries.extend(constraint_boundaries)

        return boundaries

    def _identify_constraint_boundaries(self, constraints: List[str]) -> List[BoundaryValue]:
        """基于约束识别边界值"""
        boundaries = []

        for constraint in constraints:
            # 解析约束表达式
            if '<=' in constraint:
                parts = constraint.split('<=')
                if len(parts) == 2:
                    try:
                        value = int(parts[1].strip())
                        boundaries.append(BoundaryValue("constraint_max", value, f"约束最大值: {value}"))
                        boundaries.append(BoundaryValue("constraint_max_minus_one", value - 1, f"约束最大值-1: {value-1}"))
                        boundaries.append(BoundaryValue("constraint_max_plus_one", value + 1, f"约束最大值+1: {value+1}"))
                    except ValueError:
                        pass
            elif '>=' in constraint:
                parts = constraint.split('>=')
                if len(parts) == 2:
                    try:
                        value = int(parts[1].strip())
                        boundaries.append(BoundaryValue("constraint_min", value, f"约束最小值: {value}"))
                        boundaries.append(BoundaryValue("constraint_min_minus_one", value - 1, f"约束最小值-1: {value-1}"))
                        boundaries.append(BoundaryValue("constraint_min_plus_one", value + 1, f"约束最小值+1: {value+1}"))
                    except ValueError:
                        pass

        return boundaries

边界条件识别器识别各种类型的边界值,包括基本类型的边界值和基于约束的边界值。

测试用例生成器

测试用例生成器基于契约和边界值生成测试用例。

import itertools
from typing import Dict, List, Optional, Any
from dataclasses import dataclass

@dataclass
class TestCase:
    """测试用例"""
    name: str
    inputs: Dict[str, Any]
    expected_output: Optional[Any] = None
    expected_exception: Optional[str] = None
    description: str = ""

class TestCaseGenerator:
    """测试用例生成器"""

    def __init__(self, max_test_cases: int = 100):
        self.max_test_cases = max_test_cases
        self.boundary_identifier = BoundaryConditionIdentifier()

    def generate_test_cases(self, contract: FunctionContract) -> List[TestCase]:
        """生成测试用例"""
        test_cases = []

        # 1. 生成正常用例
        normal_cases = self._generate_normal_cases(contract)
        test_cases.extend(normal_cases)

        # 2. 生成边界用例
        boundary_cases = self._generate_boundary_cases(contract)
        test_cases.extend(boundary_cases)

        # 3. 生成异常用例
        exception_cases = self._generate_exception_cases(contract)
        test_cases.extend(exception_cases)

        # 限制测试用例数量
        return test_cases[:self.max_test_cases]

    def _generate_normal_cases(self, contract: FunctionContract) -> List[TestCase]:
        """生成正常用例"""
        test_cases = []

        if not contract.parameters:
            # 无参数函数
            test_cases.append(TestCase(
                name="test_normal",
                inputs={},
                description="正常执行"
            ))
            return test_cases

        # 生成组合测试用例
        param_combinations = self._generate_param_combinations(contract.parameters)

        for i, params in enumerate(param_combinations):
            test_cases.append(TestCase(
                name=f"test_normal_{i}",
                inputs=params,
                description=f"正常用例 {i}"
            ))

        return test_cases

    def _generate_boundary_cases(self, contract: FunctionContract) -> List[TestCase]:
        """生成边界用例"""
        test_cases = []

        for param in contract.parameters:
            param_name = param["name"]
            param_type = param["type"]
            constraints = param["constraints"]

            # 识别边界值
            boundaries = self.boundary_identifier.identify_boundaries(param_type, constraints)

            for boundary in boundaries:
                test_cases.append(TestCase(
                    name=f"test_boundary_{param_name}_{boundary.name}",
                    inputs={param_name: boundary.value},
                    description=boundary.description
                ))

        return test_cases

    def _generate_exception_cases(self, contract: FunctionContract) -> List[TestCase]:
        """生成异常用例"""
        test_cases = []

        for exception in contract.exceptions:
            # 生成触发该异常的测试用例
            test_case = self._generate_exception_test_case(contract, exception)
            if test_case:
                test_cases.append(test_case)

        return test_cases

    def _generate_exception_test_case(self, contract: FunctionContract, exception: str) -> Optional[TestCase]:
        """生成特定异常的测试用例"""
        # 简化版:基于参数约束生成异常用例
        for param in contract.parameters:
            param_name = param["name"]
            param_type = param["type"]

            # 根据异常类型生成输入
            if "ValueError" in exception or "TypeError" in exception:
                # 生成类型错误的输入
                if param_type == ParamType.INT:
                    return TestCase(
                        name=f"test_exception_{exception}",
                        inputs={param_name: "invalid"},
                        expected_exception=exception,
                        description=f"触发 {exception}"
                    )
            elif "IndexError" in exception or "KeyError" in exception:
                # 生成越界或不存在键的输入
                if param_type == ParamType.LIST:
                    return TestCase(
                        name=f"test_exception_{exception}",
                        inputs={param_name: []},
                        expected_exception=exception,
                        description=f"触发 {exception}"
                    )

        return None

    def _generate_param_combinations(self, parameters: List[Dict]) -> List[Dict]:
        """生成参数组合"""
        # 为每个参数生成一些典型值
        param_values = []

        for param in parameters:
            param_name = param["name"]
            param_type = param["type"]
            default = param.get("default")

            values = self._generate_typical_values(param_type, default)
            param_values.append([(param_name, value) for value in values])

        # 生成笛卡尔积
        combinations = []
        for combination in itertools.product(*param_values):
            test_input = {}
            for key, value in combination:
                test_input[key] = value
            combinations.append(test_input)

        return combinations

    def _generate_typical_values(self, param_type: ParamType, default: Optional[Any] = None) -> List[Any]:
        """生成典型值"""
        if default is not None:
            return [default]

        typical_values = {
            ParamType.INT: [0, 1, 100, -1, -100],
            ParamType.FLOAT: [0.0, 1.0, 100.0, -1.0, 0.5],
            ParamType.STR: ["", "test", "你好世界", "a" * 100],
            ParamType.BOOL: [True, False],
            ParamType.LIST: [[], [1], [1, 2, 3]],
            ParamType.DICT: [{}, {"key": "value"}],
        }

        return typical_values.get(param_type, [None])

测试用例生成器生成正常用例、边界用例和异常用例,覆盖代码的主要路径和边界情况。

断言生成器

断言生成器为测试用例生成相应的断言。

from typing import Dict, List, Optional, Any
import ast
import inspect

class AssertionGenerator:
    """断言生成器"""

    def __init__(self):
        pass

    def generate_assertions(self,
                            contract: FunctionContract,
                            test_case: TestCase,
                            execution_result: Any) -> List[str]:
        """生成断言"""
        assertions = []

        # 1. 输出断言
        if test_case.expected_output is not None:
            assertions.append(f"assert result == {repr(test_case.expected_output)}")
        elif execution_result is not None:
            # 基于执行结果生成断言
            assertions.append(f"assert result == {repr(execution_result)}")

        # 2. 异常断言
        if test_case.expected_exception:
            assertions.append(f"with pytest.raises({test_case.expected_exception}):")

        # 3. 后置条件断言
        for postcondition in contract.postconditions:
            assertions.append(f"# {postcondition}")

        return assertions

    def generate_mock_assertions(self, contract: FunctionContract) -> List[str]:
        """生成模拟断言(用于示例)"""
        assertions = []

        # 基于返回类型生成断言
        if contract.return_type:
            if contract.return_type == ParamType.INT:
                assertions.append("assert isinstance(result, int)")
                assertions.append("assert result >= 0")
            elif contract.return_type == ParamType.STR:
                assertions.append("assert isinstance(result, str)")
                assertions.append("assert len(result) > 0")

        # 基于后置条件生成断言
        for postcondition in contract.postconditions:
            assertions.append(f"# {postcondition}")

        return assertions

断言生成器基于函数契约和测试结果生成相应的断言。

完整的测试生成 Agent

将所有组件组合起来,创建一个完整的测试生成 Agent。

import asyncio
from typing import Dict, List, Optional, Any
import os
import subprocess

class TestGenerationAgent:
    """测试生成 Agent"""

    def __init__(self, llm_api_key: str, llm_model: str = "gpt-4"):
        self.contract_inferencer = ContractInferencer()
        self.test_case_generator = TestCaseGenerator()
        self.assertion_generator = AssertionGenerator()
        self.code_generator = TestCodeGenerator()

    async def generate_tests(self,
                           func,
                           output_file: str = "test_generated.py") -> str:
        """生成测试代码"""
        func_name = func.__name__

        print(f"开始为函数 {func_name} 生成测试...")

        # 1. 推断函数契约
        print("步骤 1/5: 推断函数契约...")
        contract = self.contract_inferencer.infer_contract(func)

        # 2. 生成测试用例
        print("步骤 2/5: 生成测试用例...")
        test_cases = self.test_case_generator.generate_test_cases(contract)

        # 3. 生成断言
        print("步骤 3/5: 生成断言...")
        for test_case in test_cases:
            assertions = self.assertion_generator.generate_mock_assertions(contract)
            test_case.description += f" 断言: {', '.join(assertions[:2])}"

        # 4. 生成测试代码
        print("步骤 4/5: 生成测试代码...")
        test_code = self.code_generator.generate_test_code(func, contract, test_cases)

        # 5. 保存测试代码
        print("步骤 5/5: 保存测试代码...")
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(test_code)

        print(f"测试代码已保存到 {output_file}")
        return test_code

    def run_tests(self, test_file: str) -> Dict[str, Any]:
        """运行测试"""
        try:
            result = subprocess.run(
                ['python', '-m', 'pytest', test_file, '-v'],
                capture_output=True,
                text=True
            )

            return {
                "success": result.returncode == 0,
                "output": result.stdout,
                "errors": result.stderr
            }
        except Exception as e:
            return {
                "success": False,
                "output": "",
                "errors": str(e)
            }


class TestCodeGenerator:
    """测试代码生成器"""

    def __init__(self):
        pass

    def generate_test_code(self,
                           func,
                           contract: FunctionContract,
                           test_cases: List[TestCase]) -> str:
        """生成测试代码"""
        code = self._generate_header()
        code += self._generate_imports()
        code += self._generate_test_class(func, contract, test_cases)
        code += self._generate_footer()

        return code

    def _generate_header(self) -> str:
        """生成文件头"""
        return '''# 自动生成的测试代码
# 由 TestGenerationAgent 生成
# 请根据需要修改和完善

import pytest
import sys
import os

# 添加源码路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

'''

    def _generate_imports(self) -> str:
        """生成导入语句"""
        return '''from src.main import *  # 导入被测试的函数
from typing import Any
'''

    def _generate_test_class(self,
                             func,
                             contract: FunctionContract,
                             test_cases: List[TestCase]) -> str:
        """生成测试类"""
        func_name = func.__name__

        code = f'''
class Test{func_name.capitalize()}:
    """测试类: {func_name}"""

'''

        # 生成测试方法
        for test_case in test_cases:
            code += self._generate_test_method(func, test_case)
            code += '\n'

        return code

    def _generate_test_method(self, func, test_case: TestCase) -> str:
        """生成测试方法"""
        func_name = func.__name__

        code = f'''    def {test_case.name}(self):
        """
        {test_case.description}
        """
'''

        # 生成参数设置
        for key, value in test_case.inputs.items():
            code += f'        {key} = {repr(value)}\n'

        # 生成函数调用
        params_str = ', '.join(test_case.inputs.keys())
        code += f'\n        result = {func_name}({params_str})\n\n'

        # 生成断言
        code += '        # 断言\n'
        if test_case.expected_exception:
            code += f'        # 注意:此测试用例预期抛出 {test_case.expected_exception}\n'
        else:
            code += '        assert result is not None\n'
            code += '        # 添加更多断言...\n'

        return code

    def _generate_footer(self) -> str:
        """生成文件尾"""
        return '''
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
'''


# 示例函数
def calculate_discount(price: int, is_member: bool) -> float:
    """计算折扣价格

    Args:
        price: 原价,必须 >= 0
        is_member: 是否是会员

    Returns:
        折扣后的价格
    """
    if price < 0:
        raise ValueError("价格不能为负数")

    if is_member:
        return price * 0.8  # 会员8折
    else:
        return price * 0.9  # 非会员9折


def find_index(arr: List[int], target: int) -> int:
    """在数组中查找目标索引

    Args:
        arr: 数组
        target: 目标值

    Returns:
        目标值的索引,如果不存在返回 -1
    """
    if target < 0:
        raise ValueError("目标值不能为负数")

    for i, value in enumerate(arr):
        if value == target:
            return i

    return -1


# 使用示例
async def main():
    # 创建测试生成 Agent
    agent = TestGenerationAgent(llm_api_key="your-api-key-here")

    # 为示例函数生成测试
    print("=== 为 calculate_discount 生成测试 ===")
    discount_code = await agent.generate_tests(
        calculate_discount,
        output_file="/Users/liuyutao/Desktop/workspace/test_calculate_discount.py"
    )

    print("\n=== 为 find_index 生成测试 ===")
    index_code = await agent.generate_tests(
        find_index,
        output_file="/Users/liuyutao/Desktop/workspace/test_find_index.py"
    )

    # 运行测试(简化版,实际需要被测试函数)
    # result = agent.run_tests("/Users/liuyutao/Desktop/workspace/test_calculate_discount.py")
    # print(f"\n=== 测试运行结果 ===")
    # print(f"成功: {result['success']}")
    # print(f"输出:\n{result['output']}")
    # if result['errors']:
    #     print(f"错误:\n{result['errors']}")

if __name__ == "__main__":
    asyncio.run(main())

这个完整的测试生成 Agent 整合了契约推断、测试用例生成、断言生成和测试代码生成等功能,提供端到端的测试生成支持。

最佳实践与常见陷阱

代码契约注解

代码契约注解是测试生成的基础。清晰的契约注解能够帮助 Agent 准确理解代码的输入输出关系,生成更准确的测试用例。

最佳实践包括:

  1. 使用类型注解:为所有函数参数和返回值添加类型注解,帮助 Agent 理解数据类型约束。
  2. 编写文档字符串:详细描述函数的功能、参数要求、返回值含义和异常情况。
  3. 添加契约检查:使用契约库(如 dealicontract)明确表示前置条件和后置条件。

常见陷阱是契约注解不完整或不准确。这会导致 Agent 生成无效的测试用例或遗漏重要的测试场景。最佳实践是在代码编写时就完善契约注解,而不是事后补充。

测试用例多样性

测试用例的多样性直接影响测试覆盖率。Agent 应该生成多种类型的测试用例,包括正常用例、边界用例、异常用例、组合用例等。

边界用例特别重要,因为很多 bug 就隐藏在边界情况中。Agent 应该识别各种类型的边界条件,包括数值边界、字符串边界、集合边界、时间边界等。

另一个重要考虑是组合用例。多个参数的不同组合可能产生不同的行为,Agent 应该生成有效的参数组合测试用例,而不是只测试单一参数的变化。

断言的质量

断言的质量直接影响测试的有效性。Agent 应该生成基于契约的断言,而不是基于实现的断言。

基于契约的断言关注函数的输入输出关系,而不是具体的实现细节。这样的断言更稳定,不容易因为实现细节的变化而失效。

另一个考虑是断言的完整性。Agent 应该生成足够多的断言,验证函数的各个方面,包括返回值、状态变化、异常等。同时,断言应该易于理解和维护。

测试维护

测试维护是测试生成的一个重要考虑。代码演进时,测试也需要相应更新。Agent 应该能够识别代码变更,自动更新受影响的测试用例。

增量更新是测试维护的一个关键策略。Agent 应该只更新受影响的测试用例,而不是重新生成所有测试用例,这样可以保持测试的稳定性。

另一个考虑是测试的独立性。测试用例之间应该相互独立,一个测试用例的失败不应该影响其他测试用例的执行。

性能优化考虑

测试用例选择

测试用例选择是性能优化的关键。Agent 不应该生成所有可能的测试用例,而是选择最重要的测试用例。

重要性排序是测试用例选择的一个策略。Agent 可以根据代码覆盖率、边界重要性、异常概率等因素对测试用例排序,优先选择重要的测试用例。

另一个优化是测试用例去重。不同的测试路径可能产生相似的测试用例,Agent 应该识别并去除重复的测试用例。

并行执行

测试执行可以并行化,特别是对于独立的测试用例。Agent 可以并行执行多个测试用例,提升执行效率。

并行执行需要考虑测试用例之间的依赖关系。只有相互独立的测试用例才能并行执行,有依赖关系的测试用例需要按顺序执行。

另一个考虑是资源管理。并行执行需要合理的资源分配,避免资源竞争和过度消耗。

增量生成

增量生成是性能优化的另一个策略。Agent 可以只生成新代码或修改代码的测试用例,而不是重新生成所有测试用例。

增量生成需要跟踪代码变更历史,识别受影响的函数和模块。Agent 可以基于代码变更的 diff 信息,确定需要重新生成测试的范围。

另一个考虑是测试用例的缓存。对于未变更的代码,Agent 可以复用已有的测试用例,避免重复生成。

符号执行优化

符号执行是测试生成的关键技术,但也存在性能问题。Agent 应该优化符号执行的效率。

路径剪枝是符号执行优化的一个策略。Agent 可以剪枝不太可能触发的路径,只关注重要的路径,避免路径爆炸。

另一个优化是约束求解优化。Agent 可以使用高效的约束求解器,或者采用启发式方法加速约束求解。

参考资源

官方文档和工具

学术论文和研究

  • "Automatic Test Generation: A Survey" - 自动测试生成综述
  • "Symbolic Execution for Testing" - 基于符号执行的测试技术
  • "Test Generation with Machine Learning" - 基于机器学习的测试生成

实践指南

通过合理的设计和实现,测试生成 Agent 能够显著提升测试开发的效率和质量,成为开发团队的强大助手。