technical-patterns-lab/scripts/archive/run_sym_triangle_pkl.py
褚宏光 543572667b Add initial implementation of converging triangle detection algorithm and related documentation
- Created README.md and USAGE.md for project overview and usage instructions.
- Added core algorithm in src/converging_triangle.py for batch processing of stock data.
- Introduced data files (open.pkl, high.pkl, low.pkl, close.pkl, volume.pkl) for OHLCV data.
- Developed output documentation for results and breakout strength calculations.
- Implemented scripts for running the detection and generating reports.
- Added SVG visualizations and markdown documentation for algorithm details and usage examples.
2026-01-21 18:02:58 +08:00

280 lines
8.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
批量识别对称三角形 - 从 pkl 文件读取 OHLCV 数据
每个 pkl 文件包含 108 个股票 × N 个交易日的矩阵
"""
import os
import sys
import pickle
import numpy as np
import pandas as pd
# 让脚本能找到 src/ 下的模块
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "src"))
from sym_triangle import detect_sym_triangle, line_y, pivots_fractal, fit_boundary_line
# ============================================================================
# Tunable parameters
# ============================================================================
# Most values below are forwarded unchanged to ``detect_sym_triangle`` by
# ``main()``; the keyword each one feeds is shown in [brackets]. Their exact
# meaning is defined by that function (imported from the src/ directory).

# --- Data source: directory containing the open/high/low/close/volume .pkl files ---
DATA_DIR: str = os.path.join(os.path.dirname(__file__), "..", "..", "data")

# --- Window size [window]; stocks with fewer usable rows are skipped ---
WINDOW: int = 400

# --- Pivot detection [pivot_k] ---
PIVOT_K: int = 20

# --- Boundary-line fitting [boundary_n_segments, boundary_source] ---
BOUNDARY_N_SEGMENTS: int = 2
BOUNDARY_SOURCE: str = "full"

# --- Slope constraints [upper_slope_max, lower_slope_min] ---
UPPER_SLOPE_MAX: float = 0.10
LOWER_SLOPE_MIN: float = -0.10

# --- Touch tests [touch_tol, touch_loss_max] ---
TOUCH_TOL: float = 0.10
TOUCH_LOSS_MAX: float = 0.10

# --- Convergence requirement [shrink_ratio] ---
SHRINK_RATIO: float = 0.8

# --- Breakout tests [break_tol, vol_window, vol_k, false_break_m] ---
BREAK_TOL: float = 1e-3
VOL_WINDOW: int = 20
VOL_K: float = 1.3
FALSE_BREAK_M: int = 5

# --- Output control ---
PRINT_DEBUG: bool = False  # keep per-stock debug output off for batch runs
SAVE_ALL_CHARTS: bool = False  # intended: True = chart every stock, False = only stocks with a detection
# ============================================================================
# pkl 数据加载
# ============================================================================
class FakeModule:
    """Empty stand-in module used to bypass the ``model`` dependency when unpickling.

    ``load_pkl`` registers instances of this class under ``model`` and
    ``model.index_info`` while it unpickles. The only attribute exposed is
    ``ndarray`` (numpy's array type). Presumably the pickles reference it
    through those module paths; confirm if new pkl files fail to load.
    """

    ndarray = np.ndarray


def load_pkl(pkl_path: str) -> dict:
    """Load one pkl data file.

    Args:
        pkl_path: path to the pickle file.

    Returns:
        The unpickled object. For this project it is a dict with keys such as
        ``mtx``, ``dtes``, ``tkrs`` and ``tkrs_name``.

    The ``model`` / ``model.index_info`` shim is installed only for the duration
    of the load. Any previous ``sys.modules`` entries (or their absence) are
    restored afterwards, even if loading fails, so a real ``model`` package is
    never clobbered for the rest of the process.
    """
    shim_names = ("model", "model.index_info")
    missing = object()  # sentinel: "name was not in sys.modules"
    saved = {name: sys.modules.get(name, missing) for name in shim_names}
    for name in shim_names:
        sys.modules[name] = FakeModule()
    try:
        # NOTE(security): pickle.load can execute arbitrary code. Only load
        # trusted files produced by this project's own data pipeline.
        with open(pkl_path, "rb") as f:
            return pickle.load(f)
    finally:
        for name, previous in saved.items():
            if previous is missing:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous
def load_ohlcv_from_pkl(data_dir: str) -> tuple:
    """Load the five OHLCV pkl files from ``data_dir``.

    Reads ``open.pkl``, ``high.pkl``, ``low.pkl``, ``close.pkl`` and
    ``volume.pkl`` (in that order) via ``load_pkl``. Each matrix is taken from
    the file's ``"mtx"`` entry. Dates and ticker metadata come from ``close.pkl``.

    Returns:
        (open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name):
        each matrix has shape (n_stocks, n_days); ``dates`` has n_days entries
        (real trading dates such as 20050104); ``tkrs`` / ``tkrs_name`` hold one
        stock code / name per matrix row.

    Raises:
        ValueError: if the five matrices do not share one shape, or the
            close.pkl dates/ticker lengths do not match that shape.
    """
    fields = ("open", "high", "low", "close", "volume")
    payloads = {name: load_pkl(os.path.join(data_dir, f"{name}.pkl")) for name in fields}
    matrices = tuple(payloads[name]["mtx"] for name in fields)

    # Metadata is taken from close.pkl.
    close_data = payloads["close"]
    dates = close_data["dtes"]
    tkrs = close_data["tkrs"]
    tkrs_name = close_data["tkrs_name"]

    # Every matrix is later indexed with the same (stock, day) position, so a
    # mismatch would silently pair rows from different stocks or days.
    expected = np.shape(close_data["mtx"])
    for name, mtx in zip(fields, matrices):
        if np.shape(mtx) != expected:
            raise ValueError(
                f"{name}.pkl matrix shape {np.shape(mtx)} does not match close.pkl shape {expected}"
            )
    if len(expected) == 2:
        n_stocks, n_days = expected
        if len(dates) != n_days:
            raise ValueError(f"close.pkl has {len(dates)} dates but {n_days} day columns")
        if len(tkrs) != n_stocks or len(tkrs_name) != n_stocks:
            raise ValueError(
                f"close.pkl has {len(tkrs)} tickers / {len(tkrs_name)} names "
                f"but {n_stocks} stock rows"
            )

    return (*matrices, dates, tkrs, tkrs_name)
def get_stock_df(
    stock_idx: int,
    open_mtx: np.ndarray,
    high_mtx: np.ndarray,
    low_mtx: np.ndarray,
    close_mtx: np.ndarray,
    volume_mtx: np.ndarray,
    dates: np.ndarray,
) -> pd.DataFrame:
    """Build the OHLCV DataFrame for the stock in row ``stock_idx``.

    Columns are ``date, open, high, low, close, volume``; each price/volume
    column is ``matrix[stock_idx, :]`` aligned with ``dates``. A row is dropped
    if any of its values is NaN or exactly 0 (volume included). The surviving
    rows are renumbered from 0.
    """
    columns = {"date": dates}
    for column_name, matrix in (
        ("open", open_mtx),
        ("high", high_mtx),
        ("low", low_mtx),
        ("close", close_mtx),
        ("volume", volume_mtx),
    ):
        columns[column_name] = matrix[stock_idx, :]

    frame = pd.DataFrame(columns)
    # Treat zeros as missing, then discard any incomplete row.
    cleaned = frame.replace(0, np.nan).dropna()
    return cleaned.reset_index(drop=True)
# ============================================================================
# 绘图
# ============================================================================
def plot_sym_triangle(df: pd.DataFrame, res, stock_id: str, out_path: str) -> None:
    """Save a chart of the close price with the detected triangle boundaries.

    Args:
        df: per-stock frame with at least ``date`` and ``close`` columns. The
            row position is used as the x coordinate.
        res: detection result. Reads ``upper_coef`` / ``lower_coef`` (``(a, b)``
            pairs evaluated with ``line_y``), ``start`` / ``end`` (row positions
            in ``df``), ``width_ratio``, ``touches_upper`` and ``touches_lower``.
        stock_id: label shown in the chart title.
        out_path: destination image file.
    """
    import matplotlib.pyplot as plt

    close = df["close"].to_numpy(dtype=float)
    x = np.arange(len(df), dtype=float)
    dates = df["date"].to_numpy()

    a_u, b_u = res.upper_coef
    a_l, b_l = res.lower_coef
    start, end = res.start, res.end
    xw = np.arange(start, end + 1, dtype=float)
    upper = line_y(a_u, b_u, xw)
    lower = line_y(a_l, b_l, xw)

    # Fall back to the raw row position if it lies beyond the available dates.
    start_date = dates[start] if len(dates) > start else start
    end_date = dates[end] if len(dates) > end else end

    fig, ax = plt.subplots(figsize=(12, 5))
    try:
        ax.plot(x, close, linewidth=1.2, label="close")
        ax.plot(xw, upper, linewidth=2, label="upper")
        ax.plot(xw, lower, linewidth=2, label="lower")
        ax.axvline(end, color="gray", linestyle="--", linewidth=1)
        ax.set_title(
            f"[{stock_id}] sym_triangle: "
            f"range={start_date}-{end_date}, "
            f"slope=({a_u:.4f},{a_l:.4f}), "
            f"width_ratio={res.width_ratio:.2f}, "
            f"touches=({res.touches_upper},{res.touches_lower})"
        )
        if len(dates) > 0:
            # Label roughly 8 evenly spaced x positions with their dates.
            step = max(1, len(dates) // 8)
            idx = np.arange(0, len(dates), step)
            ax.set_xticks(idx)
            ax.set_xticklabels(dates[idx], rotation=45, ha="right")
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        # Always release the figure; in a batch loop leaked figures accumulate.
        plt.close(fig)
# ============================================================================
# 主流程
# ============================================================================
def _plot_close_only(df: pd.DataFrame, stock_id: str, out_path: str) -> None:
    """Save a plain close-price chart for a stock without a detected triangle."""
    import matplotlib.pyplot as plt

    dates = df["date"].to_numpy()
    fig, ax = plt.subplots(figsize=(12, 5))
    try:
        ax.plot(
            np.arange(len(df), dtype=float),
            df["close"].to_numpy(dtype=float),
            linewidth=1.2,
            label="close",
        )
        ax.set_title(f"[{stock_id}] sym_triangle: not detected")
        if len(dates) > 0:
            step = max(1, len(dates) // 8)
            idx = np.arange(0, len(dates), step)
            ax.set_xticks(idx)
            ax.set_xticklabels(dates[idx], rotation=45, ha="right")
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_path, dpi=150)
    finally:
        plt.close(fig)


def main() -> None:
    """Scan every stock in the pkl data for symmetric triangles.

    Writes one chart per detected stock to outputs/sym_triangles. When
    SAVE_ALL_CHARTS is True, it also writes a close-only chart for every other
    stock that has data. If anything was detected, it writes summary.csv there too.
    """
    print("=" * 60)
    print("Symmetric Triangle Batch Detection - from pkl files")
    print("=" * 60)

    # 1. Load data
    print("\n[1] Loading OHLCV pkl files...")
    open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(DATA_DIR)
    n_stocks, n_days = close_mtx.shape
    print(f" Stocks: {n_stocks}")
    print(f" Days: {n_days}")
    print(f" Date range: {dates[0]} ~ {dates[-1]}")

    # 2. Prepare the output directory
    outputs_dir = os.path.join(os.path.dirname(__file__), "..", "..", "outputs", "sym_triangles")
    os.makedirs(outputs_dir, exist_ok=True)

    # 3. Scan every stock
    print(f"\n[2] Scanning {n_stocks} stocks...")
    detected = []
    for i in range(n_stocks):
        stock_id = f"{tkrs[i]}"  # real stock code; also used as the chart file name
        df = get_stock_df(i, open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates)
        out_path = os.path.join(outputs_dir, f"{stock_id}.png")

        res = None
        if len(df) < WINDOW:
            if PRINT_DEBUG:
                print(f" [{stock_id}] 跳过: 数据不足 ({len(df)} < {WINDOW})")
        else:
            res = detect_sym_triangle(
                df,
                window=WINDOW,
                pivot_k=PIVOT_K,
                touch_tol=TOUCH_TOL,
                touch_loss_max=TOUCH_LOSS_MAX,
                shrink_ratio=SHRINK_RATIO,
                break_tol=BREAK_TOL,
                vol_window=VOL_WINDOW,
                vol_k=VOL_K,
                false_break_m=FALSE_BREAK_M,
                upper_slope_max=UPPER_SLOPE_MAX,
                lower_slope_min=LOWER_SLOPE_MIN,
                boundary_fit=True,
                boundary_n_segments=BOUNDARY_N_SEGMENTS,
                boundary_source=BOUNDARY_SOURCE,
            )
            if res is None and PRINT_DEBUG:
                print(f" [--] {stock_id} not detected")

        if res is not None:
            detected.append((i, stock_id, res, df))
            plot_sym_triangle(df, res, stock_id, out_path)
            print(f" [OK] {stock_id} -> {out_path}")
        elif SAVE_ALL_CHARTS and len(df) > 0:
            _plot_close_only(df, stock_id, out_path)

    # 4. Report results
    print("\n" + "=" * 60)
    print(f"Scan completed! {len(detected)}/{n_stocks} stocks have symmetric triangles")
    print("=" * 60)
    if detected:
        print("\nDetected stocks:")
        for i, stock_id, res, _ in detected:
            name = tkrs_name[i]
            print(f" - {stock_id} ({name}): slope=({res.upper_coef[0]:.4f},{res.lower_coef[0]:.4f}), "
                  f"width_ratio={res.width_ratio:.2f}, breakout={res.breakout}")
    print(f"\nCharts saved to: {outputs_dir}")

    # 5. Save the summary CSV
    if detected:
        summary_path = os.path.join(outputs_dir, "summary.csv")
        summary_data = []
        for i, stock_id, res, df in detected:
            # res.start / res.end are row positions in this stock's filtered df
            # (NaN/0 rows already dropped), so dates must come from df, not from
            # the global `dates` array, or they drift whenever rows were removed.
            stock_dates = df["date"].to_numpy()
            summary_data.append({
                "stock_idx": i,
                "stock_code": stock_id,
                "stock_name": tkrs_name[i],
                "start_date": stock_dates[res.start],
                "end_date": stock_dates[res.end],
                "upper_slope": res.upper_coef[0],
                "lower_slope": res.lower_coef[0],
                "width_ratio": res.width_ratio,
                "touches_upper": res.touches_upper,
                "touches_lower": res.touches_lower,
                "breakout": res.breakout,
            })
        pd.DataFrame(summary_data).to_csv(summary_path, index=False, encoding="utf-8-sig")
        print(f"Summary saved: {summary_path}")


if __name__ == "__main__":
    main()