"""
Converging triangle detection.

Supports:
- batch input as 2-D matrices of shape (n_stocks, n_days)
- historical rolling evaluation (each trading day looks back ``window`` days)
- breakout strength scores in the range 0~1
"""

from __future__ import annotations

from dataclasses import dataclass, field, asdict
from typing import List, Literal, Optional, Tuple

import numpy as np
import pandas as pd


# ============================================================================
# Parameter object
# ============================================================================


@dataclass
class ConvergingTriangleParams:
    """Tunable parameters for converging-triangle detection."""

    # Window settings: number of bars inspected per detection.
    window: int = 400

    # Pivot detection: half-width ``k`` of the fractal window (see pivots_fractal).
    pivot_k: int = 20

    # Boundary-line fitting settings for fit_boundary_line.
    # NOTE(review): detect_converging_triangle fits its boundaries with
    # fit_pivot_line and does not read these two fields.
    boundary_n_segments: int = 2
    boundary_source: str = "full"  # "full" | "pivots"

    # Slope limits in price units per bar: the upper line's slope must be
    # <= upper_slope_max and the lower line's slope >= lower_slope_min.
    upper_slope_max: float = 0.10
    lower_slope_min: float = -0.10

    # Touch detection: relative deviation from a boundary line within which a
    # pivot counts as touching it (larger outward deviations are violations).
    touch_tol: float = 0.10
    # NOTE(review): not read by the detection code visible in this module.
    touch_loss_max: float = 0.10

    # Convergence requirement: maximum allowed end-width / start-width ratio.
    shrink_ratio: float = 0.8

    # Breakout rules.
    break_tol: float = 0.001  # relative margin beyond a boundary that counts as a breakout
    vol_window: int = 20  # number of trailing bars in the volume average
    vol_k: float = 1.3  # volume must exceed vol_k * average to confirm a breakout
    false_break_m: int = 5  # NOTE(review): not read; false breakouts are not evaluated


# ============================================================================
# Result object
# ============================================================================


@dataclass
class ConvergingTriangleResult:
    """Detection result for one stock on one date."""

    # Identification
    stock_idx: int
    date_idx: int
    is_valid: bool  # whether a valid triangle was identified

    # Breakout strength (continuous score in 0~1)
    breakout_strength_up: float = 0.0
    breakout_strength_down: float = 0.0

    # Geometry
    upper_slope: float = 0.0
    lower_slope: float = 0.0
    width_ratio: float = 0.0
    touches_upper: int = 0
    touches_lower: int = 0
    apex_x: float = 0.0

    # Breakout state
    breakout_dir: str = "none"  # "up" | "down" | "none"
    volume_confirmed: Optional[bool] = None
    false_breakout: Optional[bool] = None

    # Window range
    window_start: int = 0
    window_end: int = 0

    def to_dict(self) -> dict:
        """Return all fields as a plain dict (via ``dataclasses.asdict``)."""
        return asdict(self)


# ============================================================================
# Basic utilities (reused from sym_triangle.py)
# ============================================================================


def pivots_fractal(
    high: np.ndarray, low: np.ndarray, k: int = 3
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Find fractal pivots using a symmetric left/right window.

    Bar ``i`` is a pivot high when ``high[i]`` equals the maximum of
    ``high[i-k : i+k+1]``. It is a pivot low when ``low[i]`` equals the
    minimum of ``low[i-k : i+k+1]``. Ties count, so a flat top or bottom
    can produce several adjacent pivots. The first and last ``k`` bars are
    never pivots.

    Returns:
        (pivot_high_idx, pivot_low_idx) as integer index arrays in ascending order.
    """
    n = len(high)
    ph: List[int] = []
    pl: List[int] = []
    for i in range(k, n - k):
        if high[i] == np.max(high[i - k : i + k + 1]):
            ph.append(i)
        if low[i] == np.min(low[i - k : i + k + 1]):
            pl.append(i)
    return np.array(ph, dtype=int), np.array(pl, dtype=int)


def fit_line(x: np.ndarray, y: np.ndarray) -> Tuple[float, float]:
    """
    Least-squares fit of ``y = a*x + b``; returns ``(a, b)``.

    With fewer than two points the line is flat: ``(0.0, y[0])``, or
    ``(0.0, 0.0)`` when ``y`` is empty.
    """
    if len(x) < 2:
        return 0.0, float(y[0]) if len(y) > 0 else 0.0
    a, b = np.polyfit(x, y, deg=1)
    return float(a), float(b)


def fit_boundary_line(
    x: np.ndarray, y: np.ndarray, mode: str = "upper", n_segments: int = 3
) -> Tuple[float, float]:
    """
    Fit a boundary line through per-segment extremes.

    The points are sorted by ``x`` and split into ``n_segments`` consecutive
    segments. There are at least 2 and at most ``len(x)`` segments, and the
    last one takes the remainder. Each segment contributes its highest point
    (``mode="upper"``) or its lowest point (any other mode). A least-squares
    line is then fitted through those extremes with fit_line.

    Falls back to fitting all points when fewer than two points are given or
    fewer than two segments are non-empty.
    """
    if len(x) < 2:
        return fit_line(x, y)

    n_segments = min(n_segments, len(x))
    if n_segments < 2:
        n_segments = 2

    sort_idx = np.argsort(x)
    x_sorted = x[sort_idx]
    y_sorted = y[sort_idx]

    segment_size = len(x_sorted) // n_segments
    boundary_x = []
    boundary_y = []
    for i in range(n_segments):
        start = i * segment_size
        # The last segment absorbs any remainder of the integer division.
        end = len(x_sorted) if i == n_segments - 1 else (i + 1) * segment_size
        if start >= end:
            continue
        seg_x = x_sorted[start:end]
        seg_y = y_sorted[start:end]
        idx = np.argmax(seg_y) if mode == "upper" else np.argmin(seg_y)
        boundary_x.append(seg_x[idx])
        boundary_y.append(seg_y[idx])

    if len(boundary_x) < 2:
        return fit_line(x, y)
    return fit_line(np.array(boundary_x), np.array(boundary_y))


def line_y(a: float, b: float, x: np.ndarray) -> np.ndarray:
    """Evaluate the line ``a*x + b`` at ``x``."""
    return a * x + b


def _select_pivot_pair(y_sorted: np.ndarray, upper: bool) -> Tuple[int, int]:
    """Pick (front, back) positions of two time-sorted pivots to connect; needs len >= 2."""
    extreme = np.argmax if upper else np.argmin
    n = len(y_sorted)
    mid = max(n // 2, 1)

    # Most extreme point of the first half (including mid) and of the second half.
    front = int(extreme(y_sorted[: mid + 1]))
    back = mid + int(extreme(y_sorted[mid:]))

    # If the later point is more extreme than the earlier one (a rising upper
    # line / falling lower line), re-anchor on the global extreme instead.
    if upper:
        later_more_extreme = y_sorted[back] > y_sorted[front]
    else:
        later_more_extreme = y_sorted[back] < y_sorted[front]
    if later_more_extreme:
        peak = int(extreme(y_sorted))
        if peak < n - 1:
            # Pair the global extreme with the most extreme point after it.
            front, back = peak, peak + 1 + int(extreme(y_sorted[peak + 1 :]))
        else:
            # Global extreme is the last pivot: pair it with the most extreme one before it.
            front, back = int(extreme(y_sorted[:-1])), peak

    # Both picks landed on the same pivot: fall back to the first and last pivots.
    if front == back:
        front, back = 0, n - 1
    return front, back


def fit_pivot_line(
    pivot_indices: np.ndarray,
    pivot_values: np.ndarray,
    mode: str = "upper",
    min_points: int = 2,
) -> Tuple[float, float, np.ndarray]:
    """
    Draw a boundary line through two selected pivot points.

    - ``mode="upper"``: prefers a pair of pivot highs forming a descending line.
    - ``mode="lower"`` (any mode other than "upper"): prefers a pair of pivot
      lows forming an ascending line.

    Selection, described for "upper" (the lower mode mirrors it with minima):
    1. Take the highest pivot of the first half (including the middle
       element) and the highest pivot of the second half.
    2. If the later pivot is higher, re-anchor on the global maximum. Pair it
       with the highest pivot after it, or, if it is the last pivot, with the
       highest pivot before it.
    3. If both picks are the same pivot, use the first and last pivots instead.

    Args:
        pivot_indices: x coordinates (bar indices) of the pivots.
        pivot_values: y values (prices) of the pivots.
        mode: "upper" or "lower".
        min_points: minimum number of pivots required.

    Returns:
        ``(slope, intercept, selected_indices)``, where ``selected_indices`` are
        positions into the ORIGINAL (unsorted) ``pivot_indices`` array.

        - Fewer than ``min_points`` pivots, or none at all: ``(0.0, 0.0, [])``.
        - Exactly one pivot with ``min_points <= 1``: a flat line through that pivot.
    """
    n = len(pivot_indices)
    if n < min_points or n == 0:
        return 0.0, 0.0, np.array([], dtype=int)

    # Sort pivots chronologically.
    sort_idx = np.argsort(pivot_indices)
    x_sorted = pivot_indices[sort_idx].astype(float)
    y_sorted = pivot_values[sort_idx]

    if n == 1:
        # A single pivot defines no slope. Return a flat line explicitly
        # instead of letting the pair selection take argmax of an empty slice.
        return 0.0, float(y_sorted[0]), np.array([sort_idx[0]])

    front, back = _select_pivot_pair(y_sorted, upper=(mode == "upper"))
    selected = np.array([front, back])

    x1, x2 = x_sorted[front], x_sorted[back]
    y1, y2 = y_sorted[front], y_sorted[back]
    if abs(x2 - x1) < 1e-9:
        slope = 0.0
        intercept = (y1 + y2) / 2
    else:
        slope = (y2 - y1) / (x2 - x1)
        intercept = y1 - slope * x1

    # Map sorted positions back to positions in the caller's arrays.
    return float(slope), float(intercept), sort_idx[selected]
# ============================================================================
# Breakout strength
# ============================================================================


def calc_breakout_strength(
    close: float,
    upper_line: float,
    lower_line: float,
    volume_ratio: float,
    width_ratio: float,
) -> Tuple[float, float]:
    """
    Compute upward / downward breakout strength scores in [0, 1].

    Each score is a weighted sum of three components:

    - Price score (60%): ``tanh(pct * 15)``, where ``pct`` is the relative
      distance of ``close`` beyond the boundary. Reference values:
      1% -> 0.15, 2% -> 0.29, 3% -> 0.42, 5% -> 0.64, 8% -> 0.83, 10% -> 0.91.
    - Convergence score (25%): ``1 - width_ratio`` clipped to [0, 1]; a
      tighter triangle scores higher.
    - Volume score (15%): ``volume_ratio - 1`` clipped to [0, 1]; 2x the
      average volume earns the full score.

    A direction scores 0 unless ``close`` is strictly beyond that boundary
    (positive price score). A boundary at or below zero also gives 0.

    Args:
        close: closing price.
        upper_line: upper boundary price.
        lower_line: lower boundary price.
        volume_ratio: current volume as a multiple of its average.
        width_ratio: end width / start width of the triangle.

    Returns:
        (strength_up, strength_down)
    """
    import math

    # Component weights (they sum to 1.0).
    W_PRICE = 0.60
    W_CONVERGENCE = 0.25
    W_VOLUME = 0.15
    TANH_SCALE = 15.0  # tanh scaling factor for the price component

    # 1. Price breakout scores (tanh normalisation).
    if upper_line > 0:
        pct_up = max(0.0, (close - upper_line) / upper_line)
        price_score_up = math.tanh(pct_up * TANH_SCALE)
    else:
        price_score_up = 0.0

    if lower_line > 0:
        pct_down = max(0.0, (lower_line - close) / lower_line)
        price_score_down = math.tanh(pct_down * TANH_SCALE)
    else:
        price_score_down = 0.0

    # 2. Convergence score (smaller width_ratio is better).
    convergence_score = max(0.0, min(1.0, 1.0 - width_ratio))

    # 3. Volume score (only volume above the average scores).
    vol_score = min(1.0, max(0.0, volume_ratio - 1.0)) if volume_ratio > 0 else 0.0

    # 4. Weighted sum, only for a direction that actually broke out.
    def combine(price_score: float) -> float:
        if price_score > 0:
            return (
                W_PRICE * price_score
                + W_CONVERGENCE * convergence_score
                + W_VOLUME * vol_score
            )
        return 0.0

    return min(1.0, combine(price_score_up)), min(1.0, combine(price_score_down))


# ============================================================================
# Single-window detection
# ============================================================================


def _touch_stats(deviation: np.ndarray, tol: float) -> Tuple[int, int]:
    """
    Return (touches, violations) for signed relative deviations from a line.

    Positive deviation means the pivot lies outside the triangle. A touch is
    ``|deviation| <= tol``; a violation is ``deviation > tol``.
    """
    touches = int((np.abs(deviation) <= tol).sum())
    violations = int((deviation > tol).sum())
    return touches, violations


def detect_converging_triangle(
    high: np.ndarray,
    low: np.ndarray,
    close: np.ndarray,
    volume: Optional[np.ndarray],
    params: ConvergingTriangleParams,
    stock_idx: int = 0,
    date_idx: int = 0,
) -> ConvergingTriangleResult:
    """
    Check whether a price window contains a converging triangle.

    Steps:
    1. Find fractal pivots.
    2. Connect a pair of pivot highs (upper line) and a pair of pivot lows
       (lower line) with fit_pivot_line.
    3. Require the slope limits, a positive and sufficiently shrinking width,
       enough pivots touching each line, and few pivots beyond it.
    4. Classify a breakout on the last bar.

    Only pivots within the last ``params.window`` bars are used, and the last
    bar is the evaluation date.

    Args:
        high, low, close: price arrays (length must be >= params.window).
        volume: volume array aligned with the prices, or None.
        params: detection parameters.
        stock_idx: stock index copied into the result.
        date_idx: date index copied into the result.

    Returns:
        ConvergingTriangleResult. When any check fails, ``is_valid`` is False
        and the geometry fields keep their defaults.
    """
    n = len(close)
    window = params.window

    # Default result returned by every failed check.
    invalid_result = ConvergingTriangleResult(
        stock_idx=stock_idx,
        date_idx=date_idx,
        is_valid=False,
        window_start=max(0, n - window),
        window_end=n - 1,
    )

    # Need a full window and enough bars for the fractal pivot detector.
    if n < max(window, 2 * params.pivot_k + 5):
        return invalid_result

    ph_idx, pl_idx = pivots_fractal(high, low, k=params.pivot_k)

    end = n - 1
    start = max(0, end - window + 1)

    # Keep only pivots inside the window.
    ph_in = ph_idx[(ph_idx >= start) & (ph_idx <= end)]
    pl_in = pl_idx[(pl_idx >= start) & (pl_idx <= end)]
    if len(ph_in) < 2 or len(pl_in) < 2:
        return invalid_result

    # Boundary lines through selected pivots: descending highs for the upper
    # line, ascending lows for the lower line.
    a_u, b_u, selected_ph = fit_pivot_line(
        pivot_indices=ph_in,
        pivot_values=high[ph_in],
        mode="upper",
    )
    a_l, b_l, selected_pl = fit_pivot_line(
        pivot_indices=pl_in,
        pivot_values=low[pl_in],
        mode="lower",
    )
    if len(selected_ph) < 2 or len(selected_pl) < 2:
        return invalid_result

    # Slope limits.
    if not (a_u <= params.upper_slope_max and a_l >= params.lower_slope_min):
        return invalid_result

    # Width must be positive at both ends and shrink enough.
    upper_start = float(line_y(a_u, b_u, np.array([start]))[0])
    lower_start = float(line_y(a_l, b_l, np.array([start]))[0])
    upper_end = float(line_y(a_u, b_u, np.array([end]))[0])
    lower_end = float(line_y(a_l, b_l, np.array([end]))[0])

    width_start = upper_start - lower_start
    width_end = upper_end - lower_end
    if width_start <= 0 or width_end <= 0:
        return invalid_result

    width_ratio = width_end / width_start
    if width_ratio > params.shrink_ratio:
        return invalid_result

    # Touch detection. Deviations are signed so that positive means "outside"
    # the triangle: above the upper line, or below the lower line.
    upper_line_at_ph = line_y(a_u, b_u, ph_in.astype(float))
    ph_deviation = (high[ph_in] - upper_line_at_ph) / np.maximum(upper_line_at_ph, 1e-9)
    touches_upper, violations_upper = _touch_stats(ph_deviation, params.touch_tol)

    lower_line_at_pl = line_y(a_l, b_l, pl_in.astype(float))
    pl_deviation = (lower_line_at_pl - low[pl_in]) / np.maximum(lower_line_at_pl, 1e-9)
    touches_lower, violations_lower = _touch_stats(pl_deviation, params.touch_tol)

    # Allow a few outliers: at most one third (minimum 1) of each side's
    # pivots may violate. Each side is judged against its OWN pivot count.
    if violations_upper > max(1, len(ph_in) // 3):
        return invalid_result
    if violations_lower > max(1, len(pl_in) // 3):
        return invalid_result

    # Each line needs at least two touches (the two pivots that define the
    # line normally count here).
    if touches_upper < 2 or touches_lower < 2:
        return invalid_result

    # Apex: x coordinate where the two lines intersect (inf if parallel).
    denom = a_u - a_l
    apex_x = float((b_l - b_u) / denom) if abs(denom) > 1e-12 else float("inf")

    # Breakout on the last bar.
    breakout_dir: Literal["up", "down", "none"] = "none"
    if close[end] > upper_end * (1 + params.break_tol):
        breakout_dir = "up"
    elif close[end] < lower_end * (1 - params.break_tol):
        breakout_dir = "down"

    # Volume: ratio of the last bar to the trailing average, whose window
    # includes the last bar.
    volume_confirmed: Optional[bool] = None
    volume_ratio = 1.0
    if volume is not None and len(volume) >= params.vol_window:
        vol_ma = np.mean(volume[-params.vol_window:])
        if vol_ma > 0:
            volume_ratio = volume[-1] / vol_ma
            if breakout_dir != "none":
                volume_confirmed = bool(volume[-1] > vol_ma * params.vol_k)

    # False breakouts would require bars after the evaluation date, which
    # this rolling (look-back only) design never sees, so it stays None.
    false_breakout: Optional[bool] = None

    strength_up, strength_down = calc_breakout_strength(
        close=close[end],
        upper_line=upper_end,
        lower_line=lower_end,
        volume_ratio=volume_ratio,
        width_ratio=width_ratio,
    )

    return ConvergingTriangleResult(
        stock_idx=stock_idx,
        date_idx=date_idx,
        is_valid=True,
        breakout_strength_up=strength_up,
        breakout_strength_down=strength_down,
        upper_slope=a_u,
        lower_slope=a_l,
        width_ratio=width_ratio,
        touches_upper=touches_upper,
        touches_lower=touches_lower,
        apex_x=apex_x,
        breakout_dir=breakout_dir,
        volume_confirmed=volume_confirmed,
        false_breakout=false_breakout,
        window_start=start,
        window_end=end,
    )


# ============================================================================
# Batch rolling detection
# ============================================================================


def detect_converging_triangle_batch(
    open_mtx: np.ndarray,
    high_mtx: np.ndarray,
    low_mtx: np.ndarray,
    close_mtx: np.ndarray,
    volume_mtx: np.ndarray,
    params: ConvergingTriangleParams,
    start_day: Optional[int] = None,
    end_day: Optional[int] = None,
    only_valid: bool = False,
    verbose: bool = False,
) -> pd.DataFrame:
    """
    Run rolling converging-triangle detection over every stock.

    For each stock, bars whose close is NaN are dropped. Each window is then
    the last ``params.window`` valid bars ending at the evaluated date, so a
    window may span more than ``window`` matrix columns when NaN gaps exist.

    Row coverage when ``only_valid`` is False:

    - A stock with fewer than ``window`` valid bars gets an invalid row for
      every date in range.
    - For other stocks, dates with a NaN close, or with fewer than ``window``
      valid bars up to that date, produce no row.

    NOTE(review): NaN filtering looks at close only; NaNs in high/low/volume
    are passed through.

    Args:
        open_mtx: accepted for interface symmetry; not used.
        high_mtx, low_mtx, close_mtx, volume_mtx: OHLCV matrices with
            shape (n_stocks, n_days). ``volume_mtx`` may be None.
        params: detection parameters.
        start_day: first date evaluated. Defaults to window-1, the first point
            with enough history; it is clamped to at least window-1.
        end_day: last date evaluated. Defaults to the last day; it is clamped
            to at most n_days-1.
        only_valid: keep only rows where a triangle was identified.
        verbose: print progress every 10 stocks and a final summary.

    Returns:
        DataFrame with one column per ConvergingTriangleResult field
        (stock_idx, date_idx, is_valid, breakout strengths, geometry,
        breakout state, window range). The columns are present even when
        there are no rows.
    """
    from dataclasses import fields as dataclass_fields

    n_stocks, n_days = close_mtx.shape
    window = params.window

    if start_day is None:
        start_day = window - 1
    if end_day is None:
        end_day = n_days - 1

    # Clamp the range to dates that can have a full window.
    start_day = max(window - 1, start_day)
    end_day = min(n_days - 1, end_day)

    results: List[dict] = []
    total = n_stocks * (end_day - start_day + 1)
    processed = 0

    for stock_idx in range(n_stocks):
        high_stock = high_mtx[stock_idx, :]
        low_stock = low_mtx[stock_idx, :]
        close_stock = close_mtx[stock_idx, :]
        volume_stock = volume_mtx[stock_idx, :] if volume_mtx is not None else None

        # Bars with a non-NaN close.
        valid_mask = ~np.isnan(close_stock)
        valid_indices = np.where(valid_mask)[0]

        if len(valid_indices) < window:
            # Not even one full window of valid data: every requested date is invalid.
            for date_idx in range(start_day, end_day + 1):
                if not only_valid:
                    results.append(ConvergingTriangleResult(
                        stock_idx=stock_idx,
                        date_idx=date_idx,
                        is_valid=False,
                    ).to_dict())
                processed += 1
        else:
            high_valid = high_stock[valid_mask]
            low_valid = low_stock[valid_mask]
            close_valid = close_stock[valid_mask]
            volume_valid = volume_stock[valid_mask] if volume_stock is not None else None

            # Roll over the compacted (NaN-free) series.
            for valid_end in range(window - 1, len(close_valid)):
                # Map back to the date index in the original matrix.
                orig_date_idx = valid_indices[valid_end]
                if orig_date_idx < start_day or orig_date_idx > end_day:
                    continue

                win = slice(valid_end - window + 1, valid_end + 1)
                result = detect_converging_triangle(
                    high=high_valid[win],
                    low=low_valid[win],
                    close=close_valid[win],
                    volume=volume_valid[win] if volume_valid is not None else None,
                    params=params,
                    stock_idx=stock_idx,
                    date_idx=orig_date_idx,
                )
                processed += 1
                if only_valid and not result.is_valid:
                    continue
                results.append(result.to_dict())

        # Runs for every stock, including those without enough data.
        if verbose and (stock_idx + 1) % 10 == 0:
            print(f" Progress: {stock_idx + 1}/{n_stocks} stocks, {processed}/{total} points")

    if verbose:
        print(f" Completed: {processed} points processed")

    # Pin the column set so an empty result still has the documented columns.
    columns = [f.name for f in dataclass_fields(ConvergingTriangleResult)]
    return pd.DataFrame(results, columns=columns)