technical-patterns-lab/scripts/test_full_pipeline.py
褚宏光 0f8b9d836b Refactor strength scoring system with new parameters and renaming
- Introduced a new "tilt" parameter to the strength scoring system, allowing for the assessment of triangle slope directionality.
- Renamed existing parameters: "拟合贴合度" to "形态规则度" and "边界利用率" to "价格活跃度" for improved clarity.
- Updated normalization methods for all strength components to ensure they remain within the [0, 1] range, facilitating LLM tuning.
- Enhanced documentation to reflect changes in parameter names and scoring logic, including detailed explanations of the new tilt parameter.
- Modified multiple source files and scripts to accommodate the new scoring structure and ensure backward compatibility.

Files modified:
- `src/converging_triangle.py`, `src/converging_triangle_optimized.py`, `src/triangle_detector_api.py`: Updated parameter names and scoring logic.
- `scripts/plot_converging_triangles.py`, `scripts/generate_stock_viewer.py`: Adjusted for new scoring parameters in output.
- New documentation files created to explain the renaming and new scoring system in detail.
2026-01-29 15:55:50 +08:00

385 lines
12 KiB
Python

"""
完整流水线性能测试 - 验证Numba优化效果
此脚本模拟完整的批量检测流程,对比原版和优化版的性能。
"""
import cProfile
import os
import pickle
import pstats
import sys
import time
from io import StringIO
from typing import Optional

import numpy as np
import pandas as pd

# 添加 src 路径
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))

from converging_triangle import (
    ConvergingTriangleParams,
    detect_converging_triangle_batch as detect_batch_original,
)
class FakeModule:
    """
    Empty stand-in module used to bypass the project's ``model`` dependency.

    load_pkl() registers instances of this class in ``sys.modules`` so that
    unpickling does not require the real package. The stub exposes a single
    attribute, ``ndarray``, aliased to ``numpy.ndarray``.
    """

    # Sole attribute of the stub: forwards to numpy's ndarray type.
    ndarray = np.ndarray
def load_pkl(pkl_path: str) -> dict:
    """
    Unpickle ``pkl_path`` after stubbing out the ``model`` package.

    On every call, ``sys.modules['model']`` and
    ``sys.modules['model.index_info']`` are rebound to fresh FakeModule
    instances, replacing whatever was registered under those names.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the
    file; only use this on trusted, locally produced data files.

    Returns:
        Whatever object the pickle contains (callers index it like a dict).
    """
    for stub_name in ("model", "model.index_info"):
        sys.modules[stub_name] = FakeModule()
    with open(pkl_path, "rb") as handle:
        return pickle.load(handle)
def load_test_data(data_dir: str, n_stocks: Optional[int] = None, n_days: Optional[int] = None):
    """
    Load the open/high/low/close/volume pickles from ``data_dir``.

    Each ``<field>.pkl`` is read via load_pkl() and its ``"mtx"`` entry is
    sliced as ``[stocks, days]``. ``close.pkl`` also supplies ``"dtes"``
    (dates), ``"tkrs"`` and ``"tkrs_name"``.

    Args:
        data_dir: directory containing open.pkl, high.pkl, low.pkl,
            close.pkl and volume.pkl.
        n_stocks: keep only the first ``n_stocks`` stocks; None/0 keeps all.
        n_days: keep only the last ``n_days`` days; None/0 keeps all.
            Each limit is applied independently. Previously both had to be
            given, or neither was applied.

    Returns:
        (open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name)
    """
    print("加载数据...")
    fields = ("open", "high", "low", "close", "volume")
    raw = {name: load_pkl(os.path.join(data_dir, f"{name}.pkl")) for name in fields}

    # A falsy limit (None or 0) means "take everything along that axis".
    stock_sel = slice(None, n_stocks) if n_stocks else slice(None)
    day_sel = slice(-n_days, None) if n_days else slice(None)

    open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx = (
        raw[name]["mtx"][stock_sel, day_sel] for name in fields
    )
    close_data = raw["close"]
    dates = close_data["dtes"][day_sel]
    tkrs = close_data["tkrs"][stock_sel]
    tkrs_name = close_data["tkrs_name"][stock_sel]

    print(f" 数据形状: {close_mtx.shape}")
    print(f" 日期范围: {dates[0]} ~ {dates[-1]}")
    return open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name
def _enable_numba_patches():
    """Swap the Numba implementations into converging_triangle; return the replaced bindings."""
    import converging_triangle
    from converging_triangle_optimized import (
        pivots_fractal_optimized,
        pivots_fractal_hybrid_optimized,
        fit_boundary_anchor_optimized,
        calc_geometry_score_optimized,
        calc_activity_score_optimized,
        calc_breakout_strength_optimized,
    )

    replacements = {
        "pivots_fractal": pivots_fractal_optimized,
        "pivots_fractal_hybrid": pivots_fractal_hybrid_optimized,
        "fit_boundary_anchor": fit_boundary_anchor_optimized,
        "calc_geometry_score": calc_geometry_score_optimized,
        "calc_activity_score": calc_activity_score_optimized,
        "calc_breakout_strength": calc_breakout_strength_optimized,
    }
    # Record the current bindings (None = name was absent) so the patch can be undone.
    originals = {name: vars(converging_triangle).get(name) for name in replacements}
    for name, func in replacements.items():
        setattr(converging_triangle, name, func)
    return originals


def _restore_numba_patches(originals):
    """Undo _enable_numba_patches() using the mapping it returned (no-op if empty)."""
    if not originals:
        return
    import converging_triangle
    for name, func in originals.items():
        if func is None:
            if hasattr(converging_triangle, name):
                delattr(converging_triangle, name)
        else:
            setattr(converging_triangle, name, func)


def _warm_up_numba(high_mtx, low_mtx, window, pivot_k):
    """JIT-compile pivots_fractal_optimized on stock 0's first window; return True if it ran."""
    from converging_triangle_optimized import pivots_fractal_optimized

    sample_high = high_mtx[0, :window]
    sample_low = low_mtx[0, :window]
    valid_mask = ~(np.isnan(sample_high) | np.isnan(sample_low))
    if np.sum(valid_mask) < window:
        return False
    pivots_fractal_optimized(sample_high[valid_mask], sample_low[valid_mask], k=pivot_k)
    return True


def _report_pipeline_stats(df, elapsed, total_points):
    """Print timing and detection statistics for one pipeline run."""
    print(f"\n{'='*80}")
    print("性能统计")
    print(f"{'='*80}")
    print(f"\n总耗时: {elapsed:.2f} 秒 ({elapsed/60:.2f} 分钟)")
    print(f"处理点数: {total_points}")
    # Guard both divisions: an empty detection range or a zero timing must not crash the report.
    speed = total_points / elapsed if elapsed > 0 else float("inf")
    per_point_ms = elapsed / total_points * 1000 if total_points > 0 else 0.0
    print(f"平均速度: {speed:.1f} 点/秒")
    print(f"单点耗时: {per_point_ms:.3f} 毫秒/点")

    if df is None or len(df) == 0:
        return
    valid_count = int(df["is_valid"].eq(True).sum())
    detect_rate = valid_count / total_points * 100 if total_points > 0 else 0.0
    print("\n检测结果:")
    print(f" 有效三角形: {valid_count}")
    print(f" 检出率: {detect_rate:.2f}%")
    # Both breakout columns are read, so require both (previously only `_up` was checked).
    if {"breakout_strength_up", "breakout_strength_down"} <= set(df.columns):
        strong_up = (df["breakout_strength_up"] > 0.3).sum()
        strong_down = (df["breakout_strength_down"] > 0.3).sum()
        print(f" 高强度向上突破 (>0.3): {strong_up}")
        print(f" 高强度向下突破 (>0.3): {strong_down}")


def test_pipeline(
    open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
    params: ConvergingTriangleParams,
    use_optimized: bool = False,
    profile: bool = False
):
    """
    Run the full batch-detection pipeline once and report timing statistics.

    When ``use_optimized`` is true, the Numba implementations from
    ``converging_triangle_optimized`` are monkey-patched into the
    ``converging_triangle`` module for the duration of this call only. The
    original bindings are restored before returning, even on error, so a
    later baseline run is not silently "optimized".

    Args:
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx: 2-D matrices
            indexed [stock, day] (the shape is read from ``close_mtx``).
        params: detection parameters; ``window`` and ``pivot_k`` are read here.
        use_optimized: patch in the Numba helpers before running.
        profile: run the detection under cProfile and print the top 20
            entries by cumulative time.

    Returns:
        ``(df, elapsed_seconds)`` from ``detect_batch_original``, or
        ``(None, 0)`` if the optimized module could not be enabled.
    """
    print(f"\n{'='*80}")
    print(f"测试: {'Numba优化版' if use_optimized else '原版'}")
    print(f"{'='*80}")

    n_stocks, n_days = close_mtx.shape
    window = params.window

    # Range: first day with a full window .. last day on which any stock has a close.
    start_day = window - 1
    any_valid = np.any(~np.isnan(close_mtx), axis=0)
    valid_day_idx = np.where(any_valid)[0]
    end_day = valid_day_idx[-1] if len(valid_day_idx) > 0 else n_days - 1
    # Clamp so data shorter than one window reports 0 points rather than a negative count.
    total_points = max(0, n_stocks * (end_day - start_day + 1))

    print("\n配置:")
    print(f" 股票数: {n_stocks}")
    print(f" 交易日: {n_days}")
    print(f" 窗口大小: {window}")
    print(f" 检测点数: {total_points}")
    # Both branches used to be '' so the flag was never visible.
    print(f" 使用优化: {'是' if use_optimized else '否'}")

    originals = {}
    if use_optimized:
        try:
            print("\n导入Numba优化模块...")
            originals = _enable_numba_patches()
            print(" [OK] Numba优化已启用")
            # Only pivots_fractal is warmed up; the other Numba helpers still
            # compile on first use inside the timed region.
            print("\n预热Numba编译...")
            if _warm_up_numba(high_mtx, low_mtx, window, params.pivot_k):
                print(" [OK] 预热完成")
            else:
                print(" [SKIP] 首只股票的首个窗口数据不完整,跳过预热")
        except Exception as e:
            _restore_numba_patches(originals)
            print(f" [ERROR] 无法启用优化: {e}")
            return None, 0

    print("\n开始批量检测...")
    profiler = cProfile.Profile() if profile else None
    try:
        if profiler is not None:
            profiler.enable()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        df = detect_batch_original(
            open_mtx=open_mtx,
            high_mtx=high_mtx,
            low_mtx=low_mtx,
            close_mtx=close_mtx,
            volume_mtx=volume_mtx,
            params=params,
            start_day=start_day,
            end_day=end_day,
            only_valid=True,
            verbose=False,
        )
        elapsed = time.perf_counter() - start_time
    finally:
        if profiler is not None:
            profiler.disable()
        _restore_numba_patches(originals)

    if profiler is not None:
        print("\n" + "-" * 80)
        print("Profile Top 20:")
        print("-" * 80)
        buf = StringIO()
        pstats.Stats(profiler, stream=buf).sort_stats('cumulative').print_stats(20)
        print(buf.getvalue())

    _report_pipeline_stats(df, elapsed, total_points)
    return df, elapsed


# The file is named test_*.py, so pytest would otherwise collect this
# benchmark as a test and fail on its non-fixture parameters.
test_pipeline.__test__ = False
def compare_results(df_original, df_optimized):
    """
    Verify that the original and optimized pipelines produced the same numbers.

    Rows are compared positionally. The frames must have equal length, and
    index labels are ignored. NOTE(review): this assumes both detectors emit
    rows in the same order.

    For each known numeric column present in ``df_original``, a value pair
    matches when:
      * both values are NaN, or
      * both values are exactly equal (this includes equal infinities), or
      * their absolute difference is below 1e-6.

    A NaN on only one side is a mismatch, and so is a column missing from
    ``df_optimized``.

    Returns:
        True only if both frames exist, have the same row count and every
        compared column matches; otherwise False.
    """
    print(f"\n{'='*80}")
    print("结果一致性验证")
    print(f"{'='*80}")

    if df_original is None or df_optimized is None:
        print("\n[ERROR] 无法对比:某个版本未成功运行")
        return False

    if len(df_original) != len(df_optimized):
        print("\n[WARNING] 记录数不一致:")
        print(f" 原版: {len(df_original)}")
        print(f" 优化: {len(df_optimized)}")
        return False
    print(f"\n记录数: {len(df_original)} (一致 [OK])")

    tolerance = 1e-6
    candidate_cols = [
        'breakout_strength_up', 'breakout_strength_down',
        'price_score_up', 'price_score_down',
        'convergence_score', 'volume_score', 'geometry_score',
        'upper_slope', 'lower_slope', 'width_ratio',
        'touches_upper', 'touches_lower', 'apex_x'
    ]
    numeric_cols = [c for c in candidate_cols if c in df_original.columns]

    print("\n数值列对比:")
    print(f"{'列名':<30} {'最大差异':<15} {'平均差异':<15} {'状态':<10}")
    print("-" * 70)

    all_match = True
    for col in numeric_cols:
        if col not in df_optimized.columns:
            print(f"{col:<30} {'优化版缺少该列':<31} {'[ERR]':<10}")
            all_match = False
            continue

        # Compare raw values positionally; index alignment could mis-pair rows.
        orig_vals = df_original[col].to_numpy(dtype=float, na_value=np.nan)
        opt_vals = df_optimized[col].to_numpy(dtype=float, na_value=np.nan)
        orig_nan = np.isnan(orig_vals)
        opt_nan = np.isnan(opt_vals)
        # NaN on one side only is a real difference; Series.abs().max() used to skip it.
        nan_mismatch = int(np.count_nonzero(orig_nan != opt_nan))
        comparable = ~(orig_nan | opt_nan)
        a, b = orig_vals[comparable], opt_vals[comparable]
        with np.errstate(invalid="ignore"):
            # Exact equality first so that matching infinities yield 0, not inf - inf = NaN.
            diff = np.where(a == b, 0.0, np.abs(a - b))
        # Empty / all-NaN columns have nothing to differ on (used to give max = NaN -> false ERR).
        max_diff = float(diff.max()) if diff.size else 0.0
        mean_diff = float(diff.mean()) if diff.size else 0.0

        match = nan_mismatch == 0 and max_diff < tolerance
        status = "[OK]" if match else "[ERR]"
        print(f"{col:<30} {max_diff:<15.10f} {mean_diff:<15.10f} {status:<10}")
        if nan_mismatch:
            print(f"{'':<30} NaN位置不一致: {nan_mismatch} 行")
        if not match:
            all_match = False

    print("-" * 70)
    if all_match:
        print("\n[结论] 所有数值列完全一致 (误差 < 1e-6) ✓")
    else:
        print("\n[结论] 发现不一致的列,请检查优化实现")
    return all_match
def main():
    """
    Benchmark the original and Numba-optimized detection pipelines end to end.

    The run proceeds in four steps:
      1. Load the full data set from ``../data`` with load_test_data().
      2. Run test_pipeline() once without and once with the Numba patches.
      3. Check that both runs give the same numeric output via compare_results().
      4. Print a speed comparison table and a deployment recommendation.

    If either run fails (returns ``df is None``), only an error line is
    printed after the runs.
    """
    print("=" * 80)
    print("收敛三角形检测 - 完整流水线性能测试")
    print("=" * 80)

    DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")

    # Detection parameters shared by both runs.
    params = ConvergingTriangleParams(
        window=240,
        pivot_k=15,
        boundary_n_segments=2,
        boundary_source="full",
        fitting_method="anchor",
        upper_slope_max=0,
        lower_slope_min=0,
        touch_tol=0.10,
        touch_loss_max=0.10,
        shrink_ratio=0.45,
        break_tol=0.005,
        vol_window=20,
        vol_k=1.5,
        false_break_m=5,
    )

    # Full data set; dates and tickers are not needed for the benchmark.
    (open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
     _dates, _tkrs, _tkrs_name) = load_test_data(DATA_DIR)

    print("\n" + "=" * 80)
    print("阶段 1/2: 测试原版")
    print("=" * 80)
    df_original, time_original = test_pipeline(
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
        params,
        use_optimized=False,
        profile=False
    )

    print("\n" + "=" * 80)
    print("阶段 2/2: 测试优化版")
    print("=" * 80)
    df_optimized, time_optimized = test_pipeline(
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
        params,
        use_optimized=True,
        profile=False
    )

    if df_original is None or df_optimized is None:
        print("\n[ERROR] 至少有一个版本未能运行,跳过结果对比与性能总结")
        return

    results_match = compare_results(df_original, df_optimized)

    print(f"\n{'='*80}")
    print("性能对比总结")
    print(f"{'='*80}")

    speedup = time_original / time_optimized if time_optimized > 0 else 0
    improvement = ((time_original - time_optimized) / time_original * 100) if time_original > 0 else 0
    time_saved = time_original - time_optimized

    # Signed formats (`:+`) so a slower optimized run shows "-0.50x" instead of "+-0.50x".
    time_cell_orig = f"{time_original:.2f}秒 ({time_original/60:.2f}分)"
    time_cell_opt = f"{time_optimized:.2f}秒 ({time_optimized/60:.2f}分)"
    print(f"\n{'指标':<20} {'原版':<20} {'优化版':<20} {'改善':<20}")
    print("-" * 80)
    print(f"{'总耗时':<20} {time_cell_orig:<20} {time_cell_opt:<20} {f'{-time_saved:+.2f}秒':<20}")
    print(f"{'加速比':<20} {'1.00x':<20} {f'{speedup:.2f}x':<20} {f'{speedup-1:+.2f}x':<20}")
    print(f"{'性能提升':<20} {'0%':<20} {f'{improvement:.1f}%':<20} {f'{improvement:+.1f}%':<20}")

    # Same detection-point count as test_pipeline: days [window-1, last valid day] x stocks.
    n_stocks, n_days = close_mtx.shape
    window = params.window
    valid_day_idx = np.flatnonzero(np.any(~np.isnan(close_mtx), axis=0))
    end_day_idx = valid_day_idx[-1] if len(valid_day_idx) > 0 else n_days - 1
    total_points = max(0, n_stocks * (end_day_idx - window + 2))
    # Guard against zero timings (previously an unguarded ZeroDivisionError).
    speed_original = total_points / time_original if time_original > 0 else 0.0
    speed_optimized = total_points / time_optimized if time_optimized > 0 else 0.0
    speed_delta = speed_optimized - speed_original
    print(f"{'处理速度':<20} {f'{speed_original:.0f}点/秒':<20} "
          f"{f'{speed_optimized:.0f}点/秒':<20} {f'{speed_delta:+.0f}点/秒':<20}")

    print("\n" + "=" * 80)
    print("最终结论")
    print("=" * 80)
    if results_match:
        print("\n[OK] 输出结果完全一致 ✓")
    else:
        print("\n[WARNING] 输出结果存在差异,请检查")
    print(f"\n性能提升: {speedup:.1f}x ({improvement:.1f}%)")
    print(f"时间节省: {time_saved:.2f}秒 ({time_saved/60:.2f}分钟)")

    if speedup > 100:
        print("\n[推荐] 性能提升巨大 (>100x),强烈建议部署优化版本!")
    elif speedup > 10:
        print("\n[推荐] 性能提升显著 (>10x),建议部署优化版本")
    elif speedup > 2:
        print("\n[推荐] 性能有提升 (>2x),可以考虑部署优化版本")
    else:
        print("\n[提示] 性能提升不明显,可能不需要部署")
    print("\n" + "=" * 80)


if __name__ == "__main__":
    main()