- Created README.md and USAGE.md for project overview and usage instructions. - Added core algorithm in src/converging_triangle.py for batch processing of stock data. - Introduced data files (open.pkl, high.pkl, low.pkl, close.pkl, volume.pkl) for OHLCV data. - Developed output documentation for results and breakout strength calculations. - Implemented scripts for running the detection and generating reports. - Added SVG visualizations and markdown documentation for algorithm details and usage examples.
510 lines
16 KiB
Python
510 lines
16 KiB
Python
"""
|
||
收敛三角形检测算法 (Converging Triangle Detection)
|
||
|
||
支持:
|
||
- 二维矩阵批量输入 (n_stocks, n_days)
|
||
- 历史滚动计算 (每个交易日往过去看 window 天)
|
||
- 突破强度分数 (0~1)
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from dataclasses import dataclass, field, asdict
|
||
from typing import List, Literal, Optional, Tuple
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
|
||
# ============================================================================
|
||
# 参数对象
|
||
# ============================================================================
|
||
|
||
@dataclass
class ConvergingTriangleParams:
    """Tunable parameters for converging-triangle detection."""

    # --- Lookback window ---
    # Number of bars that make up one detection window.
    window: int = 400

    # --- Pivot detection ---
    # Bars on each side of a candidate that it must dominate to be a pivot.
    pivot_k: int = 20

    # --- Boundary-line fitting ---
    # Number of segments the points are split into; one extreme is taken
    # from each segment and a line is fitted through those extremes.
    boundary_n_segments: int = 2
    # "full" fits on every high/low in the window; any other value fits on
    # the pivot points only.
    boundary_source: str = "full"  # "full" | "pivots"

    # --- Slope constraints ---
    # The upper line is rejected when its slope exceeds this value.
    upper_slope_max: float = 0.10
    # The lower line is rejected when its slope falls below this value.
    lower_slope_min: float = -0.10

    # --- Touch quality ---
    # Relative distance from a boundary within which a pivot counts as a touch.
    touch_tol: float = 0.10
    # Maximum allowed mean relative distance of the pivots from their boundary.
    touch_loss_max: float = 0.10

    # --- Convergence requirement ---
    # Maximum allowed ratio of channel width at window end to width at start.
    shrink_ratio: float = 0.8

    # --- Breakout detection ---
    # Relative margin beyond a boundary the last close must clear.
    break_tol: float = 0.001
    # Number of trailing bars averaged to get the reference volume.
    vol_window: int = 20
    # Volume must exceed reference volume * vol_k to confirm a breakout.
    vol_k: float = 1.3
    # NOTE(review): not read by any function visible in this module;
    # presumably a look-ahead bar count for false-breakout checks - confirm.
    false_break_m: int = 5
|
||
|
||
|
||
# ============================================================================
|
||
# 返回对象
|
||
# ============================================================================
|
||
|
||
@dataclass
class ConvergingTriangleResult:
    """Detection outcome for one stock on one date."""

    # --- Identification ---
    stock_idx: int
    date_idx: int
    is_valid: bool  # True when a valid converging triangle was recognised

    # --- Breakout strength (continuous score in 0..1) ---
    breakout_strength_up: float = 0.0
    breakout_strength_down: float = 0.0

    # --- Geometry ---
    upper_slope: float = 0.0  # slope of the fitted upper boundary
    lower_slope: float = 0.0  # slope of the fitted lower boundary
    width_ratio: float = 0.0  # channel width at window end / width at start
    touches_upper: int = 0  # pivot highs within tolerance of the upper line
    touches_lower: int = 0  # pivot lows within tolerance of the lower line
    apex_x: float = 0.0  # x where the two lines cross (inf when near-parallel)

    # --- Breakout state ---
    breakout_dir: str = "none"  # "up" | "down" | "none"
    volume_confirmed: Optional[bool] = None  # None when it was not evaluated
    false_breakout: Optional[bool] = None  # None when it was not evaluated

    # --- Window bounds (indices into the arrays handed to the detector) ---
    window_start: int = 0
    window_end: int = 0

    def to_dict(self) -> dict:
        """Return the result as a plain ``{field name: value}`` dictionary."""
        return asdict(self)
|
||
|
||
|
||
# ============================================================================
|
||
# 基础工具函数 (从 sym_triangle.py 复用)
|
||
# ============================================================================
|
||
|
||
def pivots_fractal(
    high: np.ndarray, low: np.ndarray, k: int = 3
) -> Tuple[np.ndarray, np.ndarray]:
    """Find fractal pivots using a symmetric window of ``k`` bars per side.

    A bar is a pivot high when its high equals the maximum of the ``2k + 1``
    highs centred on it, and a pivot low when its low equals the minimum of
    the corresponding lows. The first and last ``k`` bars are never pivots
    because their window would be incomplete.

    Args:
        high: High prices.
        low: Low prices, indexed like ``high``.
        k: Number of bars on each side of the candidate.

    Returns:
        ``(pivot_high_idx, pivot_low_idx)`` as integer index arrays in
        ascending order.
    """
    # Only bars with a full window on both sides are candidates.
    candidates = range(k, len(high) - k)

    pivot_highs: List[int] = [
        i for i in candidates if high[i] == np.max(high[i - k : i + k + 1])
    ]
    pivot_lows: List[int] = [
        i for i in candidates if low[i] == np.min(low[i - k : i + k + 1])
    ]
    return np.array(pivot_highs, dtype=int), np.array(pivot_lows, dtype=int)
|
||
|
||
|
||
def fit_line(x: np.ndarray, y: np.ndarray) -> Tuple[float, float]:
    """Least-squares fit of ``y = a * x + b``.

    Args:
        x: Sample x-coordinates.
        y: Sample y-coordinates.

    Returns:
        ``(a, b)`` as Python floats. With fewer than two x values no slope
        can be estimated, so the result is a flat line through ``y[0]``
        (or through 0.0 when ``y`` is empty).
    """
    if len(x) >= 2:
        slope, intercept = np.polyfit(x, y, deg=1)
        return float(slope), float(intercept)

    # Degenerate input: horizontal line at the only available level.
    level = float(y[0]) if len(y) > 0 else 0.0
    return 0.0, level
|
||
|
||
|
||
def fit_boundary_line(
    x: np.ndarray, y: np.ndarray, mode: str = "upper", n_segments: int = 3
) -> Tuple[float, float]:
    """Fit a boundary line through per-segment extremes.

    The points are ordered by ``x`` and cut into ``n_segments`` consecutive
    chunks of equal size (the last chunk absorbs the remainder). One extreme
    point is taken from each chunk and a straight line is fitted through
    those extremes with ``fit_line``.

    Args:
        x: Sample x-coordinates.
        y: Sample y-coordinates.
        mode: ``"upper"`` takes the highest point of each chunk; any other
            value takes the lowest.
        n_segments: Requested number of chunks; clamped to the range
            ``[2, len(x)]``.

    Returns:
        ``(a, b)`` of the fitted line ``y = a * x + b``.
    """
    if len(x) < 2:
        return fit_line(x, y)

    # At most one chunk per point, and never fewer than two chunks.
    n_segments = max(2, min(n_segments, len(x)))

    order = np.argsort(x)
    xs = x[order]
    ys = y[order]
    total = len(xs)
    step = total // n_segments

    pick_extreme = np.argmax if mode == "upper" else np.argmin

    anchors_x = []
    anchors_y = []
    for seg in range(n_segments):
        lo = seg * step
        # The final chunk runs to the end so no trailing points are dropped.
        hi = total if seg == n_segments - 1 else lo + step
        if lo >= hi:
            continue
        j = lo + pick_extreme(ys[lo:hi])
        anchors_x.append(xs[j])
        anchors_y.append(ys[j])

    # Fall back to a plain fit if fewer than two anchors were collected.
    if len(anchors_x) < 2:
        return fit_line(x, y)

    return fit_line(np.array(anchors_x), np.array(anchors_y))
|
||
|
||
|
||
def line_y(a: float, b: float, x: np.ndarray) -> np.ndarray:
    """Evaluate the line ``y = a * x + b`` at the given x values."""
    scaled = a * x
    return scaled + b
|
||
|
||
|
||
# ============================================================================
|
||
# 突破强度计算
|
||
# ============================================================================
|
||
|
||
def calc_breakout_strength(
    close: float,
    upper_line: float,
    lower_line: float,
    volume_ratio: float,
    width_ratio: float,
) -> Tuple[float, float]:
    """Score upward and downward breakout strength, each capped at 1.

    The score combines:
      - how far the close lies beyond the upper / lower boundary, relative
        to that boundary;
      - how much volume expanded (a ratio of 2 or more gives the full bonus);
      - how tight the triangle is (a smaller ``width_ratio`` gives a larger
        bonus).

    Args:
        close: Closing price being scored.
        upper_line: Upper boundary value at the same bar.
        lower_line: Lower boundary value at the same bar.
        volume_ratio: Current volume divided by its reference average.
        width_ratio: Channel width at window end divided by width at start.

    Returns:
        ``(strength_up, strength_down)``.
    """
    # Relative overshoot beyond each boundary; a non-positive boundary
    # cannot be used as a divisor, so it scores zero.
    if upper_line > 0:
        price_up = max(0, (close - upper_line) / upper_line)
    else:
        price_up = 0

    if lower_line > 0:
        price_down = max(0, (lower_line - close) / lower_line)
    else:
        price_down = 0

    # Tighter convergence makes a breakout more meaningful.
    convergence_bonus = max(0, 1 - width_ratio)

    # Volume bonus grows linearly from ratio 1 and saturates at ratio 2.
    if volume_ratio > 0:
        vol_bonus = min(1, max(0, volume_ratio - 1))
    else:
        vol_bonus = 0

    # Each bonus can lift the base score by up to 50%.
    convergence_factor = 1 + convergence_bonus * 0.5
    volume_factor = 1 + vol_bonus * 0.5

    strength_up = min(1.0, price_up * 5 * convergence_factor * volume_factor)
    strength_down = min(1.0, price_down * 5 * convergence_factor * volume_factor)

    return strength_up, strength_down
|
||
|
||
|
||
# ============================================================================
|
||
# 单点检测函数
|
||
# ============================================================================
|
||
|
||
def detect_converging_triangle(
    high: np.ndarray,
    low: np.ndarray,
    close: np.ndarray,
    volume: Optional[np.ndarray],
    params: ConvergingTriangleParams,
    stock_idx: int = 0,
    date_idx: int = 0,
) -> ConvergingTriangleResult:
    """Check whether one price window contains a converging triangle.

    The window is accepted only if it passes, in order: a minimum-length
    check, at least two pivot highs and two pivot lows, the slope limits on
    both boundaries, the width-shrink requirement, and the mean pivot-to-line
    distance limit. The first failed check returns an invalid result.

    Args:
        high: High prices of the window.
        low: Low prices of the window.
        close: Close prices of the window; the last element is the bar
            evaluated for a breakout.
        volume: Volumes of the window, or None to skip volume analysis.
        params: Detection parameters.
        stock_idx: Stock index copied into the result.
        date_idx: Date index copied into the result.

    Returns:
        A ConvergingTriangleResult. ``window_start`` / ``window_end`` are
        indices into the arrays passed in. ``false_breakout`` is always None
        because deciding it would require bars after the window.
    """
    n = len(close)
    window = params.window

    # Shared "no triangle" answer returned by every failed check below.
    rejected = ConvergingTriangleResult(
        stock_idx=stock_idx,
        date_idx=date_idx,
        is_valid=False,
        window_start=max(0, n - window),
        window_end=n - 1,
    )

    # Not enough bars for a full window or for pivot detection.
    if n < max(window, 2 * params.pivot_k + 5):
        return rejected

    end = n - 1
    start = max(0, end - window + 1)
    x_all = np.arange(n, dtype=float)
    in_window = slice(start, end + 1)

    # Pivots are found on the whole input, then restricted to the window.
    ph_idx, pl_idx = pivots_fractal(high, low, k=params.pivot_k)
    ph_in = ph_idx[(ph_idx >= start) & (ph_idx <= end)]
    pl_in = pl_idx[(pl_idx >= start) & (pl_idx <= end)]
    if len(ph_in) < 2 or len(pl_in) < 2:
        return rejected

    # Choose the points the boundary lines are fitted on.
    if params.boundary_source == "full":
        x_upper, y_upper = x_all[in_window], high[in_window]
        x_lower, y_lower = x_all[in_window], low[in_window]
    else:
        x_upper, y_upper = x_all[ph_in], high[ph_in]
        x_lower, y_lower = x_all[pl_in], low[pl_in]

    a_u, b_u = fit_boundary_line(
        x_upper, y_upper, mode="upper", n_segments=params.boundary_n_segments
    )
    a_l, b_l = fit_boundary_line(
        x_lower, y_lower, mode="lower", n_segments=params.boundary_n_segments
    )

    # Slope limits. Kept as "not (ok and ok)" so a NaN slope is rejected.
    if not (a_u <= params.upper_slope_max and a_l >= params.lower_slope_min):
        return rejected

    # Channel width at both ends of the window.
    endpoints = np.array([start, end])
    upper_start, upper_end = (float(v) for v in line_y(a_u, b_u, endpoints))
    lower_start, lower_end = (float(v) for v in line_y(a_l, b_l, endpoints))

    width_start = upper_start - lower_start
    width_end = upper_end - lower_end
    if width_start <= 0 or width_end <= 0:
        return rejected

    width_ratio = width_end / width_start
    if width_ratio > params.shrink_ratio:
        return rejected

    # Relative distance of every pivot from its boundary line; the divisor
    # is floored at 1e-9 to avoid dividing by zero.
    upper_fit = line_y(a_u, b_u, x_all[ph_in])
    lower_fit = line_y(a_l, b_l, x_all[pl_in])
    ph_dist = np.abs(high[ph_in] - upper_fit) / np.maximum(upper_fit, 1e-9)
    pl_dist = np.abs(low[pl_in] - lower_fit) / np.maximum(lower_fit, 1e-9)

    touches_upper = int((ph_dist <= params.touch_tol).sum())
    touches_lower = int((pl_dist <= params.touch_tol).sum())

    loss_upper = float(np.mean(ph_dist)) if len(ph_dist) else float("inf")
    loss_lower = float(np.mean(pl_dist)) if len(pl_dist) else float("inf")
    if loss_upper > params.touch_loss_max or loss_lower > params.touch_loss_max:
        return rejected

    # Apex: x where the two lines intersect; inf for near-parallel lines.
    denom = a_u - a_l
    apex_x = float((b_l - b_u) / denom) if abs(denom) > 1e-12 else float("inf")

    # Breakout: last close beyond a boundary by more than break_tol.
    last_close = close[end]
    breakout_dir: Literal["up", "down", "none"] = "none"
    if last_close > upper_end * (1 + params.break_tol):
        breakout_dir = "up"
    elif last_close < lower_end * (1 - params.break_tol):
        breakout_dir = "down"

    # Volume confirmation against the mean of the trailing vol_window bars
    # (the last bar is part of that mean).
    # NOTE(review): the nesting of the breakout check under `vol_ma > 0`
    # was reconstructed from a source whose indentation was lost - confirm.
    volume_confirmed: Optional[bool] = None
    volume_ratio = 1.0
    if volume is not None and len(volume) >= params.vol_window:
        vol_ma = np.mean(volume[-params.vol_window:])
        if vol_ma > 0:
            volume_ratio = volume[-1] / vol_ma
            if breakout_dir != "none":
                volume_confirmed = bool(volume[-1] > vol_ma * params.vol_k)

    # A false breakout can only be judged from later bars, which this
    # function does not receive, so the flag stays unset.
    false_breakout: Optional[bool] = None

    strength_up, strength_down = calc_breakout_strength(
        close=last_close,
        upper_line=upper_end,
        lower_line=lower_end,
        volume_ratio=volume_ratio,
        width_ratio=width_ratio,
    )

    return ConvergingTriangleResult(
        stock_idx=stock_idx,
        date_idx=date_idx,
        is_valid=True,
        breakout_strength_up=strength_up,
        breakout_strength_down=strength_down,
        upper_slope=a_u,
        lower_slope=a_l,
        width_ratio=width_ratio,
        touches_upper=touches_upper,
        touches_lower=touches_lower,
        apex_x=apex_x,
        breakout_dir=breakout_dir,
        volume_confirmed=volume_confirmed,
        false_breakout=false_breakout,
        window_start=start,
        window_end=end,
    )
|
||
|
||
|
||
# ============================================================================
|
||
# 批量滚动检测函数
|
||
# ============================================================================
|
||
|
||
def detect_converging_triangle_batch(
    open_mtx: np.ndarray,
    high_mtx: np.ndarray,
    low_mtx: np.ndarray,
    close_mtx: np.ndarray,
    volume_mtx: np.ndarray,
    params: ConvergingTriangleParams,
    start_day: Optional[int] = None,
    end_day: Optional[int] = None,
    only_valid: bool = False,
    verbose: bool = False,
) -> pd.DataFrame:
    """Run rolling converging-triangle detection over a matrix of stocks.

    For every stock, bars whose close is NaN are dropped, then a window of
    ``params.window`` remaining bars is slid over the series and
    ``detect_converging_triangle`` is called on each window whose last bar
    falls inside ``[start_day, end_day]``.

    Args:
        open_mtx: Open prices, shape ``(n_stocks, n_days)``. Accepted for
            interface symmetry; not read by this function.
        high_mtx: High prices, same shape.
        low_mtx: Low prices, same shape.
        close_mtx: Close prices, same shape; defines ``n_stocks`` and
            ``n_days`` and which bars are considered valid (non-NaN).
        volume_mtx: Volumes, same shape, or None to skip volume analysis.
        params: Detection parameters.
        start_day: First day index to evaluate. Defaults to, and is never
            lower than, ``window - 1``.
        end_day: Last day index to evaluate. Defaults to, and is never
            higher than, ``n_days - 1``.
        only_valid: When True, only rows with ``is_valid`` True are kept.
        verbose: When True, print progress every 10 processed stocks and a
            final summary.

    Returns:
        DataFrame with one row per evaluated (stock, date) and the columns
        of ConvergingTriangleResult:
          - stock_idx, date_idx
          - is_valid
          - breakout_strength_up, breakout_strength_down
          - upper_slope, lower_slope, width_ratio
          - touches_upper, touches_lower, apex_x
          - breakout_dir, volume_confirmed, false_breakout
          - window_start, window_end
        The columns are present even when no row was produced. For a stock
        with at least ``window`` valid closes, a date whose close is NaN or
        that precedes the stock's ``window``-th valid close yields no row,
        even when ``only_valid`` is False.
    """
    # Local import: the module-level dataclasses import does not bring in
    # ``fields``.
    from dataclasses import fields

    n_stocks, n_days = close_mtx.shape
    window = params.window

    # Default range, then clamp to what the data can support.
    if start_day is None:
        start_day = window - 1
    if end_day is None:
        end_day = n_days - 1
    start_day = max(window - 1, start_day)
    end_day = min(n_days - 1, end_day)

    # Fixed column order so an empty result still has the documented schema.
    columns = [f.name for f in fields(ConvergingTriangleResult)]

    results: List[dict] = []
    # Progress denominator; an empty range must not yield a negative count.
    total = n_stocks * max(0, end_day - start_day + 1)
    processed = 0

    for stock_idx in range(n_stocks):
        high_stock = high_mtx[stock_idx, :]
        low_stock = low_mtx[stock_idx, :]
        close_stock = close_mtx[stock_idx, :]
        volume_stock = volume_mtx[stock_idx, :] if volume_mtx is not None else None

        # Only the close series decides which bars are usable.
        valid_mask = ~np.isnan(close_stock)
        valid_indices = np.where(valid_mask)[0]

        if len(valid_indices) < window:
            # Not even one full window: report every date as invalid.
            for date_idx in range(start_day, end_day + 1):
                if not only_valid:
                    results.append(ConvergingTriangleResult(
                        stock_idx=stock_idx,
                        date_idx=date_idx,
                        is_valid=False,
                    ).to_dict())
                processed += 1
            continue

        high_valid = high_stock[valid_mask]
        low_valid = low_stock[valid_mask]
        close_valid = close_stock[valid_mask]
        volume_valid = volume_stock[valid_mask] if volume_stock is not None else None

        # Roll over the compacted (NaN-free) series.
        n_valid = len(close_valid)
        for valid_end in range(window - 1, n_valid):
            # Map the compacted position back to the original day index.
            orig_date_idx = int(valid_indices[valid_end])
            if orig_date_idx < start_day or orig_date_idx > end_day:
                continue

            valid_start = valid_end - window + 1
            win = slice(valid_start, valid_end + 1)

            result = detect_converging_triangle(
                high=high_valid[win],
                low=low_valid[win],
                close=close_valid[win],
                volume=volume_valid[win] if volume_valid is not None else None,
                params=params,
                stock_idx=stock_idx,
                date_idx=orig_date_idx,
            )

            processed += 1
            if only_valid and not result.is_valid:
                continue
            results.append(result.to_dict())

        if verbose and (stock_idx + 1) % 10 == 0:
            print(f"  Progress: {stock_idx + 1}/{n_stocks} stocks, {processed}/{total} points")

    if verbose:
        print(f"  Completed: {processed} points processed")

    return pd.DataFrame(results, columns=columns)
|