technical-patterns-lab/scripts/test_optimization_comparison.py
褚宏光 0f8b9d836b Refactor strength scoring system with new parameters and renaming
- Introduced a new "tilt" parameter to the strength scoring system, allowing for the assessment of triangle slope directionality.
- Renamed existing parameters: "拟合贴合度" to "形态规则度" and "边界利用率" to "价格活跃度" for improved clarity.
- Updated normalization methods for all strength components to ensure they remain within the [0, 1] range, facilitating LLM tuning.
- Enhanced documentation to reflect changes in parameter names and scoring logic, including detailed explanations of the new tilt parameter.
- Modified multiple source files and scripts to accommodate the new scoring structure and ensure backward compatibility.

Files modified:
- `src/converging_triangle.py`, `src/converging_triangle_optimized.py`, `src/triangle_detector_api.py`: Updated parameter names and scoring logic.
- `scripts/plot_converging_triangles.py`, `scripts/generate_stock_viewer.py`: Adjusted for new scoring parameters in output.
- New documentation files created to explain the renaming and new scoring system in detail.
2026-01-29 15:55:50 +08:00

350 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
性能对比测试 - 原版 vs Numba优化版
测试各个优化函数的性能提升效果,并生成详细的对比报告。
"""
import os
import sys
import pickle
import time
import numpy as np
# 添加 src 路径
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
# 导入原版函数
from converging_triangle import (
pivots_fractal,
pivots_fractal_hybrid,
fit_boundary_anchor,
calc_geometry_score,
calc_activity_score,
calc_breakout_strength,
)
# 导入优化版函数
from converging_triangle_optimized import (
pivots_fractal_optimized,
pivots_fractal_hybrid_optimized,
fit_boundary_anchor_optimized,
calc_geometry_score_optimized,
calc_activity_score_optimized,
calc_breakout_strength_optimized,
)
class FakeModule:
    """Empty stand-in module used to bypass the ``model`` dependency.

    The only attribute it offers is ``ndarray``, which is NumPy's array type.
    """

    ndarray = np.ndarray
def load_pkl(pkl_path: str) -> dict:
    """Unpickle and return the object stored at ``pkl_path``.

    Before loading, fresh ``FakeModule`` instances are registered in
    ``sys.modules`` under ``model`` and ``model.index_info``; the entries are
    left in place afterwards.

    Args:
        pkl_path: Path of the pickle file to read.

    Returns:
        Whatever object ``pickle.load`` produces for the file.
    """
    # Register the stubs first so the pickle can be read without the real
    # ``model`` package.
    for stub_name in ('model', 'model.index_info'):
        sys.modules[stub_name] = FakeModule()

    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file, so this must only be pointed at trusted data files.
    with open(pkl_path, 'rb') as handle:
        return pickle.load(handle)
def load_test_data(data_dir: str, n_stocks: int = 10, n_days: int = 500):
    """Load the ``high``/``low`` matrices and cut them down to a test window.

    Args:
        data_dir: Directory that contains ``high.pkl`` and ``low.pkl``.
        n_stocks: Number of leading rows to keep.
        n_days: Number of trailing columns to keep.

    Returns:
        ``(high_mtx, low_mtx)``: the ``"mtx"`` entry of each pickle, sliced to
        the first ``n_stocks`` rows and the last ``n_days`` columns.
    """
    print(f"加载测试数据 (stocks={n_stocks}, days={n_days})...")

    # Both files are loaded (high first, then low) before either is sliced.
    loaded = {
        field: load_pkl(os.path.join(data_dir, f"{field}.pkl"))
        for field in ("high", "low")
    }

    # Equivalent to indexing with [:n_stocks, -n_days:].
    window = (slice(None, n_stocks), slice(-n_days, None))
    high_mtx = loaded["high"]["mtx"][window]
    low_mtx = loaded["low"]["mtx"][window]

    print(f" 数据形状: {high_mtx.shape}")
    return high_mtx, low_mtx
def benchmark_function(func, name, *args, n_iterations=100, warmup=5):
    """Time repeated calls of ``func(*args)`` and report the mean duration.

    Args:
        func: Callable to benchmark.
        name: Label for the function; it is accepted for the caller's
            convenience and is not used by the timing loop.
        *args: Positional arguments forwarded to ``func`` on every call.
        n_iterations: Number of timed calls; must be at least 1.
        warmup: Number of untimed calls made before timing starts.

    Returns:
        ``(avg_time_ms, result)``: the mean time per call in milliseconds and
        the value returned by the last timed call.

    Raises:
        ValueError: If ``n_iterations`` is smaller than 1.
    """
    if n_iterations < 1:
        # Without at least one timed call there is neither a result nor a
        # meaningful average (the division below would fail).
        raise ValueError(f"n_iterations must be >= 1, got {n_iterations}")

    # Warm-up calls matter for numba: they keep one-off start-up cost out of
    # the timed section.
    for _ in range(warmup):
        func(*args)

    # perf_counter is monotonic and has the highest available resolution,
    # unlike time.time(), which can be coarse and may jump with clock changes.
    started = time.perf_counter()
    for _ in range(n_iterations):
        result = func(*args)
    elapsed = time.perf_counter() - started

    avg_time_ms = (elapsed / n_iterations) * 1000
    return avg_time_ms, result
def _outputs_match(expected, actual):
    """Return True when one original/optimized output pair agrees within tolerance."""
    if isinstance(expected, np.ndarray):
        # np.allclose broadcasts its inputs, so differently shaped arrays
        # would either raise or could compare equal by accident; insist on
        # identical shapes before comparing the values.
        if np.shape(actual) != expected.shape:
            return False
        return bool(np.allclose(expected, actual, rtol=1e-5, atol=1e-8))
    return bool(abs(expected - actual) < 1e-6)


def compare_functions(original_func, optimized_func, func_name, *args, n_iterations=100):
    """Benchmark two implementations on the same arguments and compare them.

    Both callables are timed with ``benchmark_function``; the timings, the
    speedup and a per-output consistency check are printed.

    Args:
        original_func: Reference implementation.
        optimized_func: Optimized implementation expected to return the same
            outputs.
        func_name: Label used in the printed report and in the result dict.
        *args: Positional arguments passed to both callables.
        n_iterations: Number of timed calls per implementation.

    Returns:
        A dict with the keys ``name``, ``original_time``, ``optimized_time``
        (both in milliseconds per call), ``speedup`` and ``improvement``
        (percent).
    """
    print(f"\n{'='*80}")
    print(f"测试: {func_name}")
    print(f"{'='*80}")

    # Time the original implementation.
    print(f"\n[1] 原版函数...")
    original_time, original_result = benchmark_function(
        original_func, f"{func_name}_original", *args, n_iterations=n_iterations
    )
    print(f" 平均耗时: {original_time:.4f} 毫秒/次")

    # Time the optimized implementation.
    print(f"\n[2] Numba优化版...")
    optimized_time, optimized_result = benchmark_function(
        optimized_func, f"{func_name}_optimized", *args, n_iterations=n_iterations
    )
    print(f" 平均耗时: {optimized_time:.4f} 毫秒/次")

    # Both ratios fall back to 0 when their denominator is not positive.
    speedup = original_time / optimized_time if optimized_time > 0 else 0
    improvement = ((original_time - optimized_time) / original_time * 100) if original_time > 0 else 0
    print(f"\n[3] 性能对比:")
    print(f" 加速比: {speedup:.2f}x")
    print(f" 性能提升: {improvement:.1f}%")
    print(f" 时间节省: {original_time - optimized_time:.4f} 毫秒/次")

    # Check that both implementations produced the same outputs.
    print(f"\n[4] 结果验证:")
    if isinstance(original_result, tuple):
        # zip() below stops at the shorter sequence, so a differing number of
        # outputs would otherwise go unnoticed.
        try:
            optimized_count = len(optimized_result)
        except TypeError:
            optimized_count = None
        if optimized_count != len(original_result):
            print(f" 输出数量: [ERR] 不一致 (原={len(original_result)}, 优={optimized_count})")

        for i, (orig, opt) in enumerate(zip(original_result, optimized_result)):
            match = _outputs_match(orig, opt)
            verdict = '[OK] 一致' if match else '[ERR] 不一致'
            if isinstance(orig, np.ndarray):
                print(f" 输出 {i+1} (数组): {verdict}")
                if not match:
                    # Always show shapes and a short sample on mismatch, even
                    # when one side is empty or is not an ndarray.
                    opt_arr = np.asarray(opt)
                    print(f" 原版: shape={orig.shape}, sample={np.atleast_1d(orig)[:3]}")
                    print(f" 优化: shape={opt_arr.shape}, sample={np.atleast_1d(opt_arr)[:3]}")
            else:
                print(f" 输出 {i+1} (标量): {verdict} (原={orig:.6f}, 优={opt:.6f})")
    else:
        match = _outputs_match(original_result, optimized_result)
        verdict = '[OK] 一致' if match else '[ERR] 不一致'
        if isinstance(original_result, np.ndarray):
            print(f" 结果 (数组): {verdict}")
        else:
            print(f" 结果 (标量): {verdict}")

    return {
        "name": func_name,
        "original_time": original_time,
        "optimized_time": optimized_time,
        "speedup": speedup,
        "improvement": improvement,
    }
def main():
    """Run every original-vs-optimized comparison and print a summary report.

    Loads a small data sample, benchmarks each function pair through
    ``compare_functions``, prints a per-function table with totals, and then
    prints a rough full-dataset estimate derived from the overall speedup.
    """
    print("=" * 80)
    print("收敛三角形检测 - 原版 vs Numba优化版 性能对比")
    print("=" * 80)

    # Configuration: the data directory sits next to the scripts directory.
    data_dir = os.path.join(os.path.dirname(__file__), "..", "data")

    # Load the data.
    high_mtx, low_mtx = load_test_data(data_dir, n_stocks=10, n_days=500)

    # Use a single sample stock.
    sample_idx = 0
    high = high_mtx[sample_idx, :]
    low = low_mtx[sample_idx, :]

    # Drop positions where either series is NaN.
    valid_mask = ~(np.isnan(high) | np.isnan(low))
    high = high[valid_mask]
    low = low[valid_mask]
    print(f"\n样本数据长度: {len(high)}")

    # Test parameters.
    k = 15
    flexible_zone = 5
    n_iterations = 100
    results = []

    # ========================================================================
    # Test 1: pivot detection (standard method)
    # ========================================================================
    results.append(compare_functions(
        pivots_fractal,
        pivots_fractal_optimized,
        "pivots_fractal",
        high, low, k,
        n_iterations=n_iterations,
    ))

    # ========================================================================
    # Test 2: pivot detection (hybrid method)
    # ========================================================================
    results.append(compare_functions(
        pivots_fractal_hybrid,
        pivots_fractal_hybrid_optimized,
        "pivots_fractal_hybrid",
        high, low, k, flexible_zone,
        n_iterations=n_iterations,
    ))

    # ========================================================================
    # Test 3: anchor fitting (upper boundary on highs, lower boundary on lows)
    # ========================================================================
    # Get the pivots first.
    ph, pl = pivots_fractal(high, low, k=k)
    for mode, pivots, series in (("upper", ph, high), ("lower", pl, low)):
        # Skip a boundary that has fewer than 5 pivots.
        if len(pivots) < 5:
            continue
        pivot_indices = pivots[:10]
        pivot_values = series[pivot_indices]
        # ``mode`` is bound as a default so each lambda keeps its own value.
        results.append(compare_functions(
            lambda pi, pv, ap, mode=mode: fit_boundary_anchor(
                pi, pv, ap, mode=mode, window_start=0, window_end=len(ap)-1
            ),
            lambda pi, pv, ap, mode=mode: fit_boundary_anchor_optimized(
                pi, pv, ap, mode=mode, window_start=0, window_end=len(ap)-1
            ),
            f"fit_boundary_anchor ({mode})",
            pivot_indices, pivot_values, series,
            n_iterations=n_iterations,
        ))

    # ========================================================================
    # Test 4: geometry (pattern regularity) score
    # ========================================================================
    if len(ph) >= 5:
        pivot_indices = ph[:10]
        pivot_values = high[pivot_indices]
        slope, intercept = 0.01, 100.0
        results.append(compare_functions(
            calc_geometry_score,
            calc_geometry_score_optimized,
            "calc_geometry_score",
            pivot_indices, pivot_values, slope, intercept,
            n_iterations=n_iterations,
        ))

    # ========================================================================
    # Test 5: price activity score
    # ========================================================================
    upper_slope, upper_intercept = -0.02, 120.0
    lower_slope, lower_intercept = 0.02, 80.0
    start, end = 0, len(high) - 1
    results.append(compare_functions(
        calc_activity_score,
        calc_activity_score_optimized,
        "calc_activity_score",
        high, low, upper_slope, upper_intercept, lower_slope, lower_intercept, start, end,
        n_iterations=n_iterations,
    ))

    # ========================================================================
    # Test 6: breakout strength
    # ========================================================================
    results.append(compare_functions(
        calc_breakout_strength,
        calc_breakout_strength_optimized,
        "calc_breakout_strength",
        100.0, 105.0, 95.0, 1.5, 0.6, 0.8, 0.7,
        n_iterations=n_iterations,
    ))

    # ========================================================================
    # Summary report
    # ========================================================================
    print("\n\n")
    print("=" * 80)
    print("性能对比总结")
    print("=" * 80)
    print(f"\n{'函数名':<35} {'原版(ms)':<12} {'优化(ms)':<12} {'加速比':<10} {'提升':<10}")
    print("-" * 80)

    total_original = 0
    total_optimized = 0
    for r in results:
        print(f"{r['name']:<35} {r['original_time']:<12.4f} {r['optimized_time']:<12.4f} "
              f"{r['speedup']:<10.2f}x {r['improvement']:<10.1f}%")
        total_original += r['original_time']
        total_optimized += r['optimized_time']
    print("-" * 80)

    overall_speedup = total_original / total_optimized if total_optimized > 0 else 0
    overall_improvement = ((total_original - total_optimized) / total_original * 100) if total_original > 0 else 0
    print(f"{'总计':<35} {total_original:<12.4f} {total_optimized:<12.4f} "
          f"{overall_speedup:<10.2f}x {overall_improvement:<10.1f}%")

    # Estimate the gain on the full dataset.
    print("\n" + "=" * 80)
    print("全量数据性能估算")
    print("=" * 80)

    # From earlier test runs: the full dataset (108 stocks x 500 days) takes
    # about 30.83 seconds.
    baseline_time = 30.83  # seconds
    if overall_speedup > 0:
        estimated_time = baseline_time / overall_speedup
        time_saved = baseline_time - estimated_time
        print(f"\n基于当前加速比 {overall_speedup:.2f}x 估算:")
        print(f" 原版耗时: {baseline_time:.2f} 秒 ({baseline_time/60:.2f} 分钟)")
        print(f" 优化后耗时: {estimated_time:.2f} 秒 ({estimated_time/60:.2f} 分钟)")
        print(f" 节省时间: {time_saved:.2f} 秒 ({time_saved/60:.2f} 分钟)")
        print(f" 性能提升: {overall_improvement:.1f}%")
    else:
        # A speedup of 0 is the fallback used when a total time is 0;
        # dividing by it would raise ZeroDivisionError, so skip the estimate.
        print("\n加速比无效 (总耗时为 0), 无法估算全量数据耗时")

    print("\n" + "=" * 80)
    print("建议:")
    print(" 1. 如果加速比 > 2x建议切换到优化版本")
    print(" 2. 运行完整的集成测试验证正确性")
    print(" 3. 使用 cProfile 分析优化版的新瓶颈")
    print("=" * 80)
# Run the comparison only when this file is executed as a script.
if __name__ == "__main__":
    main()