- Introduced a new "tilt" parameter to the strength scoring system, allowing assessment of triangle slope directionality.
- Renamed existing parameters: "拟合贴合度" to "形态规则度" and "边界利用率" to "价格活跃度" for improved clarity.
- Updated normalization methods for all strength components to ensure they remain within the [0, 1] range, facilitating LLM tuning.
- Enhanced documentation to reflect changes in parameter names and scoring logic, including detailed explanations of the new tilt parameter.
- Modified multiple source files and scripts to accommodate the new scoring structure and ensure backward compatibility.

Files modified:
- `src/converging_triangle.py`, `src/converging_triangle_optimized.py`, `src/triangle_detector_api.py`: updated parameter names and scoring logic.
- `scripts/plot_converging_triangles.py`, `scripts/generate_stock_viewer.py`: adjusted for new scoring parameters in output.
- New documentation files created to explain the renaming and the new scoring system in detail.
385 lines
12 KiB
Python
385 lines
12 KiB
Python
"""
|
|
完整流水线性能测试 - 验证Numba优化效果
|
|
|
|
此脚本模拟完整的批量检测流程,对比原版和优化版的性能。
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import pickle
|
|
import time
|
|
import cProfile
|
|
import pstats
|
|
from io import StringIO
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
# 添加 src 路径
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
|
|
|
from converging_triangle import (
|
|
ConvergingTriangleParams,
|
|
detect_converging_triangle_batch as detect_batch_original,
|
|
)
|
|
|
|
|
|
class FakeModule:
    """Minimal stub standing in for the unavailable `model` package.

    Installed into ``sys.modules`` before unpickling so that pickled
    objects referencing the `model` package can be loaded without it.
    `ndarray` is exposed so such references resolve to NumPy's array type
    (presumably what the pickles expect — confirm against the data files).
    """

    # The only attribute the pickle stream appears to need.
    ndarray = np.ndarray
|
|
|
|
|
|
def load_pkl(pkl_path: str) -> dict:
    """Unpickle a data file, stubbing out the `model` package first.

    The pickle files reference a `model` package that is not installed in
    this environment, so empty stand-in modules are registered in
    ``sys.modules`` before unpickling.

    Args:
        pkl_path: path to the ``.pkl`` file to load.

    Returns:
        The unpickled object (expected to be a dict).
    """
    # Register stand-ins so pickle can resolve `model` / `model.index_info`.
    for stub_name in ('model', 'model.index_info'):
        sys.modules[stub_name] = FakeModule()

    # NOTE: pickle.load executes arbitrary code — only load trusted local files.
    with open(pkl_path, 'rb') as fh:
        return pickle.load(fh)
|
|
|
|
|
|
def load_test_data(data_dir: str, n_stocks: int = None, n_days: int = None):
    """Load OHLCV matrices and metadata for the benchmark.

    Args:
        data_dir: directory holding ``open/high/low/close/volume`` .pkl files.
        n_stocks: keep only the first `n_stocks` stocks (rows).
        n_days: keep only the last `n_days` trading days (columns).
            Subsetting applies only when BOTH `n_stocks` and `n_days` are
            truthy; otherwise the full matrices are returned (same behavior
            as before).

    Returns:
        Tuple of (open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
        dates, tkrs, tkrs_name).
    """
    print("加载数据...")

    fields = ("open", "high", "low", "close", "volume")
    raw = {name: load_pkl(os.path.join(data_dir, f"{name}.pkl")) for name in fields}

    # Single slicing path instead of two duplicated full/subset branches:
    # identity slices are used when no subset was requested.
    if n_stocks and n_days:
        row_sl, col_sl = slice(None, n_stocks), slice(-n_days, None)
    else:
        row_sl, col_sl = slice(None), slice(None)

    open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx = (
        raw[name]["mtx"][row_sl, col_sl] for name in fields
    )

    # Date/ticker metadata always comes from the close-price pickle.
    dates = raw["close"]["dtes"][col_sl]
    tkrs = raw["close"]["tkrs"][row_sl]
    tkrs_name = raw["close"]["tkrs_name"][row_sl]

    print(f" 数据形状: {close_mtx.shape}")
    print(f" 日期范围: {dates[0]} ~ {dates[-1]}")

    return open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name
|
|
|
|
|
|
def test_pipeline(
    open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
    params: ConvergingTriangleParams,
    use_optimized: bool = False,
    profile: bool = False
):
    """
    Run the full batch-detection pipeline once and report timing statistics.

    Args:
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx: price/volume
            matrices — assumed shape (n_stocks, n_days); TODO confirm.
        params: detection parameters.
        use_optimized: when True, monkey-patch the Numba-optimized kernels
            into the `converging_triangle` module before running.
        profile: when True, wrap the run in cProfile and print the top 20
            entries by cumulative time.

    Returns:
        (result DataFrame, elapsed seconds), or (None, 0) when enabling the
        optimized kernels failed.
    """
    print(f"\n{'='*80}")
    print(f"测试: {'Numba优化版' if use_optimized else '原版'}")
    print(f"{'='*80}")

    n_stocks, n_days = close_mtx.shape
    window = params.window

    # Detection range: first day with a full lookback window...
    start_day = window - 1
    # ...through the last day on which at least one stock has a non-NaN close.
    any_valid = np.any(~np.isnan(close_mtx), axis=0)
    valid_day_idx = np.where(any_valid)[0]
    end_day = valid_day_idx[-1] if len(valid_day_idx) > 0 else n_days - 1

    total_points = n_stocks * (end_day - start_day + 1)

    print(f"\n配置:")
    print(f" 股票数: {n_stocks}")
    print(f" 交易日: {n_days}")
    print(f" 窗口大小: {window}")
    print(f" 检测点数: {total_points}")
    print(f" 使用优化: {'是' if use_optimized else '否'}")

    # When requested, import the Numba kernels and monkey-patch them over the
    # pure-Python implementations inside `converging_triangle`, so the batch
    # driver picks them up through its module-level name lookups.
    if use_optimized:
        try:
            print("\n导入Numba优化模块...")
            import converging_triangle
            from converging_triangle_optimized import (
                pivots_fractal_optimized,
                pivots_fractal_hybrid_optimized,
                fit_boundary_anchor_optimized,
                calc_geometry_score_optimized,
                calc_activity_score_optimized,
                calc_breakout_strength_optimized,
            )

            # Monkey-patch replacements.
            converging_triangle.pivots_fractal = pivots_fractal_optimized
            converging_triangle.pivots_fractal_hybrid = pivots_fractal_hybrid_optimized
            converging_triangle.fit_boundary_anchor = fit_boundary_anchor_optimized
            converging_triangle.calc_geometry_score = calc_geometry_score_optimized
            converging_triangle.calc_activity_score = calc_activity_score_optimized
            converging_triangle.calc_breakout_strength = calc_breakout_strength_optimized

            print(" [OK] Numba优化已启用")

            # Warm-up call so JIT compilation cost is not counted in the timing.
            print("\n预热Numba编译...")
            sample_high = high_mtx[0, :window]
            sample_low = low_mtx[0, :window]
            valid_mask = ~(np.isnan(sample_high) | np.isnan(sample_low))
            # NOTE(review): sum(valid_mask) can never exceed `window` here, so
            # this warm-up only fires when the first window is entirely NaN-free
            # — presumably intentional, but worth confirming.
            if np.sum(valid_mask) >= window:
                sample_high = sample_high[valid_mask]
                sample_low = sample_low[valid_mask]
                _ = pivots_fractal_optimized(sample_high, sample_low, k=params.pivot_k)
            print(" [OK] 预热完成")

        except Exception as e:
            print(f" [ERROR] 无法启用优化: {e}")
            return None, 0

    # Run the batch detection. Always goes through the original entry point;
    # the monkey-patching above redirects its internals when optimized.
    print("\n开始批量检测...")

    if profile:
        profiler = cProfile.Profile()
        profiler.enable()

    start_time = time.time()

    df = detect_batch_original(
        open_mtx=open_mtx,
        high_mtx=high_mtx,
        low_mtx=low_mtx,
        close_mtx=close_mtx,
        volume_mtx=volume_mtx,
        params=params,
        start_day=start_day,
        end_day=end_day,
        only_valid=True,
        verbose=False,
    )

    elapsed = time.time() - start_time

    if profile:
        profiler.disable()

        # Dump the top 20 profile entries sorted by cumulative time.
        print("\n" + "-" * 80)
        print("Profile Top 20:")
        print("-" * 80)
        s = StringIO()
        ps = pstats.Stats(profiler, stream=s).sort_stats('cumulative')
        ps.print_stats(20)
        print(s.getvalue())

    # Throughput summary.
    print(f"\n{'='*80}")
    print("性能统计")
    print(f"{'='*80}")
    print(f"\n总耗时: {elapsed:.2f} 秒 ({elapsed/60:.2f} 分钟)")
    print(f"处理点数: {total_points}")
    print(f"平均速度: {total_points/elapsed:.1f} 点/秒")
    print(f"单点耗时: {elapsed/total_points*1000:.3f} 毫秒/点")

    if len(df) > 0:
        # Count rows flagged valid and report hit rate and strong breakouts.
        valid_count = len(df[df['is_valid'] == True])
        print(f"\n检测结果:")
        print(f" 有效三角形: {valid_count}")
        print(f" 检出率: {valid_count/total_points*100:.2f}%")

        if 'breakout_strength_up' in df.columns:
            strong_up = (df['breakout_strength_up'] > 0.3).sum()
            strong_down = (df['breakout_strength_down'] > 0.3).sum()
            print(f" 高强度向上突破 (>0.3): {strong_up}")
            print(f" 高强度向下突破 (>0.3): {strong_down}")

    return df, elapsed
|
|
|
|
|
|
def compare_results(df_original, df_optimized):
    """Verify that both pipeline versions produced identical output.

    Args:
        df_original: result DataFrame from the reference run (or None).
        df_optimized: result DataFrame from the optimized run (or None).

    Returns:
        True when the row counts match and every shared numeric column agrees
        within an absolute tolerance of 1e-6; False otherwise.
    """
    print(f"\n{'='*80}")
    print("结果一致性验证")
    print(f"{'='*80}")

    # Either run may have failed and returned None — nothing to compare.
    if df_original is None or df_optimized is None:
        print("\n[ERROR] 无法对比:某个版本未成功运行")
        return False

    # Row counts must agree before column-wise comparison makes sense.
    if len(df_original) != len(df_optimized):
        print(f"\n[WARNING] 记录数不一致:")
        print(f" 原版: {len(df_original)}")
        print(f" 优化: {len(df_optimized)}")
        return False

    print(f"\n记录数: {len(df_original)} (一致 [OK])")

    # Candidate numeric columns; only those actually present are compared.
    candidate_cols = [
        'breakout_strength_up', 'breakout_strength_down',
        'price_score_up', 'price_score_down',
        'convergence_score', 'volume_score', 'geometry_score',
        'upper_slope', 'lower_slope', 'width_ratio',
        'touches_upper', 'touches_lower', 'apex_x'
    ]
    numeric_cols = [c for c in candidate_cols if c in df_original.columns]

    print(f"\n数值列对比:")
    print(f"{'列名':<30} {'最大差异':<15} {'平均差异':<15} {'状态':<10}")
    print("-" * 70)

    all_match = True
    for col in numeric_cols:
        abs_err = (df_original[col] - df_optimized[col]).abs()
        worst = abs_err.max()
        average = abs_err.mean()

        # Acceptance criterion: max absolute difference below 1e-6.
        col_ok = worst < 1e-6
        if not col_ok:
            all_match = False
        status = "[OK]" if col_ok else "[ERR]"

        print(f"{col:<30} {worst:<15.10f} {average:<15.10f} {status:<10}")

    print("-" * 70)

    if all_match:
        print("\n[结论] 所有数值列完全一致 (误差 < 1e-6) ✓")
    else:
        print("\n[结论] 发现不一致的列,请检查优化实现")

    return all_match
|
|
|
|
|
|
def main():
    """Entry point: benchmark original vs optimized pipelines and compare."""
    print("=" * 80)
    print("收敛三角形检测 - 完整流水线性能测试")
    print("=" * 80)

    # Data directory resolved relative to this script.
    DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")

    # Detection parameters for both runs.
    params = ConvergingTriangleParams(
        window=240,
        pivot_k=15,
        boundary_n_segments=2,
        boundary_source="full",
        fitting_method="anchor",
        upper_slope_max=0,
        lower_slope_min=0,
        touch_tol=0.10,
        touch_loss_max=0.10,
        shrink_ratio=0.45,
        break_tol=0.005,
        vol_window=20,
        vol_k=1.5,
        false_break_m=5,
    )

    # Load the full dataset (no subsetting).
    open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = \
        load_test_data(DATA_DIR)

    # Phase 1: reference implementation.
    print("\n" + "=" * 80)
    print("阶段 1/2: 测试原版")
    print("=" * 80)

    df_original, time_original = test_pipeline(
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
        params,
        use_optimized=False,
        profile=False
    )

    # Phase 2: Numba-optimized implementation.
    print("\n" + "=" * 80)
    print("阶段 2/2: 测试优化版")
    print("=" * 80)

    df_optimized, time_optimized = test_pipeline(
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
        params,
        use_optimized=True,
        profile=False
    )

    # Compare outputs and summarize performance only when both runs succeeded.
    if df_original is not None and df_optimized is not None:
        results_match = compare_results(df_original, df_optimized)

        print(f"\n{'='*80}")
        print("性能对比总结")
        print(f"{'='*80}")

        # Guard against division by zero when a run reported zero elapsed time.
        speedup = time_original / time_optimized if time_optimized > 0 else 0
        improvement = ((time_original - time_optimized) / time_original * 100) if time_original > 0 else 0
        time_saved = time_original - time_optimized

        print(f"\n{'指标':<20} {'原版':<20} {'优化版':<20} {'改善':<20}")
        print("-" * 80)
        print(f"{'总耗时':<20} {time_original:.2f}秒 ({time_original/60:.2f}分) "
              f"{time_optimized:.2f}秒 ({time_optimized/60:.2f}分) "
              f"-{time_saved:.2f}秒")
        print(f"{'加速比':<20} {'1.00x':<20} {f'{speedup:.2f}x':<20} {f'+{speedup-1:.2f}x':<20}")
        print(f"{'性能提升':<20} {'0%':<20} {f'{improvement:.1f}%':<20} {f'+{improvement:.1f}%':<20}")

        # Re-derive the number of detection points — same formula as
        # test_pipeline: end_day - (window - 1) + 1 points per stock.
        n_stocks, n_days = close_mtx.shape
        window = params.window
        end_day_idx = np.where(np.any(~np.isnan(close_mtx), axis=0))[0][-1]
        total_points = n_stocks * (end_day_idx - window + 2)

        speed_original = total_points / time_original
        speed_optimized = total_points / time_optimized

        print(f"{'处理速度':<20} {f'{speed_original:.0f}点/秒':<20} {f'{speed_optimized:.0f}点/秒':<20} {f'+{speed_optimized-speed_original:.0f}点/秒':<20}")

        print("\n" + "=" * 80)
        print("最终结论")
        print("=" * 80)

        if results_match:
            print("\n[OK] 输出结果完全一致 ✓")
        else:
            print("\n[WARNING] 输出结果存在差异,请检查")

        print(f"\n性能提升: {speedup:.1f}x ({improvement:.1f}%)")
        print(f"时间节省: {time_saved:.2f}秒 ({time_saved/60:.2f}分钟)")

        # Deployment recommendation tiered by measured speedup.
        if speedup > 100:
            print("\n[推荐] 性能提升巨大 (>100x),强烈建议部署优化版本!")
        elif speedup > 10:
            print("\n[推荐] 性能提升显著 (>10x),建议部署优化版本")
        elif speedup > 2:
            print("\n[推荐] 性能有提升 (>2x),可以考虑部署优化版本")
        else:
            print("\n[提示] 性能提升不明显,可能不需要部署")

        print("\n" + "=" * 80)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|