"""
Performance test script - analyzes performance bottlenecks of the
converging-triangle detection algorithm.

This script does not modify any existing code; it is only used for
performance analysis and testing. It uses cProfile to identify hot-spot
functions (profile files can be visualized with snakeviz).
"""

import os
import sys
import pickle
import time
import cProfile
import pstats
import io
from pstats import SortKey
from typing import Optional

import numpy as np

# Make the project's ``src`` directory importable.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))

from converging_triangle import (  # noqa: E402  (depends on the sys.path tweak above)
    ConvergingTriangleParams,
    detect_converging_triangle_batch,
    pivots_fractal,
    pivots_fractal_hybrid,
    fit_pivot_line,
    calc_breakout_strength,
)


class FakeModule:
    """Empty stand-in module used to bypass the ``model`` package dependency.

    ``load_pkl`` registers instances of this class under ``model`` and
    ``model.index_info`` in ``sys.modules`` before unpickling, presumably
    because the data files reference those module names. Only ``ndarray``
    (aliased to ``numpy.ndarray``) is exposed.
    """

    ndarray = np.ndarray


def load_pkl(pkl_path: str) -> dict:
    """Load a pickled data file.

    Args:
        pkl_path: Path to the ``.pkl`` file.

    Returns:
        The unpickled object. ``load_test_data`` expects a dict with the keys
        ``"mtx"``, ``"dtes"``, ``"tkrs"`` and ``"tkrs_name"``.

    Security note: ``pickle.load`` can execute arbitrary code. Only use this on
    trusted, locally produced data files.
    """
    # Install the stub modules so pickle can resolve references to ``model``.
    sys.modules['model'] = FakeModule()
    sys.modules['model.index_info'] = FakeModule()
    with open(pkl_path, 'rb') as f:
        return pickle.load(f)


def load_test_data(data_dir: str, n_stocks: int = 10, n_days: int = 500):
    """Load a subset of the test data.

    Args:
        data_dir: Directory containing ``open/high/low/close/volume.pkl``.
        n_stocks: Number of stocks (leading rows) to keep, for small-scale tests.
        n_days: Number of most recent days (trailing columns) to keep.

    Returns:
        (open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name)

    Raises:
        ValueError: if ``n_stocks`` or ``n_days`` is not positive. Without this
            check, ``n_days=0`` would silently select *all* days because
            ``x[-0:]`` is ``x[0:]``.

    Note: if the data holds fewer stocks/days than requested, the slices are
    simply shorter. Callers should rely on the returned shapes, not on the
    requested sizes.
    """
    if n_stocks <= 0 or n_days <= 0:
        raise ValueError(
            f"n_stocks and n_days must be positive (got n_stocks={n_stocks}, n_days={n_days})"
        )

    print(f"\n加载测试数据 (stocks={n_stocks}, days={n_days})...")

    open_data = load_pkl(os.path.join(data_dir, "open.pkl"))
    high_data = load_pkl(os.path.join(data_dir, "high.pkl"))
    low_data = load_pkl(os.path.join(data_dir, "low.pkl"))
    close_data = load_pkl(os.path.join(data_dir, "close.pkl"))
    volume_data = load_pkl(os.path.join(data_dir, "volume.pkl"))

    # Take the first n_stocks rows and the last n_days columns.
    open_mtx = open_data["mtx"][:n_stocks, -n_days:]
    high_mtx = high_data["mtx"][:n_stocks, -n_days:]
    low_mtx = low_data["mtx"][:n_stocks, -n_days:]
    close_mtx = close_data["mtx"][:n_stocks, -n_days:]
    volume_mtx = volume_data["mtx"][:n_stocks, -n_days:]

    dates = close_data["dtes"][-n_days:]
    tkrs = close_data["tkrs"][:n_stocks]
    tkrs_name = close_data["tkrs_name"][:n_stocks]

    print(f" 数据形状: {close_mtx.shape}")
    print(f" 日期范围: {dates[0]} ~ {dates[-1]}")

    return open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name


def benchmark_batch_detection(
    open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
    params: ConvergingTriangleParams,
    profile_output: Optional[str] = None,
):
    """Benchmark batch detection over every evaluable day of every stock.

    Args:
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx: Price/volume
            matrices. ``close_mtx.shape`` is unpacked as (n_stocks, n_days).
        params: Detection parameters. ``params.window`` determines the first
            evaluable day (``window - 1``).
        profile_output: If given (non-empty), run the timed call under
            cProfile and save the stats to this path.

    Returns:
        (df, elapsed): the result of ``detect_converging_triangle_batch``
        (presumably a pandas DataFrame) and the seconds spent in the timed
        call. When profiling, ``elapsed`` includes cProfile overhead.

    Raises:
        ValueError: if there is no evaluable point, e.g. fewer days than
            ``params.window`` or zero stocks.
    """
    print("\n" + "=" * 80)
    print("基准测试:批量检测")
    print("=" * 80)

    n_stocks, n_days = close_mtx.shape
    window = params.window

    # Test range: the first day with a full look-back window through the last day.
    start_day = window - 1
    end_day = n_days - 1
    total_points = n_stocks * (end_day - start_day + 1)
    if total_points <= 0:
        raise ValueError(
            f"Nothing to evaluate: n_stocks={n_stocks}, n_days={n_days}, window={window}"
        )

    # Report the actual flag value; hasattr() would only tell whether it exists.
    realtime = bool(getattr(params, 'realtime_mode', False))

    print("\n测试配置:")
    print(f" 股票数: {n_stocks}")
    print(f" 交易日: {n_days}")
    print(f" 窗口大小: {window}")
    print(f" 检测点数: {total_points}")
    print(f" 实时模式: {'是' if realtime else '否'}")

    # Warm-up on a tiny slice so cold-start costs do not skew the timing.
    print("\n预热中...")
    _ = detect_converging_triangle_batch(
        open_mtx=open_mtx[:2, :],
        high_mtx=high_mtx[:2, :],
        low_mtx=low_mtx[:2, :],
        close_mtx=close_mtx[:2, :],
        volume_mtx=volume_mtx[:2, :],
        params=params,
        start_day=start_day,
        end_day=min(start_day + 10, end_day),
        only_valid=True,
        verbose=False,
    )

    print("\n开始性能测试...")
    profiler = cProfile.Profile() if profile_output else None
    if profiler is not None:
        profiler.enable()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    try:
        df = detect_converging_triangle_batch(
            open_mtx=open_mtx,
            high_mtx=high_mtx,
            low_mtx=low_mtx,
            close_mtx=close_mtx,
            volume_mtx=volume_mtx,
            params=params,
            start_day=start_day,
            end_day=end_day,
            only_valid=True,
            verbose=False,
        )
        elapsed = time.perf_counter() - start_time
    finally:
        # Never leave the profiler enabled if detection raises.
        if profiler is not None:
            profiler.disable()

    if profiler is not None:
        profiler.dump_stats(profile_output)
        print(f"\n[OK] Profile 结果已保存: {profile_output}")

        # Print the 20 hottest functions by cumulative time.
        print("\n" + "-" * 80)
        print("Top 20 热点函数:")
        print("-" * 80)
        s = io.StringIO()
        ps = pstats.Stats(profiler, stream=s).sort_stats(SortKey.CUMULATIVE)
        ps.print_stats(20)
        print(s.getvalue())

    # Summary statistics (guard against a zero-length measurement).
    speed = total_points / elapsed if elapsed > 0 else float('inf')
    print("\n" + "=" * 80)
    print("性能统计")
    print("=" * 80)
    print(f"\n总耗时: {elapsed:.2f} 秒 ({elapsed/60:.2f} 分钟)")
    print(f"处理点数: {total_points}")
    print(f"平均速度: {speed:.1f} 点/秒")
    print(f"单点耗时: {elapsed/total_points*1000:.2f} 毫秒/点")

    if len(df) > 0:
        valid_count = df['is_valid'].sum()
        print("\n检测结果:")
        print(f" 有效三角形: {valid_count}")
        print(f" 检出率: {valid_count/total_points*100:.2f}%")

    return df, elapsed


def benchmark_pivot_detection(high, low, k=15, n_iterations=100):
    """Benchmark pivot-point detection: ``pivots_fractal`` vs ``pivots_fractal_hybrid``.

    Args:
        high, low: 1-D price series for a single stock.
        k: Pivot window size passed to both methods.
        n_iterations: Number of repetitions per method (must be >= 1).

    Returns:
        (elapsed_standard, elapsed_hybrid) total seconds for each method.

    Raises:
        ValueError: if ``n_iterations < 1``.
    """
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be >= 1 (got {n_iterations})")

    print("\n" + "=" * 80)
    print("基准测试:枢轴点检测")
    print("=" * 80)

    print("\n测试配置:")
    print(f" 数据长度: {len(high)}")
    print(f" 窗口大小 k: {k}")
    print(f" 迭代次数: {n_iterations}")

    # Standard method.
    start_time = time.perf_counter()
    for _ in range(n_iterations):
        ph, pl = pivots_fractal(high, low, k=k)
    elapsed_standard = time.perf_counter() - start_time

    print("\n标准方法 (pivots_fractal):")
    print(f" 总耗时: {elapsed_standard:.4f} 秒")
    print(f" 平均耗时: {elapsed_standard/n_iterations*1000:.4f} 毫秒/次")
    print(f" 检测到的枢轴点: 高点={len(ph)}, 低点={len(pl)}")

    # Hybrid method: confirmed points plus candidate points.
    start_time = time.perf_counter()
    for _ in range(n_iterations):
        ph_c, pl_c, ph_cd, pl_cd = pivots_fractal_hybrid(high, low, k=k, flexible_zone=5)
    elapsed_hybrid = time.perf_counter() - start_time

    print("\n混合方法 (pivots_fractal_hybrid):")
    print(f" 总耗时: {elapsed_hybrid:.4f} 秒")
    print(f" 平均耗时: {elapsed_hybrid/n_iterations*1000:.4f} 毫秒/次")
    print(f" 确认点: 高点={len(ph_c)}, 低点={len(pl_c)}")
    print(f" 候选点: 高点={len(ph_cd)}, 低点={len(pl_cd)}")

    # A very fast standard run can measure as 0 on coarse clocks; avoid ZeroDivisionError.
    ratio = elapsed_hybrid / elapsed_standard if elapsed_standard > 0 else float('nan')
    print("\n性能对比:")
    print(f" 混合/标准 比值: {ratio:.2f}x")

    return elapsed_standard, elapsed_hybrid


def benchmark_line_fitting(pivot_indices, pivot_values, n_iterations=100):
    """Benchmark the iterative line fit ``fit_pivot_line`` (upper boundary mode).

    Args:
        pivot_indices: Indices of the pivot points.
        pivot_values: Prices at those pivot points.
        n_iterations: Number of repetitions (must be >= 1).

    Returns:
        Total seconds spent across all iterations.

    Raises:
        ValueError: if ``n_iterations < 1``.
    """
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be >= 1 (got {n_iterations})")

    print("\n" + "=" * 80)
    print("基准测试:线性拟合")
    print("=" * 80)

    print("\n测试配置:")
    print(f" 枢轴点数: {len(pivot_indices)}")
    print(f" 迭代次数: {n_iterations}")

    start_time = time.perf_counter()
    for _ in range(n_iterations):
        _a, _b, selected = fit_pivot_line(
            pivot_indices=pivot_indices,
            pivot_values=pivot_values,
            mode="upper",
        )
    elapsed = time.perf_counter() - start_time

    print("\n迭代拟合法 (fit_pivot_line):")
    print(f" 总耗时: {elapsed:.4f} 秒")
    print(f" 平均耗时: {elapsed/n_iterations*1000:.4f} 毫秒/次")
    print(f" 选中点数: {len(selected)}")

    return elapsed


def main():
    """Main test flow: run batch benchmarks at increasing scales, then print a summary."""
    print("=" * 80)
    print("收敛三角形检测 - 性能分析")
    print("=" * 80)

    # Paths.
    DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
    OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "..", "outputs", "performance")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Test scales, from small to full.
    TEST_CONFIGS = [
        {"name": "小规模测试", "n_stocks": 10, "n_days": 300},
        {"name": "中等规模测试", "n_stocks": 50, "n_days": 500},
        {"name": "全量测试", "n_stocks": 108, "n_days": 500},
    ]

    # Detection parameters.
    params = ConvergingTriangleParams(
        window=240,
        pivot_k=15,
        boundary_n_segments=2,
        boundary_source="full",
        fitting_method="anchor",
        upper_slope_max=0,
        lower_slope_min=0,
        touch_tol=0.10,
        touch_loss_max=0.10,
        shrink_ratio=0.45,
        break_tol=0.005,
        vol_window=20,
        vol_k=1.5,
        false_break_m=5,
    )

    results = []

    for i, config in enumerate(TEST_CONFIGS):
        print("\n\n")
        print("=" * 80)
        print(f"测试配置 {i+1}/{len(TEST_CONFIGS)}: {config['name']}")
        print("=" * 80)

        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, _dates, _tkrs, _tkrs_name = \
            load_test_data(DATA_DIR, n_stocks=config['n_stocks'], n_days=config['n_days'])

        # Batch detection benchmark (profiled).
        profile_path = os.path.join(OUTPUT_DIR, f"profile_{config['name']}.prof")
        _df, elapsed = benchmark_batch_detection(
            open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
            params,
            profile_output=profile_path,
        )

        # Use the shape actually loaded: the data may hold fewer stocks/days
        # than the config requested.
        n_stocks, n_days = close_mtx.shape
        total_points = n_stocks * (n_days - params.window + 1)
        results.append({
            "config": config['name'],
            "n_stocks": n_stocks,
            "n_days": n_days,
            "total_points": total_points,
            "elapsed": elapsed,
            "speed": total_points / elapsed if elapsed > 0 else float('inf'),
        })

        # Pivot detection and line fitting micro-benchmarks (first config only).
        if i == 0:
            sample_stock_idx = 0
            high = high_mtx[sample_stock_idx, :]
            low = low_mtx[sample_stock_idx, :]

            benchmark_pivot_detection(high, low, k=params.pivot_k, n_iterations=100)

            ph, _pl = pivots_fractal(high, low, k=params.pivot_k)
            if len(ph) >= 5:
                benchmark_line_fitting(
                    pivot_indices=ph[:10],
                    pivot_values=high[ph[:10]],
                    n_iterations=100,
                )

    # Summary report.
    print("\n\n")
    print("=" * 80)
    print("性能测试总结")
    print("=" * 80)
    print(f"\n{'配置':<20} {'股票数':<10} {'交易日':<10} {'总点数':<15} {'耗时(秒)':<12} {'速度(点/秒)':<15}")
    print("-" * 90)
    for r in results:
        print(f"{r['config']:<20} {r['n_stocks']:<10} {r['n_days']:<10} "
              f"{r['total_points']:<15} {r['elapsed']:<12.2f} {r['speed']:<15.1f}")

    # Full-data runtime estimate; the message hard-codes 108 stocks x 500 days,
    # so only print it when that is what was actually measured.
    if results:
        last_result = results[-1]
        if last_result['n_stocks'] == 108 and last_result['n_days'] == 500:
            print(f"\n全量数据 (108只股票 × 500天) 预计耗时: {last_result['elapsed']:.2f} 秒 ({last_result['elapsed']/60:.2f} 分钟)")

    print("\n" + "=" * 80)
    print("Profile 文件已保存到:")
    print(f" {OUTPUT_DIR}/")
    print("\n使用 snakeviz 可视化:")
    print(" pip install snakeviz")
    print(f" snakeviz {OUTPUT_DIR}/profile_*.prof")
    print("=" * 80)


if __name__ == "__main__":
    main()