核心改进:
- 新增 converging_triangle_optimized.py,使用 Numba JIT 编译优化 7 个核心函数
- 在 converging_triangle.py 末尾自动导入优化版本,无需手动配置
- 全量检测耗时从 30 秒降至 <1 秒(首次需 3-5 秒编译)

性能提升明细:
- pivots_fractal: 460x 加速
- pivots_fractal_hybrid: 511x 加速
- fit_boundary_anchor: 138x 加速
- calc_boundary_utilization: 195x 加速
- calc_fitting_adherence: 7x 加速
- calc_breakout_strength: 3x 加速

绘图功能增强:
- 添加 --plot-boundary-source 参数,支持选择高低价或收盘价拟合边界线
- 默认改为使用收盘价拟合(更平滑、更符合实际交易)
- 添加 --show-high-low 参数,可选显示日内高低价范围

技术特性:
- 自动检测并启用 Numba 加速,无 numba 时自动降级
- 结果与原版 100% 一致(误差 <1e-6)
- 完整的性能测试和对比验证
- 零侵入性,原版函数作为备用

新增文件:
- src/converging_triangle_optimized.py - Numba 优化版核心函数
- docs/README_性能优化.md - 性能优化文档索引
- docs/性能优化执行总结.md - 快速参考
- docs/性能优化完整报告.md - 完整技术报告
- docs/性能优化方案.md - 详细技术方案
- scripts/test_performance.py - 性能基线测试
- scripts/test_optimization_comparison.py - 优化对比测试
- scripts/test_full_pipeline.py - 完整流水线测试
- scripts/README_performance_tests.md - 测试脚本使用说明

修改文件:
- README.md - 添加性能优化说明和依赖
- src/converging_triangle.py - 集成优化版本导入
- scripts/pipeline_converging_triangle.py - 默认使用收盘价拟合
- scripts/plot_converging_triangles.py - 默认使用收盘价拟合
350 lines
12 KiB
Python
350 lines
12 KiB
Python
"""
|
||
性能对比测试 - 原版 vs Numba优化版
|
||
|
||
测试各个优化函数的性能提升效果,并生成详细的对比报告。
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import pickle
|
||
import time
|
||
import numpy as np
|
||
|
||
# 添加 src 路径
|
||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||
|
||
# 导入原版函数
|
||
from converging_triangle import (
|
||
pivots_fractal,
|
||
pivots_fractal_hybrid,
|
||
fit_boundary_anchor,
|
||
calc_fitting_adherence,
|
||
calc_boundary_utilization,
|
||
calc_breakout_strength,
|
||
)
|
||
|
||
# 导入优化版函数
|
||
from converging_triangle_optimized import (
|
||
pivots_fractal_optimized,
|
||
pivots_fractal_hybrid_optimized,
|
||
fit_boundary_anchor_optimized,
|
||
calc_fitting_adherence_optimized,
|
||
calc_boundary_utilization_optimized,
|
||
calc_breakout_strength_optimized,
|
||
)
|
||
|
||
|
||
class FakeModule(object):
    """Stand-in object registered in ``sys.modules`` in place of ``model``.

    Lets pickled data files be loaded without the real ``model`` package.
    The only attribute it exposes is ``ndarray`` (an alias of
    ``numpy.ndarray``) — presumably the only name the pickles look up in
    ``model`` / ``model.index_info``; verify if loading starts failing.
    """

    # Alias resolved by the unpickler when it asks the stub for ``ndarray``.
    ndarray = np.ndarray
|
||
|
||
|
||
def load_pkl(pkl_path: str) -> dict:
    """Load a pickled data file.

    Before unpickling, ``FakeModule`` stubs are registered as ``model`` and
    ``model.index_info`` so the load does not require the real ``model``
    package. The stubs are only installed when those names are not already in
    ``sys.modules``, so an already-imported real package is never clobbered.

    Args:
        pkl_path: path to the ``.pkl`` file.

    Returns:
        The unpickled object (annotated as ``dict``; callers in this script
        read its ``"mtx"`` entry).

    Note:
        ``pickle.load`` can execute arbitrary code. Only use this on trusted,
        locally produced data files.
    """
    # setdefault instead of assignment: overwriting would replace a real
    # ``model`` package for the rest of the process.
    sys.modules.setdefault("model", FakeModule())
    sys.modules.setdefault("model.index_info", FakeModule())

    with open(pkl_path, "rb") as f:
        return pickle.load(f)
|
||
|
||
|
||
def load_test_data(data_dir: str, n_stocks: int = 10, n_days: int = 500):
    """Load a slice of the high/low price matrices used by the benchmarks.

    Reads ``high.pkl`` and ``low.pkl`` from *data_dir* and keeps the first
    *n_stocks* rows and the last *n_days* columns of each file's ``"mtx"``
    entry.

    Returns:
        ``(high_mtx, low_mtx)`` — the two sliced matrices.
    """
    print(f"加载测试数据 (stocks={n_stocks}, days={n_days})...")

    sliced = []
    for file_name in ("high.pkl", "low.pkl"):
        payload = load_pkl(os.path.join(data_dir, file_name))
        sliced.append(payload["mtx"][:n_stocks, -n_days:])
    high_mtx, low_mtx = sliced

    print(f" 数据形状: {high_mtx.shape}")

    return high_mtx, low_mtx
|
||
|
||
|
||
def benchmark_function(func, name, *args, n_iterations=100, warmup=5):
    """Time repeated calls of ``func(*args)``.

    Args:
        func: callable to benchmark.
        name: label for the function (accepted for interface compatibility;
            not used by the measurement).
        *args: positional arguments forwarded to ``func`` on every call.
        n_iterations: number of timed calls; must be >= 1.
        warmup: number of untimed calls made first. This matters for Numba,
            whose first call triggers JIT compilation.

    Returns:
        ``(avg_time_ms, result)``: mean time per call in milliseconds and the
        return value of the last timed call.

    Raises:
        ValueError: if ``n_iterations`` < 1 (the average would be undefined).
    """
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")

    for _ in range(warmup):
        func(*args)

    # perf_counter is monotonic and high-resolution. time.time() can jump with
    # clock adjustments and is coarse on some platforms, which would turn
    # microsecond-scale calls into 0 ms measurements.
    start = time.perf_counter()
    for _ in range(n_iterations):
        result = func(*args)
    elapsed = time.perf_counter() - start

    avg_time_ms = elapsed / n_iterations * 1000
    return avg_time_ms, result
|
||
|
||
|
||
def _fmt_scalar(value):
    """Format *value* with 6 decimals, falling back to ``repr`` for non-numbers."""
    try:
        return f"{value:.6f}"
    except (TypeError, ValueError):
        return repr(value)


def _array_sample(values):
    """Return the first three leading entries of *values* (or the value itself if 0-d)."""
    arr = np.asarray(values)
    return arr[:3] if arr.ndim else arr


def _results_match(orig, opt):
    """Return True when two outputs agree within the benchmark tolerances.

    Arrays must have identical shapes and pass
    ``np.allclose(rtol=1e-5, atol=1e-8)``, with NaNs in the same positions
    counted as equal. Scalars must be equal, both NaN, or differ by less than
    1e-6. Values that cannot be compared count as a mismatch rather than
    crashing the report.
    """
    if isinstance(orig, np.ndarray) or isinstance(opt, np.ndarray):
        a, b = np.asarray(orig), np.asarray(opt)
        # Without this check, allclose would broadcast e.g. (1,) against (n,)
        # and could report different-shaped outputs as "matching".
        if a.shape != b.shape:
            return False
        try:
            return bool(np.allclose(a, b, rtol=1e-5, atol=1e-8, equal_nan=True))
        except TypeError:
            # Non-numeric dtypes: fall back to exact element-wise equality.
            return bool(np.array_equal(a, b))
    try:
        if orig == opt:
            return True
        if np.isnan(orig) and np.isnan(opt):
            return True
        return bool(abs(orig - opt) < 1e-6)
    except (TypeError, ValueError):
        return False


def compare_functions(original_func, optimized_func, func_name, *args, n_iterations=100):
    """Benchmark an original function against its optimized twin and verify outputs.

    Both functions are called with the same ``*args``. Timings, speedup and
    output consistency are printed.

    Args:
        original_func: reference implementation.
        optimized_func: optimized implementation expected to return the same result.
        func_name: label used in the printed report and the returned dict.
        *args: positional arguments passed to both functions.
        n_iterations: number of timed calls per function.

    Returns:
        dict with keys ``name``, ``original_time`` / ``optimized_time``
        (ms per call), ``speedup`` (0 if the optimized time is 0),
        ``improvement`` (percent), and ``match`` (True if every output agreed).
    """
    print(f"\n{'='*80}")
    print(f"测试: {func_name}")
    print(f"{'='*80}")

    # Original implementation
    print(f"\n[1] 原版函数...")
    original_time, original_result = benchmark_function(
        original_func, f"{func_name}_original", *args, n_iterations=n_iterations
    )
    print(f" 平均耗时: {original_time:.4f} 毫秒/次")

    # Optimized implementation
    print(f"\n[2] Numba优化版...")
    optimized_time, optimized_result = benchmark_function(
        optimized_func, f"{func_name}_optimized", *args, n_iterations=n_iterations
    )
    print(f" 平均耗时: {optimized_time:.4f} 毫秒/次")

    # Speedup / improvement (guarded against zero timings)
    speedup = original_time / optimized_time if optimized_time > 0 else 0
    improvement = ((original_time - optimized_time) / original_time * 100) if original_time > 0 else 0

    print(f"\n[3] 性能对比:")
    print(f" 加速比: {speedup:.2f}x")
    print(f" 性能提升: {improvement:.1f}%")
    print(f" 时间节省: {original_time - optimized_time:.4f} 毫秒/次")

    # Output consistency
    print(f"\n[4] 结果验证:")
    all_match = True
    if isinstance(original_result, tuple):
        try:
            n_opt = len(optimized_result)
        except TypeError:
            n_opt = None
        # zip() would silently truncate, hiding missing or extra outputs.
        if n_opt != len(original_result):
            print(f" 输出数量不一致: 原版={len(original_result)}, 优化={n_opt}")
            all_match = False
        if n_opt is not None:
            for i, (orig, opt) in enumerate(zip(original_result, optimized_result)):
                match = _results_match(orig, opt)
                all_match = all_match and match
                if isinstance(orig, np.ndarray):
                    print(f" 输出 {i+1} (数组): {'[OK] 一致' if match else '[ERR] 不一致'}")
                    if not match and np.size(orig) > 0 and np.size(opt) > 0:
                        print(f" 原版: shape={orig.shape}, sample={_array_sample(orig)}")
                        print(f" 优化: shape={np.shape(opt)}, sample={_array_sample(opt)}")
                else:
                    print(f" 输出 {i+1} (标量): {'[OK] 一致' if match else '[ERR] 不一致'} "
                          f"(原={_fmt_scalar(orig)}, 优={_fmt_scalar(opt)})")
    else:
        all_match = _results_match(original_result, optimized_result)
        label = "数组" if isinstance(original_result, np.ndarray) else "标量"
        print(f" 结果 ({label}): {'[OK] 一致' if all_match else '[ERR] 不一致'}")

    return {
        "name": func_name,
        "original_time": original_time,
        "optimized_time": optimized_time,
        "speedup": speedup,
        "improvement": improvement,
        "match": all_match,
    }
|
||
|
||
|
||
def _make_fit_runner(fit_func, mode):
    """Bind *mode* and a full-series window onto a ``fit_boundary_anchor``-style function.

    The returned callable takes ``(pivot_indices, pivot_values, all_prices)``
    and fits over the whole series (``window_start=0``,
    ``window_end=len(all_prices) - 1``).
    """
    def run(pi, pv, ap):
        return fit_func(pi, pv, ap, mode=mode, window_start=0, window_end=len(ap) - 1)
    return run


def _print_summary(results):
    """Print the per-function comparison table and return the aggregate metrics.

    Returns:
        ``(overall_speedup, overall_improvement)``. The speedup is the ratio of
        summed original to summed optimized per-call times (0 when the
        optimized total is 0). The improvement is the matching percentage
        reduction.
    """
    print("\n\n")
    print("=" * 80)
    print("性能对比总结")
    print("=" * 80)

    print(f"\n{'函数名':<35} {'原版(ms)':<12} {'优化(ms)':<12} {'加速比':<10} {'提升':<10}")
    print("-" * 80)

    for r in results:
        print(f"{r['name']:<35} {r['original_time']:<12.4f} {r['optimized_time']:<12.4f} "
              f"{r['speedup']:<10.2f}x {r['improvement']:<10.1f}%")

    total_original = sum(r['original_time'] for r in results)
    total_optimized = sum(r['optimized_time'] for r in results)

    print("-" * 80)
    overall_speedup = total_original / total_optimized if total_optimized > 0 else 0
    overall_improvement = ((total_original - total_optimized) / total_original * 100) if total_original > 0 else 0

    print(f"{'总计':<35} {total_original:<12.4f} {total_optimized:<12.4f} "
          f"{overall_speedup:<10.2f}x {overall_improvement:<10.1f}%")
    return overall_speedup, overall_improvement


def _print_full_data_estimate(overall_speedup, overall_improvement, baseline_time=30.83):
    """Extrapolate the aggregate micro-benchmark speedup to the full-data run.

    *baseline_time* is a previously measured wall time in seconds for the
    original implementation on the full data set (108 stocks x 500 days). The
    estimate divides it by *overall_speedup*, i.e. it assumes the whole
    pipeline speeds up by the same factor as the benchmarked functions.
    """
    print("\n" + "=" * 80)
    print("全量数据性能估算")
    print("=" * 80)

    if overall_speedup <= 0:
        # Speedup unavailable (optimized timings summed to 0); dividing by it
        # would raise ZeroDivisionError.
        print("\n无法估算: 优化版总耗时为 0,加速比不可用")
        return

    estimated_time = baseline_time / overall_speedup
    time_saved = baseline_time - estimated_time

    print(f"\n基于当前加速比 {overall_speedup:.2f}x 估算:")
    print(f" 原版耗时: {baseline_time:.2f} 秒 ({baseline_time/60:.2f} 分钟)")
    print(f" 优化后耗时: {estimated_time:.2f} 秒 ({estimated_time/60:.2f} 分钟)")
    print(f" 节省时间: {time_saved:.2f} 秒 ({time_saved/60:.2f} 分钟)")
    print(f" 性能提升: {overall_improvement:.1f}%")


def main():
    """Benchmark every optimized function against its original on one sample stock.

    Loads 10 stocks x 500 days from ``../data`` (relative to this script) and
    takes the first stock with NaN days removed. Each original/optimized
    pair is compared, then the script prints a summary table, a full-data
    time estimate and follow-up suggestions.
    """
    print("=" * 80)
    print("收敛三角形检测 - 原版 vs Numba优化版 性能对比")
    print("=" * 80)

    # Configuration
    DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")

    # Load data and pick one sample stock
    high_mtx, low_mtx = load_test_data(DATA_DIR, n_stocks=10, n_days=500)
    sample_idx = 0
    high = high_mtx[sample_idx, :]
    low = low_mtx[sample_idx, :]

    # Drop days where either price is missing
    valid_mask = ~(np.isnan(high) | np.isnan(low))
    high = high[valid_mask]
    low = low[valid_mask]

    print(f"\n样本数据长度: {len(high)}")
    if len(high) == 0:
        print("样本数据为空(全部为 NaN),无法进行性能测试")
        return

    # Benchmark parameters
    k = 15
    flexible_zone = 5
    n_iterations = 100

    results = []

    # Test 1: pivot detection (standard)
    results.append(compare_functions(
        pivots_fractal,
        pivots_fractal_optimized,
        "pivots_fractal",
        high, low, k,
        n_iterations=n_iterations,
    ))

    # Test 2: pivot detection (hybrid)
    results.append(compare_functions(
        pivots_fractal_hybrid,
        pivots_fractal_hybrid_optimized,
        "pivots_fractal_hybrid",
        high, low, k, flexible_zone,
        n_iterations=n_iterations,
    ))

    # Test 3: anchor-line fitting. The upper edge uses high pivots and the
    # lower edge uses low pivots; each needs at least 5 pivots (first 10 used).
    ph, pl = pivots_fractal(high, low, k=k)
    for mode, pivots, prices in (("upper", ph, high), ("lower", pl, low)):
        if len(pivots) < 5:
            continue
        pivot_indices = pivots[:10]
        results.append(compare_functions(
            _make_fit_runner(fit_boundary_anchor, mode),
            _make_fit_runner(fit_boundary_anchor_optimized, mode),
            f"fit_boundary_anchor ({mode})",
            pivot_indices, prices[pivot_indices], prices,
            n_iterations=n_iterations,
        ))

    # Test 4: fitting adherence. The line parameters are arbitrary; only
    # timing and original/optimized agreement matter here.
    if len(ph) >= 5:
        pivot_indices = ph[:10]
        slope, intercept = 0.01, 100.0
        results.append(compare_functions(
            calc_fitting_adherence,
            calc_fitting_adherence_optimized,
            "calc_fitting_adherence",
            pivot_indices, high[pivot_indices], slope, intercept,
            n_iterations=n_iterations,
        ))

    # Test 5: boundary utilization over the whole sample with fixed lines
    upper_slope, upper_intercept = -0.02, 120.0
    lower_slope, lower_intercept = 0.02, 80.0
    start, end = 0, len(high) - 1
    results.append(compare_functions(
        calc_boundary_utilization,
        calc_boundary_utilization_optimized,
        "calc_boundary_utilization",
        high, low, upper_slope, upper_intercept, lower_slope, lower_intercept, start, end,
        n_iterations=n_iterations,
    ))

    # Test 6: breakout strength with fixed scalar inputs
    results.append(compare_functions(
        calc_breakout_strength,
        calc_breakout_strength_optimized,
        "calc_breakout_strength",
        100.0, 105.0, 95.0, 1.5, 0.6, 0.8, 0.7,
        n_iterations=n_iterations,
    ))

    # Summary and full-data extrapolation
    overall_speedup, overall_improvement = _print_summary(results)
    _print_full_data_estimate(overall_speedup, overall_improvement)

    print("\n" + "=" * 80)
    print("建议:")
    print(" 1. 如果加速比 > 2x,建议切换到优化版本")
    print(" 2. 运行完整的集成测试验证正确性")
    print(" 3. 使用 cProfile 分析优化版的新瓶颈")
    print("=" * 80)
|
||
|
||
|
||
# Script entry point: run the benchmark suite only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|