- Created README.md and USAGE.md for the project overview and usage instructions.
- Added the core algorithm in src/converging_triangle.py for batch processing of stock data.
- Introduced data files (open.pkl, high.pkl, low.pkl, close.pkl, volume.pkl) for OHLCV data.
- Developed output documentation for results and breakout-strength calculations.
- Implemented scripts for running the detection and generating reports.
- Added SVG visualizations and markdown documentation for algorithm details and usage examples.
280 lines
8.6 KiB
Python
"""
|
||
批量识别对称三角形 - 从 pkl 文件读取 OHLCV 数据
|
||
每个 pkl 文件包含 108 个股票 × N 个交易日的矩阵
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import pickle
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
# 让脚本能找到 src/ 下的模块
|
||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "src"))
|
||
|
||
from sym_triangle import detect_sym_triangle, line_y, pivots_fractal, fit_boundary_line
|
||
|
||
# ============================================================================
# Tunable parameters
# Every value below (except DATA_DIR and the output flags) is forwarded to
# detect_sym_triangle() by main() under the matching lower-case keyword.
# ============================================================================

# Data source: directory containing open/high/low/close/volume .pkl files.
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "data")

# Lookback window in bars (-> window). main() also skips any stock whose
# cleaned series is shorter than this.
WINDOW = 400

# Pivot detection (-> pivot_k).
PIVOT_K = 20

# Boundary-line fitting (-> boundary_n_segments / boundary_source).
BOUNDARY_N_SEGMENTS = 2
BOUNDARY_SOURCE = "full"

# Slope limits for the upper / lower boundary (-> upper_slope_max / lower_slope_min).
UPPER_SLOPE_MAX = 0.10
LOWER_SLOPE_MIN = -0.10

# Touch tests (-> touch_tol / touch_loss_max).
TOUCH_TOL = 0.10
TOUCH_LOSS_MAX = 0.10

# Convergence requirement (-> shrink_ratio).
SHRINK_RATIO = 0.8

# Breakout tests (-> break_tol / vol_window / vol_k / false_break_m).
BREAK_TOL = 0.001
VOL_WINDOW = 20
VOL_K = 1.3
FALSE_BREAK_M = 5

# Output control.
PRINT_DEBUG = False  # keep per-stock debug prints off for batch runs
# True = chart every stock, False = chart only detected ones.
# NOTE(review): main() does not currently read this flag.
SAVE_ALL_CHARTS = False
|
||
|
||
|
||
# ============================================================================
# pkl data loading
# ============================================================================

class FakeModule:
    """Empty stand-in for the unavailable ``model`` package.

    ``load_pkl`` registers instances of this class as ``model`` and
    ``model.index_info`` in ``sys.modules`` so that pickles referring to those
    modules can be loaded. The only module content it provides is
    ``ndarray`` (aliased to ``np.ndarray``).
    """

    ndarray = np.ndarray


def load_pkl(pkl_path: str) -> dict:
    """Unpickle one data file and return its contents.

    The payload is expected to be a dict such as
    ``{mtx, dtes, tkrs, tkrs_name, ...}``.

    Side effect: replaces ``sys.modules["model"]`` and
    ``sys.modules["model.index_info"]`` with fresh ``FakeModule`` instances
    on every call.

    NOTE(security): ``pickle.load`` can execute arbitrary code; only use this
    on trusted, locally produced files.
    """
    # The stubs must be in place before unpickling so module lookups succeed.
    for stub_name in ("model", "model.index_info"):
        sys.modules[stub_name] = FakeModule()

    with open(pkl_path, "rb") as handle:
        payload = pickle.load(handle)
    return payload
|
||
|
||
|
||
def load_ohlcv_from_pkl(data_dir: str) -> tuple:
    """Load the five OHLCV matrices, plus close.pkl's metadata, from ``data_dir``.

    Reads ``open.pkl``, ``high.pkl``, ``low.pkl``, ``close.pkl`` and
    ``volume.pkl`` (in that order) via ``load_pkl``. Each file must contain an
    ``"mtx"`` entry. Dates and ticker metadata are taken from ``close.pkl``
    only, so every matrix is required to have the same shape as close's.

    Returns:
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx: shape=(n_stocks, n_days)
        dates: shape=(n_days,), real trading dates (e.g. 20050104)
        tkrs: shape=(n_stocks,), ticker codes
        tkrs_name: shape=(n_stocks,), ticker names

    Raises:
        ValueError: if any matrix's shape differs from close.pkl's, if the
            matrices are not 2-D, or if the number of dates does not match
            the matrices' day axis.
    """
    fields = ("open", "high", "low", "close", "volume")
    loaded = {name: load_pkl(os.path.join(data_dir, f"{name}.pkl")) for name in fields}

    close_data = loaded["close"]
    expected_shape = np.shape(close_data["mtx"])

    # Metadata comes only from close.pkl. A matrix with a different shape
    # would otherwise be paired silently with the wrong tickers/dates, or fail
    # later with an opaque pandas length error.
    for name in fields:
        shape = np.shape(loaded[name]["mtx"])
        if shape != expected_shape:
            raise ValueError(
                f"{name}.pkl matrix has shape {shape}, expected {expected_shape} (close.pkl)"
            )

    dates = close_data["dtes"]
    tkrs = close_data["tkrs"]
    tkrs_name = close_data["tkrs_name"]

    if len(expected_shape) != 2 or len(dates) != expected_shape[1]:
        raise ValueError(
            f"close.pkl has {len(dates)} dates but its matrix shape is {expected_shape}"
        )

    matrices = tuple(loaded[name]["mtx"] for name in fields)
    return matrices + (dates, tkrs, tkrs_name)
|
||
|
||
|
||
def get_stock_df(
    stock_idx: int,
    open_mtx: np.ndarray,
    high_mtx: np.ndarray,
    low_mtx: np.ndarray,
    close_mtx: np.ndarray,
    volume_mtx: np.ndarray,
    dates: np.ndarray,
) -> pd.DataFrame:
    """Slice one stock's row out of the OHLCV matrices as a cleaned DataFrame.

    The columns are ``date, open, high, low, close, volume``. Any row with a
    NaN or a 0 in any column is dropped, and the result is re-indexed from 0.
    As a result, row positions in the returned frame generally do NOT match
    positions in ``dates``.
    """
    field_names = ("open", "high", "low", "close", "volume")
    matrices = (open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx)

    columns = {"date": dates}
    columns.update((name, mtx[stock_idx, :]) for name, mtx in zip(field_names, matrices))
    frame = pd.DataFrame(columns)

    # 0 counts as missing (like NaN); drop every incomplete row.
    frame = frame.replace(0, np.nan).dropna()
    return frame.reset_index(drop=True)
|
||
|
||
|
||
# ============================================================================
# Plotting
# ============================================================================

def plot_sym_triangle(df: pd.DataFrame, res, stock_id: str, out_path: str) -> None:
    """Save a PNG of the close series with the fitted triangle boundaries drawn over it.

    Args:
        df: Per-stock frame with at least ``close`` and ``date`` columns.
        res: Detection result; this function reads ``upper_coef``,
            ``lower_coef``, ``start``, ``end``, ``width_ratio``,
            ``touches_upper`` and ``touches_lower``.
        stock_id: Label shown in the chart title.
        out_path: Destination image path (saved at dpi=150).
    """
    import matplotlib.pyplot as plt

    slope_up, icpt_up = res.upper_coef
    slope_lo, icpt_lo = res.lower_coef
    first, last = res.start, res.end

    closes = df["close"].to_numpy(dtype=float)
    bar_x = np.arange(len(df), dtype=float)
    labels = df["date"].to_numpy()
    n_labels = len(labels)

    # Bar positions covered by the triangle window, res.end inclusive.
    tri_x = np.arange(first, last + 1, dtype=float)

    fig, ax = plt.subplots(figsize=(12, 5))
    ax.plot(bar_x, closes, linewidth=1.2, label="close")
    ax.plot(tri_x, line_y(slope_up, icpt_up, tri_x), linewidth=2, label="upper")
    ax.plot(tri_x, line_y(slope_lo, icpt_lo, tri_x), linewidth=2, label="lower")
    ax.axvline(last, color="gray", linestyle="--", linewidth=1)

    # Fall back to the raw bar index if it lies past the available dates.
    first_label = labels[first] if first < n_labels else first
    last_label = labels[last] if last < n_labels else last

    ax.set_title(
        f"[{stock_id}] sym_triangle: "
        f"range={first_label}-{last_label}, "
        f"slope=({slope_up:.4f},{slope_lo:.4f}), "
        f"width_ratio={res.width_ratio:.2f}, "
        f"touches=({res.touches_upper},{res.touches_lower})"
    )

    if n_labels:
        # Roughly eight date ticks across the x axis.
        tick_pos = np.arange(0, n_labels, max(1, n_labels // 8))
        ax.set_xticks(tick_pos)
        ax.set_xticklabels(labels[tick_pos], rotation=45, ha="right")

    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
|
||
|
||
|
||
# ============================================================================
# Main pipeline
# ============================================================================

def main() -> None:
    """Batch-scan every stock in the pkl data set for a symmetric triangle.

    Steps:
    - Load the OHLCV matrices from ``DATA_DIR``.
    - Build a cleaned per-stock frame with ``get_stock_df``, skipping stocks
      with fewer than ``WINDOW`` rows.
    - Run ``detect_sym_triangle`` with the module-level parameters.
    - Save a chart for every detection.
    - Print a summary and, if anything was detected, write ``summary.csv``
      next to the charts.
    """
    print("=" * 60)
    print("Symmetric Triangle Batch Detection - from pkl files")
    print("=" * 60)

    # 1. Load data
    print("\n[1] Loading OHLCV pkl files...")
    open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(DATA_DIR)
    n_stocks, n_days = close_mtx.shape
    print(f" Stocks: {n_stocks}")
    print(f" Days: {n_days}")
    print(f" Date range: {dates[0]} ~ {dates[-1]}")

    # 2. Prepare the output directory
    outputs_dir = os.path.join(os.path.dirname(__file__), "..", "..", "outputs", "sym_triangles")
    os.makedirs(outputs_dir, exist_ok=True)

    # 3. Scan every stock.
    # NOTE(review): SAVE_ALL_CHARTS is not consulted; only detected stocks get a chart.
    print(f"\n[2] Scanning {n_stocks} stocks...")
    detected = []

    for i in range(n_stocks):
        stock_id = f"{tkrs[i]}"  # label outputs with the real ticker code

        # Per-stock frame: NaN/0 rows dropped and re-indexed from 0.
        df = get_stock_df(i, open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates)

        if len(df) < WINDOW:
            if PRINT_DEBUG:
                print(f" [{stock_id}] 跳过: 数据不足 ({len(df)} < {WINDOW})")
            continue

        res = detect_sym_triangle(
            df,
            window=WINDOW,
            pivot_k=PIVOT_K,
            touch_tol=TOUCH_TOL,
            touch_loss_max=TOUCH_LOSS_MAX,
            shrink_ratio=SHRINK_RATIO,
            break_tol=BREAK_TOL,
            vol_window=VOL_WINDOW,
            vol_k=VOL_K,
            false_break_m=FALSE_BREAK_M,
            upper_slope_max=UPPER_SLOPE_MAX,
            lower_slope_min=LOWER_SLOPE_MIN,
            boundary_fit=True,
            boundary_n_segments=BOUNDARY_N_SEGMENTS,
            boundary_source=BOUNDARY_SOURCE,
        )

        if res is not None:
            # Keep the per-stock frame: res.start/res.end index into it.
            detected.append((i, stock_id, res, df))
            out_path = os.path.join(outputs_dir, f"{stock_id}.png")
            plot_sym_triangle(df, res, stock_id, out_path)
            print(f" [OK] {stock_id} -> {out_path}")
        elif PRINT_DEBUG:
            print(f" [--] {stock_id} not detected")

    # 4. Summary
    print("\n" + "=" * 60)
    print(f"Scan completed! {len(detected)}/{n_stocks} stocks have symmetric triangles")
    print("=" * 60)

    if detected:
        print("\nDetected stocks:")
        for i, stock_id, res, _ in detected:
            name = tkrs_name[i]
            print(f" - {stock_id} ({name}): slope=({res.upper_coef[0]:.4f},{res.lower_coef[0]:.4f}), "
                  f"width_ratio={res.width_ratio:.2f}, breakout={res.breakout}")
        print(f"\nCharts saved to: {outputs_dir}")

    # 5. Save summary CSV
    if detected:
        summary_path = os.path.join(outputs_dir, "summary.csv")
        summary_data = []
        for i, stock_id, res, stock_df in detected:
            # res.start/res.end are row positions in the filtered per-stock
            # frame (same convention plot_sym_triangle uses). Indexing the
            # global `dates` array would be shifted for any stock whose
            # NaN/0 rows were dropped.
            stock_dates = stock_df["date"]
            summary_data.append({
                "stock_idx": i,
                "stock_code": stock_id,
                "stock_name": tkrs_name[i],
                "start_date": stock_dates.iloc[res.start],
                "end_date": stock_dates.iloc[res.end],
                "upper_slope": res.upper_coef[0],
                "lower_slope": res.lower_coef[0],
                "width_ratio": res.width_ratio,
                "touches_upper": res.touches_upper,
                "touches_lower": res.touches_lower,
                "breakout": res.breakout,
            })
        pd.DataFrame(summary_data).to_csv(summary_path, index=False, encoding="utf-8-sig")
        print(f"Summary saved: {summary_path}")


if __name__ == "__main__":
    main()
|