Enhance converging triangle detection with new features and performance improvements
- Added support for daily best stocks reporting, including a new CSV output for daily best triangles based on strength. - Introduced a logging mechanism to capture detailed execution logs, improving traceability and debugging. - Implemented a v2 optimization for batch detection, significantly reducing detection time from 92 seconds to under 2 seconds. - Updated the .gitignore file to include new log files and outputs for better management. - Enhanced the pipeline script to allow for flexible configuration of detection parameters and improved user experience. Files modified: - scripts/run_converging_triangle.py: Added logging and v2 optimization. - scripts/pipeline_converging_triangle.py: Updated for new features and logging. - scripts/plot_converging_triangles.py: Adjusted for new plotting options. - New files: discuss/20260127-拟合线.md, discuss/20260128-拟合线.md, and several images for visual documentation.
This commit is contained in:
parent
759042c5bd
commit
09ac66caa1
3
.gitignore
vendored
3
.gitignore
vendored
@ -134,6 +134,9 @@ outputs/converging_triangles/all_results.csv
|
|||||||
outputs/converging_triangles/report.md
|
outputs/converging_triangles/report.md
|
||||||
outputs/converging_triangles/strong_breakout_down.csv
|
outputs/converging_triangles/strong_breakout_down.csv
|
||||||
outputs/converging_triangles/strong_breakout_up.csv
|
outputs/converging_triangles/strong_breakout_up.csv
|
||||||
|
outputs/converging_triangles/daily_best.csv
|
||||||
|
outputs/converging_triangles/run_log_*.txt
|
||||||
|
|
||||||
|
|
||||||
# 性能分析输出
|
# 性能分析输出
|
||||||
outputs/performance/*.prof
|
outputs/performance/*.prof
|
||||||
|
|||||||
163
discuss/20260127-讨论.md
Normal file
163
discuss/20260127-讨论.md
Normal file
@ -0,0 +1,163 @@
|
|||||||
|

|
||||||
|
|
||||||
|
拟合线不好,需要使用 "凸优化经典算法"。
|
||||||
|
最终是希望 上沿线或下沿线,包含大部分的 枢轴点。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 已实现:凸优化拟合方法(2026-01-27)
|
||||||
|
|
||||||
|
### 新增参数
|
||||||
|
|
||||||
|
```python
|
||||||
|
fitting_method: str = "iterative" # "iterative" | "lp" | "quantile" | "anchor"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 拟合方法对比
|
||||||
|
|
||||||
|
| 方法 | 说明 | 优点 | 缺点 |
|
||||||
|
|------|------|------|------|
|
||||||
|
| **iterative** | 迭代离群点移除 + 最小二乘法 | 稳定保守,已有调参经验 | 线"穿过"数据而非"包住" |
|
||||||
|
| **lp** | 线性规划凸优化 | 数学严谨,保证边界包络 | 对极端值敏感 |
|
||||||
|
| **quantile** | 分位数回归 (上95%/下5%) | 统计稳健,抗异常值 | 计算稍慢 |
|
||||||
|
| **anchor** | 绝对极值锚点 + 斜率优化 | 锚点明确,线更贴近主趋势 | 对枢轴点数量较敏感 |
|
||||||
|
|
||||||
|
### LP 方法数学原理
|
||||||
|
|
||||||
|
**上沿问题 (找"天花板",最紧的包络)**:
|
||||||
|
```
|
||||||
|
minimize Σ(a*x_i + b - y_i) 线与点的总距离
|
||||||
|
subject to y_i ≤ a * x_i + b 所有点在线下方
|
||||||
|
-0.5 ≤ a ≤ 0.5 斜率限制
|
||||||
|
```
|
||||||
|
|
||||||
|
**下沿问题 (找"地板",最紧的包络)**:
|
||||||
|
```
|
||||||
|
minimize Σ(y_i - a*x_i - b) 线与点的总距离
|
||||||
|
subject to y_i ≥ a * x_i + b 所有点在线上方
|
||||||
|
-0.5 ≤ a ≤ 0.5 斜率限制
|
||||||
|
```
|
||||||
|
|
||||||
|
这确保拟合线严格"包住"所有枢轴点,且尽量贴近数据,符合技术分析中"压力线/支撑线"的语义。
|
||||||
|
|
||||||
|
### Anchor 方法思路
|
||||||
|
|
||||||
|
**核心目标**:固定锚点,优化斜率,使大部分枢轴点在边界线正确一侧。
|
||||||
|
|
||||||
|
- 锚点:检测窗口内的绝对最高/最低点(排除最后1天用于突破判断)
|
||||||
|
- 上沿:找最“平缓”的下倾线,使 >=95% 枢轴高点在上沿线下方
|
||||||
|
- 下沿:找最“平缓”的上倾线,使 >=95% 枢轴低点在下沿线上方
|
||||||
|
- 实现:对斜率做二分搜索,满足覆盖率约束后取最贴近的一条线
|
||||||
|
|
||||||
|
### 测试验证
|
||||||
|
|
||||||
|
```
|
||||||
|
上沿 LP: slope=-0.006667, intercept=10.5333
|
||||||
|
验证(线-点): [0.033, 0.000, 0.067, 0.033, 0.000] (全>=0,线在点上方)
|
||||||
|
下沿 LP: slope=0.005000, intercept=8.0000
|
||||||
|
验证(点-线): [0.00, 0.05, 0.00, 0.05, 0.00] (全>=0,线在点下方)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 使用方法
|
||||||
|
|
||||||
|
```python
|
||||||
|
from src.converging_triangle import ConvergingTriangleParams, detect_converging_triangle
|
||||||
|
|
||||||
|
# 使用凸优化/统计方法
|
||||||
|
params = ConvergingTriangleParams(
|
||||||
|
fitting_method="lp", # 或 "quantile" / "anchor"
|
||||||
|
# ... 其他参数
|
||||||
|
)
|
||||||
|
|
||||||
|
result = detect_converging_triangle(high, low, close, volume, params)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 实现位置
|
||||||
|
|
||||||
|
- 参数类: `ConvergingTriangleParams.fitting_method`
|
||||||
|
- LP拟合: `fit_boundary_lp()`
|
||||||
|
- 分位数回归: `fit_boundary_quantile()`
|
||||||
|
- 锚点拟合: `fit_boundary_anchor()`
|
||||||
|
- 分发函数: `fit_pivot_line_dispatch()`
|
||||||
|
|
||||||
|
# 拟合度分数低,强度分却整体偏高
|
||||||
|

|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 已实现:边界利用率分数(2026-01-27)
|
||||||
|
|
||||||
|
### 问题分析
|
||||||
|
|
||||||
|
观察图中 SZ002748 世龙实业:
|
||||||
|
- 宽度比:0.12(非常收敛)
|
||||||
|
- 强度分:0.177(排名第三)
|
||||||
|
- 但肉眼观察:价格走势与三角形边界之间有**大量空白**
|
||||||
|
|
||||||
|
**原因**:
|
||||||
|
- 原权重:收敛分 20%、拟合贴合度 15%
|
||||||
|
- 当宽度比 0.12 时,收敛分 = 1 - 0.12 = 0.88
|
||||||
|
- 收敛分贡献 = 0.20 × 0.88 = 0.176 ≈ 全部强度分
|
||||||
|
- **收敛分只衡量"形状收窄",不衡量"价格是否贴近边界"**
|
||||||
|
|
||||||
|
### 解决方案
|
||||||
|
|
||||||
|
新增**边界利用率**分数,衡量价格走势对三角形通道空间的利用程度。
|
||||||
|
|
||||||
|
### 新增函数
|
||||||
|
|
||||||
|
```python
|
||||||
|
def calc_boundary_utilization(
|
||||||
|
high, low,
|
||||||
|
upper_slope, upper_intercept,
|
||||||
|
lower_slope, lower_intercept,
|
||||||
|
start, end,
|
||||||
|
) -> float:
|
||||||
|
"""
|
||||||
|
计算边界利用率 (0~1)
|
||||||
|
|
||||||
|
对窗口内每一天:
|
||||||
|
1. 计算价格到上下边界的距离
|
||||||
|
2. 空白比例 = (到上沿距离 + 到下沿距离) / 通道宽度
|
||||||
|
3. 当日利用率 = 1 - 空白比例
|
||||||
|
|
||||||
|
返回平均利用率
|
||||||
|
"""
|
||||||
|
```
|
||||||
|
|
||||||
|
### 新权重配置
|
||||||
|
|
||||||
|
| 分量 | 原权重 | 新权重 | 说明 |
|
||||||
|
|------|--------|--------|------|
|
||||||
|
| 突破幅度 | 50% | **50%** | 不变 |
|
||||||
|
| 收敛分 | 20% | **15%** | 降低 |
|
||||||
|
| 成交量分 | 15% | **10%** | 降低 |
|
||||||
|
| 拟合贴合度 | 15% | **10%** | 降低 |
|
||||||
|
| **边界利用率** | - | **15%** | 新增 |
|
||||||
|
|
||||||
|
### 空白惩罚(新增)
|
||||||
|
|
||||||
|
为避免“通道很宽但价格很空”的误判,加入空白惩罚:
|
||||||
|

|
||||||
|

|
||||||
|
```
|
||||||
|
UTILIZATION_FLOOR = 0.20
|
||||||
|
惩罚系数 = min(1, boundary_utilization / UTILIZATION_FLOOR)
|
||||||
|
最终强度分 = 原强度分 × 惩罚系数
|
||||||
|
```
|
||||||
|
|
||||||
|
当边界利用率明显偏低时,总分会被进一步压制。
|
||||||
|
|
||||||
|
### 结果字段
|
||||||
|
|
||||||
|
`ConvergingTriangleResult` 新增字段:
|
||||||
|
```python
|
||||||
|
boundary_utilization: float = 0.0 # 边界利用率分数
|
||||||
|
```
|
||||||
|
|
||||||
|
### 效果
|
||||||
|
|
||||||
|
- 价格贴近边界(空白少)→ 利用率高 → 强度分高
|
||||||
|
- 价格远离边界(空白多)→ 利用率低 → 强度分被惩罚
|
||||||
|
- 当边界利用率 < 0.20 时,强度分按比例衰减(空白惩罚)
|
||||||
|
- 解决"形状收敛但空白多"的误判问题
|
||||||
@ -80,3 +80,8 @@ python scripts/pipeline_converging_triangle.py --clean --all-stocks --plot-bound
|
|||||||
|
|
||||||
# 批量检测算法优化
|
# 批量检测算法优化
|
||||||

|

|
||||||
|

|
||||||
|
|
||||||
|
原来:92秒
|
||||||
|
现在:< 2秒(首次需要3-5秒编译)
|
||||||
|
提升:50倍以上 🚀
|
||||||
BIN
discuss/images/2026-01-28-11-16-46.png
Normal file
BIN
discuss/images/2026-01-28-11-16-46.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 346 KiB |
BIN
discuss/images/2026-01-28-11-16-56.png
Normal file
BIN
discuss/images/2026-01-28-11-16-56.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 385 KiB |
BIN
discuss/images/2026-01-28-15-56-12.png
Normal file
BIN
discuss/images/2026-01-28-15-56-12.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 204 KiB |
BIN
discuss/images/2026-01-28-17-13-37.png
Normal file
BIN
discuss/images/2026-01-28-17-13-37.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 164 KiB |
BIN
discuss/images/2026-01-28-17-22-42.png
Normal file
BIN
discuss/images/2026-01-28-17-22-42.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 155 KiB |
@ -95,8 +95,8 @@ def main() -> None:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--plot-boundary-source",
|
"--plot-boundary-source",
|
||||||
choices=["hl", "close"],
|
choices=["hl", "close"],
|
||||||
default="close",
|
default="hl",
|
||||||
help="绘图时边界线拟合数据源: hl=高低价, close=收盘价(不影响检测)",
|
help="绘图时边界线拟合数据源: hl=高低价(默认), close=收盘价(不影响检测)",
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
@ -495,8 +495,8 @@ def main() -> None:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--plot-boundary-source",
|
"--plot-boundary-source",
|
||||||
choices=["hl", "close"],
|
choices=["hl", "close"],
|
||||||
default="close",
|
default="hl",
|
||||||
help="绘图时边界线拟合数据源: hl=高低价, close=收盘价(不影响检测)",
|
help="绘图时边界线拟合数据源: hl=高低价(默认), close=收盘价(不影响检测)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--show-high-low",
|
"--show-high-low",
|
||||||
|
|||||||
@ -9,6 +9,8 @@ import pickle
|
|||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from datetime import datetime
|
||||||
|
from io import StringIO
|
||||||
|
|
||||||
# 让脚本能找到 src/ 下的模块
|
# 让脚本能找到 src/ 下的模块
|
||||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||||||
@ -17,6 +19,7 @@ from converging_triangle import (
|
|||||||
ConvergingTriangleParams,
|
ConvergingTriangleParams,
|
||||||
ConvergingTriangleResult,
|
ConvergingTriangleResult,
|
||||||
detect_converging_triangle_batch,
|
detect_converging_triangle_batch,
|
||||||
|
detect_converging_triangle_batch_v2, # v2优化版本
|
||||||
)
|
)
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@ -34,6 +37,7 @@ from triangle_config import (
|
|||||||
VERBOSE,
|
VERBOSE,
|
||||||
REALTIME_MODE, # 新增
|
REALTIME_MODE, # 新增
|
||||||
FLEXIBLE_ZONE, # 新增
|
FLEXIBLE_ZONE, # 新增
|
||||||
|
USE_V2_OPTIMIZATION, # v2优化开关
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -89,6 +93,30 @@ def load_ohlcv_from_pkl(data_dir: str) -> tuple:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# 日志工具
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class Logger:
|
||||||
|
"""同时输出到控制台和日志文件"""
|
||||||
|
def __init__(self, log_path: str):
|
||||||
|
self.log_path = log_path
|
||||||
|
self.buffer = StringIO()
|
||||||
|
|
||||||
|
def print(self, *args, **kwargs):
|
||||||
|
"""打印到控制台和缓冲区"""
|
||||||
|
# 打印到控制台
|
||||||
|
print(*args, **kwargs)
|
||||||
|
# 写入缓冲区
|
||||||
|
print(*args, **kwargs, file=self.buffer)
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
"""保存日志到文件"""
|
||||||
|
with open(self.log_path, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(self.buffer.getvalue())
|
||||||
|
print(f"\n日志已保存到: {self.log_path}")
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# 主流程
|
# 主流程
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
@ -96,29 +124,37 @@ def load_ohlcv_from_pkl(data_dir: str) -> tuple:
|
|||||||
def main() -> None:
|
def main() -> None:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
print("=" * 70)
|
# 初始化日志
|
||||||
print("Converging Triangle Batch Detection")
|
outputs_dir = os.path.join(os.path.dirname(__file__), "..", "outputs", "converging_triangles")
|
||||||
print("=" * 70)
|
os.makedirs(outputs_dir, exist_ok=True)
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
log_path = os.path.join(outputs_dir, f"run_log_{timestamp}.txt")
|
||||||
|
log = Logger(log_path)
|
||||||
|
|
||||||
|
log.print("=" * 70)
|
||||||
|
log.print("Converging Triangle Batch Detection")
|
||||||
|
log.print(f"运行时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
log.print("=" * 70)
|
||||||
|
|
||||||
# 1. 加载数据
|
# 1. 加载数据
|
||||||
print("\n[1] Loading OHLCV pkl files...")
|
log.print("\n[1] Loading OHLCV pkl files...")
|
||||||
load_start = time.time()
|
load_start = time.time()
|
||||||
open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(DATA_DIR)
|
open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(DATA_DIR)
|
||||||
load_time = time.time() - load_start
|
load_time = time.time() - load_start
|
||||||
n_stocks, n_days = close_mtx.shape
|
n_stocks, n_days = close_mtx.shape
|
||||||
print(f" Stocks: {n_stocks}")
|
log.print(f" Stocks: {n_stocks}")
|
||||||
print(f" Days: {n_days}")
|
log.print(f" Days: {n_days}")
|
||||||
print(f" Date range: {dates[0]} ~ {dates[-1]}")
|
log.print(f" Date range: {dates[0]} ~ {dates[-1]}")
|
||||||
print(f" 加载耗时: {load_time:.2f} 秒")
|
log.print(f" 加载耗时: {load_time:.2f} 秒")
|
||||||
|
|
||||||
# 2. 找有效数据范围(排除全 NaN 的列)
|
# 2. 找有效数据范围(排除全 NaN 的列)
|
||||||
any_valid = np.any(~np.isnan(close_mtx), axis=0)
|
any_valid = np.any(~np.isnan(close_mtx), axis=0)
|
||||||
valid_day_idx = np.where(any_valid)[0]
|
valid_day_idx = np.where(any_valid)[0]
|
||||||
if len(valid_day_idx) == 0:
|
if len(valid_day_idx) == 0:
|
||||||
print("No valid data found!")
|
log.print("No valid data found!")
|
||||||
return
|
return
|
||||||
last_valid_day = valid_day_idx[-1]
|
last_valid_day = valid_day_idx[-1]
|
||||||
print(f" Last valid day: {last_valid_day} (date: {dates[last_valid_day]})")
|
log.print(f" Last valid day: {last_valid_day} (date: {dates[last_valid_day]})")
|
||||||
|
|
||||||
# 3. 计算范围
|
# 3. 计算范围
|
||||||
start_day = PARAMS.window - 1
|
start_day = PARAMS.window - 1
|
||||||
@ -127,13 +163,35 @@ def main() -> None:
|
|||||||
if RECENT_DAYS is not None:
|
if RECENT_DAYS is not None:
|
||||||
start_day = max(start_day, end_day - RECENT_DAYS + 1)
|
start_day = max(start_day, end_day - RECENT_DAYS + 1)
|
||||||
|
|
||||||
print(f"\n[2] Detection range: day {start_day} ~ {end_day}")
|
log.print(f"\n[2] Detection range: day {start_day} ~ {end_day}")
|
||||||
print(f" Window size: {PARAMS.window}")
|
log.print(f" Window size: {PARAMS.window}")
|
||||||
print(f" Total points: {n_stocks} x {end_day - start_day + 1} = {n_stocks * (end_day - start_day + 1)}")
|
log.print(f" Total points: {n_stocks} x {end_day - start_day + 1} = {n_stocks * (end_day - start_day + 1)}")
|
||||||
|
|
||||||
# 3. 批量检测
|
# 3. 批量检测
|
||||||
print("\n[3] Running batch detection...")
|
log.print("\n[3] Running batch detection...")
|
||||||
detect_start = time.time()
|
detect_start = time.time()
|
||||||
|
|
||||||
|
# 根据配置选择v1或v2版本
|
||||||
|
if USE_V2_OPTIMIZATION and not REALTIME_MODE:
|
||||||
|
# v2优化版本(预计算枢轴点,不支持实时模式)
|
||||||
|
log.print(" 使用: v2优化版本(预计算枢轴点)")
|
||||||
|
df = detect_converging_triangle_batch_v2(
|
||||||
|
open_mtx=open_mtx,
|
||||||
|
high_mtx=high_mtx,
|
||||||
|
low_mtx=low_mtx,
|
||||||
|
close_mtx=close_mtx,
|
||||||
|
volume_mtx=volume_mtx,
|
||||||
|
params=PARAMS,
|
||||||
|
start_day=start_day,
|
||||||
|
end_day=end_day,
|
||||||
|
only_valid=ONLY_VALID,
|
||||||
|
verbose=VERBOSE,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# v1原版(支持实时模式)
|
||||||
|
if USE_V2_OPTIMIZATION and REALTIME_MODE:
|
||||||
|
log.print(" 注意: v2优化不支持实时模式,自动回退到v1")
|
||||||
|
log.print(" 使用: v1原版")
|
||||||
df = detect_converging_triangle_batch(
|
df = detect_converging_triangle_batch(
|
||||||
open_mtx=open_mtx,
|
open_mtx=open_mtx,
|
||||||
high_mtx=high_mtx,
|
high_mtx=high_mtx,
|
||||||
@ -145,56 +203,57 @@ def main() -> None:
|
|||||||
end_day=end_day,
|
end_day=end_day,
|
||||||
only_valid=ONLY_VALID,
|
only_valid=ONLY_VALID,
|
||||||
verbose=VERBOSE,
|
verbose=VERBOSE,
|
||||||
real_time_mode=REALTIME_MODE, # 新增
|
real_time_mode=REALTIME_MODE,
|
||||||
flexible_zone=FLEXIBLE_ZONE, # 新增
|
flexible_zone=FLEXIBLE_ZONE,
|
||||||
)
|
)
|
||||||
|
|
||||||
detect_time = time.time() - detect_start
|
detect_time = time.time() - detect_start
|
||||||
print(f" 检测耗时: {detect_time:.2f} 秒")
|
log.print(f" 检测耗时: {detect_time:.2f} 秒")
|
||||||
print(f" 检测模式: {'实时模式' if REALTIME_MODE else '标准模式'}")
|
if not USE_V2_OPTIMIZATION or REALTIME_MODE:
|
||||||
|
log.print(f" 检测模式: {'实时模式' if REALTIME_MODE else '标准模式'}")
|
||||||
if REALTIME_MODE:
|
if REALTIME_MODE:
|
||||||
print(f" 灵活区域: {FLEXIBLE_ZONE} 天")
|
log.print(f" 灵活区域: {FLEXIBLE_ZONE} 天")
|
||||||
|
|
||||||
# 4. 添加股票代码、名称和真实日期
|
# 4. 添加股票代码、名称和真实日期
|
||||||
if len(df) > 0:
|
if len(df) > 0:
|
||||||
df["stock_code"] = df["stock_idx"].map(lambda x: tkrs[x] if x < len(tkrs) else "")
|
df["stock_code"] = df["stock_idx"].map(lambda x: tkrs[x] if x < len(tkrs) else "")
|
||||||
df["stock_name"] = df["stock_idx"].map(lambda x: tkrs_name[x] if x < len(tkrs_name) else "")
|
df["stock_name"] = df["stock_idx"].map(lambda x: tkrs_name[x] if x < len(tkrs_name) else "")
|
||||||
df["date"] = df["date_idx"].map(lambda x: dates[x] if x < len(dates) else 0)
|
df["date"] = df["date_idx"].map(lambda x: dates[x] if x < len(dates) else 0)
|
||||||
|
# 计算综合强度(取向上和向下的最大值)
|
||||||
|
df["max_strength"] = df[["breakout_strength_up", "breakout_strength_down"]].max(axis=1)
|
||||||
|
|
||||||
# 5. 输出结果
|
# 5. 输出结果
|
||||||
print("\n" + "=" * 70)
|
log.print("\n" + "=" * 70)
|
||||||
print("Detection Results")
|
log.print("Detection Results")
|
||||||
print("=" * 70)
|
log.print("=" * 70)
|
||||||
|
|
||||||
if ONLY_VALID:
|
if ONLY_VALID:
|
||||||
print(f"\nTotal valid triangles detected: {len(df)}")
|
log.print(f"\nTotal valid triangles detected: {len(df)}")
|
||||||
else:
|
else:
|
||||||
valid_count = df["is_valid"].sum()
|
valid_count = df["is_valid"].sum()
|
||||||
print(f"\nTotal records: {len(df)}")
|
log.print(f"\nTotal records: {len(df)}")
|
||||||
print(f"Valid triangles: {valid_count} ({valid_count/len(df)*100:.1f}%)")
|
log.print(f"Valid triangles: {valid_count} ({valid_count/len(df)*100:.1f}%)")
|
||||||
|
|
||||||
# 按突破方向统计
|
# 按突破方向统计
|
||||||
if len(df) > 0 and "breakout_dir" in df.columns:
|
if len(df) > 0 and "breakout_dir" in df.columns:
|
||||||
breakout_stats = df[df["is_valid"] == True]["breakout_dir"].value_counts()
|
breakout_stats = df[df["is_valid"] == True]["breakout_dir"].value_counts()
|
||||||
print(f"\nBreakout statistics:")
|
log.print(f"\nBreakout statistics:")
|
||||||
for dir_name, count in breakout_stats.items():
|
for dir_name, count in breakout_stats.items():
|
||||||
print(f" - {dir_name}: {count}")
|
log.print(f" - {dir_name}: {count}")
|
||||||
|
|
||||||
# 突破强度统计
|
# 突破强度统计
|
||||||
if len(df) > 0 and "breakout_strength_up" in df.columns:
|
if len(df) > 0 and "breakout_strength_up" in df.columns:
|
||||||
valid_df = df[df["is_valid"] == True]
|
valid_df = df[df["is_valid"] == True]
|
||||||
if len(valid_df) > 0:
|
if len(valid_df) > 0:
|
||||||
print(f"\nBreakout strength (valid triangles):")
|
log.print(f"\nBreakout strength (valid triangles):")
|
||||||
print(f" - Up mean: {valid_df['breakout_strength_up'].mean():.4f}, max: {valid_df['breakout_strength_up'].max():.4f}")
|
log.print(f" - Up mean: {valid_df['breakout_strength_up'].mean():.4f}, max: {valid_df['breakout_strength_up'].max():.4f}")
|
||||||
print(f" - Down mean: {valid_df['breakout_strength_down'].mean():.4f}, max: {valid_df['breakout_strength_down'].max():.4f}")
|
log.print(f" - Down mean: {valid_df['breakout_strength_down'].mean():.4f}, max: {valid_df['breakout_strength_down'].max():.4f}")
|
||||||
|
|
||||||
# 6. 保存结果
|
# 6. 保存结果
|
||||||
outputs_dir = os.path.join(os.path.dirname(__file__), "..", "outputs", "converging_triangles")
|
|
||||||
os.makedirs(outputs_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# 保存完整 CSV (使用 utf-8-sig 支持中文)
|
# 保存完整 CSV (使用 utf-8-sig 支持中文)
|
||||||
csv_path = os.path.join(outputs_dir, "all_results.csv")
|
csv_path = os.path.join(outputs_dir, "all_results.csv")
|
||||||
df.to_csv(csv_path, index=False, encoding="utf-8-sig")
|
df.to_csv(csv_path, index=False, encoding="utf-8-sig")
|
||||||
print(f"\nResults saved to: {csv_path}")
|
log.print(f"\nResults saved to: {csv_path}")
|
||||||
|
|
||||||
# 保存高强度突破记录 (strength > 0.3)
|
# 保存高强度突破记录 (strength > 0.3)
|
||||||
if len(df) > 0:
|
if len(df) > 0:
|
||||||
@ -204,30 +263,79 @@ def main() -> None:
|
|||||||
if len(strong_up) > 0:
|
if len(strong_up) > 0:
|
||||||
strong_up_path = os.path.join(outputs_dir, "strong_breakout_up.csv")
|
strong_up_path = os.path.join(outputs_dir, "strong_breakout_up.csv")
|
||||||
strong_up.to_csv(strong_up_path, index=False, encoding="utf-8-sig")
|
strong_up.to_csv(strong_up_path, index=False, encoding="utf-8-sig")
|
||||||
print(f"Strong up breakouts ({len(strong_up)}): {strong_up_path}")
|
log.print(f"Strong up breakouts ({len(strong_up)}): {strong_up_path}")
|
||||||
|
|
||||||
if len(strong_down) > 0:
|
if len(strong_down) > 0:
|
||||||
strong_down_path = os.path.join(outputs_dir, "strong_breakout_down.csv")
|
strong_down_path = os.path.join(outputs_dir, "strong_breakout_down.csv")
|
||||||
strong_down.to_csv(strong_down_path, index=False, encoding="utf-8-sig")
|
strong_down.to_csv(strong_down_path, index=False, encoding="utf-8-sig")
|
||||||
print(f"Strong down breakouts ({len(strong_down)}): {strong_down_path}")
|
log.print(f"Strong down breakouts ({len(strong_down)}): {strong_down_path}")
|
||||||
|
|
||||||
# 7. 显示样本
|
# 7. 每日最佳股票报告(核心新功能)
|
||||||
|
if len(df) > 0 and "date" in df.columns:
|
||||||
|
log.print("\n" + "=" * 70)
|
||||||
|
log.print("每日最佳收敛三角形股票 (按日期)")
|
||||||
|
log.print("=" * 70)
|
||||||
|
|
||||||
|
valid_df = df[df["is_valid"] == True].copy()
|
||||||
|
if len(valid_df) > 0:
|
||||||
|
# 按日期分组,每天取分数最高的股票
|
||||||
|
daily_best = valid_df.loc[valid_df.groupby("date")["max_strength"].idxmax()]
|
||||||
|
daily_best = daily_best.sort_values("date", ascending=True)
|
||||||
|
|
||||||
|
log.print(f"\n共 {len(daily_best)} 个交易日检测到收敛三角形:")
|
||||||
|
log.print("-" * 90)
|
||||||
|
log.print(f"{'日期':<12} {'股票代码':<10} {'股票名称':<12} {'强度↑':>8} {'强度↓':>8} {'方向':<6} {'收敛比':>8}")
|
||||||
|
log.print("-" * 90)
|
||||||
|
|
||||||
|
for _, row in daily_best.iterrows():
|
||||||
|
log.print(
|
||||||
|
f"{int(row['date']):<12} "
|
||||||
|
f"{row['stock_code']:<10} "
|
||||||
|
f"{row['stock_name']:<12} "
|
||||||
|
f"{row['breakout_strength_up']:>8.4f} "
|
||||||
|
f"{row['breakout_strength_down']:>8.4f} "
|
||||||
|
f"{row['breakout_dir']:<6} "
|
||||||
|
f"{row['width_ratio']:>8.3f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
log.print("-" * 90)
|
||||||
|
|
||||||
|
# 保存每日最佳到 CSV
|
||||||
|
daily_best_path = os.path.join(outputs_dir, "daily_best.csv")
|
||||||
|
daily_best.to_csv(daily_best_path, index=False, encoding="utf-8-sig")
|
||||||
|
log.print(f"\n每日最佳已保存到: {daily_best_path}")
|
||||||
|
|
||||||
|
# 统计:每只股票被选为每日最佳的次数
|
||||||
|
log.print("\n" + "-" * 70)
|
||||||
|
log.print("股票被选为每日最佳的次数排行:")
|
||||||
|
log.print("-" * 70)
|
||||||
|
stock_counts = daily_best.groupby(["stock_code", "stock_name"]).size().reset_index(name="count")
|
||||||
|
stock_counts = stock_counts.sort_values("count", ascending=False).head(20)
|
||||||
|
for _, row in stock_counts.iterrows():
|
||||||
|
log.print(f" {row['stock_code']} {row['stock_name']}: {row['count']} 次")
|
||||||
|
else:
|
||||||
|
log.print("\n没有检测到有效的收敛三角形")
|
||||||
|
|
||||||
|
# 8. 显示样本
|
||||||
if len(df) > 0:
|
if len(df) > 0:
|
||||||
print("\n" + "-" * 70)
|
log.print("\n" + "-" * 70)
|
||||||
print("Sample results (first 10):")
|
log.print("Sample results (first 10):")
|
||||||
print("-" * 70)
|
log.print("-" * 70)
|
||||||
display_cols = [
|
display_cols = [
|
||||||
"stock_code", "date", "is_valid",
|
"stock_code", "date", "is_valid",
|
||||||
"breakout_strength_up", "breakout_strength_down",
|
"breakout_strength_up", "breakout_strength_down",
|
||||||
"breakout_dir", "width_ratio"
|
"breakout_dir", "width_ratio"
|
||||||
]
|
]
|
||||||
display_cols = [c for c in display_cols if c in df.columns]
|
display_cols = [c for c in display_cols if c in df.columns]
|
||||||
print(df[display_cols].head(10).to_string(index=False))
|
log.print(df[display_cols].head(10).to_string(index=False))
|
||||||
|
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
print("\n" + "=" * 70)
|
log.print("\n" + "=" * 70)
|
||||||
print(f"总耗时: {total_time:.2f} 秒 ({total_time/60:.2f} 分钟)")
|
log.print(f"总耗时: {total_time:.2f} 秒 ({total_time/60:.2f} 分钟)")
|
||||||
print("=" * 70)
|
log.print("=" * 70)
|
||||||
|
|
||||||
|
# 保存日志
|
||||||
|
log.save()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -120,6 +120,17 @@ SHOW_CHART_DETAILS = False # False=简洁模式(默认),True=详细模式
|
|||||||
# 简洁模式:仅显示收盘价、上沿线、下沿线
|
# 简洁模式:仅显示收盘价、上沿线、下沿线
|
||||||
# 详细模式:额外显示所有枢轴点、拟合点、分段线等调试信息
|
# 详细模式:额外显示所有枢轴点、拟合点、分段线等调试信息
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# 性能优化配置
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# 是否使用v2优化版本(预计算枢轴点)
|
||||||
|
USE_V2_OPTIMIZATION = False # True=v2优化,False=v1原版(推荐)
|
||||||
|
# v2优化通过预计算整个时间序列的枢轴点,避免滑动窗口重复计算
|
||||||
|
# 预期加速:3-5x
|
||||||
|
# 注意:v2版本暂不支持实时模式(REALTIME_MODE),会自动回退到标准模式
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# 推荐参数预设(备选方案)
|
# 推荐参数预设(备选方案)
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
@ -1416,6 +1416,9 @@ try:
|
|||||||
calc_fitting_adherence_optimized,
|
calc_fitting_adherence_optimized,
|
||||||
calc_boundary_utilization_optimized,
|
calc_boundary_utilization_optimized,
|
||||||
calc_breakout_strength_optimized,
|
calc_breakout_strength_optimized,
|
||||||
|
# v2优化:预计算枢轴点
|
||||||
|
precompute_pivots_numba,
|
||||||
|
detect_batch_with_precomputed_pivots_numba,
|
||||||
)
|
)
|
||||||
# 用优化版本覆盖原版函数(在模块级别)
|
# 用优化版本覆盖原版函数(在模块级别)
|
||||||
pivots_fractal = pivots_fractal_optimized
|
pivots_fractal = pivots_fractal_optimized
|
||||||
@ -1425,7 +1428,148 @@ try:
|
|||||||
calc_boundary_utilization = calc_boundary_utilization_optimized
|
calc_boundary_utilization = calc_boundary_utilization_optimized
|
||||||
calc_breakout_strength = calc_breakout_strength_optimized
|
calc_breakout_strength = calc_breakout_strength_optimized
|
||||||
|
|
||||||
print("[性能优化] 已启用Numba加速 (预计加速300x)")
|
_HAS_V2_OPTIMIZATION = True
|
||||||
|
print("[性能优化] 已启用Numba加速 + 预计算枢轴点优化")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
_HAS_V2_OPTIMIZATION = False
|
||||||
print("[性能优化] 未启用Numba加速,使用原版函数")
|
print("[性能优化] 未启用Numba加速,使用原版函数")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# 【v2优化】使用预计算枢轴点的批量检测
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def detect_converging_triangle_batch_v2(
|
||||||
|
open_mtx: np.ndarray,
|
||||||
|
high_mtx: np.ndarray,
|
||||||
|
low_mtx: np.ndarray,
|
||||||
|
close_mtx: np.ndarray,
|
||||||
|
volume_mtx: np.ndarray,
|
||||||
|
params: ConvergingTriangleParams,
|
||||||
|
start_day: Optional[int] = None,
|
||||||
|
end_day: Optional[int] = None,
|
||||||
|
only_valid: bool = False,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
【v2优化】使用预计算枢轴点的批量检测
|
||||||
|
|
||||||
|
优化策略:
|
||||||
|
1. 对每只股票,一次性预计算整个历史的枢轴点
|
||||||
|
2. 每个检测窗口直接读取预计算结果,避免重复计算
|
||||||
|
3. 整个股票的所有日期检测在一次Numba调用中完成
|
||||||
|
|
||||||
|
预期加速:3-5x(相比v1版本)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
与 detect_converging_triangle_batch 相同
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame(格式与v1相同)
|
||||||
|
"""
|
||||||
|
if not _HAS_V2_OPTIMIZATION:
|
||||||
|
print("[警告] v2优化未启用,回退到v1版本")
|
||||||
|
return detect_converging_triangle_batch(
|
||||||
|
open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
|
||||||
|
params, start_day, end_day, only_valid, verbose
|
||||||
|
)
|
||||||
|
|
||||||
|
n_stocks, n_days = close_mtx.shape
|
||||||
|
window = params.window
|
||||||
|
|
||||||
|
# 默认起止日
|
||||||
|
if start_day is None:
|
||||||
|
start_day = window - 1
|
||||||
|
if end_day is None:
|
||||||
|
end_day = n_days - 1
|
||||||
|
|
||||||
|
# 确保范围有效
|
||||||
|
start_day = max(window - 1, start_day)
|
||||||
|
end_day = min(n_days - 1, end_day)
|
||||||
|
|
||||||
|
results: List[dict] = []
|
||||||
|
|
||||||
|
for stock_idx in range(n_stocks):
|
||||||
|
# 提取该股票的全部数据
|
||||||
|
high_stock = high_mtx[stock_idx, :]
|
||||||
|
low_stock = low_mtx[stock_idx, :]
|
||||||
|
close_stock = close_mtx[stock_idx, :]
|
||||||
|
volume_stock = volume_mtx[stock_idx, :] if volume_mtx is not None else np.zeros(n_days)
|
||||||
|
|
||||||
|
# 找到有效数据的mask
|
||||||
|
valid_mask = ~np.isnan(close_stock)
|
||||||
|
valid_indices = np.where(valid_mask)[0].astype(np.int32)
|
||||||
|
|
||||||
|
if len(valid_indices) < window:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 【v2核心优化】预计算整个时间序列的枢轴点
|
||||||
|
is_pivot_high, is_pivot_low = precompute_pivots_numba(
|
||||||
|
high_stock, low_stock, k=params.pivot_k
|
||||||
|
)
|
||||||
|
|
||||||
|
# 批量检测(单次Numba调用处理所有日期)
|
||||||
|
(date_indices, is_valid_arr, strength_up_arr, strength_down_arr,
|
||||||
|
price_score_up_arr, price_score_down_arr, convergence_score_arr,
|
||||||
|
vol_score_arr, fitting_score_arr, boundary_util_score_arr,
|
||||||
|
upper_slope_arr, lower_slope_arr, width_ratio_arr,
|
||||||
|
touches_upper_arr, touches_lower_arr, apex_x_arr,
|
||||||
|
breakout_dir_arr, volume_confirmed_arr) = \
|
||||||
|
detect_batch_with_precomputed_pivots_numba(
|
||||||
|
high_stock, low_stock, close_stock, volume_stock,
|
||||||
|
is_pivot_high, is_pivot_low,
|
||||||
|
valid_indices, window, start_day, end_day,
|
||||||
|
params.upper_slope_max, params.lower_slope_min,
|
||||||
|
params.shrink_ratio, params.touch_tol, params.break_tol,
|
||||||
|
params.vol_window, params.vol_k
|
||||||
|
)
|
||||||
|
|
||||||
|
# 收集结果
|
||||||
|
for i in range(len(date_indices)):
|
||||||
|
if only_valid and not is_valid_arr[i]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 转换 breakout_dir
|
||||||
|
bd = breakout_dir_arr[i]
|
||||||
|
breakout_dir_str = "none" if bd == 0 else ("up" if bd == 1 else "down")
|
||||||
|
|
||||||
|
# 转换 volume_confirmed
|
||||||
|
vc = volume_confirmed_arr[i]
|
||||||
|
vol_confirmed = None if vc == -1 else bool(vc)
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
'stock_idx': stock_idx,
|
||||||
|
'date_idx': int(date_indices[i]),
|
||||||
|
'is_valid': bool(is_valid_arr[i]),
|
||||||
|
'breakout_strength_up': float(strength_up_arr[i]),
|
||||||
|
'breakout_strength_down': float(strength_down_arr[i]),
|
||||||
|
'price_score_up': float(price_score_up_arr[i]),
|
||||||
|
'price_score_down': float(price_score_down_arr[i]),
|
||||||
|
'convergence_score': float(convergence_score_arr[i]),
|
||||||
|
'volume_score': float(vol_score_arr[i]),
|
||||||
|
'fitting_score': float(fitting_score_arr[i]),
|
||||||
|
'boundary_utilization': float(boundary_util_score_arr[i]),
|
||||||
|
'upper_slope': float(upper_slope_arr[i]),
|
||||||
|
'lower_slope': float(lower_slope_arr[i]),
|
||||||
|
'width_ratio': float(width_ratio_arr[i]),
|
||||||
|
'touches_upper': int(touches_upper_arr[i]),
|
||||||
|
'touches_lower': int(touches_lower_arr[i]),
|
||||||
|
'apex_x': float(apex_x_arr[i]),
|
||||||
|
'breakout_dir': breakout_dir_str,
|
||||||
|
'volume_confirmed': vol_confirmed,
|
||||||
|
'false_breakout': None,
|
||||||
|
'window_start': 0,
|
||||||
|
'window_end': int(date_indices[i]),
|
||||||
|
'detection_mode': 'v2_optimized',
|
||||||
|
'has_candidate_pivots': False,
|
||||||
|
'candidate_pivot_count': 0,
|
||||||
|
})
|
||||||
|
|
||||||
|
if verbose and (stock_idx + 1) % 20 == 0:
|
||||||
|
print(f" Progress: {stock_idx + 1}/{n_stocks} stocks")
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f" Completed: {len(results)} results")
|
||||||
|
|
||||||
|
return pd.DataFrame(results)
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
@ -6,6 +6,7 @@
|
|||||||
2. 优化枢轴点检测算法(避免重复的nanmax/nanmin调用)
|
2. 优化枢轴点检测算法(避免重复的nanmax/nanmin调用)
|
||||||
3. 优化边界拟合算法(向量化计算)
|
3. 优化边界拟合算法(向量化计算)
|
||||||
4. 减少不必要的数组复制
|
4. 减少不必要的数组复制
|
||||||
|
5. 【v2】预计算枢轴点矩阵,避免滑动窗口重复计算
|
||||||
|
|
||||||
不使用并行(按要求)。
|
不使用并行(按要求)。
|
||||||
"""
|
"""
|
||||||
@ -14,7 +15,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import numba
|
import numba
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Tuple
|
from typing import Tuple, List, Optional
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Numba优化的核心函数
|
# Numba优化的核心函数
|
||||||
@ -592,3 +593,385 @@ def calc_breakout_strength_optimized(
|
|||||||
close, upper_line, lower_line, volume_ratio,
|
close, upper_line, lower_line, volume_ratio,
|
||||||
width_ratio, fitting_adherence, boundary_utilization
|
width_ratio, fitting_adherence, boundary_utilization
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# 【v2优化】预计算枢轴点矩阵
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@numba.jit(nopython=True, cache=True)
|
||||||
|
def precompute_pivots_numba(
|
||||||
|
high: np.ndarray,
|
||||||
|
low: np.ndarray,
|
||||||
|
k: int = 15
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
预计算整个时间序列的枢轴点标记
|
||||||
|
|
||||||
|
优化思路:一次性计算所有枢轴点,避免滑动窗口重复计算
|
||||||
|
|
||||||
|
Args:
|
||||||
|
high: 最高价数组 (n_days,)
|
||||||
|
low: 最低价数组 (n_days,)
|
||||||
|
k: 窗口大小(左右各k天)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(is_pivot_high, is_pivot_low): 布尔数组,标记每天是否为枢轴点
|
||||||
|
"""
|
||||||
|
n = len(high)
|
||||||
|
is_pivot_high = np.zeros(n, dtype=np.bool_)
|
||||||
|
is_pivot_low = np.zeros(n, dtype=np.bool_)
|
||||||
|
|
||||||
|
for i in range(k, n - k):
|
||||||
|
if np.isnan(high[i]) or np.isnan(low[i]):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 高点检测
|
||||||
|
is_ph = True
|
||||||
|
h_val = high[i]
|
||||||
|
for j in range(i - k, i + k + 1):
|
||||||
|
if j == i:
|
||||||
|
continue
|
||||||
|
if not np.isnan(high[j]) and high[j] > h_val:
|
||||||
|
is_ph = False
|
||||||
|
break
|
||||||
|
is_pivot_high[i] = is_ph
|
||||||
|
|
||||||
|
# 低点检测
|
||||||
|
is_pl = True
|
||||||
|
l_val = low[i]
|
||||||
|
for j in range(i - k, i + k + 1):
|
||||||
|
if j == i:
|
||||||
|
continue
|
||||||
|
if not np.isnan(low[j]) and low[j] < l_val:
|
||||||
|
is_pl = False
|
||||||
|
break
|
||||||
|
is_pivot_low[i] = is_pl
|
||||||
|
|
||||||
|
return is_pivot_high, is_pivot_low
|
||||||
|
|
||||||
|
|
||||||
|
@numba.jit(nopython=True, cache=True)
|
||||||
|
def get_pivots_in_window(
|
||||||
|
is_pivot: np.ndarray,
|
||||||
|
start: int,
|
||||||
|
end: int
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""从预计算的布尔数组中提取窗口内的枢轴点索引"""
|
||||||
|
# 计算窗口内枢轴点数量
|
||||||
|
count = 0
|
||||||
|
for i in range(start, end + 1):
|
||||||
|
if is_pivot[i]:
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
# 提取索引
|
||||||
|
result = np.empty(count, dtype=np.int32)
|
||||||
|
idx = 0
|
||||||
|
for i in range(start, end + 1):
|
||||||
|
if is_pivot[i]:
|
||||||
|
result[idx] = i - start # 转换为窗口内的相对索引
|
||||||
|
idx += 1
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@numba.jit(nopython=True, cache=True)
|
||||||
|
def detect_single_with_precomputed_pivots(
|
||||||
|
high: np.ndarray,
|
||||||
|
low: np.ndarray,
|
||||||
|
close: np.ndarray,
|
||||||
|
volume: np.ndarray,
|
||||||
|
is_pivot_high: np.ndarray,
|
||||||
|
is_pivot_low: np.ndarray,
|
||||||
|
window_start: int,
|
||||||
|
window_end: int,
|
||||||
|
# 参数
|
||||||
|
upper_slope_max: float,
|
||||||
|
lower_slope_min: float,
|
||||||
|
shrink_ratio: float,
|
||||||
|
touch_tol: float,
|
||||||
|
break_tol: float,
|
||||||
|
vol_window: int,
|
||||||
|
vol_k: float,
|
||||||
|
) -> Tuple[bool, float, float, float, float, float, float, float, float,
|
||||||
|
float, float, float, int, int, float, int, int]:
|
||||||
|
"""
|
||||||
|
使用预计算枢轴点的单点检测(纯Numba实现)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(is_valid, strength_up, strength_down, price_score_up, price_score_down,
|
||||||
|
convergence_score, vol_score, fitting_score, boundary_util_score,
|
||||||
|
upper_slope, lower_slope, width_ratio, touches_upper, touches_lower,
|
||||||
|
apex_x, breakout_dir, volume_confirmed)
|
||||||
|
|
||||||
|
breakout_dir: 0=none, 1=up, 2=down
|
||||||
|
volume_confirmed: 0=False, 1=True, -1=None
|
||||||
|
"""
|
||||||
|
n = window_end - window_start + 1
|
||||||
|
|
||||||
|
# 默认无效结果
|
||||||
|
invalid_result = (False, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||||
|
0.0, 0.0, 0.0, 0, 0, 0.0, 0, -1)
|
||||||
|
|
||||||
|
if n < 50: # 最小窗口检查
|
||||||
|
return invalid_result
|
||||||
|
|
||||||
|
# 提取窗口内的枢轴点
|
||||||
|
ph_count = 0
|
||||||
|
pl_count = 0
|
||||||
|
for i in range(window_start, window_end + 1):
|
||||||
|
if is_pivot_high[i]:
|
||||||
|
ph_count += 1
|
||||||
|
if is_pivot_low[i]:
|
||||||
|
pl_count += 1
|
||||||
|
|
||||||
|
if ph_count < 2 or pl_count < 2:
|
||||||
|
return invalid_result
|
||||||
|
|
||||||
|
# 收集枢轴点索引和值
|
||||||
|
ph_indices = np.empty(ph_count, dtype=np.int32)
|
||||||
|
ph_values = np.empty(ph_count, dtype=np.float64)
|
||||||
|
pl_indices = np.empty(pl_count, dtype=np.int32)
|
||||||
|
pl_values = np.empty(pl_count, dtype=np.float64)
|
||||||
|
|
||||||
|
ph_idx = 0
|
||||||
|
pl_idx = 0
|
||||||
|
for i in range(window_start, window_end + 1):
|
||||||
|
rel_i = i - window_start # 相对索引
|
||||||
|
if is_pivot_high[i]:
|
||||||
|
ph_indices[ph_idx] = rel_i
|
||||||
|
ph_values[ph_idx] = high[i]
|
||||||
|
ph_idx += 1
|
||||||
|
if is_pivot_low[i]:
|
||||||
|
pl_indices[pl_idx] = rel_i
|
||||||
|
pl_values[pl_idx] = low[i]
|
||||||
|
pl_idx += 1
|
||||||
|
|
||||||
|
# 边界拟合(anchor方法)
|
||||||
|
high_win = high[window_start:window_end + 1]
|
||||||
|
low_win = low[window_start:window_end + 1]
|
||||||
|
|
||||||
|
a_u, b_u = fit_boundary_anchor_numba(
|
||||||
|
ph_indices.astype(np.float64), ph_values,
|
||||||
|
high_win, mode=0, coverage=0.95, exclude_last=1,
|
||||||
|
window_start=0, window_end=n - 1
|
||||||
|
)
|
||||||
|
|
||||||
|
a_l, b_l = fit_boundary_anchor_numba(
|
||||||
|
pl_indices.astype(np.float64), pl_values,
|
||||||
|
low_win, mode=1, coverage=0.95, exclude_last=1,
|
||||||
|
window_start=0, window_end=n - 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# 斜率检查
|
||||||
|
if a_u > upper_slope_max or a_l < lower_slope_min:
|
||||||
|
return invalid_result
|
||||||
|
|
||||||
|
# 相向收敛检查
|
||||||
|
slope_tolerance = 0.01
|
||||||
|
both_descending = (a_u < -slope_tolerance) and (a_l < -slope_tolerance)
|
||||||
|
both_ascending = (a_u > slope_tolerance) and (a_l > slope_tolerance)
|
||||||
|
if both_descending or both_ascending:
|
||||||
|
return invalid_result
|
||||||
|
|
||||||
|
# 宽度收敛检查
|
||||||
|
start_idx = 0
|
||||||
|
end_idx = n - 1
|
||||||
|
upper_start = a_u * start_idx + b_u
|
||||||
|
lower_start = a_l * start_idx + b_l
|
||||||
|
upper_end = a_u * end_idx + b_u
|
||||||
|
lower_end = a_l * end_idx + b_l
|
||||||
|
|
||||||
|
width_start = upper_start - lower_start
|
||||||
|
width_end = upper_end - lower_end
|
||||||
|
|
||||||
|
if width_start <= 0 or width_end <= 0:
|
||||||
|
return invalid_result
|
||||||
|
|
||||||
|
width_ratio = width_end / width_start
|
||||||
|
if width_ratio > shrink_ratio:
|
||||||
|
return invalid_result
|
||||||
|
|
||||||
|
# 触碰检测
|
||||||
|
touches_upper = 0
|
||||||
|
touches_lower = 0
|
||||||
|
|
||||||
|
for i in range(ph_count):
|
||||||
|
line_y = a_u * ph_indices[i] + b_u
|
||||||
|
deviation = abs(ph_values[i] - line_y) / max(line_y, 1e-9)
|
||||||
|
if deviation <= touch_tol:
|
||||||
|
touches_upper += 1
|
||||||
|
|
||||||
|
for i in range(pl_count):
|
||||||
|
line_y = a_l * pl_indices[i] + b_l
|
||||||
|
deviation = abs(pl_values[i] - line_y) / max(line_y, 1e-9)
|
||||||
|
if deviation <= touch_tol:
|
||||||
|
touches_lower += 1
|
||||||
|
|
||||||
|
if touches_upper < 2 or touches_lower < 2:
|
||||||
|
return invalid_result
|
||||||
|
|
||||||
|
# Apex计算
|
||||||
|
denom = a_u - a_l
|
||||||
|
if abs(denom) > 1e-12:
|
||||||
|
apex_x = (b_l - b_u) / denom
|
||||||
|
else:
|
||||||
|
apex_x = 1e9
|
||||||
|
|
||||||
|
# 突破判定
|
||||||
|
close_val = close[window_end]
|
||||||
|
breakout_dir = 0 # 0=none, 1=up, 2=down
|
||||||
|
if close_val > upper_end * (1 + break_tol):
|
||||||
|
breakout_dir = 1
|
||||||
|
elif close_val < lower_end * (1 - break_tol):
|
||||||
|
breakout_dir = 2
|
||||||
|
|
||||||
|
# 成交量确认
|
||||||
|
volume_confirmed = -1 # -1 = None
|
||||||
|
volume_ratio = 1.0
|
||||||
|
if vol_window > 0 and window_end >= vol_window:
|
||||||
|
vol_sum = 0.0
|
||||||
|
for i in range(window_end - vol_window + 1, window_end + 1):
|
||||||
|
vol_sum += volume[i]
|
||||||
|
vol_ma = vol_sum / vol_window
|
||||||
|
if vol_ma > 0:
|
||||||
|
volume_ratio = volume[window_end] / vol_ma
|
||||||
|
if breakout_dir != 0:
|
||||||
|
volume_confirmed = 1 if volume[window_end] > vol_ma * vol_k else 0
|
||||||
|
|
||||||
|
# 计算拟合贴合度
|
||||||
|
adherence_upper = calc_fitting_adherence_numba(
|
||||||
|
ph_indices.astype(np.float64), ph_values, a_u, b_u
|
||||||
|
)
|
||||||
|
adherence_lower = calc_fitting_adherence_numba(
|
||||||
|
pl_indices.astype(np.float64), pl_values, a_l, b_l
|
||||||
|
)
|
||||||
|
fitting_adherence = (adherence_upper + adherence_lower) / 2.0
|
||||||
|
|
||||||
|
# 计算边界利用率
|
||||||
|
boundary_util = calc_boundary_utilization_numba(
|
||||||
|
high_win, low_win, a_u, b_u, a_l, b_l, 0, n - 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计算突破强度
|
||||||
|
(strength_up, strength_down, price_score_up, price_score_down,
|
||||||
|
convergence_score, vol_score, fitting_score, boundary_util_score) = \
|
||||||
|
calc_breakout_strength_numba(
|
||||||
|
close_val, upper_end, lower_end, volume_ratio,
|
||||||
|
width_ratio, fitting_adherence, boundary_util
|
||||||
|
)
|
||||||
|
|
||||||
|
return (True, strength_up, strength_down, price_score_up, price_score_down,
|
||||||
|
convergence_score, vol_score, fitting_score, boundary_util_score,
|
||||||
|
a_u, a_l, width_ratio, touches_upper, touches_lower,
|
||||||
|
apex_x, breakout_dir, volume_confirmed)
|
||||||
|
|
||||||
|
|
||||||
|
@numba.jit(nopython=True, cache=True)
|
||||||
|
def detect_batch_with_precomputed_pivots_numba(
|
||||||
|
high: np.ndarray,
|
||||||
|
low: np.ndarray,
|
||||||
|
close: np.ndarray,
|
||||||
|
volume: np.ndarray,
|
||||||
|
is_pivot_high: np.ndarray,
|
||||||
|
is_pivot_low: np.ndarray,
|
||||||
|
valid_indices: np.ndarray,
|
||||||
|
window: int,
|
||||||
|
start_day: int,
|
||||||
|
end_day: int,
|
||||||
|
# 参数
|
||||||
|
upper_slope_max: float,
|
||||||
|
lower_slope_min: float,
|
||||||
|
shrink_ratio: float,
|
||||||
|
touch_tol: float,
|
||||||
|
break_tol: float,
|
||||||
|
vol_window: int,
|
||||||
|
vol_k: float,
|
||||||
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray,
|
||||||
|
np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray,
|
||||||
|
np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray,
|
||||||
|
np.ndarray, np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
使用预计算枢轴点的批量检测(单只股票,多日期)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
多个数组,每个元素对应一个检测点
|
||||||
|
"""
|
||||||
|
n_valid = len(valid_indices)
|
||||||
|
n_dates = 0
|
||||||
|
|
||||||
|
# 计算有效日期数量
|
||||||
|
for valid_end in range(window - 1, n_valid):
|
||||||
|
orig_date_idx = valid_indices[valid_end]
|
||||||
|
if orig_date_idx >= start_day and orig_date_idx <= end_day:
|
||||||
|
n_dates += 1
|
||||||
|
|
||||||
|
# 预分配结果数组
|
||||||
|
date_indices = np.empty(n_dates, dtype=np.int32)
|
||||||
|
is_valid = np.zeros(n_dates, dtype=np.bool_)
|
||||||
|
strength_up = np.zeros(n_dates, dtype=np.float64)
|
||||||
|
strength_down = np.zeros(n_dates, dtype=np.float64)
|
||||||
|
price_score_up = np.zeros(n_dates, dtype=np.float64)
|
||||||
|
price_score_down = np.zeros(n_dates, dtype=np.float64)
|
||||||
|
convergence_score = np.zeros(n_dates, dtype=np.float64)
|
||||||
|
vol_score = np.zeros(n_dates, dtype=np.float64)
|
||||||
|
fitting_score = np.zeros(n_dates, dtype=np.float64)
|
||||||
|
boundary_util_score = np.zeros(n_dates, dtype=np.float64)
|
||||||
|
upper_slope = np.zeros(n_dates, dtype=np.float64)
|
||||||
|
lower_slope = np.zeros(n_dates, dtype=np.float64)
|
||||||
|
width_ratio = np.zeros(n_dates, dtype=np.float64)
|
||||||
|
touches_upper = np.zeros(n_dates, dtype=np.int32)
|
||||||
|
touches_lower = np.zeros(n_dates, dtype=np.int32)
|
||||||
|
apex_x = np.zeros(n_dates, dtype=np.float64)
|
||||||
|
breakout_dir = np.zeros(n_dates, dtype=np.int32)
|
||||||
|
volume_confirmed = np.zeros(n_dates, dtype=np.int32) - 1 # 默认-1
|
||||||
|
|
||||||
|
# 遍历日期
|
||||||
|
result_idx = 0
|
||||||
|
for valid_end in range(window - 1, n_valid):
|
||||||
|
orig_date_idx = valid_indices[valid_end]
|
||||||
|
|
||||||
|
if orig_date_idx < start_day or orig_date_idx > end_day:
|
||||||
|
continue
|
||||||
|
|
||||||
|
valid_start = valid_end - window + 1
|
||||||
|
window_start_orig = valid_indices[valid_start]
|
||||||
|
window_end_orig = valid_indices[valid_end]
|
||||||
|
|
||||||
|
date_indices[result_idx] = orig_date_idx
|
||||||
|
|
||||||
|
# 调用单点检测
|
||||||
|
result = detect_single_with_precomputed_pivots(
|
||||||
|
high, low, close, volume,
|
||||||
|
is_pivot_high, is_pivot_low,
|
||||||
|
window_start_orig, window_end_orig,
|
||||||
|
upper_slope_max, lower_slope_min, shrink_ratio,
|
||||||
|
touch_tol, break_tol, vol_window, vol_k
|
||||||
|
)
|
||||||
|
|
||||||
|
is_valid[result_idx] = result[0]
|
||||||
|
strength_up[result_idx] = result[1]
|
||||||
|
strength_down[result_idx] = result[2]
|
||||||
|
price_score_up[result_idx] = result[3]
|
||||||
|
price_score_down[result_idx] = result[4]
|
||||||
|
convergence_score[result_idx] = result[5]
|
||||||
|
vol_score[result_idx] = result[6]
|
||||||
|
fitting_score[result_idx] = result[7]
|
||||||
|
boundary_util_score[result_idx] = result[8]
|
||||||
|
upper_slope[result_idx] = result[9]
|
||||||
|
lower_slope[result_idx] = result[10]
|
||||||
|
width_ratio[result_idx] = result[11]
|
||||||
|
touches_upper[result_idx] = result[12]
|
||||||
|
touches_lower[result_idx] = result[13]
|
||||||
|
apex_x[result_idx] = result[14]
|
||||||
|
breakout_dir[result_idx] = result[15]
|
||||||
|
volume_confirmed[result_idx] = result[16]
|
||||||
|
|
||||||
|
result_idx += 1
|
||||||
|
|
||||||
|
return (date_indices, is_valid, strength_up, strength_down,
|
||||||
|
price_score_up, price_score_down, convergence_score,
|
||||||
|
vol_score, fitting_score, boundary_util_score,
|
||||||
|
upper_slope, lower_slope, width_ratio,
|
||||||
|
touches_upper, touches_lower, apex_x,
|
||||||
|
breakout_dir, volume_confirmed)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user