""" 收敛三角形检测算法 (Converging Triangle Detection) 支持: - 二维矩阵批量输入 (n_stocks, n_days) - 历史滚动计算 (每个交易日往过去看 window 天) - 突破强度分数 (0~1) """ from __future__ import annotations from dataclasses import dataclass, field, asdict from typing import List, Literal, Optional, Tuple import numpy as np import pandas as pd # ============================================================================ # 参数对象 # ============================================================================ @dataclass class ConvergingTriangleParams: """收敛三角形检测参数""" # 窗口设置 window: int = 400 # 枢轴点检测 pivot_k: int = 20 # 边界线拟合 boundary_n_segments: int = 2 boundary_source: str = "full" # "full" | "pivots" # 斜率约束 upper_slope_max: float = 0.10 lower_slope_min: float = -0.10 # 触碰判定 touch_tol: float = 0.10 touch_loss_max: float = 0.10 # 收敛要求 shrink_ratio: float = 0.8 # 突破判定 break_tol: float = 0.001 vol_window: int = 20 vol_k: float = 1.3 false_break_m: int = 5 # ============================================================================ # 返回对象 # ============================================================================ @dataclass class ConvergingTriangleResult: """收敛三角形检测结果 (单个股票单个日期)""" # 基础标识 stock_idx: int date_idx: int is_valid: bool # 是否识别到有效三角形 # 突破强度 (0~1 连续分数) breakout_strength_up: float = 0.0 breakout_strength_down: float = 0.0 # 几何属性 upper_slope: float = 0.0 lower_slope: float = 0.0 width_ratio: float = 0.0 touches_upper: int = 0 touches_lower: int = 0 apex_x: float = 0.0 # 突破状态 breakout_dir: str = "none" # "up" | "down" | "none" volume_confirmed: Optional[bool] = None false_breakout: Optional[bool] = None # 窗口范围 window_start: int = 0 window_end: int = 0 def to_dict(self) -> dict: """转换为字典""" return asdict(self) # ============================================================================ # 基础工具函数 (从 sym_triangle.py 复用) # ============================================================================ def pivots_fractal( high: np.ndarray, low: np.ndarray, k: int = 3 ) -> Tuple[np.ndarray, np.ndarray]: """左右窗口分形:返回 pivot_high_idx, pivot_low_idx""" n = len(high) ph: List[int] = [] pl: List[int] = [] for i in range(k, n - k): if high[i] == np.max(high[i - k : i + k + 1]): ph.append(i) if low[i] == np.min(low[i - k : i + k + 1]): pl.append(i) return np.array(ph, dtype=int), np.array(pl, dtype=int) def fit_line(x: np.ndarray, y: np.ndarray) -> Tuple[float, float]: """拟合 y = a*x + b""" if len(x) < 2: return 0.0, float(y[0]) if len(y) > 0 else 0.0 a, b = np.polyfit(x, y, deg=1) return float(a), float(b) def fit_boundary_line( x: np.ndarray, y: np.ndarray, mode: str = "upper", n_segments: int = 3 ) -> Tuple[float, float]: """ 边界线拟合(分段取极值) - mode="upper": 每段取最高点 - mode="lower": 每段取最低点 """ if len(x) < 2: return fit_line(x, y) n_segments = min(n_segments, len(x)) if n_segments < 2: n_segments = 2 sort_idx = np.argsort(x) x_sorted = x[sort_idx] y_sorted = y[sort_idx] segment_size = len(x_sorted) // n_segments boundary_x = [] boundary_y = [] for i in range(n_segments): start = i * segment_size if i == n_segments - 1: end = len(x_sorted) else: end = (i + 1) * segment_size if start >= end: continue seg_x = x_sorted[start:end] seg_y = y_sorted[start:end] if mode == "upper": idx = np.argmax(seg_y) else: idx = np.argmin(seg_y) boundary_x.append(seg_x[idx]) boundary_y.append(seg_y[idx]) if len(boundary_x) < 2: return fit_line(x, y) return fit_line(np.array(boundary_x), np.array(boundary_y)) def line_y(a: float, b: float, x: np.ndarray) -> np.ndarray: """计算线上的 y 值""" return a * x + b # ============================================================================ # 突破强度计算 # ============================================================================ def calc_breakout_strength( close: float, upper_line: float, lower_line: float, volume_ratio: float, width_ratio: float, ) -> Tuple[float, float]: """ 计算向上/向下突破强度 (0~1) 综合考虑: - 价格突破幅度 (close 相对于上/下沿的距离) - 成交量放大倍数 - 收敛程度 (width_ratio 越小越强) Returns: (strength_up, strength_down) """ # 价格突破分数 price_up = max(0, (close - upper_line) / upper_line) if upper_line > 0 else 0 price_down = max(0, (lower_line - close) / lower_line) if lower_line > 0 else 0 # 收敛加成 (越收敛, 突破越有效) convergence_bonus = max(0, 1 - width_ratio) # 成交量加成 (放量2倍=满分) vol_bonus = min(1, max(0, volume_ratio - 1)) if volume_ratio > 0 else 0 # 加权合成 # 基础分数 * 收敛加成 * 成交量加成 strength_up = min(1.0, price_up * 5 * (1 + convergence_bonus * 0.5) * (1 + vol_bonus * 0.5)) strength_down = min(1.0, price_down * 5 * (1 + convergence_bonus * 0.5) * (1 + vol_bonus * 0.5)) return strength_up, strength_down # ============================================================================ # 单点检测函数 # ============================================================================ def detect_converging_triangle( high: np.ndarray, low: np.ndarray, close: np.ndarray, volume: Optional[np.ndarray], params: ConvergingTriangleParams, stock_idx: int = 0, date_idx: int = 0, ) -> ConvergingTriangleResult: """ 检测单个窗口是否存在收敛三角形 Args: high, low, close: 窗口内的价格数据 (长度 = window) volume: 窗口内的成交量数据 (可选) params: 检测参数 stock_idx: 股票索引 (用于结果标识) date_idx: 日期索引 (用于结果标识) Returns: ConvergingTriangleResult """ n = len(close) window = params.window # 创建默认无效结果 invalid_result = ConvergingTriangleResult( stock_idx=stock_idx, date_idx=date_idx, is_valid=False, window_start=max(0, n - window), window_end=n - 1, ) # 数据长度检查 if n < max(window, 2 * params.pivot_k + 5): return invalid_result # 计算枢轴点 ph_idx, pl_idx = pivots_fractal(high, low, k=params.pivot_k) end = n - 1 start = max(0, end - window + 1) x_all = np.arange(n, dtype=float) # 筛选窗口内的枢轴点 ph_in = ph_idx[(ph_idx >= start) & (ph_idx <= end)] pl_in = pl_idx[(pl_idx >= start) & (pl_idx <= end)] if len(ph_in) < 2 or len(pl_in) < 2: return invalid_result # 拟合边界线 if params.boundary_source == "full": x_upper = x_all[start : end + 1] y_upper = high[start : end + 1] x_lower = x_all[start : end + 1] y_lower = low[start : end + 1] else: x_upper = x_all[ph_in] y_upper = high[ph_in] x_lower = x_all[pl_in] y_lower = low[pl_in] a_u, b_u = fit_boundary_line(x_upper, y_upper, mode="upper", n_segments=params.boundary_n_segments) a_l, b_l = fit_boundary_line(x_lower, y_lower, mode="lower", n_segments=params.boundary_n_segments) # 斜率检查 if not (a_u <= params.upper_slope_max and a_l >= params.lower_slope_min): return invalid_result # 宽度收敛检查 upper_start = float(line_y(a_u, b_u, np.array([start]))[0]) lower_start = float(line_y(a_l, b_l, np.array([start]))[0]) upper_end = float(line_y(a_u, b_u, np.array([end]))[0]) lower_end = float(line_y(a_l, b_l, np.array([end]))[0]) width_start = upper_start - lower_start width_end = upper_end - lower_end if width_start <= 0 or width_end <= 0: return invalid_result width_ratio = width_end / width_start if width_ratio > params.shrink_ratio: return invalid_result # 触碰程度检查 ph_dist = np.abs(high[ph_in] - line_y(a_u, b_u, x_all[ph_in])) / np.maximum( line_y(a_u, b_u, x_all[ph_in]), 1e-9 ) pl_dist = np.abs(low[pl_in] - line_y(a_l, b_l, x_all[pl_in])) / np.maximum( line_y(a_l, b_l, x_all[pl_in]), 1e-9 ) touches_upper = int((ph_dist <= params.touch_tol).sum()) touches_lower = int((pl_dist <= params.touch_tol).sum()) loss_upper = float(np.mean(ph_dist)) if len(ph_dist) else float("inf") loss_lower = float(np.mean(pl_dist)) if len(pl_dist) else float("inf") if loss_upper > params.touch_loss_max or loss_lower > params.touch_loss_max: return invalid_result # Apex 计算 denom = a_u - a_l apex_x = float((b_l - b_u) / denom) if abs(denom) > 1e-12 else float("inf") # 突破判定 breakout_dir: Literal["up", "down", "none"] = "none" breakout_idx: Optional[int] = None if close[end] > upper_end * (1 + params.break_tol): breakout_dir = "up" breakout_idx = end elif close[end] < lower_end * (1 - params.break_tol): breakout_dir = "down" breakout_idx = end # 成交量确认 volume_confirmed: Optional[bool] = None volume_ratio = 1.0 if volume is not None and len(volume) >= params.vol_window: vol_ma = np.mean(volume[-params.vol_window:]) if vol_ma > 0: volume_ratio = volume[-1] / vol_ma if breakout_dir != "none": volume_confirmed = bool(volume[-1] > vol_ma * params.vol_k) # 假突破检测 (历史回测模式,实时无法得知) false_breakout: Optional[bool] = None # 注意: 这里是基于历史数据,无法检测假突破 # 假突破需要看"未来"数据,与当前设计不符 # 计算突破强度 strength_up, strength_down = calc_breakout_strength( close=close[end], upper_line=upper_end, lower_line=lower_end, volume_ratio=volume_ratio, width_ratio=width_ratio, ) return ConvergingTriangleResult( stock_idx=stock_idx, date_idx=date_idx, is_valid=True, breakout_strength_up=strength_up, breakout_strength_down=strength_down, upper_slope=a_u, lower_slope=a_l, width_ratio=width_ratio, touches_upper=touches_upper, touches_lower=touches_lower, apex_x=apex_x, breakout_dir=breakout_dir, volume_confirmed=volume_confirmed, false_breakout=false_breakout, window_start=start, window_end=end, ) # ============================================================================ # 批量滚动检测函数 # ============================================================================ def detect_converging_triangle_batch( open_mtx: np.ndarray, high_mtx: np.ndarray, low_mtx: np.ndarray, close_mtx: np.ndarray, volume_mtx: np.ndarray, params: ConvergingTriangleParams, start_day: Optional[int] = None, end_day: Optional[int] = None, only_valid: bool = False, verbose: bool = False, ) -> pd.DataFrame: """ 批量滚动检测收敛三角形 Args: open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx: OHLCV 矩阵, shape=(n_stocks, n_days) params: 检测参数 start_day: 从哪一天开始计算 (默认: window-1, 即第一个有足够历史的点) end_day: 到哪一天结束 (默认: 最后一天) only_valid: 是否只返回识别到三角形的记录 verbose: 是否打印进度 Returns: DataFrame with columns: - stock_idx, date_idx - is_valid - breakout_strength_up, breakout_strength_down - upper_slope, lower_slope, width_ratio - touches_upper, touches_lower, apex_x - breakout_dir, volume_confirmed, false_breakout - window_start, window_end """ n_stocks, n_days = close_mtx.shape window = params.window # 默认起止日 if start_day is None: start_day = window - 1 if end_day is None: end_day = n_days - 1 # 确保范围有效 start_day = max(window - 1, start_day) end_day = min(n_days - 1, end_day) results: List[dict] = [] total = n_stocks * (end_day - start_day + 1) processed = 0 for stock_idx in range(n_stocks): # 提取该股票的全部数据,并过滤 NaN high_stock = high_mtx[stock_idx, :] low_stock = low_mtx[stock_idx, :] close_stock = close_mtx[stock_idx, :] volume_stock = volume_mtx[stock_idx, :] if volume_mtx is not None else None # 找到有效数据的 mask(非 NaN) valid_mask = ~np.isnan(close_stock) valid_indices = np.where(valid_mask)[0] if len(valid_indices) < window: # 该股票有效数据不足一个窗口 for date_idx in range(start_day, end_day + 1): if not only_valid: results.append(ConvergingTriangleResult( stock_idx=stock_idx, date_idx=date_idx, is_valid=False, ).to_dict()) processed += 1 continue # 提取有效数据 high_valid = high_stock[valid_mask] low_valid = low_stock[valid_mask] close_valid = close_stock[valid_mask] volume_valid = volume_stock[valid_mask] if volume_stock is not None else None # 在有效数据上滚动 n_valid = len(close_valid) for valid_end in range(window - 1, n_valid): # 原始数据中的 date_idx orig_date_idx = valid_indices[valid_end] # 检查是否在指定范围内 if orig_date_idx < start_day or orig_date_idx > end_day: continue # 提取窗口 valid_start = valid_end - window + 1 high_win = high_valid[valid_start:valid_end + 1] low_win = low_valid[valid_start:valid_end + 1] close_win = close_valid[valid_start:valid_end + 1] volume_win = volume_valid[valid_start:valid_end + 1] if volume_valid is not None else None # 检测 result = detect_converging_triangle( high=high_win, low=low_win, close=close_win, volume=volume_win, params=params, stock_idx=stock_idx, date_idx=orig_date_idx, ) if only_valid and not result.is_valid: processed += 1 continue results.append(result.to_dict()) processed += 1 if verbose and (stock_idx + 1) % 10 == 0: print(f" Progress: {stock_idx + 1}/{n_stocks} stocks, {processed}/{total} points") if verbose: print(f" Completed: {processed} points processed") return pd.DataFrame(results)