"""
Full-pipeline performance test: verify the effect of the Numba optimizations.

This script simulates the complete batch-detection flow and compares the
original implementation against the optimized one, both for speed and for
agreement of the numeric output.
"""

import os
import sys
import pickle
import time
import cProfile
import pstats
from io import StringIO
from typing import Optional

import numpy as np
import pandas as pd

# Make the project's ``src`` directory importable before the project imports.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))

from converging_triangle import (
    ConvergingTriangleParams,
    detect_converging_triangle_batch as detect_batch_original,
)

# Attribute on ``converging_triangle`` -> name of its replacement in
# ``converging_triangle_optimized``.
_NUMBA_PATCHES = {
    "pivots_fractal": "pivots_fractal_optimized",
    "pivots_fractal_hybrid": "pivots_fractal_hybrid_optimized",
    "fit_boundary_anchor": "fit_boundary_anchor_optimized",
    "calc_geometry_score": "calc_geometry_score_optimized",
    "calc_activity_score": "calc_activity_score_optimized",
    "calc_breakout_strength": "calc_breakout_strength_optimized",
}

# Marks "the attribute did not exist before patching".
_MISSING = object()

# Largest absolute difference still reported as a match.
_MATCH_TOLERANCE = 1e-6


class FakeModule:
    """Placeholder registered in ``sys.modules`` to bypass the ``model`` dependency."""

    ndarray = np.ndarray


def load_pkl(pkl_path: str) -> dict:
    """Load one pickle file and return the unpickled object.

    Placeholder ``model`` and ``model.index_info`` modules are registered in
    ``sys.modules`` first so that unpickling does not need the real package.

    Args:
        pkl_path: Path of the pickle file to read.

    Returns:
        Whatever object the pickle file contains.
    """
    sys.modules['model'] = FakeModule()
    sys.modules['model.index_info'] = FakeModule()

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point this at data files from a trusted source.
    with open(pkl_path, 'rb') as f:
        data = pickle.load(f)
    return data


def load_test_data(
    data_dir: str,
    n_stocks: Optional[int] = None,
    n_days: Optional[int] = None,
):
    """Load the OHLCV matrices and their labels from ``data_dir``.

    Args:
        data_dir: Directory holding ``open.pkl``, ``high.pkl``, ``low.pkl``,
            ``close.pkl`` and ``volume.pkl``.
        n_stocks: If truthy, keep only the first ``n_stocks`` rows.
        n_days: If truthy, keep only the last ``n_days`` columns.

    Returns:
        Tuple ``(open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates,
        tkrs, tkrs_name)``; the labels are taken from ``close.pkl``.
    """
    print(f"加载数据...")

    loaded = {
        name: load_pkl(os.path.join(data_dir, f"{name}.pkl"))
        for name in ("open", "high", "low", "close", "volume")
    }

    # Each limit is applied on its own. Previously both had to be given,
    # otherwise the one that was given was silently ignored.
    stock_sel = slice(None, n_stocks) if n_stocks else slice(None)
    day_sel = slice(-n_days, None) if n_days else slice(None)

    open_mtx = loaded["open"]["mtx"][stock_sel, day_sel]
    high_mtx = loaded["high"]["mtx"][stock_sel, day_sel]
    low_mtx = loaded["low"]["mtx"][stock_sel, day_sel]
    close_mtx = loaded["close"]["mtx"][stock_sel, day_sel]
    volume_mtx = loaded["volume"]["mtx"][stock_sel, day_sel]

    close_data = loaded["close"]
    dates = close_data["dtes"][day_sel]
    tkrs = close_data["tkrs"][stock_sel]
    tkrs_name = close_data["tkrs_name"][stock_sel]

    print(f" 数据形状: {close_mtx.shape}")
    print(f" 日期范围: {dates[0]} ~ {dates[-1]}")

    return open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name


def _detection_range(close_mtx, window):
    """Return ``(start_day, end_day, total_points)`` for a detection run.

    ``end_day`` is the last column with at least one non-NaN close, falling
    back to the last column when every value is NaN.
    """
    n_stocks, n_days = close_mtx.shape
    start_day = window - 1

    any_valid = np.any(~np.isnan(close_mtx), axis=0)
    valid_day_idx = np.where(any_valid)[0]
    end_day = valid_day_idx[-1] if len(valid_day_idx) > 0 else n_days - 1

    total_points = n_stocks * (end_day - start_day + 1)
    return start_day, end_day, total_points


def _apply_numba_patches():
    """Monkey-patch ``converging_triangle`` with the optimized functions.

    Returns:
        Tuple ``(module, saved, replacements)``: the patched module, the
        previous attribute values (``_MISSING`` where there was none) and the
        optimized callables keyed by the patched attribute name.
    """
    import converging_triangle
    import converging_triangle_optimized as optimized

    # Resolve every replacement first so a missing function fails before
    # anything has been patched.
    replacements = {
        target: getattr(optimized, source)
        for target, source in _NUMBA_PATCHES.items()
    }

    saved = {}
    for target, func in replacements.items():
        saved[target] = getattr(converging_triangle, target, _MISSING)
        setattr(converging_triangle, target, func)
    return converging_triangle, saved, replacements


def _restore_patches(module, saved):
    """Undo ``_apply_numba_patches`` so later runs see the original functions."""
    if module is None:
        return
    for target, previous in saved.items():
        if previous is _MISSING:
            if hasattr(module, target):
                delattr(module, target)
        else:
            setattr(module, target, previous)


def _warm_up_numba(pivots_fractal_optimized, high_mtx, low_mtx, params):
    """Call the optimized pivot function once, before timing starts.

    The call only happens when the first stock's first window has no NaN in
    high or low; otherwise the warm-up is skipped.
    """
    window = params.window
    sample_high = high_mtx[0, :window]
    sample_low = low_mtx[0, :window]
    valid_mask = ~(np.isnan(sample_high) | np.isnan(sample_low))

    if np.sum(valid_mask) >= window:
        sample_high = sample_high[valid_mask]
        sample_low = sample_low[valid_mask]
        _ = pivots_fractal_optimized(sample_high, sample_low, k=params.pivot_k)


def _print_run_stats(df, elapsed, total_points):
    """Print timing and detection statistics for one pipeline run."""
    print(f"\n{'='*80}")
    print("性能统计")
    print(f"{'='*80}")

    print(f"\n总耗时: {elapsed:.2f} 秒 ({elapsed/60:.2f} 分钟)")
    print(f"处理点数: {total_points}")
    # Skip the ratios instead of dividing by zero on degenerate runs.
    if elapsed > 0:
        print(f"平均速度: {total_points/elapsed:.1f} 点/秒")
    if total_points > 0:
        print(f"单点耗时: {elapsed/total_points*1000:.3f} 毫秒/点")

    if len(df) > 0:
        valid_count = int((df['is_valid'] == True).sum())
        print(f"\n检测结果:")
        print(f" 有效三角形: {valid_count}")
        if total_points > 0:
            print(f" 检出率: {valid_count/total_points*100:.2f}%")

        if 'breakout_strength_up' in df.columns and 'breakout_strength_down' in df.columns:
            strong_up = (df['breakout_strength_up'] > 0.3).sum()
            strong_down = (df['breakout_strength_down'] > 0.3).sum()
            print(f" 高强度向上突破 (>0.3): {strong_up}")
            print(f" 高强度向下突破 (>0.3): {strong_down}")


def test_pipeline(
    open_mtx,
    high_mtx,
    low_mtx,
    close_mtx,
    volume_mtx,
    params: ConvergingTriangleParams,
    use_optimized: bool = False,
    profile: bool = False
):
    """Run the full batch detection once and report timing statistics.

    Args:
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx: Price and volume
            matrices passed straight through to the batch detector.
        params: Detection parameters; ``window`` and ``pivot_k`` are read here.
        use_optimized: Patch ``converging_triangle`` with the Numba-optimized
            functions for the duration of this call.
        profile: Run under ``cProfile`` and print the top 20 entries sorted
            by cumulative time.

    Returns:
        ``(df, elapsed)``: the detector's DataFrame and the elapsed seconds,
        or ``(None, 0)`` if the optimized module could not be enabled.
    """
    print(f"\n{'='*80}")
    print(f"测试: {'Numba优化版' if use_optimized else '原版'}")
    print(f"{'='*80}")

    n_stocks, n_days = close_mtx.shape
    window = params.window
    start_day, end_day, total_points = _detection_range(close_mtx, window)

    print(f"\n配置:")
    print(f" 股票数: {n_stocks}")
    print(f" 交易日: {n_days}")
    print(f" 窗口大小: {window}")
    print(f" 检测点数: {total_points}")
    print(f" 使用优化: {'是' if use_optimized else '否'}")

    patched_module, saved = None, {}
    try:
        if use_optimized:
            try:
                print("\n导入Numba优化模块...")
                patched_module, saved, replacements = _apply_numba_patches()
                print(" [OK] Numba优化已启用")

                # Warm up so JIT compilation is not counted in the timing.
                print("\n预热Numba编译...")
                _warm_up_numba(
                    replacements["pivots_fractal"], high_mtx, low_mtx, params
                )
                print(" [OK] 预热完成")
            except Exception as e:
                print(f" [ERROR] 无法启用优化: {e}")
                return None, 0

        print("\n开始批量检测...")

        profiler = None
        if profile:
            profiler = cProfile.Profile()
            profiler.enable()

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()

        df = detect_batch_original(
            open_mtx=open_mtx,
            high_mtx=high_mtx,
            low_mtx=low_mtx,
            close_mtx=close_mtx,
            volume_mtx=volume_mtx,
            params=params,
            start_day=start_day,
            end_day=end_day,
            only_valid=True,
            verbose=False,
        )

        elapsed = time.perf_counter() - start_time

        if profiler is not None:
            profiler.disable()

            print("\n" + "-" * 80)
            print("Profile Top 20:")
            print("-" * 80)
            s = StringIO()
            ps = pstats.Stats(profiler, stream=s).sort_stats('cumulative')
            ps.print_stats(20)
            print(s.getvalue())

        _print_run_stats(df, elapsed, total_points)
        return df, elapsed
    finally:
        # Always undo the monkey patches; otherwise a later "original" run in
        # the same process would silently use the optimized functions.
        _restore_patches(patched_module, saved)


def _column_diff(original, optimized):
    """Compare two numeric Series, treating NaN and infinities explicitly.

    Returns:
        ``(max_diff, mean_diff, nan_mismatches)``. The differences are taken
        over rows where both sides are non-NaN (0.0 when there are none);
        ``nan_mismatches`` counts rows where exactly one side is NaN.
    """
    # Index-align like Series subtraction does; labels present on one side
    # only come back as NaN and are counted as mismatches below.
    left, right = original.align(optimized)
    a = left.to_numpy(dtype=float)
    b = right.to_numpy(dtype=float)

    nan_a = np.isnan(a)
    nan_b = np.isnan(b)
    nan_mismatches = int(np.count_nonzero(nan_a != nan_b))

    comparable = ~(nan_a | nan_b)
    a = a[comparable]
    b = b[comparable]
    if a.size == 0:
        return 0.0, 0.0, nan_mismatches

    # Equal values (including matching infinities) differ by exactly 0;
    # inf - inf would otherwise produce NaN.
    with np.errstate(invalid="ignore"):
        diff = np.where(a == b, 0.0, np.abs(a - b))
    return float(diff.max()), float(diff.mean()), nan_mismatches


def compare_results(df_original, df_optimized):
    """Check that both versions produced the same numeric output.

    Args:
        df_original: DataFrame from the original run, or ``None``.
        df_optimized: DataFrame from the optimized run, or ``None``.

    Returns:
        ``True`` when the record counts match and every compared column
        agrees within 1e-6 with NaN in the same rows; ``False`` otherwise,
        including when either input is ``None``.
    """
    print(f"\n{'='*80}")
    print("结果一致性验证")
    print(f"{'='*80}")

    if df_original is None or df_optimized is None:
        print("\n[ERROR] 无法对比:某个版本未成功运行")
        return False

    if len(df_original) != len(df_optimized):
        print(f"\n[WARNING] 记录数不一致:")
        print(f" 原版: {len(df_original)}")
        print(f" 优化: {len(df_optimized)}")
        return False

    print(f"\n记录数: {len(df_original)} (一致 [OK])")

    numeric_cols = [
        'breakout_strength_up', 'breakout_strength_down',
        'price_score_up', 'price_score_down',
        'convergence_score', 'volume_score',
        'geometry_score',
        'upper_slope', 'lower_slope',
        'width_ratio',
        'touches_upper', 'touches_lower',
        'apex_x'
    ]
    numeric_cols = [c for c in numeric_cols if c in df_original.columns]

    print(f"\n数值列对比:")
    print(f"{'列名':<30} {'最大差异':<15} {'平均差异':<15} {'状态':<10}")
    print("-" * 70)

    all_match = True
    for col in numeric_cols:
        if col not in df_optimized.columns:
            # A column produced only by the original is a mismatch, not a crash.
            print(f"{col:<30} {'N/A':<15} {'N/A':<15} {'[ERR]':<10}")
            all_match = False
            continue

        max_diff, mean_diff, nan_mismatches = _column_diff(
            df_original[col], df_optimized[col]
        )

        match = nan_mismatches == 0 and max_diff < _MATCH_TOLERANCE
        status = "[OK]" if match else "[ERR]"

        print(f"{col:<30} {max_diff:<15.10f} {mean_diff:<15.10f} {status:<10}")
        if nan_mismatches:
            print(f" NaN位置不一致: {nan_mismatches} 行")

        if not match:
            all_match = False

    print("-" * 70)

    if all_match:
        print("\n[结论] 所有数值列完全一致 (误差 < 1e-6) ✓")
    else:
        print("\n[结论] 发现不一致的列,请检查优化实现")

    return all_match


def main():
    """Run the original and optimized pipelines and print a comparison."""
    print("=" * 80)
    print("收敛三角形检测 - 完整流水线性能测试")
    print("=" * 80)

    data_dir = os.path.join(os.path.dirname(__file__), "..", "data")

    # Detection parameters.
    params = ConvergingTriangleParams(
        window=240,
        pivot_k=15,
        boundary_n_segments=2,
        boundary_source="full",
        fitting_method="anchor",
        upper_slope_max=0,
        lower_slope_min=0,
        touch_tol=0.10,
        touch_loss_max=0.10,
        shrink_ratio=0.45,
        break_tol=0.005,
        vol_window=20,
        vol_k=1.5,
        false_break_m=5,
    )

    # Load the full data set; the labels are not needed here.
    open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, _dates, _tkrs, _tkrs_name = \
        load_test_data(data_dir)

    # Stage 1: original implementation.
    print("\n" + "=" * 80)
    print("阶段 1/2: 测试原版")
    print("=" * 80)
    df_original, time_original = test_pipeline(
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
        params, use_optimized=False, profile=False
    )

    # Stage 2: optimized implementation.
    print("\n" + "=" * 80)
    print("阶段 2/2: 测试优化版")
    print("=" * 80)
    df_optimized, time_optimized = test_pipeline(
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
        params, use_optimized=True, profile=False
    )

    if df_original is not None and df_optimized is not None:
        results_match = compare_results(df_original, df_optimized)

        print(f"\n{'='*80}")
        print("性能对比总结")
        print(f"{'='*80}")

        speedup = time_original / time_optimized if time_optimized > 0 else 0
        improvement = ((time_original - time_optimized) / time_original * 100) if time_original > 0 else 0
        time_saved = time_original - time_optimized

        print(f"\n{'指标':<20} {'原版':<20} {'优化版':<20} {'改善':<20}")
        print("-" * 80)
        print(f"{'总耗时':<20} {time_original:.2f}秒 ({time_original/60:.2f}分) "
              f"{time_optimized:.2f}秒 ({time_optimized/60:.2f}分) "
              f"-{time_saved:.2f}秒")
        print(f"{'加速比':<20} {'1.00x':<20} {f'{speedup:.2f}x':<20} {f'+{speedup-1:.2f}x':<20}")
        print(f"{'性能提升':<20} {'0%':<20} {f'{improvement:.1f}%':<20} {f'+{improvement:.1f}%':<20}")

        # Same point count that test_pipeline reports, via the shared helper.
        _, _, total_points = _detection_range(close_mtx, params.window)

        # Guarded like speedup above so a zero timing cannot crash the report.
        speed_original = total_points / time_original if time_original > 0 else 0.0
        speed_optimized = total_points / time_optimized if time_optimized > 0 else 0.0

        print(f"{'处理速度':<20} {f'{speed_original:.0f}点/秒':<20} {f'{speed_optimized:.0f}点/秒':<20} {f'+{speed_optimized-speed_original:.0f}点/秒':<20}")

        print("\n" + "=" * 80)
        print("最终结论")
        print("=" * 80)

        if results_match:
            print("\n[OK] 输出结果完全一致 ✓")
        else:
            print("\n[WARNING] 输出结果存在差异,请检查")

        print(f"\n性能提升: {speedup:.1f}x ({improvement:.1f}%)")
        print(f"时间节省: {time_saved:.2f}秒 ({time_saved/60:.2f}分钟)")

        if speedup > 100:
            print("\n[推荐] 性能提升巨大 (>100x),强烈建议部署优化版本!")
        elif speedup > 10:
            print("\n[推荐] 性能提升显著 (>10x),建议部署优化版本")
        elif speedup > 2:
            print("\n[推荐] 性能有提升 (>2x),可以考虑部署优化版本")
        else:
            print("\n[提示] 性能提升不明显,可能不需要部署")
    else:
        # Say why the comparison is missing instead of ending silently.
        print("\n[ERROR] 至少一个版本未成功运行,跳过对比")

    print("\n" + "=" * 80)


if __name__ == "__main__":
    main()