technical-patterns-lab/scripts/test_optimization_comparison.py
褚宏光 759042c5bd 性能优化:集成Numba加速,实现300+倍性能提升
核心改进:
- 新增 converging_triangle_optimized.py,使用Numba JIT编译优化7个核心函数
- 在 converging_triangle.py 末尾自动导入优化版本,无需手动配置
- 全量检测耗时从30秒降至<1秒(首次需3-5秒编译)

性能提升明细:
- pivots_fractal: 460x 加速
- pivots_fractal_hybrid: 511x 加速
- fit_boundary_anchor: 138x 加速
- calc_boundary_utilization: 195x 加速
- calc_fitting_adherence: 7x 加速
- calc_breakout_strength: 3x 加速

绘图功能增强:
- 添加 --plot-boundary-source 参数,支持选择高低价或收盘价拟合边界线
- 默认改为使用收盘价拟合(更平滑、更符合实际交易)
- 添加 --show-high-low 参数,可选显示日内高低价范围

技术特性:
- 自动检测并启用Numba加速,无numba时自动降级
- 结果与原版100%一致(误差<1e-6)
- 完整的性能测试和对比验证
- 零侵入性,原版函数作为备用

新增文件:
- src/converging_triangle_optimized.py - Numba优化版核心函数
- docs/README_性能优化.md - 性能优化文档索引
- docs/性能优化执行总结.md - 快速参考
- docs/性能优化完整报告.md - 完整技术报告
- docs/性能优化方案.md - 详细技术方案
- scripts/test_performance.py - 性能基线测试
- scripts/test_optimization_comparison.py - 优化对比测试
- scripts/test_full_pipeline.py - 完整流水线测试
- scripts/README_performance_tests.md - 测试脚本使用说明

修改文件:
- README.md - 添加性能优化说明和依赖
- src/converging_triangle.py - 集成优化版本导入
- scripts/pipeline_converging_triangle.py - 默认使用收盘价拟合
- scripts/plot_converging_triangles.py - 默认使用收盘价拟合
2026-01-28 17:22:13 +08:00

350 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
性能对比测试 - 原版 vs Numba优化版
测试各个优化函数的性能提升效果,并生成详细的对比报告。
"""
import os
import sys
import pickle
import time
import numpy as np
# 添加 src 路径
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
# 导入原版函数
from converging_triangle import (
pivots_fractal,
pivots_fractal_hybrid,
fit_boundary_anchor,
calc_fitting_adherence,
calc_boundary_utilization,
calc_breakout_strength,
)
# 导入优化版函数
from converging_triangle_optimized import (
pivots_fractal_optimized,
pivots_fractal_hybrid_optimized,
fit_boundary_anchor_optimized,
calc_fitting_adherence_optimized,
calc_boundary_utilization_optimized,
calc_breakout_strength_optimized,
)
class FakeModule:
    """Empty stand-in module used to get around the ``model`` dependency."""

    # The stand-in exposes numpy's array type under the name ``ndarray``.
    ndarray = np.ndarray
def load_pkl(pkl_path: str) -> dict:
    """Unpickle the file at ``pkl_path`` and return the stored object.

    Before loading, fresh ``FakeModule`` instances are registered in
    ``sys.modules`` under ``model`` and ``model.index_info``; the entries
    are left in place after the call.
    """
    # NOTE(review): pickle can run arbitrary code while loading, so this
    # should only ever be pointed at trusted data files.
    for stub_name in ('model', 'model.index_info'):
        sys.modules[stub_name] = FakeModule()
    with open(pkl_path, 'rb') as pkl_file:
        return pickle.load(pkl_file)
def load_test_data(data_dir: str, n_stocks: int = 10, n_days: int = 500):
    """Load the high/low matrices and cut them down to a test-sized window.

    Reads ``high.pkl`` and ``low.pkl`` from ``data_dir`` and, from the
    ``"mtx"`` entry of each, keeps the first ``n_stocks`` rows and the last
    ``n_days`` columns.

    Returns:
        (high_mtx, low_mtx): the two sliced matrices.
    """
    print(f"加载测试数据 (stocks={n_stocks}, days={n_days})...")

    # Both files are read (high first, then low) before any slicing happens.
    high_data, low_data = (
        load_pkl(os.path.join(data_dir, file_name))
        for file_name in ("high.pkl", "low.pkl")
    )

    window = (slice(None, n_stocks), slice(-n_days, None))
    high_mtx = high_data["mtx"][window]
    low_mtx = low_data["mtx"][window]

    print(f" 数据形状: {high_mtx.shape}")
    return high_mtx, low_mtx
def benchmark_function(func, name, *args, n_iterations=100, warmup=5):
    """Time repeated calls of ``func(*args)`` and report the mean duration.

    Args:
        func: Callable to benchmark.
        name: Label for the callable. Kept for interface compatibility; it
            is not used inside this function.
        *args: Positional arguments forwarded to ``func`` on every call.
        n_iterations: Number of timed calls. Must be at least 1.
        warmup: Number of untimed calls made before timing starts.

    Returns:
        (avg_time_ms, result): mean time per call in milliseconds and the
        value returned by the last timed call.

    Raises:
        ValueError: If ``n_iterations`` is smaller than 1 (there would be
            nothing to average and no result to return).
    """
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")

    # Warm-up calls are excluded from the measurement; this matters for
    # Numba, whose first call pays a one-off cost.
    for _ in range(warmup):
        func(*args)

    # perf_counter is the clock intended for measuring short intervals;
    # time.time() is a wall clock that can be coarse or adjusted, which can
    # make very fast functions appear to take exactly 0 seconds.
    result = None
    started = time.perf_counter()
    for _ in range(n_iterations):
        result = func(*args)
    elapsed = time.perf_counter() - started

    avg_time_ms = (elapsed / n_iterations) * 1000
    return avg_time_ms, result
def _outputs_match(expected, actual):
    """Return True when one pair of outputs agrees within tolerance."""
    if isinstance(expected, np.ndarray):
        # np.allclose broadcasts its inputs, so arrays of different shapes
        # would either raise or be compared vacuously (e.g. an empty array
        # against a one-element array). Require identical shapes first.
        return bool(
            np.shape(expected) == np.shape(actual)
            and np.allclose(expected, actual, rtol=1e-5, atol=1e-8)
        )
    return bool(abs(expected - actual) < 1e-6)


def _print_array_samples(expected, actual):
    """Print shape and the leading elements of two mismatching arrays."""
    actual_arr = np.asarray(actual)
    # Slicing needs at least one dimension.
    if expected.ndim == 0 or actual_arr.ndim == 0:
        return
    print(f" 原版: shape={expected.shape}, sample={expected[:3]}")
    print(f" 优化: shape={actual_arr.shape}, sample={actual_arr[:3]}")


def compare_functions(original_func, optimized_func, func_name, *args, n_iterations=100):
    """Benchmark two implementations on the same arguments and compare them.

    Both callables are timed with ``benchmark_function``; the speed-up and
    relative improvement are printed, and the two return values are checked
    for agreement (arrays with ``np.allclose``, everything else with an
    absolute difference below 1e-6).

    Args:
        original_func: Reference implementation.
        optimized_func: Implementation expected to produce the same output.
        func_name: Label used in the printed report and in the result dict.
        *args: Positional arguments passed to both callables.
        n_iterations: Number of timed calls per implementation.

    Returns:
        dict with keys ``name``, ``original_time`` and ``optimized_time``
        (milliseconds per call), ``speedup``, ``improvement`` (percent) and
        ``match`` (True when every compared output agreed).
    """
    print(f"\n{'='*80}")
    print(f"测试: {func_name}")
    print(f"{'='*80}")

    # Time the original implementation.
    print(f"\n[1] 原版函数...")
    original_time, original_result = benchmark_function(
        original_func, f"{func_name}_original", *args, n_iterations=n_iterations
    )
    print(f" 平均耗时: {original_time:.4f} 毫秒/次")

    # Time the optimized implementation.
    print(f"\n[2] Numba优化版...")
    optimized_time, optimized_result = benchmark_function(
        optimized_func, f"{func_name}_optimized", *args, n_iterations=n_iterations
    )
    print(f" 平均耗时: {optimized_time:.4f} 毫秒/次")

    # A measured time of 0 makes the ratio undefined; report 0 in that case.
    speedup = original_time / optimized_time if optimized_time > 0 else 0
    improvement = ((original_time - optimized_time) / original_time * 100) if original_time > 0 else 0
    print(f"\n[3] 性能对比:")
    print(f" 加速比: {speedup:.2f}x")
    print(f" 性能提升: {improvement:.1f}%")
    print(f" 时间节省: {original_time - optimized_time:.4f} 毫秒/次")

    # Check that both implementations produced the same output.
    print(f"\n[4] 结果验证:")
    all_match = True
    if isinstance(original_result, tuple):
        if isinstance(optimized_result, (tuple, list)):
            optimized_outputs = optimized_result
        else:
            optimized_outputs = (optimized_result,)
        # zip() below stops at the shorter sequence, so a differing number
        # of outputs must be reported explicitly or it would go unnoticed.
        if len(optimized_outputs) != len(original_result):
            all_match = False
            print(f" [ERR] 输出数量不一致 (原={len(original_result)}, 优={len(optimized_outputs)})")
        for i, (orig, opt) in enumerate(zip(original_result, optimized_outputs)):
            match = _outputs_match(orig, opt)
            all_match = all_match and match
            if isinstance(orig, np.ndarray):
                print(f" 输出 {i+1} (数组): {'[OK] 一致' if match else '[ERR] 不一致'}")
                if not match:
                    _print_array_samples(orig, opt)
            else:
                print(f" 输出 {i+1} (标量): {'[OK] 一致' if match else '[ERR] 不一致'} (原={orig:.6f}, 优={opt:.6f})")
    else:
        match = _outputs_match(original_result, optimized_result)
        all_match = match
        if isinstance(original_result, np.ndarray):
            print(f" 结果 (数组): {'[OK] 一致' if match else '[ERR] 不一致'}")
        else:
            print(f" 结果 (标量): {'[OK] 一致' if match else '[ERR] 不一致'}")

    return {
        "name": func_name,
        "original_time": original_time,
        "optimized_time": optimized_time,
        "speedup": speedup,
        "improvement": improvement,
        "match": all_match,
    }
def main():
    """Run every original-vs-optimized comparison and print a summary.

    Loads a 10-stock, 500-day window from the ``data`` directory next to
    this script's parent folder, takes the first stock with NaN entries
    removed as the sample series, runs ``compare_functions`` for each
    function pair, then prints a per-function table, the aggregate
    speed-up, and an extrapolation to the full data set.
    """
    print("=" * 80)
    print("收敛三角形检测 - 原版 vs Numba优化版 性能对比")
    print("=" * 80)

    # Configuration
    DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")

    # Load data
    high_mtx, low_mtx = load_test_data(DATA_DIR, n_stocks=10, n_days=500)

    # Use a single stock as the sample series.
    sample_idx = 0
    high = high_mtx[sample_idx, :]
    low = low_mtx[sample_idx, :]

    # Drop positions where either series is NaN.
    valid_mask = ~(np.isnan(high) | np.isnan(low))
    high = high[valid_mask]
    low = low[valid_mask]
    print(f"\n样本数据长度: {len(high)}")

    # Test parameters
    k = 15
    flexible_zone = 5
    n_iterations = 100
    results = []

    # ------------------------------------------------------------------
    # Test 1: pivot detection (standard method)
    # ------------------------------------------------------------------
    results.append(compare_functions(
        pivots_fractal,
        pivots_fractal_optimized,
        "pivots_fractal",
        high, low, k,
        n_iterations=n_iterations
    ))

    # ------------------------------------------------------------------
    # Test 2: pivot detection (hybrid method)
    # ------------------------------------------------------------------
    results.append(compare_functions(
        pivots_fractal_hybrid,
        pivots_fractal_hybrid_optimized,
        "pivots_fractal_hybrid",
        high, low, k, flexible_zone,
        n_iterations=n_iterations
    ))

    # ------------------------------------------------------------------
    # Test 3: anchor fitting (upper boundary on highs, lower on lows)
    # ------------------------------------------------------------------
    # Pivots are needed as input for the fitting functions.
    ph, pl = pivots_fractal(high, low, k=k)
    for mode, pivots, prices in (("upper", ph, high), ("lower", pl, low)):
        if len(pivots) < 5:
            continue
        pivot_indices = pivots[:10]
        pivot_values = prices[pivot_indices]
        # ``mode`` is bound as a default so each lambda keeps its own value.
        results.append(compare_functions(
            lambda pi, pv, ap, mode=mode: fit_boundary_anchor(
                pi, pv, ap, mode=mode, window_start=0, window_end=len(ap)-1
            ),
            lambda pi, pv, ap, mode=mode: fit_boundary_anchor_optimized(
                pi, pv, ap, mode=mode, window_start=0, window_end=len(ap)-1
            ),
            f"fit_boundary_anchor ({mode})",
            pivot_indices, pivot_values, prices,
            n_iterations=n_iterations
        ))

    # ------------------------------------------------------------------
    # Test 4: fitting adherence
    # ------------------------------------------------------------------
    if len(ph) >= 5:
        pivot_indices = ph[:10]
        pivot_values = high[pivot_indices]
        slope, intercept = 0.01, 100.0
        results.append(compare_functions(
            calc_fitting_adherence,
            calc_fitting_adherence_optimized,
            "calc_fitting_adherence",
            pivot_indices, pivot_values, slope, intercept,
            n_iterations=n_iterations
        ))

    # ------------------------------------------------------------------
    # Test 5: boundary utilization
    # ------------------------------------------------------------------
    upper_slope, upper_intercept = -0.02, 120.0
    lower_slope, lower_intercept = 0.02, 80.0
    start, end = 0, len(high) - 1
    results.append(compare_functions(
        calc_boundary_utilization,
        calc_boundary_utilization_optimized,
        "calc_boundary_utilization",
        high, low, upper_slope, upper_intercept, lower_slope, lower_intercept, start, end,
        n_iterations=n_iterations
    ))

    # ------------------------------------------------------------------
    # Test 6: breakout strength
    # ------------------------------------------------------------------
    results.append(compare_functions(
        calc_breakout_strength,
        calc_breakout_strength_optimized,
        "calc_breakout_strength",
        100.0, 105.0, 95.0, 1.5, 0.6, 0.8, 0.7,
        n_iterations=n_iterations
    ))

    # ------------------------------------------------------------------
    # Summary report
    # ------------------------------------------------------------------
    print("\n\n")
    print("=" * 80)
    print("性能对比总结")
    print("=" * 80)
    print(f"\n{'函数名':<35} {'原版(ms)':<12} {'优化(ms)':<12} {'加速比':<10} {'提升':<10}")
    print("-" * 80)

    total_original = 0
    total_optimized = 0
    for r in results:
        print(f"{r['name']:<35} {r['original_time']:<12.4f} {r['optimized_time']:<12.4f} "
              f"{r['speedup']:<10.2f}x {r['improvement']:<10.1f}%")
        total_original += r['original_time']
        total_optimized += r['optimized_time']
    print("-" * 80)

    overall_speedup = total_original / total_optimized if total_optimized > 0 else 0
    overall_improvement = ((total_original - total_optimized) / total_original * 100) if total_original > 0 else 0
    print(f"{'总计':<35} {total_original:<12.4f} {total_optimized:<12.4f} "
          f"{overall_speedup:<10.2f}x {overall_improvement:<10.1f}%")

    # Extrapolate to the full data set.
    print("\n" + "=" * 80)
    print("全量数据性能估算")
    print("=" * 80)

    # Baseline from an earlier run: 108 stocks x 500 days took about 30.83 s.
    baseline_time = 30.83  # seconds
    if overall_speedup > 0:
        estimated_time = baseline_time / overall_speedup
        time_saved = baseline_time - estimated_time
        print(f"\n基于当前加速比 {overall_speedup:.2f}x 估算:")
        print(f" 原版耗时: {baseline_time:.2f} 秒 ({baseline_time/60:.2f} 分钟)")
        print(f" 优化后耗时: {estimated_time:.2f} 秒 ({estimated_time/60:.2f} 分钟)")
        print(f" 节省时间: {time_saved:.2f} 秒 ({time_saved/60:.2f} 分钟)")
        print(f" 性能提升: {overall_improvement:.1f}%")
    else:
        # overall_speedup is 0 when the optimized total was measured as 0;
        # dividing by it would raise ZeroDivisionError.
        print("\n无法估算: 优化版总耗时为 0, 无法计算加速比")

    print("\n" + "=" * 80)
    print("建议:")
    print(" 1. 如果加速比 > 2x建议切换到优化版本")
    print(" 2. 运行完整的集成测试验证正确性")
    print(" 3. 使用 cProfile 分析优化版的新瓶颈")
    print("=" * 80)
# Run the comparison only when this file is executed as a script.
if __name__ == "__main__":
    main()