Enhance converging triangle detection with new features and performance improvements

- Added support for daily best stocks reporting, including a new CSV output for daily best triangles based on strength.
- Introduced a logging mechanism to capture detailed execution logs, improving traceability and debugging.
- Implemented a v2 optimization for batch detection, significantly reducing detection time from 92 seconds to under 2 seconds.
- Updated the .gitignore file to include new log files and outputs for better management.
- Enhanced the pipeline script to allow for flexible configuration of detection parameters and improved user experience.

Files modified:
- scripts/run_converging_triangle.py: Added logging and v2 optimization.
- scripts/pipeline_converging_triangle.py: Updated for new features and logging.
- scripts/plot_converging_triangles.py: Adjusted for new plotting options.
- New files: discuss/20260127-拟合线.md, discuss/20260128-拟合线.md, and several images for visual documentation.
褚宏光 2026-01-28 18:43:46 +08:00
parent 759042c5bd
commit 09ac66caa1
14 changed files with 881 additions and 64 deletions

.gitignore vendored

@@ -134,6 +134,9 @@ outputs/converging_triangles/all_results.csv
outputs/converging_triangles/report.md
outputs/converging_triangles/strong_breakout_down.csv
outputs/converging_triangles/strong_breakout_up.csv
outputs/converging_triangles/daily_best.csv
outputs/converging_triangles/run_log_*.txt
# Performance profiling output
outputs/performance/*.prof

discuss/20260127-讨论.md Normal file

@@ -0,0 +1,163 @@
![](images/2026-01-27-11-32-39.png)
The fitted lines are poor; we should use a classic convex-optimization algorithm.
Ultimately, we want the upper or lower boundary line to contain most of the pivot points.
---
## Convex-optimization fitting methods implemented (2026-01-27)
### New parameter
```python
fitting_method: str = "iterative" # "iterative" | "lp" | "quantile" | "anchor"
```
### Fitting method comparison
| Method | Description | Pros | Cons |
|------|------|------|------|
| **iterative** | Iterative outlier removal + least squares | Stable and conservative; existing tuning experience | Lines "cut through" the data rather than "enclosing" it |
| **lp** | Linear-programming convex optimization | Mathematically rigorous; guarantees an enclosing boundary | Sensitive to extreme values |
| **quantile** | Quantile regression (upper 95% / lower 5%) | Statistically robust; resistant to outliers | Slightly slower |
| **anchor** | Absolute-extreme anchor + slope optimization | Explicit anchor; line hugs the main trend | Sensitive to the number of pivot points |
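For reference, the quantile row above also reduces to a linear program via the pinball loss. A minimal sketch assuming `scipy` is available; `fit_quantile_line` is an illustrative name, not the repo's `fit_boundary_quantile`:

```python
import numpy as np
from scipy.optimize import linprog

def fit_quantile_line(x, y, q=0.95, slope_bound=0.5):
    """Quantile regression line via LP: minimize the pinball loss
    sum(q*u_i + (1-q)*v_i) with y_i = a*x_i + b + u_i - v_i, u, v >= 0.
    For q=0.95, roughly 95% of points end up on or below the line."""
    x = np.asarray(x, float)
    y = np.asarray(y, float)
    n = len(x)
    # decision variables: [a, b, u_1..u_n, v_1..v_n]
    c = np.concatenate([[0.0, 0.0], np.full(n, q), np.full(n, 1.0 - q)])
    A_eq = np.hstack([np.column_stack([x, np.ones(n)]), np.eye(n), -np.eye(n)])
    bounds = [(-slope_bound, slope_bound), (None, None)] + [(0.0, None)] * (2 * n)
    res = linprog(c, A_eq=A_eq, b_eq=y, bounds=bounds)
    return res.x[0], res.x[1]

x = np.arange(10.0)
y = np.array([1.00, 1.20, 0.90, 1.10, 1.00, 1.30, 0.95, 1.05, 1.15, 1.00])
a, b = fit_quantile_line(x, y, q=0.95)
print(a, b)
```

With q = 0.95 the loss makes points above the line ~19x more expensive than points below it, which is why the fitted line rides near the top of the data.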
### Mathematics of the LP method
**Upper boundary (find the "ceiling", i.e. the tightest envelope)**:
```
minimize   Σ(a*x_i + b - y_i)    total distance between the line and the points
subject to y_i ≤ a*x_i + b       every point lies below the line
           -0.5 ≤ a ≤ 0.5        slope bound
```
**Lower boundary (find the "floor", i.e. the tightest envelope)**:
```
minimize   Σ(y_i - a*x_i - b)    total distance between the line and the points
subject to y_i ≥ a*x_i + b       every point lies above the line
           -0.5 ≤ a ≤ 0.5        slope bound
```
This guarantees the fitted line strictly encloses all pivot points while staying as close to the data as possible, matching the resistance/support-line semantics of technical analysis.
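The upper-boundary LP maps directly onto `scipy.optimize.linprog`: the objective Σ(a·x_i + b − y_i) is linear in (a, b) (the −Σy_i term is a constant), and each pivot contributes one inequality. A minimal sketch assuming `scipy` is available; `fit_upper_envelope_lp` is an illustrative name, not the repo's `fit_boundary_lp`:

```python
import numpy as np
from scipy.optimize import linprog

def fit_upper_envelope_lp(x, y, slope_bound=0.5):
    """Tightest line above all points: min sum(a*x_i + b - y_i) s.t. a*x_i + b >= y_i."""
    x = np.asarray(x, float)
    y = np.asarray(y, float)
    # objective: sum(a*x_i + b) up to the constant -sum(y_i)
    c = np.array([x.sum(), float(len(x))])
    # constraints: -(a*x_i + b) <= -y_i  <=>  a*x_i + b >= y_i
    A_ub = np.column_stack([-x, -np.ones_like(x)])
    b_ub = -y
    res = linprog(c, A_ub=A_ub, b_ub=b_ub,
                  bounds=[(-slope_bound, slope_bound), (None, None)])
    a, b = res.x
    return a, b

# toy pivot highs drifting down by 0.05 per bar
x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
y = np.array([10.5, 10.45, 10.40, 10.35, 10.30])
a, b = fit_upper_envelope_lp(x, y)
print(a, b)
```

The lower boundary is symmetric: negate the objective direction and flip the inequality (`a*x_i + b <= y_i`).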
### Idea behind the anchor method
**Core goal**: fix an anchor point and optimize the slope so that most pivot points sit on the correct side of the boundary line.
- Anchor: the absolute high/low within the detection window (the last day is excluded and reserved for breakout detection)
- Upper boundary: find the flattest downward-sloping line such that >= 95% of pivot highs lie below it
- Lower boundary: find the flattest upward-sloping line such that >= 95% of pivot lows lie above it
- Implementation: binary-search the slope, then take the line that fits most tightly while satisfying the coverage constraint
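The binary search above can be sketched for the upper boundary as follows (an illustrative helper, not the repo's `fit_boundary_anchor`; it assumes coverage is monotone in the slope, searching for the steepest downward line that still covers enough pivot highs):

```python
import numpy as np

def fit_upper_anchor(pivot_x, pivot_y, coverage=0.95, slope_bound=0.5, n_iter=50):
    """Pin the line at the absolute-high pivot, then binary-search the steepest
    downward slope that keeps >= `coverage` of pivot highs on or below the line."""
    pivot_x = np.asarray(pivot_x, float)
    pivot_y = np.asarray(pivot_y, float)
    k = int(np.argmax(pivot_y))
    ax, ay = pivot_x[k], pivot_y[k]            # anchor: absolute high

    def covered(a):
        # fraction of pivot highs on/below the line through the anchor
        return np.mean(pivot_y <= ay + a * (pivot_x - ax) + 1e-9)

    lo, hi = -slope_bound, 0.0                 # a flat line at the max always covers
    for _ in range(n_iter):
        mid = (lo + hi) / 2.0
        if covered(mid) >= coverage:
            hi = mid                           # mid still covers: try steeper
        else:
            lo = mid
    a = hi
    return a, ay - a * ax                      # slope, intercept

slope, intercept = fit_upper_anchor([0, 5, 10, 15, 20],
                                    [10.0, 9.8, 9.6, 9.4, 9.2])
print(slope, intercept)
```

On this toy series of descending pivot highs the search converges to slope ≈ -0.04, the line through the anchor that just touches the remaining highs.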
### Test validation
```
Upper LP: slope=-0.006667, intercept=10.5333
check (line - point): [0.033, 0.000, 0.067, 0.033, 0.000]  (all >= 0: line above the points)
Lower LP: slope=0.005000, intercept=8.0000
check (point - line): [0.00, 0.05, 0.00, 0.05, 0.00]  (all >= 0: line below the points)
```
### Usage
```python
from src.converging_triangle import ConvergingTriangleParams, detect_converging_triangle

# use a convex-optimization / statistical fitting method
params = ConvergingTriangleParams(
    fitting_method="lp",  # or "quantile" / "anchor"
    # ... other parameters
)
result = detect_converging_triangle(high, low, close, volume, params)
```
### Implementation locations
- Parameter class: `ConvergingTriangleParams.fitting_method`
- LP fit: `fit_boundary_lp()`
- Quantile regression: `fit_boundary_quantile()`
- Anchor fit: `fit_boundary_anchor()`
- Dispatch function: `fit_pivot_line_dispatch()`
# Fit scores are low, yet strength scores are broadly high
![](images/2026-01-27-16-26-02.png)
---
## Boundary-utilization score implemented (2026-01-27)
### Problem analysis
Looking at SZ002748 (世龙实业) in the chart:
- Width ratio 0.12 (very convergent)
- Strength score 0.177 (ranked third)
- Yet visually there is **a lot of empty space** between the price action and the triangle boundaries

**Cause**:
- Original weights: convergence score 20%, fitting adherence 15%
- With a width ratio of 0.12, the convergence score = 1 - 0.12 = 0.88
- Convergence contribution = 0.20 × 0.88 = 0.176, i.e. nearly the entire strength score
- **The convergence score only measures "shape narrowing", not "whether prices hug the boundaries"**
### Solution
Add a **boundary utilization** score that measures how fully the price action uses the space of the triangle channel.
### New function
```python
def calc_boundary_utilization(
    high, low,
    upper_slope, upper_intercept,
    lower_slope, lower_intercept,
    start, end,
) -> float:
    """
    Compute the boundary utilization (0~1).

    For each day in the window:
    1. compute the price's distance to the upper and lower boundaries
    2. gap ratio = (distance to upper + distance to lower) / channel width
    3. daily utilization = 1 - gap ratio

    Returns the average utilization over the window.
    """
```
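Under the docstring's definition, a minimal NumPy sketch looks like this (the repo's real implementation may differ in NaN handling and clipping; `boundary_utilization_sketch` is an illustrative name):

```python
import numpy as np

def boundary_utilization_sketch(high, low, upper_slope, upper_intercept,
                                lower_slope, lower_intercept, start, end):
    """Average channel utilization over days [start, end]."""
    idx = np.arange(start, end + 1, dtype=float)
    upper = upper_slope * idx + upper_intercept    # upper boundary per day
    lower = lower_slope * idx + lower_intercept    # lower boundary per day
    width = upper - lower
    # gap = space above the day's high plus space below the day's low
    gap = (upper - high[start:end + 1]) + (low[start:end + 1] - lower)
    util = 1.0 - gap / np.where(width > 1e-9, width, np.nan)
    return float(np.nanmean(np.clip(util, 0.0, 1.0)))

# toy 4-day window in a flat channel [8, 10]: two days hug it, two days float mid-channel
high = np.array([9.9, 9.5, 9.8, 9.6])
low = np.array([8.2, 8.8, 8.4, 8.9])
u = boundary_utilization_sketch(high, low, 0.0, 10.0, 0.0, 8.0, 0, 3)
print(u)
```

In this toy window the per-day utilizations are 0.85, 0.35, 0.70, 0.35, averaging 0.5625.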
### New weight configuration
| Component | Old weight | New weight | Note |
|------|--------|--------|------|
| Breakout magnitude | 50% | **50%** | unchanged |
| Convergence score | 20% | **15%** | lowered |
| Volume score | 15% | **10%** | lowered |
| Fitting adherence | 15% | **10%** | lowered |
| **Boundary utilization** | - | **15%** | new |
### Gap penalty (new)
To avoid misclassifying "wide channel but empty prices" cases, a gap penalty is applied:
![](images/2026-01-27-18-49-33.png)
![](images/2026-01-27-18-49-17.png)
```
UTILIZATION_FLOOR = 0.20
penalty = min(1, boundary_utilization / UTILIZATION_FLOOR)
final strength = original strength × penalty
```
When boundary utilization is clearly low, the total score is pushed down further.
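The rule above is just a clamped ratio; a one-line sketch (function name illustrative):

```python
UTILIZATION_FLOOR = 0.20

def apply_utilization_penalty(strength: float, boundary_utilization: float) -> float:
    """Scale the strength score down when utilization falls below the floor."""
    return strength * min(1.0, boundary_utilization / UTILIZATION_FLOOR)

# SZ002748-style case: strength 0.177 but utilization only 0.10 (half the floor)
print(apply_utilization_penalty(0.177, 0.10))
```

Utilization at or above the floor leaves the score untouched; below it, the score shrinks linearly toward zero.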
### Result field
`ConvergingTriangleResult` gains a new field:
```python
boundary_utilization: float = 0.0  # boundary utilization score
```
### Effect
- Prices hug the boundaries (little gap) → high utilization → high strength score
- Prices stay far from the boundaries (large gap) → low utilization → strength score penalized
- When boundary utilization < 0.20, the strength score decays proportionally (gap penalty)
- Fixes the misclassification of "convergent shape but mostly empty" patterns


@@ -80,3 +80,8 @@ python scripts/pipeline_converging_triangle.py --clean --all-stocks --plot-bound
# Batch detection algorithm optimization
![](images/2026-01-28-17-13-37.png)
![](images/2026-01-28-17-22-42.png)
Before: 92 seconds
Now: < 2 seconds (the first run takes 3-5 seconds to compile)
A speedup of more than 50x 🚀

Binary files not shown: five new images added (346 KiB, 385 KiB, 204 KiB, 164 KiB, 155 KiB).


@@ -95,8 +95,8 @@ def main() -> None:
parser.add_argument(
"--plot-boundary-source",
choices=["hl", "close"],
default="close",
help="绘图时边界线拟合数据源: hl=高低价, close=收盘价(不影响检测)",
default="hl",
help="绘图时边界线拟合数据源: hl=高低价(默认), close=收盘价(不影响检测)",
)
args = parser.parse_args()


@@ -495,8 +495,8 @@ def main() -> None:
parser.add_argument(
"--plot-boundary-source",
choices=["hl", "close"],
default="close",
help="绘图时边界线拟合数据源: hl=高低价, close=收盘价(不影响检测)",
default="hl",
help="绘图时边界线拟合数据源: hl=高低价(默认), close=收盘价(不影响检测)",
)
parser.add_argument(
"--show-high-low",


@@ -9,6 +9,8 @@ import pickle
import time
import numpy as np
import pandas as pd
from datetime import datetime
from io import StringIO
# 让脚本能找到 src/ 下的模块
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
@@ -17,6 +19,7 @@ from converging_triangle import (
ConvergingTriangleParams,
ConvergingTriangleResult,
detect_converging_triangle_batch,
detect_converging_triangle_batch_v2, # v2优化版本
)
# ============================================================================
@@ -34,6 +37,7 @@ from triangle_config import (
VERBOSE,
REALTIME_MODE, # 新增
FLEXIBLE_ZONE, # 新增
USE_V2_OPTIMIZATION, # v2优化开关
)
@@ -89,6 +93,30 @@ def load_ohlcv_from_pkl(data_dir: str) -> tuple:
)
# ============================================================================
# Logging utility
# ============================================================================
class Logger:
    """Write output to both the console and a log file."""
    def __init__(self, log_path: str):
        self.log_path = log_path
        self.buffer = StringIO()
    def print(self, *args, **kwargs):
        """Print to the console and to the in-memory buffer."""
        print(*args, **kwargs)                    # console
        print(*args, **kwargs, file=self.buffer)  # buffer
    def save(self):
        """Flush the buffered log to the log file."""
        with open(self.log_path, 'w', encoding='utf-8') as f:
            f.write(self.buffer.getvalue())
        print(f"\nLog saved to: {self.log_path}")
# ============================================================================
# 主流程
# ============================================================================
@@ -96,29 +124,37 @@ def load_ohlcv_from_pkl(data_dir: str) -> tuple:
def main() -> None:
start_time = time.time()
print("=" * 70)
print("Converging Triangle Batch Detection")
print("=" * 70)
# 初始化日志
outputs_dir = os.path.join(os.path.dirname(__file__), "..", "outputs", "converging_triangles")
os.makedirs(outputs_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_path = os.path.join(outputs_dir, f"run_log_{timestamp}.txt")
log = Logger(log_path)
log.print("=" * 70)
log.print("Converging Triangle Batch Detection")
log.print(f"运行时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
log.print("=" * 70)
# 1. 加载数据
print("\n[1] Loading OHLCV pkl files...")
log.print("\n[1] Loading OHLCV pkl files...")
load_start = time.time()
open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(DATA_DIR)
load_time = time.time() - load_start
n_stocks, n_days = close_mtx.shape
print(f" Stocks: {n_stocks}")
print(f" Days: {n_days}")
print(f" Date range: {dates[0]} ~ {dates[-1]}")
print(f" 加载耗时: {load_time:.2f}")
log.print(f" Stocks: {n_stocks}")
log.print(f" Days: {n_days}")
log.print(f" Date range: {dates[0]} ~ {dates[-1]}")
log.print(f" 加载耗时: {load_time:.2f}")
# 2. 找有效数据范围(排除全 NaN 的列)
any_valid = np.any(~np.isnan(close_mtx), axis=0)
valid_day_idx = np.where(any_valid)[0]
if len(valid_day_idx) == 0:
print("No valid data found!")
log.print("No valid data found!")
return
last_valid_day = valid_day_idx[-1]
print(f" Last valid day: {last_valid_day} (date: {dates[last_valid_day]})")
log.print(f" Last valid day: {last_valid_day} (date: {dates[last_valid_day]})")
# 3. 计算范围
start_day = PARAMS.window - 1
@@ -127,74 +163,97 @@ def main() -> None:
if RECENT_DAYS is not None:
start_day = max(start_day, end_day - RECENT_DAYS + 1)
print(f"\n[2] Detection range: day {start_day} ~ {end_day}")
print(f" Window size: {PARAMS.window}")
print(f" Total points: {n_stocks} x {end_day - start_day + 1} = {n_stocks * (end_day - start_day + 1)}")
log.print(f"\n[2] Detection range: day {start_day} ~ {end_day}")
log.print(f" Window size: {PARAMS.window}")
log.print(f" Total points: {n_stocks} x {end_day - start_day + 1} = {n_stocks * (end_day - start_day + 1)}")
# 3. 批量检测
print("\n[3] Running batch detection...")
log.print("\n[3] Running batch detection...")
detect_start = time.time()
df = detect_converging_triangle_batch(
open_mtx=open_mtx,
high_mtx=high_mtx,
low_mtx=low_mtx,
close_mtx=close_mtx,
volume_mtx=volume_mtx,
params=PARAMS,
start_day=start_day,
end_day=end_day,
only_valid=ONLY_VALID,
verbose=VERBOSE,
real_time_mode=REALTIME_MODE, # 新增
flexible_zone=FLEXIBLE_ZONE, # 新增
)
# 根据配置选择v1或v2版本
if USE_V2_OPTIMIZATION and not REALTIME_MODE:
# v2优化版本预计算枢轴点不支持实时模式
log.print(" 使用: v2优化版本预计算枢轴点")
df = detect_converging_triangle_batch_v2(
open_mtx=open_mtx,
high_mtx=high_mtx,
low_mtx=low_mtx,
close_mtx=close_mtx,
volume_mtx=volume_mtx,
params=PARAMS,
start_day=start_day,
end_day=end_day,
only_valid=ONLY_VALID,
verbose=VERBOSE,
)
else:
# v1原版支持实时模式
if USE_V2_OPTIMIZATION and REALTIME_MODE:
log.print(" 注意: v2优化不支持实时模式自动回退到v1")
log.print(" 使用: v1原版")
df = detect_converging_triangle_batch(
open_mtx=open_mtx,
high_mtx=high_mtx,
low_mtx=low_mtx,
close_mtx=close_mtx,
volume_mtx=volume_mtx,
params=PARAMS,
start_day=start_day,
end_day=end_day,
only_valid=ONLY_VALID,
verbose=VERBOSE,
real_time_mode=REALTIME_MODE,
flexible_zone=FLEXIBLE_ZONE,
)
detect_time = time.time() - detect_start
print(f" 检测耗时: {detect_time:.2f}")
print(f" 检测模式: {'实时模式' if REALTIME_MODE else '标准模式'}")
if REALTIME_MODE:
print(f" 灵活区域: {FLEXIBLE_ZONE}")
log.print(f" 检测耗时: {detect_time:.2f}")
if not USE_V2_OPTIMIZATION or REALTIME_MODE:
log.print(f" 检测模式: {'实时模式' if REALTIME_MODE else '标准模式'}")
if REALTIME_MODE:
log.print(f" 灵活区域: {FLEXIBLE_ZONE}")
# 4. 添加股票代码、名称和真实日期
if len(df) > 0:
df["stock_code"] = df["stock_idx"].map(lambda x: tkrs[x] if x < len(tkrs) else "")
df["stock_name"] = df["stock_idx"].map(lambda x: tkrs_name[x] if x < len(tkrs_name) else "")
df["date"] = df["date_idx"].map(lambda x: dates[x] if x < len(dates) else 0)
# 计算综合强度(取向上和向下的最大值)
df["max_strength"] = df[["breakout_strength_up", "breakout_strength_down"]].max(axis=1)
# 5. 输出结果
print("\n" + "=" * 70)
print("Detection Results")
print("=" * 70)
log.print("\n" + "=" * 70)
log.print("Detection Results")
log.print("=" * 70)
if ONLY_VALID:
print(f"\nTotal valid triangles detected: {len(df)}")
log.print(f"\nTotal valid triangles detected: {len(df)}")
else:
valid_count = df["is_valid"].sum()
print(f"\nTotal records: {len(df)}")
print(f"Valid triangles: {valid_count} ({valid_count/len(df)*100:.1f}%)")
log.print(f"\nTotal records: {len(df)}")
log.print(f"Valid triangles: {valid_count} ({valid_count/len(df)*100:.1f}%)")
# 按突破方向统计
if len(df) > 0 and "breakout_dir" in df.columns:
breakout_stats = df[df["is_valid"] == True]["breakout_dir"].value_counts()
print(f"\nBreakout statistics:")
log.print(f"\nBreakout statistics:")
for dir_name, count in breakout_stats.items():
print(f" - {dir_name}: {count}")
log.print(f" - {dir_name}: {count}")
# 突破强度统计
if len(df) > 0 and "breakout_strength_up" in df.columns:
valid_df = df[df["is_valid"] == True]
if len(valid_df) > 0:
print(f"\nBreakout strength (valid triangles):")
print(f" - Up mean: {valid_df['breakout_strength_up'].mean():.4f}, max: {valid_df['breakout_strength_up'].max():.4f}")
print(f" - Down mean: {valid_df['breakout_strength_down'].mean():.4f}, max: {valid_df['breakout_strength_down'].max():.4f}")
log.print(f"\nBreakout strength (valid triangles):")
log.print(f" - Up mean: {valid_df['breakout_strength_up'].mean():.4f}, max: {valid_df['breakout_strength_up'].max():.4f}")
log.print(f" - Down mean: {valid_df['breakout_strength_down'].mean():.4f}, max: {valid_df['breakout_strength_down'].max():.4f}")
# 6. 保存结果
outputs_dir = os.path.join(os.path.dirname(__file__), "..", "outputs", "converging_triangles")
os.makedirs(outputs_dir, exist_ok=True)
# 保存完整 CSV (使用 utf-8-sig 支持中文)
csv_path = os.path.join(outputs_dir, "all_results.csv")
df.to_csv(csv_path, index=False, encoding="utf-8-sig")
print(f"\nResults saved to: {csv_path}")
log.print(f"\nResults saved to: {csv_path}")
# 保存高强度突破记录 (strength > 0.3)
if len(df) > 0:
@@ -204,30 +263,79 @@ def main() -> None:
if len(strong_up) > 0:
strong_up_path = os.path.join(outputs_dir, "strong_breakout_up.csv")
strong_up.to_csv(strong_up_path, index=False, encoding="utf-8-sig")
print(f"Strong up breakouts ({len(strong_up)}): {strong_up_path}")
log.print(f"Strong up breakouts ({len(strong_up)}): {strong_up_path}")
if len(strong_down) > 0:
strong_down_path = os.path.join(outputs_dir, "strong_breakout_down.csv")
strong_down.to_csv(strong_down_path, index=False, encoding="utf-8-sig")
print(f"Strong down breakouts ({len(strong_down)}): {strong_down_path}")
log.print(f"Strong down breakouts ({len(strong_down)}): {strong_down_path}")
# 7. 显示样本
# 7. 每日最佳股票报告(核心新功能)
if len(df) > 0 and "date" in df.columns:
log.print("\n" + "=" * 70)
log.print("每日最佳收敛三角形股票 (按日期)")
log.print("=" * 70)
valid_df = df[df["is_valid"] == True].copy()
if len(valid_df) > 0:
# 按日期分组,每天取分数最高的股票
daily_best = valid_df.loc[valid_df.groupby("date")["max_strength"].idxmax()]
daily_best = daily_best.sort_values("date", ascending=True)
log.print(f"\n{len(daily_best)} 个交易日检测到收敛三角形:")
log.print("-" * 90)
log.print(f"{'日期':<12} {'股票代码':<10} {'股票名称':<12} {'强度↑':>8} {'强度↓':>8} {'方向':<6} {'收敛比':>8}")
log.print("-" * 90)
for _, row in daily_best.iterrows():
log.print(
f"{int(row['date']):<12} "
f"{row['stock_code']:<10} "
f"{row['stock_name']:<12} "
f"{row['breakout_strength_up']:>8.4f} "
f"{row['breakout_strength_down']:>8.4f} "
f"{row['breakout_dir']:<6} "
f"{row['width_ratio']:>8.3f}"
)
log.print("-" * 90)
# 保存每日最佳到 CSV
daily_best_path = os.path.join(outputs_dir, "daily_best.csv")
daily_best.to_csv(daily_best_path, index=False, encoding="utf-8-sig")
log.print(f"\n每日最佳已保存到: {daily_best_path}")
# 统计:每只股票被选为每日最佳的次数
log.print("\n" + "-" * 70)
log.print("股票被选为每日最佳的次数排行:")
log.print("-" * 70)
stock_counts = daily_best.groupby(["stock_code", "stock_name"]).size().reset_index(name="count")
stock_counts = stock_counts.sort_values("count", ascending=False).head(20)
for _, row in stock_counts.iterrows():
log.print(f" {row['stock_code']} {row['stock_name']}: {row['count']}")
else:
log.print("\n没有检测到有效的收敛三角形")
# 8. 显示样本
if len(df) > 0:
print("\n" + "-" * 70)
print("Sample results (first 10):")
print("-" * 70)
log.print("\n" + "-" * 70)
log.print("Sample results (first 10):")
log.print("-" * 70)
display_cols = [
"stock_code", "date", "is_valid",
"breakout_strength_up", "breakout_strength_down",
"breakout_dir", "width_ratio"
]
display_cols = [c for c in display_cols if c in df.columns]
print(df[display_cols].head(10).to_string(index=False))
log.print(df[display_cols].head(10).to_string(index=False))
total_time = time.time() - start_time
print("\n" + "=" * 70)
print(f"总耗时: {total_time:.2f} 秒 ({total_time/60:.2f} 分钟)")
print("=" * 70)
log.print("\n" + "=" * 70)
log.print(f"总耗时: {total_time:.2f} 秒 ({total_time/60:.2f} 分钟)")
log.print("=" * 70)
# 保存日志
log.save()
if __name__ == "__main__":


@@ -120,6 +120,17 @@ SHOW_CHART_DETAILS = False  # False = concise mode (default), True = detailed mode
# Concise mode: show only the close price and the upper/lower boundary lines
# Detailed mode: additionally show all pivot points, fit points, segment lines, and other debug info
# ============================================================================
# Performance optimization config
# ============================================================================
# Whether to use the v2 optimized version (precomputed pivot points)
USE_V2_OPTIMIZATION = False  # True = v2 optimized, False = v1 original (recommended)
# v2 precomputes pivot points over the entire time series, avoiding repeated
# sliding-window computation. Expected speedup: 3-5x.
# Note: v2 does not yet support realtime mode (REALTIME_MODE); it automatically falls back to standard mode.
# ============================================================================
# Recommended parameter presets (alternatives)
# ============================================================================


@@ -1416,6 +1416,9 @@ try:
calc_fitting_adherence_optimized,
calc_boundary_utilization_optimized,
calc_breakout_strength_optimized,
# v2 optimization: precomputed pivot points
precompute_pivots_numba,
detect_batch_with_precomputed_pivots_numba,
)
# Override the originals with the optimized versions (at module level)
pivots_fractal = pivots_fractal_optimized
@@ -1425,7 +1428,148 @@ try:
calc_boundary_utilization = calc_boundary_utilization_optimized
calc_breakout_strength = calc_breakout_strength_optimized
print("[性能优化] 已启用Numba加速 (预计加速300x)")
_HAS_V2_OPTIMIZATION = True
print("[性能优化] 已启用Numba加速 + 预计算枢轴点优化")
except ImportError:
_HAS_V2_OPTIMIZATION = False
print("[性能优化] 未启用Numba加速使用原版函数")
# ============================================================================
# 【v2优化】使用预计算枢轴点的批量检测
# ============================================================================
def detect_converging_triangle_batch_v2(
open_mtx: np.ndarray,
high_mtx: np.ndarray,
low_mtx: np.ndarray,
close_mtx: np.ndarray,
volume_mtx: np.ndarray,
params: ConvergingTriangleParams,
start_day: Optional[int] = None,
end_day: Optional[int] = None,
only_valid: bool = False,
verbose: bool = False,
) -> pd.DataFrame:
"""
v2优化使用预计算枢轴点的批量检测
优化策略
1. 对每只股票一次性预计算整个历史的枢轴点
2. 每个检测窗口直接读取预计算结果避免重复计算
3. 整个股票的所有日期检测在一次Numba调用中完成
预期加速3-5x相比v1版本
Args:
detect_converging_triangle_batch 相同
Returns:
DataFrame格式与v1相同
"""
if not _HAS_V2_OPTIMIZATION:
print("[警告] v2优化未启用回退到v1版本")
return detect_converging_triangle_batch(
open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
params, start_day, end_day, only_valid, verbose
)
n_stocks, n_days = close_mtx.shape
window = params.window
# 默认起止日
if start_day is None:
start_day = window - 1
if end_day is None:
end_day = n_days - 1
# 确保范围有效
start_day = max(window - 1, start_day)
end_day = min(n_days - 1, end_day)
results: List[dict] = []
for stock_idx in range(n_stocks):
# 提取该股票的全部数据
high_stock = high_mtx[stock_idx, :]
low_stock = low_mtx[stock_idx, :]
close_stock = close_mtx[stock_idx, :]
volume_stock = volume_mtx[stock_idx, :] if volume_mtx is not None else np.zeros(n_days)
# 找到有效数据的mask
valid_mask = ~np.isnan(close_stock)
valid_indices = np.where(valid_mask)[0].astype(np.int32)
if len(valid_indices) < window:
continue
# 【v2核心优化】预计算整个时间序列的枢轴点
is_pivot_high, is_pivot_low = precompute_pivots_numba(
high_stock, low_stock, k=params.pivot_k
)
# 批量检测单次Numba调用处理所有日期
(date_indices, is_valid_arr, strength_up_arr, strength_down_arr,
price_score_up_arr, price_score_down_arr, convergence_score_arr,
vol_score_arr, fitting_score_arr, boundary_util_score_arr,
upper_slope_arr, lower_slope_arr, width_ratio_arr,
touches_upper_arr, touches_lower_arr, apex_x_arr,
breakout_dir_arr, volume_confirmed_arr) = \
detect_batch_with_precomputed_pivots_numba(
high_stock, low_stock, close_stock, volume_stock,
is_pivot_high, is_pivot_low,
valid_indices, window, start_day, end_day,
params.upper_slope_max, params.lower_slope_min,
params.shrink_ratio, params.touch_tol, params.break_tol,
params.vol_window, params.vol_k
)
# 收集结果
for i in range(len(date_indices)):
if only_valid and not is_valid_arr[i]:
continue
# 转换 breakout_dir
bd = breakout_dir_arr[i]
breakout_dir_str = "none" if bd == 0 else ("up" if bd == 1 else "down")
# 转换 volume_confirmed
vc = volume_confirmed_arr[i]
vol_confirmed = None if vc == -1 else bool(vc)
results.append({
'stock_idx': stock_idx,
'date_idx': int(date_indices[i]),
'is_valid': bool(is_valid_arr[i]),
'breakout_strength_up': float(strength_up_arr[i]),
'breakout_strength_down': float(strength_down_arr[i]),
'price_score_up': float(price_score_up_arr[i]),
'price_score_down': float(price_score_down_arr[i]),
'convergence_score': float(convergence_score_arr[i]),
'volume_score': float(vol_score_arr[i]),
'fitting_score': float(fitting_score_arr[i]),
'boundary_utilization': float(boundary_util_score_arr[i]),
'upper_slope': float(upper_slope_arr[i]),
'lower_slope': float(lower_slope_arr[i]),
'width_ratio': float(width_ratio_arr[i]),
'touches_upper': int(touches_upper_arr[i]),
'touches_lower': int(touches_lower_arr[i]),
'apex_x': float(apex_x_arr[i]),
'breakout_dir': breakout_dir_str,
'volume_confirmed': vol_confirmed,
'false_breakout': None,
'window_start': 0,
'window_end': int(date_indices[i]),
'detection_mode': 'v2_optimized',
'has_candidate_pivots': False,
'candidate_pivot_count': 0,
})
if verbose and (stock_idx + 1) % 20 == 0:
print(f" Progress: {stock_idx + 1}/{n_stocks} stocks")
if verbose:
print(f" Completed: {len(results)} results")
return pd.DataFrame(results)
# ============================================================================


@@ -6,6 +6,7 @@
2. Optimized pivot detection to avoid repeated nanmax/nanmin calls
3. Optimized boundary fitting with vectorized computation
4. Reduced unnecessary array copies
5. v2: precomputed pivot matrix, avoiding repeated sliding-window computation
No parallelism (as requested)
"""
@@ -14,7 +15,7 @@ from __future__ import annotations
import numba
import numpy as np
from typing import Tuple
from typing import Tuple, List, Optional
# ============================================================================
# Numba优化的核心函数
@@ -592,3 +593,385 @@ def calc_breakout_strength_optimized(
close, upper_line, lower_line, volume_ratio,
width_ratio, fitting_adherence, boundary_utilization
)
# ============================================================================
# 【v2优化】预计算枢轴点矩阵
# ============================================================================
@numba.jit(nopython=True, cache=True)
def precompute_pivots_numba(
high: np.ndarray,
low: np.ndarray,
k: int = 15
) -> Tuple[np.ndarray, np.ndarray]:
"""
预计算整个时间序列的枢轴点标记
优化思路一次性计算所有枢轴点避免滑动窗口重复计算
Args:
high: 最高价数组 (n_days,)
low: 最低价数组 (n_days,)
k: 窗口大小左右各k天
Returns:
(is_pivot_high, is_pivot_low): 布尔数组标记每天是否为枢轴点
"""
n = len(high)
is_pivot_high = np.zeros(n, dtype=np.bool_)
is_pivot_low = np.zeros(n, dtype=np.bool_)
for i in range(k, n - k):
if np.isnan(high[i]) or np.isnan(low[i]):
continue
# 高点检测
is_ph = True
h_val = high[i]
for j in range(i - k, i + k + 1):
if j == i:
continue
if not np.isnan(high[j]) and high[j] > h_val:
is_ph = False
break
is_pivot_high[i] = is_ph
# 低点检测
is_pl = True
l_val = low[i]
for j in range(i - k, i + k + 1):
if j == i:
continue
if not np.isnan(low[j]) and low[j] < l_val:
is_pl = False
break
is_pivot_low[i] = is_pl
return is_pivot_high, is_pivot_low
@numba.jit(nopython=True, cache=True)
def get_pivots_in_window(
is_pivot: np.ndarray,
start: int,
end: int
) -> np.ndarray:
"""从预计算的布尔数组中提取窗口内的枢轴点索引"""
# 计算窗口内枢轴点数量
count = 0
for i in range(start, end + 1):
if is_pivot[i]:
count += 1
# 提取索引
result = np.empty(count, dtype=np.int32)
idx = 0
for i in range(start, end + 1):
if is_pivot[i]:
result[idx] = i - start # 转换为窗口内的相对索引
idx += 1
return result
@numba.jit(nopython=True, cache=True)
def detect_single_with_precomputed_pivots(
high: np.ndarray,
low: np.ndarray,
close: np.ndarray,
volume: np.ndarray,
is_pivot_high: np.ndarray,
is_pivot_low: np.ndarray,
window_start: int,
window_end: int,
# 参数
upper_slope_max: float,
lower_slope_min: float,
shrink_ratio: float,
touch_tol: float,
break_tol: float,
vol_window: int,
vol_k: float,
) -> Tuple[bool, float, float, float, float, float, float, float, float,
float, float, float, int, int, float, int, int]:
"""
使用预计算枢轴点的单点检测纯Numba实现
Returns:
(is_valid, strength_up, strength_down, price_score_up, price_score_down,
convergence_score, vol_score, fitting_score, boundary_util_score,
upper_slope, lower_slope, width_ratio, touches_upper, touches_lower,
apex_x, breakout_dir, volume_confirmed)
breakout_dir: 0=none, 1=up, 2=down
volume_confirmed: 0=False, 1=True, -1=None
"""
n = window_end - window_start + 1
# 默认无效结果
invalid_result = (False, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0, 0, 0.0, 0, -1)
if n < 50: # 最小窗口检查
return invalid_result
# 提取窗口内的枢轴点
ph_count = 0
pl_count = 0
for i in range(window_start, window_end + 1):
if is_pivot_high[i]:
ph_count += 1
if is_pivot_low[i]:
pl_count += 1
if ph_count < 2 or pl_count < 2:
return invalid_result
# 收集枢轴点索引和值
ph_indices = np.empty(ph_count, dtype=np.int32)
ph_values = np.empty(ph_count, dtype=np.float64)
pl_indices = np.empty(pl_count, dtype=np.int32)
pl_values = np.empty(pl_count, dtype=np.float64)
ph_idx = 0
pl_idx = 0
for i in range(window_start, window_end + 1):
rel_i = i - window_start # 相对索引
if is_pivot_high[i]:
ph_indices[ph_idx] = rel_i
ph_values[ph_idx] = high[i]
ph_idx += 1
if is_pivot_low[i]:
pl_indices[pl_idx] = rel_i
pl_values[pl_idx] = low[i]
pl_idx += 1
# 边界拟合anchor方法
high_win = high[window_start:window_end + 1]
low_win = low[window_start:window_end + 1]
a_u, b_u = fit_boundary_anchor_numba(
ph_indices.astype(np.float64), ph_values,
high_win, mode=0, coverage=0.95, exclude_last=1,
window_start=0, window_end=n - 1
)
a_l, b_l = fit_boundary_anchor_numba(
pl_indices.astype(np.float64), pl_values,
low_win, mode=1, coverage=0.95, exclude_last=1,
window_start=0, window_end=n - 1
)
# 斜率检查
if a_u > upper_slope_max or a_l < lower_slope_min:
return invalid_result
# 相向收敛检查
slope_tolerance = 0.01
both_descending = (a_u < -slope_tolerance) and (a_l < -slope_tolerance)
both_ascending = (a_u > slope_tolerance) and (a_l > slope_tolerance)
if both_descending or both_ascending:
return invalid_result
# 宽度收敛检查
start_idx = 0
end_idx = n - 1
upper_start = a_u * start_idx + b_u
lower_start = a_l * start_idx + b_l
upper_end = a_u * end_idx + b_u
lower_end = a_l * end_idx + b_l
width_start = upper_start - lower_start
width_end = upper_end - lower_end
if width_start <= 0 or width_end <= 0:
return invalid_result
width_ratio = width_end / width_start
if width_ratio > shrink_ratio:
return invalid_result
# 触碰检测
touches_upper = 0
touches_lower = 0
for i in range(ph_count):
line_y = a_u * ph_indices[i] + b_u
deviation = abs(ph_values[i] - line_y) / max(line_y, 1e-9)
if deviation <= touch_tol:
touches_upper += 1
for i in range(pl_count):
line_y = a_l * pl_indices[i] + b_l
deviation = abs(pl_values[i] - line_y) / max(line_y, 1e-9)
if deviation <= touch_tol:
touches_lower += 1
if touches_upper < 2 or touches_lower < 2:
return invalid_result
# Apex计算
denom = a_u - a_l
if abs(denom) > 1e-12:
apex_x = (b_l - b_u) / denom
else:
apex_x = 1e9
# 突破判定
close_val = close[window_end]
breakout_dir = 0 # 0=none, 1=up, 2=down
if close_val > upper_end * (1 + break_tol):
breakout_dir = 1
elif close_val < lower_end * (1 - break_tol):
breakout_dir = 2
# 成交量确认
volume_confirmed = -1 # -1 = None
volume_ratio = 1.0
if vol_window > 0 and window_end >= vol_window:
vol_sum = 0.0
for i in range(window_end - vol_window + 1, window_end + 1):
vol_sum += volume[i]
vol_ma = vol_sum / vol_window
if vol_ma > 0:
volume_ratio = volume[window_end] / vol_ma
if breakout_dir != 0:
volume_confirmed = 1 if volume[window_end] > vol_ma * vol_k else 0
# 计算拟合贴合度
adherence_upper = calc_fitting_adherence_numba(
ph_indices.astype(np.float64), ph_values, a_u, b_u
)
adherence_lower = calc_fitting_adherence_numba(
pl_indices.astype(np.float64), pl_values, a_l, b_l
)
fitting_adherence = (adherence_upper + adherence_lower) / 2.0
# 计算边界利用率
boundary_util = calc_boundary_utilization_numba(
high_win, low_win, a_u, b_u, a_l, b_l, 0, n - 1
)
# 计算突破强度
(strength_up, strength_down, price_score_up, price_score_down,
convergence_score, vol_score, fitting_score, boundary_util_score) = \
calc_breakout_strength_numba(
close_val, upper_end, lower_end, volume_ratio,
width_ratio, fitting_adherence, boundary_util
)
return (True, strength_up, strength_down, price_score_up, price_score_down,
convergence_score, vol_score, fitting_score, boundary_util_score,
a_u, a_l, width_ratio, touches_upper, touches_lower,
apex_x, breakout_dir, volume_confirmed)
@numba.jit(nopython=True, cache=True)
def detect_batch_with_precomputed_pivots_numba(
high: np.ndarray,
low: np.ndarray,
close: np.ndarray,
volume: np.ndarray,
is_pivot_high: np.ndarray,
is_pivot_low: np.ndarray,
valid_indices: np.ndarray,
window: int,
start_day: int,
end_day: int,
# 参数
upper_slope_max: float,
lower_slope_min: float,
shrink_ratio: float,
touch_tol: float,
break_tol: float,
vol_window: int,
vol_k: float,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray,
np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray,
np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray,
np.ndarray, np.ndarray, np.ndarray]:
"""
使用预计算枢轴点的批量检测单只股票多日期
Returns:
多个数组每个元素对应一个检测点
"""
n_valid = len(valid_indices)
n_dates = 0
# 计算有效日期数量
for valid_end in range(window - 1, n_valid):
orig_date_idx = valid_indices[valid_end]
if orig_date_idx >= start_day and orig_date_idx <= end_day:
n_dates += 1
# 预分配结果数组
date_indices = np.empty(n_dates, dtype=np.int32)
is_valid = np.zeros(n_dates, dtype=np.bool_)
strength_up = np.zeros(n_dates, dtype=np.float64)
strength_down = np.zeros(n_dates, dtype=np.float64)
price_score_up = np.zeros(n_dates, dtype=np.float64)
price_score_down = np.zeros(n_dates, dtype=np.float64)
convergence_score = np.zeros(n_dates, dtype=np.float64)
vol_score = np.zeros(n_dates, dtype=np.float64)
fitting_score = np.zeros(n_dates, dtype=np.float64)
boundary_util_score = np.zeros(n_dates, dtype=np.float64)
upper_slope = np.zeros(n_dates, dtype=np.float64)
lower_slope = np.zeros(n_dates, dtype=np.float64)
width_ratio = np.zeros(n_dates, dtype=np.float64)
touches_upper = np.zeros(n_dates, dtype=np.int32)
touches_lower = np.zeros(n_dates, dtype=np.int32)
apex_x = np.zeros(n_dates, dtype=np.float64)
breakout_dir = np.zeros(n_dates, dtype=np.int32)
volume_confirmed = np.zeros(n_dates, dtype=np.int32) - 1 # 默认-1
# 遍历日期
result_idx = 0
for valid_end in range(window - 1, n_valid):
orig_date_idx = valid_indices[valid_end]
if orig_date_idx < start_day or orig_date_idx > end_day:
continue
valid_start = valid_end - window + 1
window_start_orig = valid_indices[valid_start]
window_end_orig = valid_indices[valid_end]
date_indices[result_idx] = orig_date_idx
# 调用单点检测
result = detect_single_with_precomputed_pivots(
high, low, close, volume,
is_pivot_high, is_pivot_low,
window_start_orig, window_end_orig,
upper_slope_max, lower_slope_min, shrink_ratio,
touch_tol, break_tol, vol_window, vol_k
)
is_valid[result_idx] = result[0]
strength_up[result_idx] = result[1]
strength_down[result_idx] = result[2]
price_score_up[result_idx] = result[3]
price_score_down[result_idx] = result[4]
convergence_score[result_idx] = result[5]
vol_score[result_idx] = result[6]
fitting_score[result_idx] = result[7]
boundary_util_score[result_idx] = result[8]
upper_slope[result_idx] = result[9]
lower_slope[result_idx] = result[10]
width_ratio[result_idx] = result[11]
touches_upper[result_idx] = result[12]
touches_lower[result_idx] = result[13]
apex_x[result_idx] = result[14]
breakout_dir[result_idx] = result[15]
volume_confirmed[result_idx] = result[16]
result_idx += 1
return (date_indices, is_valid, strength_up, strength_down,
price_score_up, price_score_down, convergence_score,
vol_score, fitting_score, boundary_util_score,
upper_slope, lower_slope, width_ratio,
touches_upper, touches_lower, apex_x,
breakout_dir, volume_confirmed)