"""
Performance comparison: original vs Numba-optimized implementations.

Benchmarks each optimized function against its original counterpart,
verifies that both produce matching outputs, and prints a detailed
comparison report.
"""

import os
import sys
import pickle
import time

import numpy as np

# Make the project's ``src`` directory importable.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))

# Original implementations.
from converging_triangle import (
    pivots_fractal,
    pivots_fractal_hybrid,
    fit_boundary_anchor,
    calc_fitting_adherence,
    calc_boundary_utilization,
    calc_breakout_strength,
)

# Numba-optimized implementations.
from converging_triangle_optimized import (
    pivots_fractal_optimized,
    pivots_fractal_hybrid_optimized,
    fit_boundary_anchor_optimized,
    calc_fitting_adherence_optimized,
    calc_boundary_utilization_optimized,
    calc_breakout_strength_optimized,
)

# Tolerances used when checking that original and optimized outputs agree.
ARRAY_RTOL = 1e-5
ARRAY_ATOL = 1e-8
SCALAR_ATOL = 1e-6


class FakeModule:
    """Empty placeholder module used to bypass the ``model`` dependency.

    ``load_pkl`` installs instances of this class into ``sys.modules`` so the
    real ``model`` package is not needed. Its only attribute maps
    ``ndarray`` to ``np.ndarray``.
    """

    ndarray = np.ndarray


def load_pkl(pkl_path: str) -> dict:
    """Load a pickle file after stubbing out the ``model`` package.

    Args:
        pkl_path: Path to the ``.pkl`` file.

    Returns:
        The unpickled object (callers index it like a dict).

    Note:
        ``pickle.load`` can execute arbitrary code; only use this on trusted,
        locally produced data files.
    """
    # Presumably the pickles reference ``model`` / ``model.index_info``;
    # registering stubs lets unpickling resolve them without the real package.
    sys.modules['model'] = FakeModule()
    sys.modules['model.index_info'] = FakeModule()
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)
    return data


def load_test_data(data_dir: str, n_stocks: int = 10, n_days: int = 500):
    """Load the high/low matrices used for benchmarking.

    Reads ``high.pkl`` and ``low.pkl`` from ``data_dir`` and keeps the first
    ``n_stocks`` rows and the last ``n_days`` columns of each ``"mtx"`` entry.

    Returns:
        ``(high_mtx, low_mtx)``: the two sliced 2-D arrays.
    """
    print(f"加载测试数据 (stocks={n_stocks}, days={n_days})...")

    high_data = load_pkl(os.path.join(data_dir, "high.pkl"))
    low_data = load_pkl(os.path.join(data_dir, "low.pkl"))

    high_mtx = high_data["mtx"][:n_stocks, -n_days:]
    low_mtx = low_data["mtx"][:n_stocks, -n_days:]

    print(f" 数据形状: {high_mtx.shape}")
    return high_mtx, low_mtx


def benchmark_function(func, name, *args, n_iterations=100, warmup=5):
    """Benchmark a single function.

    Args:
        func: Function to benchmark.
        name: Label for the function (currently unused; kept for API
            compatibility).
        *args: Positional arguments forwarded to ``func`` on every call.
        n_iterations: Number of timed calls; must be >= 1.
        warmup: Number of untimed calls made first. This matters for Numba,
            which typically compiles lazily on the first call.

    Returns:
        ``(avg_time_ms, result)``: mean time per call in milliseconds and the
        result of the last timed call.

    Raises:
        ValueError: If ``n_iterations`` is less than 1.
    """
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")

    # Warm-up calls are excluded from the measurement.
    for _ in range(warmup):
        func(*args)

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for _ in range(n_iterations):
        result = func(*args)
    elapsed = time.perf_counter() - start

    avg_time_ms = elapsed / n_iterations * 1000
    return avg_time_ms, result


def _format_value(value):
    """Format a scalar with 6 decimals, falling back to ``repr`` for non-numbers."""
    try:
        return f"{value:.6f}"
    except (TypeError, ValueError):
        return repr(value)


def _outputs_match(orig, opt):
    """Return True if two outputs agree within the benchmark tolerances.

    Arrays must have identical shapes; NaNs in the same position count as
    equal. Non-numeric payloads (e.g. ``None``) fall back to exact equality.
    """
    is_array = isinstance(orig, np.ndarray) or isinstance(opt, np.ndarray)
    rtol, atol = (ARRAY_RTOL, ARRAY_ATOL) if is_array else (0.0, SCALAR_ATOL)

    a, b = np.asarray(orig), np.asarray(opt)
    # allclose would broadcast or raise on mismatched shapes, e.g. when the
    # two implementations find a different number of pivots.
    if a.shape != b.shape:
        return False
    try:
        return bool(np.allclose(a, b, rtol=rtol, atol=atol, equal_nan=True))
    except (TypeError, ValueError):
        return bool(np.array_equal(a, b))


def _check_output(label, orig, opt):
    """Print whether one original/optimized output pair matches; return the verdict."""
    match = _outputs_match(orig, opt)
    verdict = '[OK] 一致' if match else '[ERR] 不一致'

    if isinstance(orig, np.ndarray):
        print(f" {label} (数组): {verdict}")
        if not match:
            a, b = np.asarray(orig), np.asarray(opt)
            # ravel() makes the sample safe for 0-d and multi-d arrays alike.
            print(f" 原版: shape={a.shape}, sample={a.ravel()[:3]}")
            print(f" 优化: shape={b.shape}, sample={b.ravel()[:3]}")
    else:
        print(f" {label} (标量): {verdict} (原={_format_value(orig)}, 优={_format_value(opt)})")
    return match


def compare_functions(original_func, optimized_func, func_name, *args, n_iterations=100):
    """Benchmark an original/optimized function pair and verify their outputs.

    Args:
        original_func: Reference implementation.
        optimized_func: Optimized implementation called with the same args.
        func_name: Label used in the printed output and the returned record.
        *args: Positional arguments passed to both functions.
        n_iterations: Number of timed calls per function.

    Returns:
        Dict with ``name``, ``original_time`` and ``optimized_time`` (ms per
        call), ``speedup``, ``improvement`` (percent) and ``match`` (True when
        every output agreed).
    """
    print(f"\n{'='*80}")
    print(f"测试: {func_name}")
    print(f"{'='*80}")

    # Original implementation.
    print("\n[1] 原版函数...")
    original_time, original_result = benchmark_function(
        original_func, f"{func_name}_original", *args, n_iterations=n_iterations
    )
    print(f" 平均耗时: {original_time:.4f} 毫秒/次")

    # Optimized implementation.
    print("\n[2] Numba优化版...")
    optimized_time, optimized_result = benchmark_function(
        optimized_func, f"{func_name}_optimized", *args, n_iterations=n_iterations
    )
    print(f" 平均耗时: {optimized_time:.4f} 毫秒/次")

    # Speedup / improvement (0 when the denominator is not positive).
    speedup = original_time / optimized_time if optimized_time > 0 else 0
    improvement = ((original_time - optimized_time) / original_time * 100) if original_time > 0 else 0

    print("\n[3] 性能对比:")
    print(f" 加速比: {speedup:.2f}x")
    print(f" 性能提升: {improvement:.1f}%")
    print(f" 时间节省: {original_time - optimized_time:.4f} 毫秒/次")

    # Verify that both implementations agree.
    print("\n[4] 结果验证:")
    if isinstance(original_result, tuple):
        opt_is_tuple = isinstance(optimized_result, tuple)
        if not opt_is_tuple or len(optimized_result) != len(original_result):
            # zip() would silently drop extra outputs and hide the mismatch.
            n_opt = len(optimized_result) if opt_is_tuple else 1
            print(f" 输出数量不一致: 原版={len(original_result)}, 优化={n_opt}")
            all_match = False
        else:
            all_match = True
            for i, (orig, opt) in enumerate(zip(original_result, optimized_result)):
                # Check every output (no short-circuit) so each one is reported.
                if not _check_output(f"输出 {i+1}", orig, opt):
                    all_match = False
    else:
        all_match = _check_output("结果", original_result, optimized_result)

    return {
        "name": func_name,
        "original_time": original_time,
        "optimized_time": optimized_time,
        "speedup": speedup,
        "improvement": improvement,
        "match": all_match,
    }


def _with_anchor_mode(fit_func, mode):
    """Bind ``mode`` and a full-length window onto a boundary-fitting function.

    The returned callable takes ``(pivot_indices, pivot_values, prices)`` and
    calls ``fit_func`` with ``window_start=0`` and
    ``window_end=len(prices) - 1``.
    """
    def call(pivot_indices, pivot_values, prices):
        return fit_func(
            pivot_indices, pivot_values, prices,
            mode=mode, window_start=0, window_end=len(prices) - 1,
        )
    return call


def main():
    """Run every benchmark on one sample stock and print a summary report."""
    print("=" * 80)
    print("收敛三角形检测 - 原版 vs Numba优化版 性能对比")
    print("=" * 80)

    # Configuration.
    DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")

    high_mtx, low_mtx = load_test_data(DATA_DIR, n_stocks=10, n_days=500)

    # Benchmark on a single sample row.
    sample_idx = 0
    high = high_mtx[sample_idx, :]
    low = low_mtx[sample_idx, :]

    # Drop positions where either series is NaN.
    valid_mask = ~(np.isnan(high) | np.isnan(low))
    high = high[valid_mask]
    low = low[valid_mask]

    print(f"\n样本数据长度: {len(high)}")
    if len(high) == 0:
        print("样本数据全部为 NaN, 无法进行测试")
        return

    # Benchmark parameters.
    k = 15
    flexible_zone = 5
    n_iterations = 100

    results = []

    # Test 1: pivot detection (standard method).
    results.append(compare_functions(
        pivots_fractal, pivots_fractal_optimized, "pivots_fractal",
        high, low, k,
        n_iterations=n_iterations,
    ))

    # Test 2: pivot detection (hybrid method).
    results.append(compare_functions(
        pivots_fractal_hybrid, pivots_fractal_hybrid_optimized, "pivots_fractal_hybrid",
        high, low, k, flexible_zone,
        n_iterations=n_iterations,
    ))

    # Test 3: anchor-based boundary fitting, upper then lower.
    # Pivots come from the original implementation so both versions get identical inputs.
    ph, pl = pivots_fractal(high, low, k=k)
    for mode, pivots, prices in (("upper", ph, high), ("lower", pl, low)):
        if len(pivots) < 5:
            continue
        pivot_indices = pivots[:10]
        pivot_values = prices[pivot_indices]
        results.append(compare_functions(
            _with_anchor_mode(fit_boundary_anchor, mode),
            _with_anchor_mode(fit_boundary_anchor_optimized, mode),
            f"fit_boundary_anchor ({mode})",
            pivot_indices, pivot_values, prices,
            n_iterations=n_iterations,
        ))

    # Test 4: fitting adherence (fixed illustrative slope/intercept).
    if len(ph) >= 5:
        pivot_indices = ph[:10]
        pivot_values = high[pivot_indices]
        slope, intercept = 0.01, 100.0
        results.append(compare_functions(
            calc_fitting_adherence, calc_fitting_adherence_optimized, "calc_fitting_adherence",
            pivot_indices, pivot_values, slope, intercept,
            n_iterations=n_iterations,
        ))

    # Test 5: boundary utilization over the whole sample.
    upper_slope, upper_intercept = -0.02, 120.0
    lower_slope, lower_intercept = 0.02, 80.0
    start, end = 0, len(high) - 1
    results.append(compare_functions(
        calc_boundary_utilization, calc_boundary_utilization_optimized, "calc_boundary_utilization",
        high, low, upper_slope, upper_intercept, lower_slope, lower_intercept, start, end,
        n_iterations=n_iterations,
    ))

    # Test 6: breakout strength (scalar inputs).
    results.append(compare_functions(
        calc_breakout_strength, calc_breakout_strength_optimized, "calc_breakout_strength",
        100.0, 105.0, 95.0, 1.5, 0.6, 0.8, 0.7,
        n_iterations=n_iterations,
    ))

    # Summary report.
    print("\n\n")
    print("=" * 80)
    print("性能对比总结")
    print("=" * 80)

    print(f"\n{'函数名':<35} {'原版(ms)':<12} {'优化(ms)':<12} {'加速比':<10} {'提升':<10}")
    print("-" * 80)

    for r in results:
        # Build the suffixed text first so padding applies to "3.45x", not "3.45".
        speedup_txt = f"{r['speedup']:.2f}x"
        improvement_txt = f"{r['improvement']:.1f}%"
        print(f"{r['name']:<35} {r['original_time']:<12.4f} {r['optimized_time']:<12.4f} "
              f"{speedup_txt:<10} {improvement_txt:<10}")

    total_original = sum(r['original_time'] for r in results)
    total_optimized = sum(r['optimized_time'] for r in results)

    print("-" * 80)
    overall_speedup = total_original / total_optimized if total_optimized > 0 else 0
    overall_improvement = ((total_original - total_optimized) / total_original * 100) if total_original > 0 else 0
    overall_speedup_txt = f"{overall_speedup:.2f}x"
    overall_improvement_txt = f"{overall_improvement:.1f}%"
    print(f"{'总计':<35} {total_original:<12.4f} {total_optimized:<12.4f} "
          f"{overall_speedup_txt:<10} {overall_improvement_txt:<10}")

    # Speedups are meaningless for functions whose outputs disagree.
    mismatched = [r['name'] for r in results if not r['match']]
    if mismatched:
        print(f"\n[ERR] 以下函数结果不一致, 加速比仅供参考: {', '.join(mismatched)}")

    # Extrapolate to the full dataset.
    print("\n" + "=" * 80)
    print("全量数据性能估算")
    print("=" * 80)

    # Hard-coded from an earlier run: the full dataset (108 stocks x 500 days)
    # took about 30.83 s with the original code.
    # NOTE(review): this extrapolation assumes the benchmarked functions
    # dominate that runtime; the total above sums per-call averages, so it
    # does not weight functions by how often the real pipeline calls them.
    baseline_time = 30.83  # seconds

    if overall_speedup > 0:
        estimated_time = baseline_time / overall_speedup
        time_saved = baseline_time - estimated_time

        print(f"\n基于当前加速比 {overall_speedup:.2f}x 估算:")
        print(f" 原版耗时: {baseline_time:.2f} 秒 ({baseline_time/60:.2f} 分钟)")
        print(f" 优化后耗时: {estimated_time:.2f} 秒 ({estimated_time/60:.2f} 分钟)")
        print(f" 节省时间: {time_saved:.2f} 秒 ({time_saved/60:.2f} 分钟)")
        print(f" 性能提升: {overall_improvement:.1f}%")
    else:
        # Avoid ZeroDivisionError when no valid optimized timing was recorded.
        print("\n无法估算: 加速比为 0 (优化版总耗时无效)")

    print("\n" + "=" * 80)
    print("建议:")
    print(" 1. 如果加速比 > 2x,建议切换到优化版本")
    print(" 2. 运行完整的集成测试验证正确性")
    print(" 3. 使用 cProfile 分析优化版的新瓶颈")
    print("=" * 80)


if __name__ == "__main__":
    main()