import json
import os
import sys

import numpy as np
import pandas as pd

# Make the modules under src/ importable when this file is run as a script.
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "src"))

from sym_triangle import detect_sym_triangle, line_y, pivots_fractal, fit_line, fit_boundary_line

# ============================================================================
# [Tunable parameters] - edit the values here, then re-run the script
# ============================================================================

# --- Window size ---
# Number of points, counted backwards from the latest data point, that form
# the analysis window. Example: 400 analyses the most recent 400 trading days.
WINDOW = 400

# --- Data source ---
# Directory holding the OHLCV files (open/high/low/close/volume.json).
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "data")

# --- Pivot detection ---
# pivot_k: left/right window size; larger values yield fewer pivots
# (only the more pronounced peaks/troughs). Suggested range: 5~30.
PIVOT_K = 20

# --- Boundary line fitting ---
# boundary_n_segments: number of segments the window is split into; one
# extreme point per segment is used for the fit.
#   larger  -> more points in the fit -> the line is more "averaged"
#   smaller -> fewer points in the fit -> the line hugs the extremes
# Suggested range: 2~5.
BOUNDARY_N_SEGMENTS = 2

# boundary_source:
#   - "pivots": fit using pivot points only (more stable)
#   - "full":   take per-segment extremes from the full high/low series
#               (hugs the edge more closely)
BOUNDARY_SOURCE = "full"

# --- Slope constraints ---
# upper_slope_max: maximum allowed slope of the upper boundary.
#   - 0:               upper boundary must fall (textbook symmetric triangle)
#   - positive (0.05): allow a slightly rising upper boundary (relaxed)
#   - negative (-0.01): require a clearly falling upper boundary
UPPER_SLOPE_MAX = 0.10

# lower_slope_min: minimum allowed slope of the lower boundary.
#   - 0:               lower boundary must rise
#   - negative (-0.1): allow a slightly falling lower boundary (relaxed)
LOWER_SLOPE_MIN = -0.10

# --- Touch detection ---
# touch_tol: tolerance (relative deviation) for a pivot to count as
# "touching" a boundary line. Larger -> more lenient.
TOUCH_TOL = 0.10

# touch_loss_max: upper bound on the mean relative deviation (loss function).
TOUCH_LOSS_MAX = 0.10

# Whether to print diagnostics (including loss_upper / loss_lower).
PRINT_DEBUG = True

# --- Convergence requirement ---
# shrink_ratio: maximum allowed (end width / start width) of the triangle.
# Smaller -> a more pronounced convergence is required.
SHRINK_RATIO = 0.8

# --- Breakout detection ---
# These four values are forwarded unchanged to detect_sym_triangle().
BREAK_TOL = 0.001
VOL_WINDOW = 20
VOL_K = 1.3
FALSE_BREAK_M = 5
# ============================================================================


def load_series_from_json(json_path: str, name: str) -> pd.DataFrame:
    """Read ``data.labels`` / ``data.values`` from one JSON file.

    Args:
        json_path: Path of the JSON file to read (opened as UTF-8).
        name: Column name given to the values; also used in the error message.

    Returns:
        A DataFrame with a numeric ``date`` column and a numeric ``name``
        column. Rows where either field fails numeric conversion are dropped
        and the index is reset.

    Raises:
        ValueError: If labels or values are missing, empty, or differ in
            length.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        raw = json.load(f)

    # `or {}` also covers an explicit JSON null under "data", so that case
    # reaches the ValueError below instead of failing with AttributeError.
    data = raw.get("data") or {}
    labels = data.get("labels", [])
    values = data.get("values", [])
    if not labels or not values or len(labels) != len(values):
        raise ValueError(f"{name}.json 中未找到等长的 labels / values")

    df = pd.DataFrame({"date": labels, name: values})
    # Non-numeric entries become NaN here and are removed by dropna below.
    df["date"] = pd.to_numeric(df["date"], errors="coerce")
    df[name] = pd.to_numeric(df[name], errors="coerce")
    return df.dropna(subset=["date", name]).reset_index(drop=True)


def load_ohlcv_from_dir(data_dir: str) -> pd.DataFrame:
    """Load open/high/low/close/volume.json from a directory, aligned by date.

    Args:
        data_dir: Directory containing the five ``<field>.json`` files.

    Returns:
        A DataFrame with columns ``date, open, high, low, close, volume``,
        restricted to dates present in all five files (inner join) and sorted
        by ``date`` in ascending order.
    """
    fields = ("open", "high", "low", "close", "volume")
    frames = [
        load_series_from_json(os.path.join(data_dir, f"{field}.json"), field)
        for field in fields
    ]

    df = frames[0]
    for other in frames[1:]:
        df = df.merge(other, on="date", how="inner")
    return df.sort_values("date").reset_index(drop=True)


def plot_sym_triangle(df: pd.DataFrame, res, out_path: str) -> None:
    """Plot close prices with the detected boundaries, then save and show.

    Args:
        df: OHLCV frame; the ``close`` and ``date`` columns are read.
        res: Detection result. The attributes read here are ``upper_coef``,
            ``lower_coef``, ``start``, ``end``, ``width_ratio``,
            ``touches_upper`` and ``touches_lower``.
        out_path: File path the figure is written to (dpi=150).
    """
    # Imported lazily so that matplotlib is only needed when plotting.
    import matplotlib.pyplot as plt

    close = df["close"].to_numpy(dtype=float)
    x = np.arange(len(df), dtype=float)
    dates = df["date"].to_numpy()

    a_u, b_u = res.upper_coef
    a_l, b_l = res.lower_coef
    start, end = res.start, res.end

    # Boundary lines are drawn only across the detected window [start, end].
    xw = np.arange(start, end + 1, dtype=float)
    upper = line_y(a_u, b_u, xw)
    lower = line_y(a_l, b_l, xw)

    fig = plt.figure(figsize=(12, 5))
    plt.plot(x, close, linewidth=1.2, label="close")
    plt.plot(xw, upper, linewidth=2, label="upper")
    plt.plot(xw, lower, linewidth=2, label="lower")
    plt.axvline(end, color="gray", linestyle="--", linewidth=1)

    # Fall back to the raw indices when there are no dates to look up.
    start_date = dates[start] if len(dates) > 0 else start
    end_date = dates[end] if len(dates) > 0 else end
    plt.title(
        "sym_triangle: "
        f"range={start_date}-{end_date}, "
        f"slope=({a_u:.4f},{a_l:.4f}), "
        f"width_ratio={res.width_ratio:.2f}, "
        f"touches=({res.touches_upper},{res.touches_lower})"
    )

    # Show roughly eight date tick labels so they do not crowd each other.
    if len(dates) > 0:
        step = max(1, len(dates) // 8)
        idx = np.arange(0, len(dates), step)
        plt.xticks(idx, dates[idx], rotation=45, ha="right")

    plt.legend()
    plt.tight_layout()
    plt.savefig(out_path, dpi=150)
    plt.show()
    # Release the figure; with a non-interactive backend show() does not.
    plt.close(fig)


def debug_latest_window(
    df: pd.DataFrame,
    window: int,
    pivot_k: int,
    touch_tol: float,
    shrink_ratio: float,
    boundary_fit: bool = True,
    boundary_n_segments: int = 3,
    boundary_source: str = "pivots",
    lower_slope_min: float = 0.0,
    upper_slope_max: float = 0.0,
    touch_loss_max: float = TOUCH_LOSS_MAX,
) -> None:
    """Print key diagnostics for the latest window to explain a non-detection.

    The function recomputes pivots, boundary lines, the width ratio and the
    touch statistics for the last ``window`` rows of ``df`` and prints them
    next to the thresholds they are compared against. It returns nothing.

    Args:
        df: OHLCV frame; the ``high`` and ``low`` columns are read.
        window: Requested number of trailing rows to analyse.
        pivot_k: ``k`` passed to ``pivots_fractal``.
        touch_tol: Relative deviation at or below which a pivot counts as a
            touch.
        shrink_ratio: Threshold printed next to the width ratio.
        boundary_fit: If true, fit with ``fit_boundary_line``; otherwise use
            ``fit_line`` on the pivots.
        boundary_n_segments: ``n_segments`` passed to ``fit_boundary_line``.
        boundary_source: ``"full"`` fits on the whole high/low window; any
            other value fits on the pivots inside the window.
        lower_slope_min: Threshold printed for the lower slope.
        upper_slope_max: Threshold printed for the upper slope.
        touch_loss_max: Threshold printed for the mean relative deviation;
            defaults to the module-level ``TOUCH_LOSS_MAX``.
    """
    n = len(df)
    end = n - 1
    start = max(0, end - window + 1)

    high = df["high"].to_numpy(dtype=float)
    low = df["low"].to_numpy(dtype=float)
    x_all = np.arange(n, dtype=float)

    # Keep only the pivots that fall inside the analysis window.
    ph_idx, pl_idx = pivots_fractal(high, low, k=pivot_k)
    ph_in = ph_idx[(ph_idx >= start) & (ph_idx <= end)]
    pl_in = pl_idx[(pl_idx >= start) & (pl_idx <= end)]

    # Report the real span: it is shorter than `window` when data is scarce.
    print(
        f"window=[{start},{end}], len={end - start + 1}, "
        f"pivots_high={len(ph_in)}, pivots_low={len(pl_in)}"
    )
    if len(ph_in) < 2 or len(pl_in) < 2:
        print("诊断:枢轴点不足(high/low pivots < 2)。")
        return

    if boundary_fit:
        if boundary_source == "full":
            x_upper = x_all[start : end + 1]
            y_upper = high[start : end + 1]
            x_lower = x_all[start : end + 1]
            y_lower = low[start : end + 1]
        else:
            x_upper = x_all[ph_in]
            y_upper = high[ph_in]
            x_lower = x_all[pl_in]
            y_lower = low[pl_in]
        a_u, b_u = fit_boundary_line(
            x_upper, y_upper, mode="upper", n_segments=boundary_n_segments
        )
        a_l, b_l = fit_boundary_line(
            x_lower, y_lower, mode="lower", n_segments=boundary_n_segments
        )
    else:
        a_u, b_u = fit_line(x_all[ph_in], high[ph_in])
        a_l, b_l = fit_line(x_all[pl_in], low[pl_in])

    # Channel width at both ends of the window.
    upper_start = float(line_y(a_u, b_u, np.array([start]))[0])
    lower_start = float(line_y(a_l, b_l, np.array([start]))[0])
    upper_end = float(line_y(a_u, b_u, np.array([end]))[0])
    lower_end = float(line_y(a_l, b_l, np.array([end]))[0])
    width_start = upper_start - lower_start
    width_end = upper_end - lower_end
    # A non-positive starting width has no meaningful ratio; report inf.
    width_ratio = width_end / width_start if width_start > 0 else float("inf")

    # Relative distance of each pivot from its line; the denominator is
    # floored at 1e-9 to avoid division by zero.
    upper_at_ph = line_y(a_u, b_u, x_all[ph_in])
    lower_at_pl = line_y(a_l, b_l, x_all[pl_in])
    ph_dist = np.abs(high[ph_in] - upper_at_ph) / np.maximum(upper_at_ph, 1e-9)
    pl_dist = np.abs(low[pl_in] - lower_at_pl) / np.maximum(lower_at_pl, 1e-9)

    touches_upper = int((ph_dist <= touch_tol).sum())
    touches_lower = int((pl_dist <= touch_tol).sum())
    loss_upper = float(np.mean(ph_dist)) if len(ph_dist) else float("inf")
    loss_lower = float(np.mean(pl_dist)) if len(pl_dist) else float("inf")

    print(
        f"a_u={a_u:.6f}, a_l={a_l:.6f} "
        f"(need a_u<={upper_slope_max}, a_l>={lower_slope_min})"
    )
    print(f"width_ratio={width_ratio:.3f} (need <= {shrink_ratio})")
    print(
        f"touches_upper={touches_upper}, touches_lower={touches_lower} "
        f"(loss_upper={loss_upper:.4f}, loss_lower={loss_lower:.4f}, "
        f"need <= {touch_loss_max})"
    )


def main() -> None:
    """Load the OHLCV data, run detection on the latest window, and plot it."""
    df = load_ohlcv_from_dir(DATA_DIR)

    # Only the most recent window (from the latest point backwards) is analysed.
    res = detect_sym_triangle(
        df,
        window=WINDOW,
        pivot_k=PIVOT_K,
        touch_tol=TOUCH_TOL,
        touch_loss_max=TOUCH_LOSS_MAX,
        shrink_ratio=SHRINK_RATIO,
        break_tol=BREAK_TOL,
        vol_window=VOL_WINDOW,
        vol_k=VOL_K,
        false_break_m=FALSE_BREAK_M,
        upper_slope_max=UPPER_SLOPE_MAX,
        lower_slope_min=LOWER_SLOPE_MIN,
        boundary_fit=True,
        boundary_n_segments=BOUNDARY_N_SEGMENTS,
        boundary_source=BOUNDARY_SOURCE,
    )

    if PRINT_DEBUG:
        debug_latest_window(
            df,
            window=WINDOW,
            pivot_k=PIVOT_K,
            touch_tol=TOUCH_TOL,
            shrink_ratio=SHRINK_RATIO,
            boundary_fit=True,
            boundary_n_segments=BOUNDARY_N_SEGMENTS,
            boundary_source=BOUNDARY_SOURCE,
            lower_slope_min=LOWER_SLOPE_MIN,
            upper_slope_max=UPPER_SLOPE_MAX,
            touch_loss_max=TOUCH_LOSS_MAX,
        )

    if res is None:
        print("未识别到对称三角形")
        return

    print(res)
    outputs_dir = os.path.join(os.path.dirname(__file__), "..", "..", "outputs")
    os.makedirs(outputs_dir, exist_ok=True)
    out_path = os.path.join(outputs_dir, "sym_triangle_result.png")
    plot_sym_triangle(df, res, out_path)
    print(f"图已保存:{out_path}")


if __name__ == "__main__":
    main()