technical-patterns-lab/scripts/test_performance.py
褚宏光 759042c5bd 性能优化:集成Numba加速,实现300+倍性能提升
核心改进:
- 新增 converging_triangle_optimized.py,使用Numba JIT编译优化7个核心函数
- 在 converging_triangle.py 末尾自动导入优化版本,无需手动配置
- 全量检测耗时从30秒降至<1秒(首次需3-5秒编译)

性能提升明细:
- pivots_fractal: 460x 加速
- pivots_fractal_hybrid: 511x 加速
- fit_boundary_anchor: 138x 加速
- calc_boundary_utilization: 195x 加速
- calc_fitting_adherence: 7x 加速
- calc_breakout_strength: 3x 加速

绘图功能增强:
- 添加 --plot-boundary-source 参数,支持选择高低价或收盘价拟合边界线
- 默认改为使用收盘价拟合(更平滑、更符合实际交易)
- 添加 --show-high-low 参数,可选显示日内高低价范围

技术特性:
- 自动检测并启用Numba加速,无numba时自动降级
- 结果与原版100%一致(误差<1e-6)
- 完整的性能测试和对比验证
- 零侵入性,原版函数作为备用

新增文件:
- src/converging_triangle_optimized.py - Numba优化版核心函数
- docs/README_性能优化.md - 性能优化文档索引
- docs/性能优化执行总结.md - 快速参考
- docs/性能优化完整报告.md - 完整技术报告
- docs/性能优化方案.md - 详细技术方案
- scripts/test_performance.py - 性能基线测试
- scripts/test_optimization_comparison.py - 优化对比测试
- scripts/test_full_pipeline.py - 完整流水线测试
- scripts/README_performance_tests.md - 测试脚本使用说明

修改文件:
- README.md - 添加性能优化说明和依赖
- src/converging_triangle.py - 集成优化版本导入
- scripts/pipeline_converging_triangle.py - 默认使用收盘价拟合
- scripts/plot_converging_triangles.py - 默认使用收盘价拟合
2026-01-28 17:22:13 +08:00

385 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
性能测试脚本 - 分析收敛三角形检测算法的性能瓶颈
此脚本不修改任何现有代码,仅用于性能分析和测试。
使用 cProfile 和 line_profiler 来识别热点函数。
"""
import cProfile
import io
import os
import pickle
import pstats
import sys
import time
from pstats import SortKey
from typing import Optional

import numpy as np
# 添加 src 路径
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from converging_triangle import (
ConvergingTriangleParams,
detect_converging_triangle_batch,
pivots_fractal,
pivots_fractal_hybrid,
fit_pivot_line,
calc_breakout_strength,
)
class FakeModule:
    """Stub standing in for the `model` package so its pickles can be loaded."""
    # Pickled arrays reference model.ndarray; alias it to numpy's.
    ndarray = np.ndarray


def load_pkl(pkl_path: str) -> dict:
    """
    Load a pickle file, first registering stub modules so that objects
    pickled against the (unavailable) `model` package deserialize cleanly.

    Args:
        pkl_path: path to the .pkl file.

    Returns:
        The unpickled object (a dict for this project's data files).
    """
    for stub_name in ("model", "model.index_info"):
        sys.modules[stub_name] = FakeModule()
    with open(pkl_path, "rb") as fh:
        return pickle.load(fh)
def load_test_data(data_dir: str, n_stocks: int = 10, n_days: int = 500):
    """
    Load a subset of the test data.

    Args:
        data_dir: directory holding the pkl matrices.
        n_stocks: how many stocks to keep (for small-scale testing).
        n_days: how many most-recent trading days to keep.

    Returns:
        (open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name)
    """
    print(f"\n加载测试数据 (stocks={n_stocks}, days={n_days})...")
    fields = ("open", "high", "low", "close", "volume")
    raw = {name: load_pkl(os.path.join(data_dir, f"{name}.pkl")) for name in fields}
    # Trim each matrix to the first n_stocks rows and the last n_days columns.
    matrices = tuple(raw[name]["mtx"][:n_stocks, -n_days:] for name in fields)
    close_data = raw["close"]
    dates = close_data["dtes"][-n_days:]
    tkrs = close_data["tkrs"][:n_stocks]
    tkrs_name = close_data["tkrs_name"][:n_stocks]
    # matrices[3] is the close matrix — same shape as the others.
    print(f" 数据形状: {matrices[3].shape}")
    print(f" 日期范围: {dates[0]} ~ {dates[-1]}")
    return (*matrices, dates, tkrs, tkrs_name)
def benchmark_batch_detection(
    open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
    params: ConvergingTriangleParams,
    profile_output: Optional[str] = None,
):
    """
    Benchmark: batch detection over full matrices.

    Args:
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx: (n_stocks, n_days)
            price/volume matrices fed to the detector.
        params: detection parameters; params.window sets the first testable day.
        profile_output: if given, save cProfile stats to this file and print
            the top-20 hotspot functions.

    Returns:
        (df, elapsed): the detector's result DataFrame and wall-clock seconds.
    """
    print("\n" + "=" * 80)
    print("基准测试:批量检测")
    print("=" * 80)
    n_stocks, n_days = close_mtx.shape
    window = params.window
    # Test range: first full window through the last day.
    start_day = window - 1
    end_day = n_days - 1
    total_points = n_stocks * (end_day - start_day + 1)
    print(f"\n测试配置:")
    print(f" 股票数: {n_stocks}")
    print(f" 交易日: {n_days}")
    print(f" 窗口大小: {window}")
    print(f" 检测点数: {total_points}")
    # BUG FIX: both branches of this conditional were empty strings, so the
    # line printed nothing either way; restore meaningful 是/否 output.
    print(f" 实时模式: {'是' if hasattr(params, 'realtime_mode') else '否'}")

    def _detect_full():
        # Single definition of the full-range detection call, shared by the
        # profiled and unprofiled paths (previously duplicated verbatim).
        return detect_converging_triangle_batch(
            open_mtx=open_mtx,
            high_mtx=high_mtx,
            low_mtx=low_mtx,
            close_mtx=close_mtx,
            volume_mtx=volume_mtx,
            params=params,
            start_day=start_day,
            end_day=end_day,
            only_valid=True,
            verbose=False,
        )

    # Warm up on a tiny slice to avoid cold-start effects (JIT compile, caches).
    print("\n预热中...")
    _ = detect_converging_triangle_batch(
        open_mtx=open_mtx[:2, :],
        high_mtx=high_mtx[:2, :],
        low_mtx=low_mtx[:2, :],
        close_mtx=close_mtx[:2, :],
        volume_mtx=volume_mtx[:2, :],
        params=params,
        start_day=start_day,
        end_day=min(start_day + 10, end_day),
        only_valid=True,
        verbose=False,
    )
    # Timed run, optionally under cProfile.
    print("\n开始性能测试...")
    profiler = None
    if profile_output:
        profiler = cProfile.Profile()
        profiler.enable()
    start_time = time.time()
    df = _detect_full()
    elapsed = time.time() - start_time
    if profiler is not None:
        profiler.disable()
        # Persist the raw stats, then print the top hotspots.
        profiler.dump_stats(profile_output)
        print(f"\n[OK] Profile 结果已保存: {profile_output}")
        print("\n" + "-" * 80)
        print("Top 20 热点函数:")
        print("-" * 80)
        s = io.StringIO()
        ps = pstats.Stats(profiler, stream=s).sort_stats(SortKey.CUMULATIVE)
        ps.print_stats(20)
        print(s.getvalue())
    # Statistics.
    print("\n" + "=" * 80)
    print("性能统计")
    print("=" * 80)
    print(f"\n总耗时: {elapsed:.2f} 秒 ({elapsed/60:.2f} 分钟)")
    print(f"处理点数: {total_points}")
    # Guard against a zero-duration run (trivially small input on a fast box).
    safe_elapsed = max(elapsed, 1e-12)
    print(f"平均速度: {total_points/safe_elapsed:.1f} 点/秒")
    print(f"单点耗时: {safe_elapsed/total_points*1000:.2f} 毫秒/点")
    if len(df) > 0:
        valid_count = df['is_valid'].sum()
        print(f"\n检测结果:")
        print(f" 有效三角形: {valid_count}")
        print(f" 检出率: {valid_count/total_points*100:.2f}%")
    return df, elapsed
def benchmark_pivot_detection(high, low, k=15, n_iterations=100):
    """
    Benchmark: pivot-point detection, comparing the standard fractal method
    against the hybrid (confirmed + candidate) variant over n_iterations runs.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("基准测试:枢轴点检测")
    print(banner)
    print(f"\n测试配置:")
    print(f" 数据长度: {len(high)}")
    print(f" 窗口大小 k: {k}")
    print(f" 迭代次数: {n_iterations}")

    # -- standard fractal method --
    t0 = time.time()
    for _ in range(n_iterations):
        high_pivots, low_pivots = pivots_fractal(high, low, k=k)
    elapsed_standard = time.time() - t0
    print(f"\n标准方法 (pivots_fractal):")
    print(f" 总耗时: {elapsed_standard:.4f}")
    print(f" 平均耗时: {elapsed_standard/n_iterations*1000:.4f} 毫秒/次")
    print(f" 检测到的枢轴点: 高点={len(high_pivots)}, 低点={len(low_pivots)}")

    # -- hybrid method (confirmed + candidate pivots) --
    t0 = time.time()
    for _ in range(n_iterations):
        conf_h, conf_l, cand_h, cand_l = pivots_fractal_hybrid(high, low, k=k, flexible_zone=5)
    elapsed_hybrid = time.time() - t0
    print(f"\n混合方法 (pivots_fractal_hybrid):")
    print(f" 总耗时: {elapsed_hybrid:.4f}")
    print(f" 平均耗时: {elapsed_hybrid/n_iterations*1000:.4f} 毫秒/次")
    print(f" 确认点: 高点={len(conf_h)}, 低点={len(conf_l)}")
    print(f" 候选点: 高点={len(cand_h)}, 低点={len(cand_l)}")

    print(f"\n性能对比:")
    print(f" 混合/标准 比值: {elapsed_hybrid/elapsed_standard:.2f}x")
    return elapsed_standard, elapsed_hybrid
def benchmark_line_fitting(pivot_indices, pivot_values, n_iterations=100):
    """
    Benchmark: linear boundary fitting via fit_pivot_line, timed over
    n_iterations repetitions in "upper" mode.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("基准测试:线性拟合")
    print(banner)
    print(f"\n测试配置:")
    print(f" 枢轴点数: {len(pivot_indices)}")
    print(f" 迭代次数: {n_iterations}")

    t0 = time.time()
    for _ in range(n_iterations):
        slope, intercept, selected = fit_pivot_line(
            pivot_indices=pivot_indices,
            pivot_values=pivot_values,
            mode="upper",
        )
    elapsed = time.time() - t0

    print(f"\n迭代拟合法 (fit_pivot_line):")
    print(f" 总耗时: {elapsed:.4f}")
    print(f" 平均耗时: {elapsed/n_iterations*1000:.4f} 毫秒/次")
    print(f" 选中点数: {len(selected)}")
    return elapsed
def main():
    """
    Main test flow.

    Runs the batch-detection benchmark at three increasing scales and, on the
    first (smallest) scale only, the pivot-detection and line-fitting
    micro-benchmarks. Ends with a summary table and profile-file locations.
    """
    print("=" * 80)
    print("收敛三角形检测 - 性能分析")
    print("=" * 80)
    # Paths: input data directory and where .prof files are written.
    DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
    OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "..", "outputs", "performance")
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    # Test parameters: small / medium / full-scale configurations.
    TEST_CONFIGS = [
        {"name": "小规模测试", "n_stocks": 10, "n_days": 300},
        {"name": "中等规模测试", "n_stocks": 50, "n_days": 500},
        {"name": "全量测试", "n_stocks": 108, "n_days": 500},
    ]
    # Detection parameters (shared by every scale level).
    params = ConvergingTriangleParams(
        window=240,
        pivot_k=15,
        boundary_n_segments=2,
        boundary_source="full",
        fitting_method="anchor",
        upper_slope_max=0,
        lower_slope_min=0,
        touch_tol=0.10,
        touch_loss_max=0.10,
        shrink_ratio=0.45,
        break_tol=0.005,
        vol_window=20,
        vol_k=1.5,
        false_break_m=5,
    )
    results = []
    # Run each scale level in turn.
    for i, config in enumerate(TEST_CONFIGS):
        print("\n\n")
        print("=" * 80)
        print(f"测试配置 {i+1}/{len(TEST_CONFIGS)}: {config['name']}")
        print("=" * 80)
        # Load the data subset for this scale.
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = \
            load_test_data(DATA_DIR, n_stocks=config['n_stocks'], n_days=config['n_days'])
        # Batch-detection benchmark (always run under the profiler).
        profile_path = os.path.join(OUTPUT_DIR, f"profile_{config['name']}.prof")
        df, elapsed = benchmark_batch_detection(
            open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
            params,
            profile_output=profile_path
        )
        results.append({
            "config": config['name'],
            "n_stocks": config['n_stocks'],
            "n_days": config['n_days'],
            "total_points": config['n_stocks'] * (config['n_days'] - params.window + 1),
            "elapsed": elapsed,
            "speed": config['n_stocks'] * (config['n_days'] - params.window + 1) / elapsed,
        })
        # Pivot-detection benchmark (first scale only, on one sample stock).
        if i == 0:
            sample_stock_idx = 0
            high = high_mtx[sample_stock_idx, :]
            low = low_mtx[sample_stock_idx, :]
            benchmark_pivot_detection(high, low, k=params.pivot_k, n_iterations=100)
            # Line-fitting benchmark (only when enough pivot highs exist).
            ph, pl = pivots_fractal(high, low, k=params.pivot_k)
            if len(ph) >= 5:
                benchmark_line_fitting(
                    pivot_indices=ph[:10],
                    pivot_values=high[ph[:10]],
                    n_iterations=100
                )
    # Summary report.
    print("\n\n")
    print("=" * 80)
    print("性能测试总结")
    print("=" * 80)
    print(f"\n{'配置':<20} {'股票数':<10} {'交易日':<10} {'总点数':<15} {'耗时(秒)':<12} {'速度(点/秒)':<15}")
    print("-" * 90)
    for r in results:
        print(f"{r['config']:<20} {r['n_stocks']:<10} {r['n_days']:<10} "
              f"{r['total_points']:<15} {r['elapsed']:<12.2f} {r['speed']:<15.1f}")
    # Extrapolate the full-universe runtime from the last (largest) run.
    if len(results) > 0:
        last_result = results[-1]
        if last_result['n_stocks'] == 108:
            print(f"\n全量数据 (108只股票 × 500天) 预计耗时: {last_result['elapsed']:.2f} 秒 ({last_result['elapsed']/60:.2f} 分钟)")
    print("\n" + "=" * 80)
    print("Profile 文件已保存到:")
    print(f" {OUTPUT_DIR}/")
    print("\n使用 snakeviz 可视化:")
    print(f" pip install snakeviz")
    print(f" snakeviz {OUTPUT_DIR}/profile_*.prof")
    print("=" * 80)


if __name__ == "__main__":
    main()