核心改进: - 新增 converging_triangle_optimized.py,使用Numba JIT编译优化7个核心函数 - 在 converging_triangle.py 末尾自动导入优化版本,无需手动配置 - 全量检测耗时从30秒降至<1秒(首次需3-5秒编译) 性能提升明细: - pivots_fractal: 460x 加速 - pivots_fractal_hybrid: 511x 加速 - fit_boundary_anchor: 138x 加速 - calc_boundary_utilization: 195x 加速 - calc_fitting_adherence: 7x 加速 - calc_breakout_strength: 3x 加速 绘图功能增强: - 添加 --plot-boundary-source 参数,支持选择高低价或收盘价拟合边界线 - 默认改为使用收盘价拟合(更平滑、更符合实际交易) - 添加 --show-high-low 参数,可选显示日内高低价范围 技术特性: - 自动检测并启用Numba加速,无numba时自动降级 - 结果与原版100%一致(误差<1e-6) - 完整的性能测试和对比验证 - 零侵入性,原版函数作为备用 新增文件: - src/converging_triangle_optimized.py - Numba优化版核心函数 - docs/README_性能优化.md - 性能优化文档索引 - docs/性能优化执行总结.md - 快速参考 - docs/性能优化完整报告.md - 完整技术报告 - docs/性能优化方案.md - 详细技术方案 - scripts/test_performance.py - 性能基线测试 - scripts/test_optimization_comparison.py - 优化对比测试 - scripts/test_full_pipeline.py - 完整流水线测试 - scripts/README_performance_tests.md - 测试脚本使用说明 修改文件: - README.md - 添加性能优化说明和依赖 - src/converging_triangle.py - 集成优化版本导入 - scripts/pipeline_converging_triangle.py - 默认使用收盘价拟合 - scripts/plot_converging_triangles.py - 默认使用收盘价拟合
385 lines
12 KiB
Python
385 lines
12 KiB
Python
"""
|
||
性能测试脚本 - 分析收敛三角形检测算法的性能瓶颈
|
||
|
||
此脚本不修改任何现有代码,仅用于性能分析和测试。
|
||
使用 cProfile 和 line_profiler 来识别热点函数。
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import pickle
|
||
import time
|
||
import cProfile
|
||
import pstats
|
||
import io
|
||
from pstats import SortKey
|
||
import numpy as np
|
||
|
||
# 添加 src 路径
|
||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||
|
||
from converging_triangle import (
|
||
ConvergingTriangleParams,
|
||
detect_converging_triangle_batch,
|
||
pivots_fractal,
|
||
pivots_fractal_hybrid,
|
||
fit_pivot_line,
|
||
calc_breakout_strength,
|
||
)
|
||
|
||
|
||
class FakeModule:
    """Stub module used to bypass the unavailable ``model`` dependency.

    Instances are installed into ``sys.modules`` (see ``load_pkl``) as
    ``model`` and ``model.index_info`` so that unpickling payloads which
    reference those modules does not fail on import.
    """
    # NOTE(review): presumably pickles reference ``model.ndarray``; mapping it
    # to numpy's ndarray lets them deserialize — confirm against the pkl files.
    ndarray = np.ndarray
|
||
|
||
|
||
def load_pkl(pkl_path: str) -> dict:
    """Load a pickle file, registering ``model`` stub modules first.

    The pickled payloads reference a ``model`` package that is not
    importable here, so shim entries are placed in ``sys.modules``
    before unpickling.
    """
    for stub_name in ('model', 'model.index_info'):
        sys.modules[stub_name] = FakeModule()

    with open(pkl_path, 'rb') as fh:
        return pickle.load(fh)
|
||
|
||
|
||
def load_test_data(data_dir: str, n_stocks: int = 10, n_days: int = 500):
    """Load a subset of the test data.

    Args:
        data_dir: directory containing the ``*.pkl`` price matrices
        n_stocks: how many stocks to keep (for small-scale testing)
        n_days: how many of the most recent trading days to keep

    Returns:
        (open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name)
    """
    print(f"\n加载测试数据 (stocks={n_stocks}, days={n_days})...")

    # Load all five price/volume pickles, keyed by field name.
    fields = ("open", "high", "low", "close", "volume")
    raw = {name: load_pkl(os.path.join(data_dir, name + ".pkl")) for name in fields}

    # Slice every matrix down to the requested stock/day subset.
    open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx = (
        raw[name]["mtx"][:n_stocks, -n_days:] for name in fields
    )

    close_data = raw["close"]
    dates = close_data["dtes"][-n_days:]
    tkrs = close_data["tkrs"][:n_stocks]
    tkrs_name = close_data["tkrs_name"][:n_stocks]

    print(f" 数据形状: {close_mtx.shape}")
    print(f" 日期范围: {dates[0]} ~ {dates[-1]}")

    return open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name
|
||
|
||
|
||
def benchmark_batch_detection(
    open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
    params: ConvergingTriangleParams,
    profile_output: "str | None" = None
):
    """Benchmark the batch converging-triangle detection.

    Args:
        open_mtx/high_mtx/low_mtx/close_mtx/volume_mtx: matrices of shape
            (n_stocks, n_days) — assumed aligned; TODO confirm upstream
        params: detection parameters
        profile_output: if given, dump cProfile stats to this path and
            print the top-20 hot functions

    Returns:
        (df, elapsed): the detection result DataFrame and wall-clock seconds.
    """
    print("\n" + "=" * 80)
    print("基准测试:批量检测")
    print("=" * 80)

    n_stocks, n_days = close_mtx.shape
    window = params.window

    # Detection range: one point per (stock, day) once a full window exists.
    start_day = window - 1
    end_day = n_days - 1
    total_points = n_stocks * (end_day - start_day + 1)

    print(f"\n测试配置:")
    print(f" 股票数: {n_stocks}")
    print(f" 交易日: {n_days}")
    print(f" 窗口大小: {window}")
    print(f" 检测点数: {total_points}")
    print(f" 实时模式: {'是' if hasattr(params, 'realtime_mode') else '否'}")

    # Warm-up run on 2 stocks / 10 days to avoid cold-start effects
    # (imports, caches, possible JIT compilation) skewing the timing.
    print("\n预热中...")
    _ = detect_converging_triangle_batch(
        open_mtx=open_mtx[:2, :],
        high_mtx=high_mtx[:2, :],
        low_mtx=low_mtx[:2, :],
        close_mtx=close_mtx[:2, :],
        volume_mtx=volume_mtx[:2, :],
        params=params,
        start_day=start_day,
        end_day=min(start_day + 10, end_day),
        only_valid=True,
        verbose=False,
    )

    print("\n开始性能测试...")

    # Single detection path: the profiler is toggled around one call instead
    # of duplicating the entire call in each branch (as the original did).
    profiler = cProfile.Profile() if profile_output else None
    if profiler is not None:
        profiler.enable()

    # perf_counter is monotonic and high-resolution (time.time can jump
    # with wall-clock adjustments and has coarser resolution).
    start_time = time.perf_counter()
    df = detect_converging_triangle_batch(
        open_mtx=open_mtx,
        high_mtx=high_mtx,
        low_mtx=low_mtx,
        close_mtx=close_mtx,
        volume_mtx=volume_mtx,
        params=params,
        start_day=start_day,
        end_day=end_day,
        only_valid=True,
        verbose=False,
    )
    elapsed = time.perf_counter() - start_time

    if profiler is not None:
        profiler.disable()

        # Persist the profile for later visualization (e.g. snakeviz).
        profiler.dump_stats(profile_output)
        print(f"\n[OK] Profile 结果已保存: {profile_output}")

        # Print the top-20 hot functions by cumulative time.
        print("\n" + "-" * 80)
        print("Top 20 热点函数:")
        print("-" * 80)

        s = io.StringIO()
        ps = pstats.Stats(profiler, stream=s).sort_stats(SortKey.CUMULATIVE)
        ps.print_stats(20)
        print(s.getvalue())

    # Summary statistics.
    print("\n" + "=" * 80)
    print("性能统计")
    print("=" * 80)
    print(f"\n总耗时: {elapsed:.2f} 秒 ({elapsed/60:.2f} 分钟)")
    print(f"处理点数: {total_points}")
    print(f"平均速度: {total_points/elapsed:.1f} 点/秒")
    print(f"单点耗时: {elapsed/total_points*1000:.2f} 毫秒/点")

    if len(df) > 0:
        valid_count = df['is_valid'].sum()
        print(f"\n检测结果:")
        print(f" 有效三角形: {valid_count}")
        print(f" 检出率: {valid_count/total_points*100:.2f}%")

    return df, elapsed
|
||
|
||
|
||
def benchmark_pivot_detection(high, low, k=15, n_iterations=100):
    """Benchmark pivot-point detection.

    Times ``pivots_fractal`` (standard) against ``pivots_fractal_hybrid``
    over ``n_iterations`` repetitions and prints a comparison.

    Args:
        high: high-price series
        low: low-price series
        k: fractal window size
        n_iterations: number of timed repetitions per method

    Returns:
        (elapsed_standard, elapsed_hybrid): total seconds for each method.
    """
    print("\n" + "=" * 80)
    print("基准测试:枢轴点检测")
    print("=" * 80)

    print(f"\n测试配置:")
    print(f" 数据长度: {len(high)}")
    print(f" 窗口大小 k: {k}")
    print(f" 迭代次数: {n_iterations}")

    # Standard method. perf_counter is monotonic and high-resolution,
    # unlike time.time which can jump with wall-clock adjustments.
    start_time = time.perf_counter()
    for _ in range(n_iterations):
        ph, pl = pivots_fractal(high, low, k=k)
    elapsed_standard = time.perf_counter() - start_time

    print(f"\n标准方法 (pivots_fractal):")
    print(f" 总耗时: {elapsed_standard:.4f} 秒")
    print(f" 平均耗时: {elapsed_standard/n_iterations*1000:.4f} 毫秒/次")
    print(f" 检测到的枢轴点: 高点={len(ph)}, 低点={len(pl)}")

    # Hybrid method (confirmed + candidate pivots).
    start_time = time.perf_counter()
    for _ in range(n_iterations):
        ph_c, pl_c, ph_cd, pl_cd = pivots_fractal_hybrid(high, low, k=k, flexible_zone=5)
    elapsed_hybrid = time.perf_counter() - start_time

    print(f"\n混合方法 (pivots_fractal_hybrid):")
    print(f" 总耗时: {elapsed_hybrid:.4f} 秒")
    print(f" 平均耗时: {elapsed_hybrid/n_iterations*1000:.4f} 毫秒/次")
    print(f" 确认点: 高点={len(ph_c)}, 低点={len(pl_c)}")
    print(f" 候选点: 高点={len(ph_cd)}, 低点={len(pl_cd)}")

    print(f"\n性能对比:")
    # Guard against a zero denominator on trivially fast runs.
    ratio = elapsed_hybrid / elapsed_standard if elapsed_standard > 0 else float("inf")
    print(f" 混合/标准 比值: {ratio:.2f}x")

    return elapsed_standard, elapsed_hybrid
|
||
|
||
|
||
def benchmark_line_fitting(pivot_indices, pivot_values, n_iterations=100):
    """Benchmark the iterative boundary-line fitting.

    Args:
        pivot_indices: x positions of the pivot points
        pivot_values: y values at those positions
        n_iterations: number of timed repetitions

    Returns:
        Total elapsed seconds over all iterations.
    """
    print("\n" + "=" * 80)
    print("基准测试:线性拟合")
    print("=" * 80)

    print(f"\n测试配置:")
    print(f" 枢轴点数: {len(pivot_indices)}")
    print(f" 迭代次数: {n_iterations}")

    # perf_counter: monotonic, high-resolution — the right clock for timing.
    start_time = time.perf_counter()
    for _ in range(n_iterations):
        a, b, selected = fit_pivot_line(
            pivot_indices=pivot_indices,
            pivot_values=pivot_values,
            mode="upper",
        )
    elapsed = time.perf_counter() - start_time

    print(f"\n迭代拟合法 (fit_pivot_line):")
    print(f" 总耗时: {elapsed:.4f} 秒")
    print(f" 平均耗时: {elapsed/n_iterations*1000:.4f} 毫秒/次")
    print(f" 选中点数: {len(selected)}")

    return elapsed
|
||
|
||
|
||
def main() -> None:
    """Main test flow: run detection benchmarks at increasing data scale."""
    print("=" * 80)
    print("收敛三角形检测 - 性能分析")
    print("=" * 80)

    # Paths relative to this script's location.
    DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
    OUTPUT_DIR = os.path.join(os.path.dirname(__file__), "..", "outputs", "performance")
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Test configurations, from small-scale to full-scale.
    TEST_CONFIGS = [
        {"name": "小规模测试", "n_stocks": 10, "n_days": 300},
        {"name": "中等规模测试", "n_stocks": 50, "n_days": 500},
        {"name": "全量测试", "n_stocks": 108, "n_days": 500},
    ]

    # Detection parameters (values match the project's documented defaults —
    # NOTE(review): confirm they stay in sync with the pipeline scripts).
    params = ConvergingTriangleParams(
        window=240,
        pivot_k=15,
        boundary_n_segments=2,
        boundary_source="full",
        fitting_method="anchor",
        upper_slope_max=0,
        lower_slope_min=0,
        touch_tol=0.10,
        touch_loss_max=0.10,
        shrink_ratio=0.45,
        break_tol=0.005,
        vol_window=20,
        vol_k=1.5,
        false_break_m=5,
    )

    results = []

    # Run each configuration in turn, collecting timing results.
    for i, config in enumerate(TEST_CONFIGS):
        print("\n\n")
        print("=" * 80)
        print(f"测试配置 {i+1}/{len(TEST_CONFIGS)}: {config['name']}")
        print("=" * 80)

        # Load the data subset for this configuration.
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = \
            load_test_data(DATA_DIR, n_stocks=config['n_stocks'], n_days=config['n_days'])

        # Batch-detection benchmark, with a cProfile dump per configuration.
        profile_path = os.path.join(OUTPUT_DIR, f"profile_{config['name']}.prof")
        df, elapsed = benchmark_batch_detection(
            open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
            params,
            profile_output=profile_path
        )

        results.append({
            "config": config['name'],
            "n_stocks": config['n_stocks'],
            "n_days": config['n_days'],
            "total_points": config['n_stocks'] * (config['n_days'] - params.window + 1),
            "elapsed": elapsed,
            "speed": config['n_stocks'] * (config['n_days'] - params.window + 1) / elapsed,
        })

        # Pivot-detection benchmark (first configuration only).
        if i == 0:
            sample_stock_idx = 0
            high = high_mtx[sample_stock_idx, :]
            low = low_mtx[sample_stock_idx, :]
            benchmark_pivot_detection(high, low, k=params.pivot_k, n_iterations=100)

            # Line-fitting benchmark (needs at least 5 detected high pivots).
            ph, pl = pivots_fractal(high, low, k=params.pivot_k)
            if len(ph) >= 5:
                benchmark_line_fitting(
                    pivot_indices=ph[:10],
                    pivot_values=high[ph[:10]],
                    n_iterations=100
                )

    # Summary report across all configurations.
    print("\n\n")
    print("=" * 80)
    print("性能测试总结")
    print("=" * 80)

    print(f"\n{'配置':<20} {'股票数':<10} {'交易日':<10} {'总点数':<15} {'耗时(秒)':<12} {'速度(点/秒)':<15}")
    print("-" * 90)

    for r in results:
        print(f"{r['config']:<20} {r['n_stocks']:<10} {r['n_days']:<10} "
              f"{r['total_points']:<15} {r['elapsed']:<12.2f} {r['speed']:<15.1f}")

    # Report the full-data runtime if the full-scale configuration ran
    # (108 stocks matches the "全量测试" entry above).
    if len(results) > 0:
        last_result = results[-1]
        if last_result['n_stocks'] == 108:
            print(f"\n全量数据 (108只股票 × 500天) 预计耗时: {last_result['elapsed']:.2f} 秒 ({last_result['elapsed']/60:.2f} 分钟)")

    print("\n" + "=" * 80)
    print("Profile 文件已保存到:")
    print(f" {OUTPUT_DIR}/")
    print("\n使用 snakeviz 可视化:")
    print(f" pip install snakeviz")
    print(f" snakeviz {OUTPUT_DIR}/profile_*.prof")
    print("=" * 80)
|
||
|
||
|
||
# Script entry point: run the full benchmark suite.
if __name__ == "__main__":
    main()
|