- Introduced a new "tilt" parameter to the strength scoring system, allowing for the assessment of triangle slope directionality.
- Renamed existing parameters: "拟合贴合度" to "形态规则度" and "边界利用率" to "价格活跃度" for improved clarity.
- Updated normalization methods for all strength components to ensure they remain within the [0, 1] range, facilitating LLM tuning.
- Enhanced documentation to reflect changes in parameter names and scoring logic, including detailed explanations of the new tilt parameter.
- Modified multiple source files and scripts to accommodate the new scoring structure and ensure backward compatibility.

Files modified:
- `src/converging_triangle.py`, `src/converging_triangle_optimized.py`, `src/triangle_detector_api.py`: Updated parameter names and scoring logic.
- `scripts/plot_converging_triangles.py`, `scripts/generate_stock_viewer.py`: Adjusted for new scoring parameters in output.
- New documentation files created to explain the renaming and new scoring system in detail.
350 lines
12 KiB
Python
350 lines
12 KiB
Python
"""
|
||
性能对比测试 - 原版 vs Numba优化版
|
||
|
||
测试各个优化函数的性能提升效果,并生成详细的对比报告。
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import pickle
|
||
import time
|
||
import numpy as np
|
||
|
||
# 添加 src 路径
|
||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||
|
||
# 导入原版函数
|
||
from converging_triangle import (
|
||
pivots_fractal,
|
||
pivots_fractal_hybrid,
|
||
fit_boundary_anchor,
|
||
calc_geometry_score,
|
||
calc_activity_score,
|
||
calc_breakout_strength,
|
||
)
|
||
|
||
# 导入优化版函数
|
||
from converging_triangle_optimized import (
|
||
pivots_fractal_optimized,
|
||
pivots_fractal_hybrid_optimized,
|
||
fit_boundary_anchor_optimized,
|
||
calc_geometry_score_optimized,
|
||
calc_activity_score_optimized,
|
||
calc_breakout_strength_optimized,
|
||
)
|
||
|
||
|
||
class FakeModule:
    """Empty shell module used to bypass the ``model`` dependency.

    Instances are registered in ``sys.modules`` by ``load_pkl`` under the
    names ``model`` and ``model.index_info``.
    """

    # Alias of numpy's array type.
    # NOTE(review): presumably present so pickles that reference ``ndarray``
    # through the stubbed module can resolve it - confirm against the data.
    ndarray = np.ndarray
|
||
|
||
|
||
def load_pkl(pkl_path: str) -> dict:
    """Load a pickle file while ``model`` / ``model.index_info`` are stubbed.

    ``FakeModule`` instances are placed in ``sys.modules`` under both names
    for the duration of the load. Whatever was registered under those names
    before the call is put back afterwards, so the stubs do not mask a real
    ``model`` package for the rest of the process.

    Args:
        pkl_path: Path of the pickle file to read.

    Returns:
        The unpickled object.
        NOTE(review): annotated as ``dict`` and indexed with ``"mtx"`` by
        ``load_test_data``; the full schema is not visible here.
    """
    missing = object()  # sentinel: sys.modules may legitimately hold None
    stub_names = ("model", "model.index_info")
    previous = {name: sys.modules.get(name, missing) for name in stub_names}

    for name in stub_names:
        sys.modules[name] = FakeModule()
    try:
        with open(pkl_path, 'rb') as f:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # use this on trusted, locally produced data files.
            data = pickle.load(f)
    finally:
        # Restore the prior import state even when unpickling fails.
        for name, module in previous.items():
            if module is missing:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = module
    return data
|
||
|
||
|
||
def load_test_data(data_dir: str, n_stocks: int = 10, n_days: int = 500):
    """Load the high/low price matrices and cut them down to a test sample.

    Reads ``high.pkl`` and ``low.pkl`` from ``data_dir`` through ``load_pkl``
    and takes the ``"mtx"`` entry of each, keeping the first ``n_stocks``
    rows and the last ``n_days`` columns.

    NOTE(review): assumes ``"mtx"`` is a 2-D array laid out as
    (stocks, days), as the parameter names suggest - confirm against the
    data files.

    Args:
        data_dir: Directory containing ``high.pkl`` and ``low.pkl``.
        n_stocks: Number of leading rows to keep.
        n_days: Number of trailing columns to keep.

    Returns:
        ``(high_mtx, low_mtx)``: the two sliced matrices.

    Raises:
        ValueError: If ``n_stocks`` or ``n_days`` is negative.
    """
    if n_stocks < 0 or n_days < 0:
        raise ValueError(
            f"n_stocks and n_days must be non-negative, got {n_stocks} and {n_days}"
        )

    print(f"加载测试数据 (stocks={n_stocks}, days={n_days})...")

    high_data = load_pkl(os.path.join(data_dir, "high.pkl"))
    low_data = load_pkl(os.path.join(data_dir, "low.pkl"))

    def _sample(mtx):
        # An explicit start column is used because ``mtx[:, -0:]`` would
        # select every column rather than none when n_days is 0.
        first_col = max(mtx.shape[1] - n_days, 0)
        return mtx[:n_stocks, first_col:]

    high_mtx = _sample(high_data["mtx"])
    low_mtx = _sample(low_data["mtx"])

    print(f" 数据形状: {high_mtx.shape}")

    return high_mtx, low_mtx
|
||
|
||
|
||
def benchmark_function(func, name, *args, n_iterations=100, warmup=5):
    """Time repeated calls of a single function.

    Args:
        func: Callable to benchmark.
        name: Label for the function. Currently unused inside this function;
            kept so existing callers keep working.
        *args: Positional arguments passed to ``func`` on every call.
        n_iterations: Number of timed calls; must be at least 1.
        warmup: Number of untimed calls made first (important for numba, so
            that compilation is not included in the measurement).

    Returns:
        ``(avg_time_ms, result)``: average time per timed call in
        milliseconds, and the return value of the last timed call.

    Raises:
        ValueError: If ``n_iterations`` is smaller than 1.
    """
    if n_iterations < 1:
        raise ValueError(f"n_iterations must be at least 1, got {n_iterations}")

    # Warm-up calls: their results and timings are discarded.
    for _ in range(warmup):
        func(*args)

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump when the system clock is adjusted.
    start_time = time.perf_counter()
    for _ in range(n_iterations):
        result = func(*args)
    elapsed = time.perf_counter() - start_time

    avg_time_ms = (elapsed / n_iterations) * 1000

    return avg_time_ms, result
|
||
|
||
|
||
def compare_functions(original_func, optimized_func, func_name, *args, n_iterations=100):
    """Benchmark an original/optimized pair, print a report, return the timings.

    Both callables are timed through ``benchmark_function`` with the same
    positional ``args``; the value returned by the last call of each is then
    compared (arrays with ``np.allclose``, scalars with an absolute
    tolerance of 1e-6).

    Args:
        original_func: Reference implementation.
        optimized_func: Implementation being evaluated.
        func_name: Label used in the printed report and the returned dict.
        *args: Positional arguments passed unchanged to both callables.
        n_iterations: Number of timed calls per callable.

    Returns:
        dict with the keys ``name``, ``original_time`` and
        ``optimized_time`` (average milliseconds per call), ``speedup``
        (original / optimized, 0 when the optimized time is 0),
        ``improvement`` (percentage of the original time saved) and
        ``consistent`` (True when every compared output agreed).
    """
    print(f"\n{'='*80}")
    print(f"测试: {func_name}")
    print(f"{'='*80}")

    # Time the original implementation.
    print(f"\n[1] 原版函数...")
    original_time, original_result = benchmark_function(
        original_func, f"{func_name}_original", *args, n_iterations=n_iterations
    )
    print(f" 平均耗时: {original_time:.4f} 毫秒/次")

    # Time the optimized implementation.
    print(f"\n[2] Numba优化版...")
    optimized_time, optimized_result = benchmark_function(
        optimized_func, f"{func_name}_optimized", *args, n_iterations=n_iterations
    )
    print(f" 平均耗时: {optimized_time:.4f} 毫秒/次")

    # Speed-up ratio and relative saving; both fall back to 0 rather than
    # dividing by a zero measurement.
    speedup = original_time / optimized_time if optimized_time > 0 else 0
    improvement = ((original_time - optimized_time) / original_time * 100) if original_time > 0 else 0

    print(f"\n[3] 性能对比:")
    print(f" 加速比: {speedup:.2f}x")
    print(f" 性能提升: {improvement:.1f}%")
    print(f" 时间节省: {original_time - optimized_time:.4f} 毫秒/次")

    # Check that both implementations produced the same output.
    print(f"\n[4] 结果验证:")
    if isinstance(original_result, tuple):
        consistent = _compare_output_tuple(original_result, optimized_result)
    else:
        consistent = _values_match(original_result, optimized_result)
        kind = "数组" if isinstance(original_result, np.ndarray) else "标量"
        print(f" 结果 ({kind}): {'[OK] 一致' if consistent else '[ERR] 不一致'}")

    return {
        "name": func_name,
        "original_time": original_time,
        "optimized_time": optimized_time,
        "speedup": speedup,
        "improvement": improvement,
        "consistent": consistent,
    }


def _values_match(orig, opt):
    """Return True when one original/optimized output pair agrees."""
    if isinstance(orig, np.ndarray):
        # np.allclose broadcasts its inputs and raises ValueError for
        # incompatible shapes, so a shape difference is a mismatch up front.
        if np.shape(orig) != np.shape(opt):
            return False
        return bool(np.allclose(orig, opt, rtol=1e-5, atol=1e-8))
    try:
        return bool(abs(orig - opt) < 1e-6)
    except TypeError:
        # Non-numeric outputs (e.g. None or str) cannot be subtracted.
        try:
            return bool(orig == opt)
        except ValueError:
            return False
    except ValueError:
        # The comparison produced an array with no single truth value.
        return False


def _format_scalar(value):
    """Format a value with six decimals, falling back to ``repr``."""
    try:
        return f"{value:.6f}"
    except (TypeError, ValueError):
        return repr(value)


def _compare_output_tuple(original_result, optimized_result):
    """Print a per-element verdict for tuple outputs; return overall agreement."""
    try:
        n_optimized = len(optimized_result)
    except TypeError:
        n_optimized = None

    consistent = n_optimized == len(original_result)
    if not consistent:
        # zip() below would otherwise silently ignore the extra outputs.
        print(f" [ERR] 输出数量不一致: 原版={len(original_result)}, 优化={n_optimized}")
        if n_optimized is None:
            return False

    for i, (orig, opt) in enumerate(zip(original_result, optimized_result)):
        match = _values_match(orig, opt)
        if isinstance(orig, np.ndarray):
            print(f" 输出 {i+1} (数组): {'[OK] 一致' if match else '[ERR] 不一致'}")
            if not match and np.size(orig) > 0 and np.size(opt) > 0:
                # np.shape / np.atleast_1d also accept 0-d arrays and
                # non-array outputs, which len() and .shape do not.
                print(f" 原版: shape={np.shape(orig)}, sample={np.atleast_1d(orig)[:3]}")
                print(f" 优化: shape={np.shape(opt)}, sample={np.atleast_1d(opt)[:3]}")
        else:
            print(f" 输出 {i+1} (标量): {'[OK] 一致' if match else '[ERR] 不一致'} (原={_format_scalar(orig)}, 优={_format_scalar(opt)})")
        consistent = consistent and match

    return consistent
|
||
|
||
|
||
def main():
    """Run every original-vs-optimized benchmark and print the report."""
    print("=" * 80)
    print("收敛三角形检测 - 原版 vs Numba优化版 性能对比")
    print("=" * 80)

    # Configuration
    data_dir = os.path.join(os.path.dirname(__file__), "..", "data")

    # Load data
    high_mtx, low_mtx = load_test_data(data_dir, n_stocks=10, n_days=500)

    # Pick one sample stock
    sample_idx = 0
    high = high_mtx[sample_idx, :]
    low = low_mtx[sample_idx, :]

    # Drop positions where either series is NaN
    valid_mask = ~(np.isnan(high) | np.isnan(low))
    high = high[valid_mask]
    low = low[valid_mask]

    print(f"\n样本数据长度: {len(high)}")
    if len(high) == 0:
        # Nothing to benchmark once every bar has been filtered out.
        print("样本数据为空,跳过性能测试")
        return

    # Test parameters
    k = 15
    flexible_zone = 5
    n_iterations = 100

    results = []

    # ========================================================================
    # Test 1: pivot detection (standard method)
    # ========================================================================
    results.append(
        compare_functions(
            pivots_fractal,
            pivots_fractal_optimized,
            "pivots_fractal",
            high, low, k,
            n_iterations=n_iterations,
        )
    )

    # ========================================================================
    # Test 2: pivot detection (hybrid method)
    # ========================================================================
    results.append(
        compare_functions(
            pivots_fractal_hybrid,
            pivots_fractal_hybrid_optimized,
            "pivots_fractal_hybrid",
            high, low, k, flexible_zone,
            n_iterations=n_iterations,
        )
    )

    # ========================================================================
    # Test 3: anchor fitting (upper boundary, then lower boundary)
    # ========================================================================
    # Get the pivot points first
    ph, pl = pivots_fractal(high, low, k=k)

    for mode, pivots, series in (("upper", ph, high), ("lower", pl, low)):
        if len(pivots) < 5:
            continue
        pivot_indices = pivots[:10]
        pivot_values = series[pivot_indices]

        # ``mode=mode`` binds the current loop value into each lambda.
        results.append(
            compare_functions(
                lambda pi, pv, ap, mode=mode: fit_boundary_anchor(
                    pi, pv, ap, mode=mode, window_start=0, window_end=len(ap)-1
                ),
                lambda pi, pv, ap, mode=mode: fit_boundary_anchor_optimized(
                    pi, pv, ap, mode=mode, window_start=0, window_end=len(ap)-1
                ),
                f"fit_boundary_anchor ({mode})",
                pivot_indices, pivot_values, series,
                n_iterations=n_iterations,
            )
        )

    # ========================================================================
    # Test 4: geometry score
    # ========================================================================
    if len(ph) >= 5:
        pivot_indices = ph[:10]
        pivot_values = high[pivot_indices]
        slope, intercept = 0.01, 100.0

        results.append(
            compare_functions(
                calc_geometry_score,
                calc_geometry_score_optimized,
                "calc_geometry_score",
                pivot_indices, pivot_values, slope, intercept,
                n_iterations=n_iterations,
            )
        )

    # ========================================================================
    # Test 5: activity score
    # ========================================================================
    upper_slope, upper_intercept = -0.02, 120.0
    lower_slope, lower_intercept = 0.02, 80.0
    start, end = 0, len(high) - 1

    results.append(
        compare_functions(
            calc_activity_score,
            calc_activity_score_optimized,
            "calc_activity_score",
            high, low, upper_slope, upper_intercept, lower_slope, lower_intercept, start, end,
            n_iterations=n_iterations,
        )
    )

    # ========================================================================
    # Test 6: breakout strength
    # ========================================================================
    results.append(
        compare_functions(
            calc_breakout_strength,
            calc_breakout_strength_optimized,
            "calc_breakout_strength",
            100.0, 105.0, 95.0, 1.5, 0.6, 0.8, 0.7,
            n_iterations=n_iterations,
        )
    )

    # ========================================================================
    # Summary report
    # ========================================================================
    overall_speedup, overall_improvement = _print_summary(results)
    _print_full_data_estimate(overall_speedup, overall_improvement)

    print("\n" + "=" * 80)
    print("建议:")
    print(" 1. 如果加速比 > 2x,建议切换到优化版本")
    print(" 2. 运行完整的集成测试验证正确性")
    print(" 3. 使用 cProfile 分析优化版的新瓶颈")
    print("=" * 80)


def _print_summary(results):
    """Print the per-function table; return (overall_speedup, overall_improvement)."""
    print("\n\n")
    print("=" * 80)
    print("性能对比总结")
    print("=" * 80)

    print(f"\n{'函数名':<35} {'原版(ms)':<12} {'优化(ms)':<12} {'加速比':<10} {'提升':<10}")
    print("-" * 80)

    total_original = 0
    total_optimized = 0

    for r in results:
        print(f"{r['name']:<35} {r['original_time']:<12.4f} {r['optimized_time']:<12.4f} "
              f"{r['speedup']:<10.2f}x {r['improvement']:<10.1f}%")
        total_original += r['original_time']
        total_optimized += r['optimized_time']

    print("-" * 80)
    overall_speedup = total_original / total_optimized if total_optimized > 0 else 0
    overall_improvement = ((total_original - total_optimized) / total_original * 100) if total_original > 0 else 0

    print(f"{'总计':<35} {total_original:<12.4f} {total_optimized:<12.4f} "
          f"{overall_speedup:<10.2f}x {overall_improvement:<10.1f}%")

    return overall_speedup, overall_improvement


def _print_full_data_estimate(overall_speedup, overall_improvement, baseline_time=30.83):
    """Print the projected full-data runtime for the measured speed-up."""
    print("\n" + "=" * 80)
    print("全量数据性能估算")
    print("=" * 80)

    # ``baseline_time`` defaults to 30.83 s: per the original comment, a
    # previous full-data run (108 stocks x 500 days) took about that long.
    if overall_speedup <= 0:
        # A zero speed-up means no usable timing; avoid dividing by it.
        print("\n无法估算: 总体加速比为 0")
        return

    estimated_time = baseline_time / overall_speedup
    time_saved = baseline_time - estimated_time

    print(f"\n基于当前加速比 {overall_speedup:.2f}x 估算:")
    print(f" 原版耗时: {baseline_time:.2f} 秒 ({baseline_time/60:.2f} 分钟)")
    print(f" 优化后耗时: {estimated_time:.2f} 秒 ({estimated_time/60:.2f} 分钟)")
    print(f" 节省时间: {time_saved:.2f} 秒 ({time_saved/60:.2f} 分钟)")
    print(f" 性能提升: {overall_improvement:.1f}%")
||
|
||
|
||
# Run the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main()
|