测试生成 Agent:测试用例设计与自动化测试
测试生成 Agent 通过代码理解、符号执行和边界分析,自动生成高质量测试用例,提升测试覆盖率,降低测试维护成本,让开发者能够更专注于业务逻辑实现。
测试生成 Agent:测试用例设计与自动化测试
概述与动机
在软件开发过程中,测试是保证代码质量和功能正确性的关键环节。然而,编写高质量的测试用例往往比实现业务逻辑还要耗时耗力。开发者需要在正常用例、边界用例、异常用例之间找到平衡,确保测试既能覆盖主要功能,又能捕获潜在的边界问题。随着代码复杂度的增长,手动编写和维护测试用件变得越来越困难。
测试生成 Agent 通过结合静态分析、符号执行和机器学习技术,自动化测试用例的生成和维护。它能够理解代码的输入输出契约,识别代码中的分支和边界条件,生成全面且高效的测试用例。更重要的是,Agent 能够随着代码的演进而自动更新测试,保持测试的同步性和有效性。
从业务价值角度看,高质量的测试用例能够显著降低线上故障率,提升系统稳定性。测试生成 Agent 的价值不仅在于减少编写测试的时间,更在于提升测试的质量和覆盖率。它能发现手工测试容易遗漏的边界情况和潜在 bug,提供更高的测试置信度。同时,通过自动化的测试维护,降低了代码重构的成本,让开发者更有信心进行代码优化和重构。
核心概念与架构设计
代码理解与契约推断
测试用例生成的基础是对代码的深度理解。Agent 需要理解代码的功能、输入输出契约、边界条件和异常情况。这种理解不仅包括语法层面的分析,还包括语义层面的推断。
代码理解的核心是构建代码的控制流图和数据流图。控制流图展示代码的执行路径,包括条件分支、循环、异常处理等。数据流图展示数据在代码中的流动,包括变量定义、使用、修改等。通过分析这两个图,Agent 能够识别代码中的关键路径和边界条件。
契约推断是代码理解的高级形式。Agent 能够从代码的实现中推断出输入输出的约束条件,比如函数的参数类型、返回值范围、前提条件和后置条件。这些契约信息对于生成有效的测试用例至关重要。
边界条件识别
边界条件是测试的关键,很多 bug 就隐藏在边界情况中。Agent 需要识别各种类型的边界条件,包括数值边界、字符串边界、集合边界、时间边界等。
数值边界包括最大值、最小值、零值、负值、浮点数边界(如 NaN、Infinity)。字符串边界包括空字符串、单字符、最大长度字符串、特殊字符(如 Unicode、控制字符)。集合边界包括空集合、单元素集合、最大容量集合、重复元素。时间边界包括时间戳零值、最大时间戳、闰秒、时区边界。
边界条件识别的一个挑战是代码可能显式或隐式地假设某些边界。Agent 需要通过代码分析发现这些假设,并生成相应的测试用例。比如,代码中可能有 if (index < array.length) 这样的检查,Agent 应该生成 index = array.length - 1 和 index = array.length 的测试用例。
符号执行与路径覆盖
符号执行是一种程序分析技术,通过将输入视为符号变量而不是具体值,分析程序的所有可能执行路径。Agent 可以利用符号执行自动生成能够覆盖代码中每条路径的测试用例。
符号执行的核心是路径条件(Path Condition)。对于代码中的每个条件分支,符号执行会产生不同的路径条件。通过求解路径条件,Agent 能够生成触发该路径的具体输入值。
路径覆盖的一个挑战是路径爆炸问题。随着代码复杂度的增加,可能的执行路径呈指数级增长。Agent 需要采用启发式方法,优先选择重要的路径(如错误处理路径、边界路径)进行覆盖,而不是盲目地覆盖所有路径。
断言生成策略
测试用例的价值不仅在于输入设计,还在于断言的正确性。Agent 需要生成能够验证代码行为的断言,包括输出断言、状态断言和异常断言。
输出断言验证函数的返回值是否符合预期。Agent 可以通过分析代码逻辑推断输出值,或者通过执行代码并观察输出值来生成断言。状态断言验证函数执行后的状态变化,比如对象属性的修改、全局变量的变化等。异常断言验证函数在特定输入下是否抛出预期的异常。
断言生成的一个挑战是平衡断言的严格性和灵活性。过于严格的断言可能因为实现细节的变化而失效,过于宽松的断言则无法有效验证代码行为。Agent 需要理解代码的契约,生成基于契约的断言,而不是基于实现的断言。
Agent 架构设计
测试生成 Agent 的架构采用流水线设计,从代码分析到测试生成的各个阶段职责清晰。
代码解析模块负责将源代码转换为中间表示,包括 AST、控制流图、数据流图等。契约推断模块分析代码的输入输出契约,包括类型约束、值域约束、前置条件、后置条件等。边界条件识别模块识别代码中的各种边界条件。路径分析模块分析代码的控制流,识别需要覆盖的关键路径。测试用例生成模块基于契约、边界条件和路径分析,生成测试输入。断言生成模块为每个测试用例生成相应的断言。测试代码生成模块将测试用例和断言转换为可执行的测试代码。测试执行模块执行测试代码,收集测试结果。覆盖率分析模块分析测试覆盖率,识别未覆盖的代码区域。测试优化模块基于覆盖率分析结果,优化测试用例,提高覆盖率。
关键技术实现
代码解析与契约推断
代码解析是测试生成的基础。我们实现一个代码解析器和契约推断器。
import ast
import inspect
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from enum import Enum
class ParamType(Enum):
"""参数类型"""
INT = "int"
FLOAT = "float"
STR = "str"
BOOL = "bool"
LIST = "list"
DICT = "dict"
CUSTOM = "custom"
@dataclass
class FunctionContract:
"""函数契约"""
name: str
parameters: List[Dict[str, Any]]
return_type: Optional[ParamType]
preconditions: List[str]
postconditions: List[str]
exceptions: List[str]
class ContractInferencer:
"""契约推断器"""
def __init__(self):
self.type_hints = {
'int': ParamType.INT,
'float': ParamType.FLOAT,
'str': ParamType.STR,
'bool': ParamType.BOOL,
'list': ParamType.LIST,
'dict': ParamType.DICT,
}
def infer_contract(self, func) -> FunctionContract:
"""推断函数契约"""
# 获取函数信息
func_name = func.__name__
func_code = inspect.getsource(func)
# 解析函数
tree = ast.parse(func_code)
func_node = self._find_function_node(tree, func_name)
if not func_node:
raise ValueError(f"无法找到函数节点: {func_name}")
# 推断参数
parameters = self._infer_parameters(func, func_node)
# 推断返回类型
return_type = self._infer_return_type(func, func_node)
# 推断前置条件
preconditions = self._infer_preconditions(func_node)
# 推断后置条件
postconditions = self._infer_postconditions(func_node)
# 推断异常
exceptions = self._infer_exceptions(func_node)
return FunctionContract(
name=func_name,
parameters=parameters,
return_type=return_type,
preconditions=preconditions,
postconditions=postconditions,
exceptions=exceptions
)
def _find_function_node(self, tree: ast.AST, func_name: str) -> Optional[ast.FunctionDef]:
"""查找函数节点"""
for node in ast.walk(tree):
if isinstance(node, ast.FunctionDef) and node.name == func_name:
return node
return None
def _infer_parameters(self, func, func_node: ast.FunctionDef) -> List[Dict[str, Any]]:
"""推断参数"""
parameters = []
# 获取类型注解
type_hints = inspect.getfullargspec(func).annotations
# 分析每个参数
for arg in func_node.args.args:
param_name = arg.arg
param_type = self._infer_param_type(arg, type_hints)
# 推断参数约束
constraints = self._infer_param_constraints(func_node, param_name)
parameters.append({
"name": param_name,
"type": param_type,
"constraints": constraints,
"default": self._get_default_value(arg)
})
return parameters
def _infer_param_type(self, arg: ast.arg, type_hints: Dict[str, Any]) -> ParamType:
"""推断参数类型"""
# 检查类型注解
if arg.arg in type_hints:
type_hint = type_hints[arg.arg]
if hasattr(type_hint, '__name__'):
type_name = type_hint.__name__
elif hasattr(type_hint, '_name'):
type_name = type_hint._name
else:
type_name = str(type_hint)
if type_name in self.type_hints:
return self.type_hints[type_name]
# 尝试从默认值推断
if arg.default:
return self._infer_type_from_value(arg.default)
return ParamType.CUSTOM
def _infer_type_from_value(self, node: ast.AST) -> ParamType:
"""从值推断类型"""
if isinstance(node, ast.Num):
return ParamType.INT if isinstance(node.n, int) else ParamType.FLOAT
elif isinstance(node, ast.Str):
return ParamType.STR
elif isinstance(node, ast.NameConstant):
return ParamType.BOOL
elif isinstance(node, ast.List):
return ParamType.LIST
elif isinstance(node, ast.Dict):
return ParamType.DICT
else:
return ParamType.CUSTOM
def _infer_param_constraints(self, func_node: ast.FunctionDef, param_name: str) -> List[str]:
"""推断参数约束"""
constraints = []
# 查找参数检查
for node in ast.walk(func_node):
if isinstance(node, ast.Compare):
# 检查是否是对该参数的比较
if isinstance(node.left, ast.Name) and node.left.id == param_name:
constraints.append(self._generate_constraint(node))
elif any(
isinstance(operand, ast.Name) and operand.id == param_name
for operand in node.comparators
):
constraints.append(self._generate_constraint(node))
return constraints
def _generate_constraint(self, node: ast.Compare) -> str:
"""生成约束表达式"""
import astor
left = astor.to_source(node.left).strip()
ops = ['<' if isinstance(op, ast.Lt) else '>' if isinstance(op, ast.Gt) else
'<=' if isinstance(op, ast.LtE) else '>=' if isinstance(op, ast.GtE) else
'==' if isinstance(op, ast.Eq) else '!=' if isinstance(op, ast.NotEq) else
'is' if isinstance(op, ast.Is) else 'is not' if isinstance(op, ast.IsNot) else
'in' if isinstance(op, ast.In) else 'not in' if isinstance(op, ast.NotIn) else '?'
for op in node.ops]
comparators = [astor.to_source(comp).strip() for comp in node.comparators]
constraint = f"{left} {ops[0]} {comparators[0]}"
return constraint
def _get_default_value(self, arg: ast.arg) -> Any:
"""获取默认值"""
if arg.default:
import astor
return ast.literal_eval(astor.to_source(arg.default).strip())
return None
def _infer_return_type(self, func, func_node: ast.FunctionDef) -> Optional[ParamType]:
"""推断返回类型"""
# 检查返回类型注解
type_hints = inspect.getfullargspec(func).annotations
if 'return' in type_hints:
return_type = type_hints['return']
if hasattr(return_type, '__name__'):
type_name = return_type.__name__
elif hasattr(return_type, '_name'):
type_name = return_type._name
else:
type_name = str(return_type)
if type_name in self.type_hints:
return self.type_hints[type_name]
# 从 return 语句推断
return_nodes = [node for node in ast.walk(func_node) if isinstance(node, ast.Return)]
if return_nodes:
return_node = return_nodes[0]
if return_node.value:
return self._infer_type_from_value(return_node.value)
return None
def _infer_preconditions(self, func_node: ast.FunctionDef) -> List[str]:
"""推断前置条件"""
preconditions = []
# 查找 if 检查
for node in ast.walk(func_node):
if isinstance(node, ast.If):
# 分析条件
condition = self._extract_condition(node.test)
if condition:
preconditions.append(f"确保 {condition}")
return preconditions
def _infer_postconditions(self, func_node: ast.FunctionDef) -> List[str]:
"""推断后置条件"""
postconditions = []
# 查找 return 语句
return_nodes = [node for node in ast.walk(func_node) if isinstance(node, ast.Return)]
if return_nodes:
postconditions.append("返回正确的值")
# 查找状态修改
# (简化版,实际应该分析对象状态变化)
return postconditions
def _infer_exceptions(self, func_node: ast.FunctionDef) -> List[str]:
"""推断异常"""
exceptions = []
# 查找 raise 语句
raise_nodes = [node for node in ast.walk(func_node) if isinstance(node, ast.Raise)]
for raise_node in raise_nodes:
if raise_node.exc:
if isinstance(raise_node.exc, ast.Name):
exceptions.append(raise_node.exc.id)
elif isinstance(raise_node.exc, ast.Call):
if isinstance(raise_node.exc.func, ast.Name):
exceptions.append(raise_node.exc.func.id)
return list(set(exceptions)) # 去重
def _extract_condition(self, node: ast.AST) -> str:
"""提取条件表达式"""
import astor
return astor.to_source(node).strip()
契约推断器分析函数的参数、返回值、前置条件、后置条件和异常,为测试用例生成提供基础信息。
边界条件识别器
边界条件识别器识别各种类型的边界条件,为测试用例生成提供边界值。
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
import math
@dataclass
class BoundaryValue:
"""边界值"""
name: str
value: Any
description: str
class BoundaryConditionIdentifier:
"""边界条件识别器"""
def __init__(self):
# 基本类型的边界值
self.type_boundaries = {
ParamType.INT: [
BoundaryValue("min_int", -2**31, "32位整数最小值"),
BoundaryValue("max_int", 2**31 - 1, "32位整数最大值"),
BoundaryValue("zero", 0, "零值"),
BoundaryValue("negative_one", -1, "负一"),
BoundaryValue("positive_one", 1, "正一"),
],
ParamType.FLOAT: [
BoundaryValue("min_float", float('-inf'), "浮点数负无穷"),
BoundaryValue("max_float", float('inf'), "浮点数正无穷"),
BoundaryValue("zero", 0.0, "浮点数零值"),
BoundaryValue("nan", float('nan'), "非数值"),
BoundaryValue("epsilon", 1e-10, "极小值"),
],
ParamType.STR: [
BoundaryValue("empty_string", "", "空字符串"),
BoundaryValue("single_char", "a", "单字符"),
BoundaryValue("max_length", "a" * 10000, "长字符串"),
BoundaryValue("unicode", "你好世界", "Unicode 字符"),
BoundaryValue("whitespace", " ", "空白字符"),
],
ParamType.BOOL: [
BoundaryValue("true", True, "布尔真值"),
BoundaryValue("false", False, "布尔假值"),
],
ParamType.LIST: [
BoundaryValue("empty_list", [], "空列表"),
BoundaryValue("single_element", [1], "单元素列表"),
BoundaryValue("max_size", list(range(1000)), "大列表"),
BoundaryValue("duplicate_elements", [1, 1, 1], "重复元素"),
],
ParamType.DICT: [
BoundaryValue("empty_dict", {}, "空字典"),
BoundaryValue("single_key", {"key": "value"}, "单键字典"),
BoundaryValue("max_size", {str(i): i for i in range(1000)}, "大字典"),
],
}
def identify_boundaries(self, param_type: ParamType, constraints: List[str]) -> List[BoundaryValue]:
"""识别边界值"""
boundaries = []
# 获取基本类型的边界值
if param_type in self.type_boundaries:
boundaries.extend(self.type_boundaries[param_type])
# 基于约束识别额外边界
constraint_boundaries = self._identify_constraint_boundaries(constraints)
boundaries.extend(constraint_boundaries)
return boundaries
def _identify_constraint_boundaries(self, constraints: List[str]) -> List[BoundaryValue]:
"""基于约束识别边界值"""
boundaries = []
for constraint in constraints:
# 解析约束表达式
if '<=' in constraint:
parts = constraint.split('<=')
if len(parts) == 2:
try:
value = int(parts[1].strip())
boundaries.append(BoundaryValue("constraint_max", value, f"约束最大值: {value}"))
boundaries.append(BoundaryValue("constraint_max_minus_one", value - 1, f"约束最大值-1: {value-1}"))
boundaries.append(BoundaryValue("constraint_max_plus_one", value + 1, f"约束最大值+1: {value+1}"))
except ValueError:
pass
elif '>=' in constraint:
parts = constraint.split('>=')
if len(parts) == 2:
try:
value = int(parts[1].strip())
boundaries.append(BoundaryValue("constraint_min", value, f"约束最小值: {value}"))
boundaries.append(BoundaryValue("constraint_min_minus_one", value - 1, f"约束最小值-1: {value-1}"))
boundaries.append(BoundaryValue("constraint_min_plus_one", value + 1, f"约束最小值+1: {value+1}"))
except ValueError:
pass
return boundaries
边界条件识别器识别各种类型的边界值,包括基本类型的边界值和基于约束的边界值。
测试用例生成器
测试用例生成器基于契约和边界值生成测试用例。
import itertools
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
@dataclass
class TestCase:
"""测试用例"""
name: str
inputs: Dict[str, Any]
expected_output: Optional[Any] = None
expected_exception: Optional[str] = None
description: str = ""
class TestCaseGenerator:
"""测试用例生成器"""
def __init__(self, max_test_cases: int = 100):
self.max_test_cases = max_test_cases
self.boundary_identifier = BoundaryConditionIdentifier()
def generate_test_cases(self, contract: FunctionContract) -> List[TestCase]:
"""生成测试用例"""
test_cases = []
# 1. 生成正常用例
normal_cases = self._generate_normal_cases(contract)
test_cases.extend(normal_cases)
# 2. 生成边界用例
boundary_cases = self._generate_boundary_cases(contract)
test_cases.extend(boundary_cases)
# 3. 生成异常用例
exception_cases = self._generate_exception_cases(contract)
test_cases.extend(exception_cases)
# 限制测试用例数量
return test_cases[:self.max_test_cases]
def _generate_normal_cases(self, contract: FunctionContract) -> List[TestCase]:
"""生成正常用例"""
test_cases = []
if not contract.parameters:
# 无参数函数
test_cases.append(TestCase(
name="test_normal",
inputs={},
description="正常执行"
))
return test_cases
# 生成组合测试用例
param_combinations = self._generate_param_combinations(contract.parameters)
for i, params in enumerate(param_combinations):
test_cases.append(TestCase(
name=f"test_normal_{i}",
inputs=params,
description=f"正常用例 {i}"
))
return test_cases
def _generate_boundary_cases(self, contract: FunctionContract) -> List[TestCase]:
"""生成边界用例"""
test_cases = []
for param in contract.parameters:
param_name = param["name"]
param_type = param["type"]
constraints = param["constraints"]
# 识别边界值
boundaries = self.boundary_identifier.identify_boundaries(param_type, constraints)
for boundary in boundaries:
test_cases.append(TestCase(
name=f"test_boundary_{param_name}_{boundary.name}",
inputs={param_name: boundary.value},
description=boundary.description
))
return test_cases
def _generate_exception_cases(self, contract: FunctionContract) -> List[TestCase]:
"""生成异常用例"""
test_cases = []
for exception in contract.exceptions:
# 生成触发该异常的测试用例
test_case = self._generate_exception_test_case(contract, exception)
if test_case:
test_cases.append(test_case)
return test_cases
def _generate_exception_test_case(self, contract: FunctionContract, exception: str) -> Optional[TestCase]:
"""生成特定异常的测试用例"""
# 简化版:基于参数约束生成异常用例
for param in contract.parameters:
param_name = param["name"]
param_type = param["type"]
# 根据异常类型生成输入
if "ValueError" in exception or "TypeError" in exception:
# 生成类型错误的输入
if param_type == ParamType.INT:
return TestCase(
name=f"test_exception_{exception}",
inputs={param_name: "invalid"},
expected_exception=exception,
description=f"触发 {exception}"
)
elif "IndexError" in exception or "KeyError" in exception:
# 生成越界或不存在键的输入
if param_type == ParamType.LIST:
return TestCase(
name=f"test_exception_{exception}",
inputs={param_name: []},
expected_exception=exception,
description=f"触发 {exception}"
)
return None
def _generate_param_combinations(self, parameters: List[Dict]) -> List[Dict]:
"""生成参数组合"""
# 为每个参数生成一些典型值
param_values = []
for param in parameters:
param_name = param["name"]
param_type = param["type"]
default = param.get("default")
values = self._generate_typical_values(param_type, default)
param_values.append([(param_name, value) for value in values])
# 生成笛卡尔积
combinations = []
for combination in itertools.product(*param_values):
test_input = {}
for key, value in combination:
test_input[key] = value
combinations.append(test_input)
return combinations
def _generate_typical_values(self, param_type: ParamType, default: Optional[Any] = None) -> List[Any]:
"""生成典型值"""
if default is not None:
return [default]
typical_values = {
ParamType.INT: [0, 1, 100, -1, -100],
ParamType.FLOAT: [0.0, 1.0, 100.0, -1.0, 0.5],
ParamType.STR: ["", "test", "你好世界", "a" * 100],
ParamType.BOOL: [True, False],
ParamType.LIST: [[], [1], [1, 2, 3]],
ParamType.DICT: [{}, {"key": "value"}],
}
return typical_values.get(param_type, [None])
测试用例生成器生成正常用例、边界用例和异常用例,覆盖代码的主要路径和边界情况。
断言生成器
断言生成器为测试用例生成相应的断言。
from typing import Dict, List, Optional, Any
import ast
import inspect
class AssertionGenerator:
"""断言生成器"""
def __init__(self):
pass
def generate_assertions(self,
contract: FunctionContract,
test_case: TestCase,
execution_result: Any) -> List[str]:
"""生成断言"""
assertions = []
# 1. 输出断言
if test_case.expected_output is not None:
assertions.append(f"assert result == {repr(test_case.expected_output)}")
elif execution_result is not None:
# 基于执行结果生成断言
assertions.append(f"assert result == {repr(execution_result)}")
# 2. 异常断言
if test_case.expected_exception:
assertions.append(f"with pytest.raises({test_case.expected_exception}):")
# 3. 后置条件断言
for postcondition in contract.postconditions:
assertions.append(f"# {postcondition}")
return assertions
def generate_mock_assertions(self, contract: FunctionContract) -> List[str]:
"""生成模拟断言(用于示例)"""
assertions = []
# 基于返回类型生成断言
if contract.return_type:
if contract.return_type == ParamType.INT:
assertions.append("assert isinstance(result, int)")
assertions.append("assert result >= 0")
elif contract.return_type == ParamType.STR:
assertions.append("assert isinstance(result, str)")
assertions.append("assert len(result) > 0")
# 基于后置条件生成断言
for postcondition in contract.postconditions:
assertions.append(f"# {postcondition}")
return assertions
断言生成器基于函数契约和测试结果生成相应的断言。
完整的测试生成 Agent
将所有组件组合起来,创建一个完整的测试生成 Agent。
import asyncio
from typing import Dict, List, Optional, Any
import os
import subprocess
class TestGenerationAgent:
"""测试生成 Agent"""
def __init__(self, llm_api_key: str, llm_model: str = "gpt-4"):
self.contract_inferencer = ContractInferencer()
self.test_case_generator = TestCaseGenerator()
self.assertion_generator = AssertionGenerator()
self.code_generator = TestCodeGenerator()
async def generate_tests(self,
func,
output_file: str = "test_generated.py") -> str:
"""生成测试代码"""
func_name = func.__name__
print(f"开始为函数 {func_name} 生成测试...")
# 1. 推断函数契约
print("步骤 1/5: 推断函数契约...")
contract = self.contract_inferencer.infer_contract(func)
# 2. 生成测试用例
print("步骤 2/5: 生成测试用例...")
test_cases = self.test_case_generator.generate_test_cases(contract)
# 3. 生成断言
print("步骤 3/5: 生成断言...")
for test_case in test_cases:
assertions = self.assertion_generator.generate_mock_assertions(contract)
test_case.description += f" 断言: {', '.join(assertions[:2])}"
# 4. 生成测试代码
print("步骤 4/5: 生成测试代码...")
test_code = self.code_generator.generate_test_code(func, contract, test_cases)
# 5. 保存测试代码
print("步骤 5/5: 保存测试代码...")
with open(output_file, 'w', encoding='utf-8') as f:
f.write(test_code)
print(f"测试代码已保存到 {output_file}")
return test_code
def run_tests(self, test_file: str) -> Dict[str, Any]:
"""运行测试"""
try:
result = subprocess.run(
['python', '-m', 'pytest', test_file, '-v'],
capture_output=True,
text=True
)
return {
"success": result.returncode == 0,
"output": result.stdout,
"errors": result.stderr
}
except Exception as e:
return {
"success": False,
"output": "",
"errors": str(e)
}
class TestCodeGenerator:
"""测试代码生成器"""
def __init__(self):
pass
def generate_test_code(self,
func,
contract: FunctionContract,
test_cases: List[TestCase]) -> str:
"""生成测试代码"""
code = self._generate_header()
code += self._generate_imports()
code += self._generate_test_class(func, contract, test_cases)
code += self._generate_footer()
return code
def _generate_header(self) -> str:
"""生成文件头"""
return '''# 自动生成的测试代码
# 由 TestGenerationAgent 生成
# 请根据需要修改和完善
import pytest
import sys
import os
# 添加源码路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
'''
def _generate_imports(self) -> str:
"""生成导入语句"""
return '''from src.main import * # 导入被测试的函数
from typing import Any
'''
def _generate_test_class(self,
func,
contract: FunctionContract,
test_cases: List[TestCase]) -> str:
"""生成测试类"""
func_name = func.__name__
code = f'''
class Test{func_name.capitalize()}:
"""测试类: {func_name}"""
'''
# 生成测试方法
for test_case in test_cases:
code += self._generate_test_method(func, test_case)
code += '\n'
return code
def _generate_test_method(self, func, test_case: TestCase) -> str:
"""生成测试方法"""
func_name = func.__name__
code = f''' def {test_case.name}(self):
"""
{test_case.description}
"""
'''
# 生成参数设置
for key, value in test_case.inputs.items():
code += f' {key} = {repr(value)}\n'
# 生成函数调用
params_str = ', '.join(test_case.inputs.keys())
code += f'\n result = {func_name}({params_str})\n\n'
# 生成断言
code += ' # 断言\n'
if test_case.expected_exception:
code += f' # 注意:此测试用例预期抛出 {test_case.expected_exception}\n'
else:
code += ' assert result is not None\n'
code += ' # 添加更多断言...\n'
return code
def _generate_footer(self) -> str:
"""生成文件尾"""
return '''
if __name__ == "__main__":
pytest.main([__file__, "-v"])
'''
# 示例函数
def calculate_discount(price: int, is_member: bool) -> float:
"""计算折扣价格
Args:
price: 原价,必须 >= 0
is_member: 是否是会员
Returns:
折扣后的价格
"""
if price < 0:
raise ValueError("价格不能为负数")
if is_member:
return price * 0.8 # 会员8折
else:
return price * 0.9 # 非会员9折
def find_index(arr: List[int], target: int) -> int:
"""在数组中查找目标索引
Args:
arr: 数组
target: 目标值
Returns:
目标值的索引,如果不存在返回 -1
"""
if target < 0:
raise ValueError("目标值不能为负数")
for i, value in enumerate(arr):
if value == target:
return i
return -1
# 使用示例
async def main():
# 创建测试生成 Agent
agent = TestGenerationAgent(llm_api_key="your-api-key-here")
# 为示例函数生成测试
print("=== 为 calculate_discount 生成测试 ===")
discount_code = await agent.generate_tests(
calculate_discount,
output_file="/Users/liuyutao/Desktop/workspace/test_calculate_discount.py"
)
print("\n=== 为 find_index 生成测试 ===")
index_code = await agent.generate_tests(
find_index,
output_file="/Users/liuyutao/Desktop/workspace/test_find_index.py"
)
# 运行测试(简化版,实际需要被测试函数)
# result = agent.run_tests("/Users/liuyutao/Desktop/workspace/test_calculate_discount.py")
# print(f"\n=== 测试运行结果 ===")
# print(f"成功: {result['success']}")
# print(f"输出:\n{result['output']}")
# if result['errors']:
# print(f"错误:\n{result['errors']}")
if __name__ == "__main__":
asyncio.run(main())
这个完整的测试生成 Agent 整合了契约推断、测试用例生成、断言生成和测试代码生成等功能,提供端到端的测试生成支持。
最佳实践与常见陷阱
代码契约注解
代码契约注解是测试生成的基础。清晰的契约注解能够帮助 Agent 准确理解代码的输入输出关系,生成更准确的测试用例。
最佳实践包括:
- 使用类型注解:为所有函数参数和返回值添加类型注解,帮助 Agent 理解数据类型约束。
- 编写文档字符串:详细描述函数的功能、参数要求、返回值含义和异常情况。
- 添加契约检查:使用契约库(如
deal、icontract)明确表示前置条件和后置条件。
常见陷阱是契约注解不完整或不准确。这会导致 Agent 生成无效的测试用例或遗漏重要的测试场景。最佳实践是在代码编写时就完善契约注解,而不是事后补充。
测试用例多样性
测试用例的多样性直接影响测试覆盖率。Agent 应该生成多种类型的测试用例,包括正常用例、边界用例、异常用例、组合用例等。
边界用例特别重要,因为很多 bug 就隐藏在边界情况中。Agent 应该识别各种类型的边界条件,包括数值边界、字符串边界、集合边界、时间边界等。
另一个重要考虑是组合用例。多个参数的不同组合可能产生不同的行为,Agent 应该生成有效的参数组合测试用例,而不是只测试单一参数的变化。
断言的质量
断言的质量直接影响测试的有效性。Agent 应该生成基于契约的断言,而不是基于实现的断言。
基于契约的断言关注函数的输入输出关系,而不是具体的实现细节。这样的断言更稳定,不容易因为实现细节的变化而失效。
另一个考虑是断言的完整性。Agent 应该生成足够多的断言,验证函数的各个方面,包括返回值、状态变化、异常等。同时,断言应该易于理解和维护。
测试维护
测试维护是测试生成的一个重要考虑。代码演进时,测试也需要相应更新。Agent 应该能够识别代码变更,自动更新受影响的测试用例。
增量更新是测试维护的一个关键策略。Agent 应该只更新受影响的测试用例,而不是重新生成所有测试用例,这样可以保持测试的稳定性。
另一个考虑是测试的独立性。测试用例之间应该相互独立,一个测试用例的失败不应该影响其他测试用例的执行。
性能优化考虑
测试用例选择
测试用例选择是性能优化的关键。Agent 不应该生成所有可能的测试用例,而是选择最重要的测试用例。
重要性排序是测试用例选择的一个策略。Agent 可以根据代码覆盖率、边界重要性、异常概率等因素对测试用例排序,优先选择重要的测试用例。
另一个优化是测试用例去重。不同的测试路径可能产生相似的测试用例,Agent 应该识别并去除重复的测试用例。
并行执行
测试执行可以并行化,特别是对于独立的测试用例。Agent 可以并行执行多个测试用例,提升执行效率。
并行执行需要考虑测试用例之间的依赖关系。只有相互独立的测试用例才能并行执行,有依赖关系的测试用例需要按顺序执行。
另一个考虑是资源管理。并行执行需要合理的资源分配,避免资源竞争和过度消耗。
增量生成
增量生成是性能优化的另一个策略。Agent 可以只生成新代码或修改代码的测试用例,而不是重新生成所有测试用例。
增量生成需要跟踪代码变更历史,识别受影响的函数和模块。Agent 可以基于代码变更的 diff 信息,确定需要重新生成测试的范围。
另一个考虑是测试用例的缓存。对于未变更的代码,Agent 可以复用已有的测试用例,避免重复生成。
符号执行优化
符号执行是测试生成的关键技术,但也存在性能问题。Agent 应该优化符号执行的效率。
路径剪枝是符号执行优化的一个策略。Agent 可以剪枝不太可能触发的路径,只关注重要的路径,避免路径爆炸。
另一个优化是约束求解优化。Agent 可以使用高效的约束求解器,或者采用启发式方法加速约束求解。
参考资源
官方文档和工具
- Pytest - Python 测试框架
- unittest - Python 内置测试框架
- Pylint - Python 代码分析工具
- Coverage.py - Python 代码覆盖率工具
学术论文和研究
- "Automatic Test Generation: A Survey" - 自动测试生成综述
- "Symbolic Execution for Testing" - 基于符号执行的测试技术
- "Test Generation with Machine Learning" - 基于机器学习的测试生成
实践指南
- Effective Testing with Pytest - 高效的 Pytest 测试实践
- Test-Driven Development - 测试驱动开发指南
- Testing Best Practices - Google 测试最佳实践
通过合理的设计和实现,测试生成 Agent 能够显著提升测试开发的效率和质量,成为开发团队的强大助手。