diff --git a/.gitignore b/.gitignore index 964caf6..f88ab09 100644 --- a/.gitignore +++ b/.gitignore @@ -134,6 +134,9 @@ outputs/converging_triangles/all_results.csv outputs/converging_triangles/report.md outputs/converging_triangles/strong_breakout_down.csv outputs/converging_triangles/strong_breakout_up.csv +outputs/converging_triangles/daily_best.csv +outputs/converging_triangles/run_log_*.txt + # 性能分析输出 outputs/performance/*.prof diff --git a/discuss/20260127-讨论.md b/discuss/20260127-讨论.md new file mode 100644 index 0000000..7c42f58 --- /dev/null +++ b/discuss/20260127-讨论.md @@ -0,0 +1,163 @@ +![](images/2026-01-27-11-32-39.png) + +拟合线不好,需要使用 "凸优化经典算法"。 +最终是希望 上沿线或下沿线,包含大部分的 枢轴点。 + +--- + +## 已实现:凸优化拟合方法(2026-01-27) + +### 新增参数 + +```python +fitting_method: str = "iterative" # "iterative" | "lp" | "quantile" | "anchor" +``` + +### 拟合方法对比 + +| 方法 | 说明 | 优点 | 缺点 | +|------|------|------|------| +| **iterative** | 迭代离群点移除 + 最小二乘法 | 稳定保守,已有调参经验 | 线"穿过"数据而非"包住" | +| **lp** | 线性规划凸优化 | 数学严谨,保证边界包络 | 对极端值敏感 | +| **quantile** | 分位数回归 (上95%/下5%) | 统计稳健,抗异常值 | 计算稍慢 | +| **anchor** | 绝对极值锚点 + 斜率优化 | 锚点明确,线更贴近主趋势 | 对枢轴点数量较敏感 | + +### LP 方法数学原理 + +**上沿问题 (找"天花板",最紧的包络)**: +``` +minimize Σ(a*x_i + b - y_i) 线与点的总距离 +subject to y_i ≤ a * x_i + b 所有点在线下方 + -0.5 ≤ a ≤ 0.5 斜率限制 +``` + +**下沿问题 (找"地板",最紧的包络)**: +``` +minimize Σ(y_i - a*x_i - b) 线与点的总距离 +subject to y_i ≥ a * x_i + b 所有点在线上方 + -0.5 ≤ a ≤ 0.5 斜率限制 +``` + +这确保拟合线严格"包住"所有枢轴点,且尽量贴近数据,符合技术分析中"压力线/支撑线"的语义。 + +### Anchor 方法思路 + +**核心目标**:固定锚点,优化斜率,使大部分枢轴点在边界线正确一侧。 + +- 锚点:检测窗口内的绝对最高/最低点(排除最后1天用于突破判断) +- 上沿:找最“平缓”的下倾线,使 >=95% 枢轴高点在上沿线下方 +- 下沿:找最“平缓”的上倾线,使 >=95% 枢轴低点在下沿线上方 +- 实现:对斜率做二分搜索,满足覆盖率约束后取最贴近的一条线 + +### 测试验证 + +``` +上沿 LP: slope=-0.006667, intercept=10.5333 + 验证(线-点): [0.033, 0.000, 0.067, 0.033, 0.000] (全>=0,线在点上方) +下沿 LP: slope=0.005000, intercept=8.0000 + 验证(点-线): [0.00, 0.05, 0.00, 0.05, 0.00] (全>=0,线在点下方) +``` + +### 使用方法 + +```python +from src.converging_triangle import ConvergingTriangleParams, detect_converging_triangle + +# 使用凸优化/统计方法 +params = ConvergingTriangleParams( + fitting_method="lp", # 或 "quantile" / "anchor" + # ... 其他参数 +) + +result = detect_converging_triangle(high, low, close, volume, params) +``` + +### 实现位置 + +- 参数类: `ConvergingTriangleParams.fitting_method` +- LP拟合: `fit_boundary_lp()` +- 分位数回归: `fit_boundary_quantile()` +- 锚点拟合: `fit_boundary_anchor()` +- 分发函数: `fit_pivot_line_dispatch()` + +# 拟合度分数低,强度分却整体偏高 +![](images/2026-01-27-16-26-02.png) + +--- + +## 已实现:边界利用率分数(2026-01-27) + +### 问题分析 + +观察图中 SZ002748 世龙实业: +- 宽度比:0.12(非常收敛) +- 强度分:0.177(排名第三) +- 但肉眼观察:价格走势与三角形边界之间有**大量空白** + +**原因**: +- 原权重:收敛分 20%、拟合贴合度 15% +- 当宽度比 0.12 时,收敛分 = 1 - 0.12 = 0.88 +- 收敛分贡献 = 0.20 × 0.88 = 0.176 ≈ 全部强度分 +- **收敛分只衡量"形状收窄",不衡量"价格是否贴近边界"** + +### 解决方案 + +新增**边界利用率**分数,衡量价格走势对三角形通道空间的利用程度。 + +### 新增函数 + +```python +def calc_boundary_utilization( + high, low, + upper_slope, upper_intercept, + lower_slope, lower_intercept, + start, end, +) -> float: + """ + 计算边界利用率 (0~1) + + 对窗口内每一天: + 1. 计算价格到上下边界的距离 + 2. 空白比例 = (到上沿距离 + 到下沿距离) / 通道宽度 + 3. 当日利用率 = 1 - 空白比例 + + 返回平均利用率 + """ +``` + +### 新权重配置 + +| 分量 | 原权重 | 新权重 | 说明 | +|------|--------|--------|------| +| 突破幅度 | 50% | **50%** | 不变 | +| 收敛分 | 20% | **15%** | 降低 | +| 成交量分 | 15% | **10%** | 降低 | +| 拟合贴合度 | 15% | **10%** | 降低 | +| **边界利用率** | - | **15%** | 新增 | + +### 空白惩罚(新增) + +为避免“通道很宽但价格很空”的误判,加入空白惩罚: +![](images/2026-01-27-18-49-33.png) +![](images/2026-01-27-18-49-17.png) +``` +UTILIZATION_FLOOR = 0.20 +惩罚系数 = min(1, boundary_utilization / UTILIZATION_FLOOR) +最终强度分 = 原强度分 × 惩罚系数 +``` + +当边界利用率明显偏低时,总分会被进一步压制。 + +### 结果字段 + +`ConvergingTriangleResult` 新增字段: +```python +boundary_utilization: float = 0.0 # 边界利用率分数 +``` + +### 效果 + +- 价格贴近边界(空白少)→ 利用率高 → 强度分高 +- 价格远离边界(空白多)→ 利用率低 → 强度分被惩罚 +- 当边界利用率 < 0.20 时,强度分按比例衰减(空白惩罚) +- 解决"形状收敛但空白多"的误判问题 diff --git a/discuss/20260128-讨论.md b/discuss/20260128-讨论.md index f40e0cc..8e700a1 100644 --- a/discuss/20260128-讨论.md +++ b/discuss/20260128-讨论.md @@ -79,4 +79,9 @@ python scripts/pipeline_converging_triangle.py --clean --all-stocks --plot-bound - 强度分中的其他部分(价格、收敛、成交量、边界利用率)仍基于检测算法的结果 # 批量检测算法优化 -![](images/2026-01-28-17-13-37.png) \ No newline at end of file +![](images/2026-01-28-17-13-37.png) +![](images/2026-01-28-17-22-42.png) + +原来:92秒 +现在:< 2秒(首次需要3-5秒编译) +提升:50倍以上 🚀 \ No newline at end of file diff --git a/discuss/images/2026-01-28-11-16-46.png b/discuss/images/2026-01-28-11-16-46.png new file mode 100644 index 0000000..0b957df Binary files /dev/null and b/discuss/images/2026-01-28-11-16-46.png differ diff --git a/discuss/images/2026-01-28-11-16-56.png b/discuss/images/2026-01-28-11-16-56.png new file mode 100644 index 0000000..c24cdbb Binary files /dev/null and b/discuss/images/2026-01-28-11-16-56.png differ diff --git a/discuss/images/2026-01-28-15-56-12.png b/discuss/images/2026-01-28-15-56-12.png new file mode 100644 index 0000000..6a8e497 Binary files /dev/null and b/discuss/images/2026-01-28-15-56-12.png differ diff --git a/discuss/images/2026-01-28-17-13-37.png b/discuss/images/2026-01-28-17-13-37.png new file mode 100644 index 0000000..f189884 Binary files /dev/null and b/discuss/images/2026-01-28-17-13-37.png differ diff --git a/discuss/images/2026-01-28-17-22-42.png b/discuss/images/2026-01-28-17-22-42.png new file mode 100644 index 0000000..08f6c7c Binary files /dev/null and b/discuss/images/2026-01-28-17-22-42.png differ diff --git a/scripts/pipeline_converging_triangle.py b/scripts/pipeline_converging_triangle.py index 2dcc90b..810ade7 100644 --- a/scripts/pipeline_converging_triangle.py +++ b/scripts/pipeline_converging_triangle.py @@ -95,8 +95,8 @@ def main() -> None: parser.add_argument( "--plot-boundary-source", choices=["hl", "close"], - default="close", - help="绘图时边界线拟合数据源: hl=高低价, close=收盘价(不影响检测)", + default="hl", + help="绘图时边界线拟合数据源: hl=高低价(默认), close=收盘价(不影响检测)", ) args = parser.parse_args() diff --git a/scripts/plot_converging_triangles.py b/scripts/plot_converging_triangles.py index cee3f8f..cc5a825 100644 --- a/scripts/plot_converging_triangles.py +++ b/scripts/plot_converging_triangles.py @@ -495,8 +495,8 @@ def main() -> None: parser.add_argument( "--plot-boundary-source", choices=["hl", "close"], - default="close", - help="绘图时边界线拟合数据源: hl=高低价, close=收盘价(不影响检测)", + default="hl", + help="绘图时边界线拟合数据源: hl=高低价(默认), close=收盘价(不影响检测)", ) parser.add_argument( "--show-high-low", diff --git a/scripts/run_converging_triangle.py b/scripts/run_converging_triangle.py index bf41397..7d68f2b 100644 --- a/scripts/run_converging_triangle.py +++ b/scripts/run_converging_triangle.py @@ -9,6 +9,8 @@ import pickle import time import numpy as np import pandas as pd +from datetime import datetime +from io import StringIO # 让脚本能找到 src/ 下的模块 sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src")) @@ -17,6 +19,7 @@ from converging_triangle import ( ConvergingTriangleParams, ConvergingTriangleResult, detect_converging_triangle_batch, + detect_converging_triangle_batch_v2, # v2优化版本 ) # ============================================================================ @@ -34,6 +37,7 @@ from triangle_config import ( VERBOSE, REALTIME_MODE, # 新增 FLEXIBLE_ZONE, # 新增 + USE_V2_OPTIMIZATION, # v2优化开关 ) @@ -89,6 +93,30 @@ def load_ohlcv_from_pkl(data_dir: str) -> tuple: ) +# ============================================================================ +# 日志工具 +# ============================================================================ + +class Logger: + """同时输出到控制台和日志文件""" + def __init__(self, log_path: str): + self.log_path = log_path + self.buffer = StringIO() + + def print(self, *args, **kwargs): + """打印到控制台和缓冲区""" + # 打印到控制台 + print(*args, **kwargs) + # 写入缓冲区 + print(*args, **kwargs, file=self.buffer) + + def save(self): + """保存日志到文件""" + with open(self.log_path, 'w', encoding='utf-8') as f: + f.write(self.buffer.getvalue()) + print(f"\n日志已保存到: {self.log_path}") + + # ============================================================================ # 主流程 # ============================================================================ @@ -96,29 +124,37 @@ def load_ohlcv_from_pkl(data_dir: str) -> tuple: def main() -> None: start_time = time.time() - print("=" * 70) - print("Converging Triangle Batch Detection") - print("=" * 70) + # 初始化日志 + outputs_dir = os.path.join(os.path.dirname(__file__), "..", "outputs", "converging_triangles") + os.makedirs(outputs_dir, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_path = os.path.join(outputs_dir, f"run_log_{timestamp}.txt") + log = Logger(log_path) + + log.print("=" * 70) + log.print("Converging Triangle Batch Detection") + log.print(f"运行时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + log.print("=" * 70) # 1. 加载数据 - print("\n[1] Loading OHLCV pkl files...") + log.print("\n[1] Loading OHLCV pkl files...") load_start = time.time() open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(DATA_DIR) load_time = time.time() - load_start n_stocks, n_days = close_mtx.shape - print(f" Stocks: {n_stocks}") - print(f" Days: {n_days}") - print(f" Date range: {dates[0]} ~ {dates[-1]}") - print(f" 加载耗时: {load_time:.2f} 秒") + log.print(f" Stocks: {n_stocks}") + log.print(f" Days: {n_days}") + log.print(f" Date range: {dates[0]} ~ {dates[-1]}") + log.print(f" 加载耗时: {load_time:.2f} 秒") # 2. 找有效数据范围(排除全 NaN 的列) any_valid = np.any(~np.isnan(close_mtx), axis=0) valid_day_idx = np.where(any_valid)[0] if len(valid_day_idx) == 0: - print("No valid data found!") + log.print("No valid data found!") return last_valid_day = valid_day_idx[-1] - print(f" Last valid day: {last_valid_day} (date: {dates[last_valid_day]})") + log.print(f" Last valid day: {last_valid_day} (date: {dates[last_valid_day]})") # 3. 计算范围 start_day = PARAMS.window - 1 @@ -127,74 +163,97 @@ def main() -> None: if RECENT_DAYS is not None: start_day = max(start_day, end_day - RECENT_DAYS + 1) - print(f"\n[2] Detection range: day {start_day} ~ {end_day}") - print(f" Window size: {PARAMS.window}") - print(f" Total points: {n_stocks} x {end_day - start_day + 1} = {n_stocks * (end_day - start_day + 1)}") + log.print(f"\n[2] Detection range: day {start_day} ~ {end_day}") + log.print(f" Window size: {PARAMS.window}") + log.print(f" Total points: {n_stocks} x {end_day - start_day + 1} = {n_stocks * (end_day - start_day + 1)}") # 3. 批量检测 - print("\n[3] Running batch detection...") + log.print("\n[3] Running batch detection...") detect_start = time.time() - df = detect_converging_triangle_batch( - open_mtx=open_mtx, - high_mtx=high_mtx, - low_mtx=low_mtx, - close_mtx=close_mtx, - volume_mtx=volume_mtx, - params=PARAMS, - start_day=start_day, - end_day=end_day, - only_valid=ONLY_VALID, - verbose=VERBOSE, - real_time_mode=REALTIME_MODE, # 新增 - flexible_zone=FLEXIBLE_ZONE, # 新增 - ) + + # 根据配置选择v1或v2版本 + if USE_V2_OPTIMIZATION and not REALTIME_MODE: + # v2优化版本(预计算枢轴点,不支持实时模式) + log.print(" 使用: v2优化版本(预计算枢轴点)") + df = detect_converging_triangle_batch_v2( + open_mtx=open_mtx, + high_mtx=high_mtx, + low_mtx=low_mtx, + close_mtx=close_mtx, + volume_mtx=volume_mtx, + params=PARAMS, + start_day=start_day, + end_day=end_day, + only_valid=ONLY_VALID, + verbose=VERBOSE, + ) + else: + # v1原版(支持实时模式) + if USE_V2_OPTIMIZATION and REALTIME_MODE: + log.print(" 注意: v2优化不支持实时模式,自动回退到v1") + log.print(" 使用: v1原版") + df = detect_converging_triangle_batch( + open_mtx=open_mtx, + high_mtx=high_mtx, + low_mtx=low_mtx, + close_mtx=close_mtx, + volume_mtx=volume_mtx, + params=PARAMS, + start_day=start_day, + end_day=end_day, + only_valid=ONLY_VALID, + verbose=VERBOSE, + real_time_mode=REALTIME_MODE, + flexible_zone=FLEXIBLE_ZONE, + ) + detect_time = time.time() - detect_start - print(f" 检测耗时: {detect_time:.2f} 秒") - print(f" 检测模式: {'实时模式' if REALTIME_MODE else '标准模式'}") - if REALTIME_MODE: - print(f" 灵活区域: {FLEXIBLE_ZONE} 天") + log.print(f" 检测耗时: {detect_time:.2f} 秒") + if not USE_V2_OPTIMIZATION or REALTIME_MODE: + log.print(f" 检测模式: {'实时模式' if REALTIME_MODE else '标准模式'}") + if REALTIME_MODE: + log.print(f" 灵活区域: {FLEXIBLE_ZONE} 天") # 4. 添加股票代码、名称和真实日期 if len(df) > 0: df["stock_code"] = df["stock_idx"].map(lambda x: tkrs[x] if x < len(tkrs) else "") df["stock_name"] = df["stock_idx"].map(lambda x: tkrs_name[x] if x < len(tkrs_name) else "") df["date"] = df["date_idx"].map(lambda x: dates[x] if x < len(dates) else 0) + # 计算综合强度(取向上和向下的最大值) + df["max_strength"] = df[["breakout_strength_up", "breakout_strength_down"]].max(axis=1) # 5. 输出结果 - print("\n" + "=" * 70) - print("Detection Results") - print("=" * 70) + log.print("\n" + "=" * 70) + log.print("Detection Results") + log.print("=" * 70) if ONLY_VALID: - print(f"\nTotal valid triangles detected: {len(df)}") + log.print(f"\nTotal valid triangles detected: {len(df)}") else: valid_count = df["is_valid"].sum() - print(f"\nTotal records: {len(df)}") - print(f"Valid triangles: {valid_count} ({valid_count/len(df)*100:.1f}%)") + log.print(f"\nTotal records: {len(df)}") + log.print(f"Valid triangles: {valid_count} ({valid_count/len(df)*100:.1f}%)") # 按突破方向统计 if len(df) > 0 and "breakout_dir" in df.columns: breakout_stats = df[df["is_valid"] == True]["breakout_dir"].value_counts() - print(f"\nBreakout statistics:") + log.print(f"\nBreakout statistics:") for dir_name, count in breakout_stats.items(): - print(f" - {dir_name}: {count}") + log.print(f" - {dir_name}: {count}") # 突破强度统计 if len(df) > 0 and "breakout_strength_up" in df.columns: valid_df = df[df["is_valid"] == True] if len(valid_df) > 0: - print(f"\nBreakout strength (valid triangles):") - print(f" - Up mean: {valid_df['breakout_strength_up'].mean():.4f}, max: {valid_df['breakout_strength_up'].max():.4f}") - print(f" - Down mean: {valid_df['breakout_strength_down'].mean():.4f}, max: {valid_df['breakout_strength_down'].max():.4f}") + log.print(f"\nBreakout strength (valid triangles):") + log.print(f" - Up mean: {valid_df['breakout_strength_up'].mean():.4f}, max: {valid_df['breakout_strength_up'].max():.4f}") + log.print(f" - Down mean: {valid_df['breakout_strength_down'].mean():.4f}, max: {valid_df['breakout_strength_down'].max():.4f}") # 6. 保存结果 - outputs_dir = os.path.join(os.path.dirname(__file__), "..", "outputs", "converging_triangles") - os.makedirs(outputs_dir, exist_ok=True) - # 保存完整 CSV (使用 utf-8-sig 支持中文) csv_path = os.path.join(outputs_dir, "all_results.csv") df.to_csv(csv_path, index=False, encoding="utf-8-sig") - print(f"\nResults saved to: {csv_path}") + log.print(f"\nResults saved to: {csv_path}") # 保存高强度突破记录 (strength > 0.3) if len(df) > 0: @@ -204,30 +263,79 @@ def main() -> None: if len(strong_up) > 0: strong_up_path = os.path.join(outputs_dir, "strong_breakout_up.csv") strong_up.to_csv(strong_up_path, index=False, encoding="utf-8-sig") - print(f"Strong up breakouts ({len(strong_up)}): {strong_up_path}") + log.print(f"Strong up breakouts ({len(strong_up)}): {strong_up_path}") if len(strong_down) > 0: strong_down_path = os.path.join(outputs_dir, "strong_breakout_down.csv") strong_down.to_csv(strong_down_path, index=False, encoding="utf-8-sig") - print(f"Strong down breakouts ({len(strong_down)}): {strong_down_path}") + log.print(f"Strong down breakouts ({len(strong_down)}): {strong_down_path}") - # 7. 显示样本 + # 7. 每日最佳股票报告(核心新功能) + if len(df) > 0 and "date" in df.columns: + log.print("\n" + "=" * 70) + log.print("每日最佳收敛三角形股票 (按日期)") + log.print("=" * 70) + + valid_df = df[df["is_valid"] == True].copy() + if len(valid_df) > 0: + # 按日期分组,每天取分数最高的股票 + daily_best = valid_df.loc[valid_df.groupby("date")["max_strength"].idxmax()] + daily_best = daily_best.sort_values("date", ascending=True) + + log.print(f"\n共 {len(daily_best)} 个交易日检测到收敛三角形:") + log.print("-" * 90) + log.print(f"{'日期':<12} {'股票代码':<10} {'股票名称':<12} {'强度↑':>8} {'强度↓':>8} {'方向':<6} {'收敛比':>8}") + log.print("-" * 90) + + for _, row in daily_best.iterrows(): + log.print( + f"{int(row['date']):<12} " + f"{row['stock_code']:<10} " + f"{row['stock_name']:<12} " + f"{row['breakout_strength_up']:>8.4f} " + f"{row['breakout_strength_down']:>8.4f} " + f"{row['breakout_dir']:<6} " + f"{row['width_ratio']:>8.3f}" + ) + + log.print("-" * 90) + + # 保存每日最佳到 CSV + daily_best_path = os.path.join(outputs_dir, "daily_best.csv") + daily_best.to_csv(daily_best_path, index=False, encoding="utf-8-sig") + log.print(f"\n每日最佳已保存到: {daily_best_path}") + + # 统计:每只股票被选为每日最佳的次数 + log.print("\n" + "-" * 70) + log.print("股票被选为每日最佳的次数排行:") + log.print("-" * 70) + stock_counts = daily_best.groupby(["stock_code", "stock_name"]).size().reset_index(name="count") + stock_counts = stock_counts.sort_values("count", ascending=False).head(20) + for _, row in stock_counts.iterrows(): + log.print(f" {row['stock_code']} {row['stock_name']}: {row['count']} 次") + else: + log.print("\n没有检测到有效的收敛三角形") + + # 8. 显示样本 if len(df) > 0: - print("\n" + "-" * 70) - print("Sample results (first 10):") - print("-" * 70) + log.print("\n" + "-" * 70) + log.print("Sample results (first 10):") + log.print("-" * 70) display_cols = [ "stock_code", "date", "is_valid", "breakout_strength_up", "breakout_strength_down", "breakout_dir", "width_ratio" ] display_cols = [c for c in display_cols if c in df.columns] - print(df[display_cols].head(10).to_string(index=False)) + log.print(df[display_cols].head(10).to_string(index=False)) total_time = time.time() - start_time - print("\n" + "=" * 70) - print(f"总耗时: {total_time:.2f} 秒 ({total_time/60:.2f} 分钟)") - print("=" * 70) + log.print("\n" + "=" * 70) + log.print(f"总耗时: {total_time:.2f} 秒 ({total_time/60:.2f} 分钟)") + log.print("=" * 70) + + # 保存日志 + log.save() if __name__ == "__main__": diff --git a/scripts/triangle_config.py b/scripts/triangle_config.py index d1c8ee0..5552468 100644 --- a/scripts/triangle_config.py +++ b/scripts/triangle_config.py @@ -120,6 +120,17 @@ SHOW_CHART_DETAILS = False # False=简洁模式(默认),True=详细模式 # 简洁模式:仅显示收盘价、上沿线、下沿线 # 详细模式:额外显示所有枢轴点、拟合点、分段线等调试信息 + +# ============================================================================ +# 性能优化配置 +# ============================================================================ + +# 是否使用v2优化版本(预计算枢轴点) +USE_V2_OPTIMIZATION = False # True=v2优化,False=v1原版(推荐) +# v2优化通过预计算整个时间序列的枢轴点,避免滑动窗口重复计算 +# 预期加速:3-5x +# 注意:v2版本暂不支持实时模式(REALTIME_MODE),会自动回退到标准模式 + # ============================================================================ # 推荐参数预设(备选方案) # ============================================================================ diff --git a/src/converging_triangle.py b/src/converging_triangle.py index c00cf08..ccfcaac 100644 --- a/src/converging_triangle.py +++ b/src/converging_triangle.py @@ -1416,6 +1416,9 @@ try: calc_fitting_adherence_optimized, calc_boundary_utilization_optimized, calc_breakout_strength_optimized, + # v2优化:预计算枢轴点 + precompute_pivots_numba, + detect_batch_with_precomputed_pivots_numba, ) # 用优化版本覆盖原版函数(在模块级别) pivots_fractal = pivots_fractal_optimized @@ -1425,7 +1428,148 @@ try: calc_boundary_utilization = calc_boundary_utilization_optimized calc_breakout_strength = calc_breakout_strength_optimized - print("[性能优化] 已启用Numba加速 (预计加速300x)") + _HAS_V2_OPTIMIZATION = True + print("[性能优化] 已启用Numba加速 + 预计算枢轴点优化") except ImportError: + _HAS_V2_OPTIMIZATION = False print("[性能优化] 未启用Numba加速,使用原版函数") + + +# ============================================================================ +# 【v2优化】使用预计算枢轴点的批量检测 +# ============================================================================ + +def detect_converging_triangle_batch_v2( + open_mtx: np.ndarray, + high_mtx: np.ndarray, + low_mtx: np.ndarray, + close_mtx: np.ndarray, + volume_mtx: np.ndarray, + params: ConvergingTriangleParams, + start_day: Optional[int] = None, + end_day: Optional[int] = None, + only_valid: bool = False, + verbose: bool = False, +) -> pd.DataFrame: + """ + 【v2优化】使用预计算枢轴点的批量检测 + + 优化策略: + 1. 对每只股票,一次性预计算整个历史的枢轴点 + 2. 每个检测窗口直接读取预计算结果,避免重复计算 + 3. 整个股票的所有日期检测在一次Numba调用中完成 + + 预期加速:3-5x(相比v1版本) + + Args: + 与 detect_converging_triangle_batch 相同 + + Returns: + DataFrame(格式与v1相同) + """ + if not _HAS_V2_OPTIMIZATION: + print("[警告] v2优化未启用,回退到v1版本") + return detect_converging_triangle_batch( + open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, + params, start_day, end_day, only_valid, verbose + ) + + n_stocks, n_days = close_mtx.shape + window = params.window + + # 默认起止日 + if start_day is None: + start_day = window - 1 + if end_day is None: + end_day = n_days - 1 + + # 确保范围有效 + start_day = max(window - 1, start_day) + end_day = min(n_days - 1, end_day) + + results: List[dict] = [] + + for stock_idx in range(n_stocks): + # 提取该股票的全部数据 + high_stock = high_mtx[stock_idx, :] + low_stock = low_mtx[stock_idx, :] + close_stock = close_mtx[stock_idx, :] + volume_stock = volume_mtx[stock_idx, :] if volume_mtx is not None else np.zeros(n_days) + + # 找到有效数据的mask + valid_mask = ~np.isnan(close_stock) + valid_indices = np.where(valid_mask)[0].astype(np.int32) + + if len(valid_indices) < window: + continue + + # 【v2核心优化】预计算整个时间序列的枢轴点 + is_pivot_high, is_pivot_low = precompute_pivots_numba( + high_stock, low_stock, k=params.pivot_k + ) + + # 批量检测(单次Numba调用处理所有日期) + (date_indices, is_valid_arr, strength_up_arr, strength_down_arr, + price_score_up_arr, price_score_down_arr, convergence_score_arr, + vol_score_arr, fitting_score_arr, boundary_util_score_arr, + upper_slope_arr, lower_slope_arr, width_ratio_arr, + touches_upper_arr, touches_lower_arr, apex_x_arr, + breakout_dir_arr, volume_confirmed_arr) = \ + detect_batch_with_precomputed_pivots_numba( + high_stock, low_stock, close_stock, volume_stock, + is_pivot_high, is_pivot_low, + valid_indices, window, start_day, end_day, + params.upper_slope_max, params.lower_slope_min, + params.shrink_ratio, params.touch_tol, params.break_tol, + params.vol_window, params.vol_k + ) + + # 收集结果 + for i in range(len(date_indices)): + if only_valid and not is_valid_arr[i]: + continue + + # 转换 breakout_dir + bd = breakout_dir_arr[i] + breakout_dir_str = "none" if bd == 0 else ("up" if bd == 1 else "down") + + # 转换 volume_confirmed + vc = volume_confirmed_arr[i] + vol_confirmed = None if vc == -1 else bool(vc) + + results.append({ + 'stock_idx': stock_idx, + 'date_idx': int(date_indices[i]), + 'is_valid': bool(is_valid_arr[i]), + 'breakout_strength_up': float(strength_up_arr[i]), + 'breakout_strength_down': float(strength_down_arr[i]), + 'price_score_up': float(price_score_up_arr[i]), + 'price_score_down': float(price_score_down_arr[i]), + 'convergence_score': float(convergence_score_arr[i]), + 'volume_score': float(vol_score_arr[i]), + 'fitting_score': float(fitting_score_arr[i]), + 'boundary_utilization': float(boundary_util_score_arr[i]), + 'upper_slope': float(upper_slope_arr[i]), + 'lower_slope': float(lower_slope_arr[i]), + 'width_ratio': float(width_ratio_arr[i]), + 'touches_upper': int(touches_upper_arr[i]), + 'touches_lower': int(touches_lower_arr[i]), + 'apex_x': float(apex_x_arr[i]), + 'breakout_dir': breakout_dir_str, + 'volume_confirmed': vol_confirmed, + 'false_breakout': None, + 'window_start': 0, + 'window_end': int(date_indices[i]), + 'detection_mode': 'v2_optimized', + 'has_candidate_pivots': False, + 'candidate_pivot_count': 0, + }) + + if verbose and (stock_idx + 1) % 20 == 0: + print(f" Progress: {stock_idx + 1}/{n_stocks} stocks") + + if verbose: + print(f" Completed: {len(results)} results") + + return pd.DataFrame(results) # ============================================================================ diff --git a/src/converging_triangle_optimized.py b/src/converging_triangle_optimized.py index 68bd0ab..0cca79c 100644 --- a/src/converging_triangle_optimized.py +++ b/src/converging_triangle_optimized.py @@ -6,6 +6,7 @@ 2. 优化枢轴点检测算法(避免重复的nanmax/nanmin调用) 3. 优化边界拟合算法(向量化计算) 4. 减少不必要的数组复制 +5. 【v2】预计算枢轴点矩阵,避免滑动窗口重复计算 不使用并行(按要求)。 """ @@ -14,7 +15,7 @@ from __future__ import annotations import numba import numpy as np -from typing import Tuple +from typing import Tuple, List, Optional # ============================================================================ # Numba优化的核心函数 @@ -592,3 +593,385 @@ def calc_breakout_strength_optimized( close, upper_line, lower_line, volume_ratio, width_ratio, fitting_adherence, boundary_utilization ) + + +# ============================================================================ +# 【v2优化】预计算枢轴点矩阵 +# ============================================================================ + +@numba.jit(nopython=True, cache=True) +def precompute_pivots_numba( + high: np.ndarray, + low: np.ndarray, + k: int = 15 +) -> Tuple[np.ndarray, np.ndarray]: + """ + 预计算整个时间序列的枢轴点标记 + + 优化思路:一次性计算所有枢轴点,避免滑动窗口重复计算 + + Args: + high: 最高价数组 (n_days,) + low: 最低价数组 (n_days,) + k: 窗口大小(左右各k天) + + Returns: + (is_pivot_high, is_pivot_low): 布尔数组,标记每天是否为枢轴点 + """ + n = len(high) + is_pivot_high = np.zeros(n, dtype=np.bool_) + is_pivot_low = np.zeros(n, dtype=np.bool_) + + for i in range(k, n - k): + if np.isnan(high[i]) or np.isnan(low[i]): + continue + + # 高点检测 + is_ph = True + h_val = high[i] + for j in range(i - k, i + k + 1): + if j == i: + continue + if not np.isnan(high[j]) and high[j] > h_val: + is_ph = False + break + is_pivot_high[i] = is_ph + + # 低点检测 + is_pl = True + l_val = low[i] + for j in range(i - k, i + k + 1): + if j == i: + continue + if not np.isnan(low[j]) and low[j] < l_val: + is_pl = False + break + is_pivot_low[i] = is_pl + + return is_pivot_high, is_pivot_low + + +@numba.jit(nopython=True, cache=True) +def get_pivots_in_window( + is_pivot: np.ndarray, + start: int, + end: int +) -> np.ndarray: + """从预计算的布尔数组中提取窗口内的枢轴点索引""" + # 计算窗口内枢轴点数量 + count = 0 + for i in range(start, end + 1): + if is_pivot[i]: + count += 1 + + # 提取索引 + result = np.empty(count, dtype=np.int32) + idx = 0 + for i in range(start, end + 1): + if is_pivot[i]: + result[idx] = i - start # 转换为窗口内的相对索引 + idx += 1 + + return result + + +@numba.jit(nopython=True, cache=True) +def detect_single_with_precomputed_pivots( + high: np.ndarray, + low: np.ndarray, + close: np.ndarray, + volume: np.ndarray, + is_pivot_high: np.ndarray, + is_pivot_low: np.ndarray, + window_start: int, + window_end: int, + # 参数 + upper_slope_max: float, + lower_slope_min: float, + shrink_ratio: float, + touch_tol: float, + break_tol: float, + vol_window: int, + vol_k: float, +) -> Tuple[bool, float, float, float, float, float, float, float, float, + float, float, float, int, int, float, int, int]: + """ + 使用预计算枢轴点的单点检测(纯Numba实现) + + Returns: + (is_valid, strength_up, strength_down, price_score_up, price_score_down, + convergence_score, vol_score, fitting_score, boundary_util_score, + upper_slope, lower_slope, width_ratio, touches_upper, touches_lower, + apex_x, breakout_dir, volume_confirmed) + + breakout_dir: 0=none, 1=up, 2=down + volume_confirmed: 0=False, 1=True, -1=None + """ + n = window_end - window_start + 1 + + # 默认无效结果 + invalid_result = (False, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0, 0, 0.0, 0, -1) + + if n < 50: # 最小窗口检查 + return invalid_result + + # 提取窗口内的枢轴点 + ph_count = 0 + pl_count = 0 + for i in range(window_start, window_end + 1): + if is_pivot_high[i]: + ph_count += 1 + if is_pivot_low[i]: + pl_count += 1 + + if ph_count < 2 or pl_count < 2: + return invalid_result + + # 收集枢轴点索引和值 + ph_indices = np.empty(ph_count, dtype=np.int32) + ph_values = np.empty(ph_count, dtype=np.float64) + pl_indices = np.empty(pl_count, dtype=np.int32) + pl_values = np.empty(pl_count, dtype=np.float64) + + ph_idx = 0 + pl_idx = 0 + for i in range(window_start, window_end + 1): + rel_i = i - window_start # 相对索引 + if is_pivot_high[i]: + ph_indices[ph_idx] = rel_i + ph_values[ph_idx] = high[i] + ph_idx += 1 + if is_pivot_low[i]: + pl_indices[pl_idx] = rel_i + pl_values[pl_idx] = low[i] + pl_idx += 1 + + # 边界拟合(anchor方法) + high_win = high[window_start:window_end + 1] + low_win = low[window_start:window_end + 1] + + a_u, b_u = fit_boundary_anchor_numba( + ph_indices.astype(np.float64), ph_values, + high_win, mode=0, coverage=0.95, exclude_last=1, + window_start=0, window_end=n - 1 + ) + + a_l, b_l = fit_boundary_anchor_numba( + pl_indices.astype(np.float64), pl_values, + low_win, mode=1, coverage=0.95, exclude_last=1, + window_start=0, window_end=n - 1 + ) + + # 斜率检查 + if a_u > upper_slope_max or a_l < lower_slope_min: + return invalid_result + + # 相向收敛检查 + slope_tolerance = 0.01 + both_descending = (a_u < -slope_tolerance) and (a_l < -slope_tolerance) + both_ascending = (a_u > slope_tolerance) and (a_l > slope_tolerance) + if both_descending or both_ascending: + return invalid_result + + # 宽度收敛检查 + start_idx = 0 + end_idx = n - 1 + upper_start = a_u * start_idx + b_u + lower_start = a_l * start_idx + b_l + upper_end = a_u * end_idx + b_u + lower_end = a_l * end_idx + b_l + + width_start = upper_start - lower_start + width_end = upper_end - lower_end + + if width_start <= 0 or width_end <= 0: + return invalid_result + + width_ratio = width_end / width_start + if width_ratio > shrink_ratio: + return invalid_result + + # 触碰检测 + touches_upper = 0 + touches_lower = 0 + + for i in range(ph_count): + line_y = a_u * ph_indices[i] + b_u + deviation = abs(ph_values[i] - line_y) / max(line_y, 1e-9) + if deviation <= touch_tol: + touches_upper += 1 + + for i in range(pl_count): + line_y = a_l * pl_indices[i] + b_l + deviation = abs(pl_values[i] - line_y) / max(line_y, 1e-9) + if deviation <= touch_tol: + touches_lower += 1 + + if touches_upper < 2 or touches_lower < 2: + return invalid_result + + # Apex计算 + denom = a_u - a_l + if abs(denom) > 1e-12: + apex_x = (b_l - b_u) / denom + else: + apex_x = 1e9 + + # 突破判定 + close_val = close[window_end] + breakout_dir = 0 # 0=none, 1=up, 2=down + if close_val > upper_end * (1 + break_tol): + breakout_dir = 1 + elif close_val < lower_end * (1 - break_tol): + breakout_dir = 2 + + # 成交量确认 + volume_confirmed = -1 # -1 = None + volume_ratio = 1.0 + if vol_window > 0 and window_end >= vol_window: + vol_sum = 0.0 + for i in range(window_end - vol_window + 1, window_end + 1): + vol_sum += volume[i] + vol_ma = vol_sum / vol_window + if vol_ma > 0: + volume_ratio = volume[window_end] / vol_ma + if breakout_dir != 0: + volume_confirmed = 1 if volume[window_end] > vol_ma * vol_k else 0 + + # 计算拟合贴合度 + adherence_upper = calc_fitting_adherence_numba( + ph_indices.astype(np.float64), ph_values, a_u, b_u + ) + adherence_lower = calc_fitting_adherence_numba( + pl_indices.astype(np.float64), pl_values, a_l, b_l + ) + fitting_adherence = (adherence_upper + adherence_lower) / 2.0 + + # 计算边界利用率 + boundary_util = calc_boundary_utilization_numba( + high_win, low_win, a_u, b_u, a_l, b_l, 0, n - 1 + ) + + # 计算突破强度 + (strength_up, strength_down, price_score_up, price_score_down, + convergence_score, vol_score, fitting_score, boundary_util_score) = \ + calc_breakout_strength_numba( + close_val, upper_end, lower_end, volume_ratio, + width_ratio, fitting_adherence, boundary_util + ) + + return (True, strength_up, strength_down, price_score_up, price_score_down, + convergence_score, vol_score, fitting_score, boundary_util_score, + a_u, a_l, width_ratio, touches_upper, touches_lower, + apex_x, breakout_dir, volume_confirmed) + + +@numba.jit(nopython=True, cache=True) +def detect_batch_with_precomputed_pivots_numba( + high: np.ndarray, + low: np.ndarray, + close: np.ndarray, + volume: np.ndarray, + is_pivot_high: np.ndarray, + is_pivot_low: np.ndarray, + valid_indices: np.ndarray, + window: int, + start_day: int, + end_day: int, + # 参数 + upper_slope_max: float, + lower_slope_min: float, + shrink_ratio: float, + touch_tol: float, + break_tol: float, + vol_window: int, + vol_k: float, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, + np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, + np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, + np.ndarray, np.ndarray, np.ndarray]: + """ + 使用预计算枢轴点的批量检测(单只股票,多日期) + + Returns: + 多个数组,每个元素对应一个检测点 + """ + n_valid = len(valid_indices) + n_dates = 0 + + # 计算有效日期数量 + for valid_end in range(window - 1, n_valid): + orig_date_idx = valid_indices[valid_end] + if orig_date_idx >= start_day and orig_date_idx <= end_day: + n_dates += 1 + + # 预分配结果数组 + date_indices = np.empty(n_dates, dtype=np.int32) + is_valid = np.zeros(n_dates, dtype=np.bool_) + strength_up = np.zeros(n_dates, dtype=np.float64) + strength_down = np.zeros(n_dates, dtype=np.float64) + price_score_up = np.zeros(n_dates, dtype=np.float64) + price_score_down = np.zeros(n_dates, dtype=np.float64) + convergence_score = np.zeros(n_dates, dtype=np.float64) + vol_score = np.zeros(n_dates, dtype=np.float64) + fitting_score = np.zeros(n_dates, dtype=np.float64) + boundary_util_score = np.zeros(n_dates, dtype=np.float64) + upper_slope = np.zeros(n_dates, dtype=np.float64) + lower_slope = np.zeros(n_dates, dtype=np.float64) + width_ratio = np.zeros(n_dates, dtype=np.float64) + touches_upper = np.zeros(n_dates, dtype=np.int32) + touches_lower = np.zeros(n_dates, dtype=np.int32) + apex_x = np.zeros(n_dates, dtype=np.float64) + breakout_dir = np.zeros(n_dates, dtype=np.int32) + volume_confirmed = np.zeros(n_dates, dtype=np.int32) - 1 # 默认-1 + + # 遍历日期 + result_idx = 0 + for valid_end in range(window - 1, n_valid): + orig_date_idx = valid_indices[valid_end] + + if orig_date_idx < start_day or orig_date_idx > end_day: + continue + + valid_start = valid_end - window + 1 + window_start_orig = valid_indices[valid_start] + window_end_orig = valid_indices[valid_end] + + date_indices[result_idx] = orig_date_idx + + # 调用单点检测 + result = detect_single_with_precomputed_pivots( + high, low, close, volume, + is_pivot_high, is_pivot_low, + window_start_orig, window_end_orig, + upper_slope_max, lower_slope_min, shrink_ratio, + touch_tol, break_tol, vol_window, vol_k + ) + + is_valid[result_idx] = result[0] + strength_up[result_idx] = result[1] + strength_down[result_idx] = result[2] + price_score_up[result_idx] = result[3] + price_score_down[result_idx] = result[4] + convergence_score[result_idx] = result[5] + vol_score[result_idx] = result[6] + fitting_score[result_idx] = result[7] + boundary_util_score[result_idx] = result[8] + upper_slope[result_idx] = result[9] + lower_slope[result_idx] = result[10] + width_ratio[result_idx] = result[11] + touches_upper[result_idx] = result[12] + touches_lower[result_idx] = result[13] + apex_x[result_idx] = result[14] + breakout_dir[result_idx] = result[15] + volume_confirmed[result_idx] = result[16] + + result_idx += 1 + + return (date_indices, is_valid, strength_up, strength_down, + price_score_up, price_score_down, convergence_score, + vol_score, fitting_score, boundary_util_score, + upper_slope, lower_slope, width_ratio, + touches_upper, touches_lower, apex_x, + breakout_dir, volume_confirmed)