technical-patterns-lab/scripts/run_converging_triangle.py
褚宏光 8dea3fbccb Enhance converging triangle analysis with new scripts and data outputs
- Updated all_results.csv with additional stock data and breakout strength metrics.
- Revised report.md to improve clarity and detail on stock selection criteria and results.
- Expanded strong_breakout_down.csv and strong_breakout_up.csv with new entries reflecting recent analysis.
- Introduced new chart images for selected stocks to visualize breakout patterns.
- Added plot_converging_triangles.py script for generating visualizations of stocks meeting convergence criteria.
- Enhanced report_converging_triangles.py to allow for date-specific reporting and improved output formatting.
- Optimized run_converging_triangle.py for performance and added execution time logging.
2026-01-22 10:00:47 +08:00

245 lines
8.3 KiB
Python

"""
批量滚动检测收敛三角形 - 从 pkl 文件读取 OHLCV 数据
每个股票的每个交易日都会计算,输出 DataFrame
"""
import os
import sys
import pickle
import time
import numpy as np
import pandas as pd
# 让脚本能找到 src/ 下的模块
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
from converging_triangle import (
ConvergingTriangleParams,
ConvergingTriangleResult,
detect_converging_triangle_batch,
)
# ============================================================================
# [Tunable parameters]
# ============================================================================
# --- Data source ---
# Directory holding the OHLCV pickle files, resolved relative to this script.
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
# --- Detection parameters ---
# NOTE(review): the meaning of each field is defined by ConvergingTriangleParams
# in src/converging_triangle, which is not visible here. Of these fields,
# main() reads only PARAMS.window (first computable day = window - 1); the
# whole object is passed to detect_converging_triangle_batch.
PARAMS = ConvergingTriangleParams(
    window=120,
    pivot_k=15,
    boundary_n_segments=2,
    boundary_source="full",
    upper_slope_max=0.10,
    lower_slope_min=-0.10,
    touch_tol=0.10,
    touch_loss_max=0.10,
    shrink_ratio=0.8,
    break_tol=0.001,
    vol_window=20,
    vol_k=1.3,
    false_break_m=5,
)
# --- Computation range ---
# None means the full history is processed; an integer N restricts the run to the most recent N days.
RECENT_DAYS = 500 # default: only the most recent 500 days, to keep the run time manageable
# --- Output control ---
ONLY_VALID = True # True: output only records where a triangle was detected; False: output every record
# Passed through to detect_converging_triangle_batch as its ``verbose`` argument.
VERBOSE = True
# ============================================================================
# pkl 数据加载
# ============================================================================
class FakeModule:
    """Empty stand-in module used to bypass the ``model`` dependency.

    ``load_pkl`` registers instances of this class in ``sys.modules`` under the
    names ``"model"`` and ``"model.index_info"`` before unpickling.
    """
    # The only attribute exposed: an alias of numpy's ndarray type.
    # NOTE(review): presumably the pickles reference ``model.ndarray`` and this
    # alias lets those references resolve — confirm against the data files.
    ndarray = np.ndarray
def load_pkl(pkl_path: str) -> dict:
    """Unpickle the file at *pkl_path* and return the stored object.

    Before loading, fresh ``FakeModule`` instances are installed in
    ``sys.modules`` as ``"model"`` and ``"model.index_info"`` (replacing any
    existing entries) so that unpickling does not need the real ``model``
    package.

    Args:
        pkl_path: Path of the pickle file to read.

    Returns:
        The unpickled object; callers in this file index it as a dict with
        keys such as ``mtx``, ``dtes``, ``tkrs`` and ``tkrs_name``.
    """
    # Register the stubs on every call, exactly as many times as files are loaded.
    for stub_name in ("model", "model.index_info"):
        sys.modules[stub_name] = FakeModule()

    # NOTE: pickle.load can execute arbitrary code; only use on trusted data files.
    with open(pkl_path, "rb") as handle:
        return pickle.load(handle)
def load_ohlcv_from_pkl(data_dir: str) -> tuple:
    """Load the OHLCV matrices and their metadata from pickle files.

    Reads ``open.pkl``, ``high.pkl``, ``low.pkl``, ``close.pkl`` and
    ``volume.pkl`` from *data_dir* (in that order) via ``load_pkl``.

    Args:
        data_dir: Directory containing the five pickle files.

    Returns:
        An 8-tuple ``(open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx,
        dates, tkrs, tkrs_name)``.  The five matrices are the ``"mtx"`` entry
        of each file, documented as shape ``(n_stocks, n_days)``.  ``dates``
        (real dates such as 20050104), ``tkrs`` (stock codes) and
        ``tkrs_name`` (stock names) are taken from the close file only.
    """
    fields = ("open", "high", "low", "close", "volume")
    loaded = {
        field: load_pkl(os.path.join(data_dir, f"{field}.pkl"))
        for field in fields
    }

    # Metadata comes from the close file; the other files' metadata is ignored.
    close_meta = loaded["close"]
    dates = close_meta["dtes"]
    tkrs = close_meta["tkrs"]
    tkrs_name = close_meta["tkrs_name"]

    return (*(loaded[field]["mtx"] for field in fields), dates, tkrs, tkrs_name)
# ============================================================================
# 主流程
# ============================================================================
def main() -> None:
    """Run the batch converging-triangle detection and write the CSV outputs.

    Steps: load the OHLCV matrices from ``DATA_DIR``; find the last day on
    which at least one stock has a non-NaN close; run
    ``detect_converging_triangle_batch`` over the selected day range; attach
    stock code / stock name / real date columns; print summary statistics;
    save the results; print the first ten rows.  Elapsed time for loading,
    detection and the whole run is printed along the way.

    Returns early, without writing anything, when the close matrix contains
    no valid (non-NaN) data at all.
    """
    start_time = time.time()
    banner = "=" * 70
    print(banner)
    print("Converging Triangle Batch Detection")
    print(banner)

    # 1. Load data
    print("\n[1] Loading OHLCV pkl files...")
    load_start = time.time()
    open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(DATA_DIR)
    load_time = time.time() - load_start
    n_stocks, n_days = close_mtx.shape
    print(f" Stocks: {n_stocks}")
    print(f" Days: {n_days}")
    print(f" Date range: {dates[0]} ~ {dates[-1]}")
    print(f" 加载耗时: {load_time:.2f}")

    # 2. Find the valid data range (skip day columns that are NaN for every stock)
    any_valid = np.any(~np.isnan(close_mtx), axis=0)
    valid_day_idx = np.where(any_valid)[0]
    if len(valid_day_idx) == 0:
        print("No valid data found!")
        return
    last_valid_day = valid_day_idx[-1]
    print(f" Last valid day: {last_valid_day} (date: {dates[last_valid_day]})")

    # 3. Detection range: the first usable day index is window - 1
    start_day = PARAMS.window - 1
    end_day = last_valid_day  # use the last valid day, not the end of the matrix
    if RECENT_DAYS is not None:
        start_day = max(start_day, end_day - RECENT_DAYS + 1)
    n_range_days = end_day - start_day + 1
    print(f"\n[2] Detection range: day {start_day} ~ {end_day}")
    print(f" Window size: {PARAMS.window}")
    print(f" Total points: {n_stocks} x {n_range_days} = {n_stocks * n_range_days}")

    # 4. Batch detection
    print("\n[3] Running batch detection...")
    detect_start = time.time()
    df = detect_converging_triangle_batch(
        open_mtx=open_mtx,
        high_mtx=high_mtx,
        low_mtx=low_mtx,
        close_mtx=close_mtx,
        volume_mtx=volume_mtx,
        params=PARAMS,
        start_day=start_day,
        end_day=end_day,
        only_valid=ONLY_VALID,
        verbose=VERBOSE,
    )
    detect_time = time.time() - detect_start
    print(f" 检测耗时: {detect_time:.2f}")

    # 5. Attach stock code, stock name and real date; out-of-range indices
    #    fall back to "" / 0 instead of raising.
    if len(df) > 0:
        df["stock_code"] = df["stock_idx"].map(lambda x: tkrs[x] if x < len(tkrs) else "")
        df["stock_name"] = df["stock_idx"].map(lambda x: tkrs_name[x] if x < len(tkrs_name) else "")
        df["date"] = df["date_idx"].map(lambda x: dates[x] if x < len(dates) else 0)

    # 6. Print summary statistics
    _print_summary(df)

    # 7. Save results
    _save_results(df)

    # 8. Show a sample of the results
    if len(df) > 0:
        print("\n" + "-" * 70)
        print("Sample results (first 10):")
        print("-" * 70)
        wanted_cols = [
            "stock_code", "date", "is_valid",
            "breakout_strength_up", "breakout_strength_down",
            "breakout_dir", "width_ratio",
        ]
        # Only show the columns the detector actually produced.
        display_cols = [c for c in wanted_cols if c in df.columns]
        print(df[display_cols].head(10).to_string(index=False))

    total_time = time.time() - start_time
    print("\n" + banner)
    print(f"总耗时: {total_time:.2f} 秒 ({total_time/60:.2f} 分钟)")
    print(banner)


def _print_summary(df: pd.DataFrame) -> None:
    """Print record counts, breakout-direction counts and strength statistics."""
    print("\n" + "=" * 70)
    print("Detection Results")
    print("=" * 70)

    total = len(df)
    if ONLY_VALID:
        print(f"\nTotal valid triangles detected: {total}")
    elif total == 0:
        # An empty result previously hit df["is_valid"] / a division by
        # len(df) == 0 here; report zero counts instead.
        print("\nTotal records: 0")
        print("Valid triangles: 0")
    else:
        valid_count = df["is_valid"].sum()
        print(f"\nTotal records: {total}")
        print(f"Valid triangles: {valid_count} ({valid_count/total*100:.1f}%)")

    if total == 0:
        return

    # Counts per breakout direction, over valid triangles only
    if "breakout_dir" in df.columns:
        breakout_stats = df[df["is_valid"] == True]["breakout_dir"].value_counts()  # noqa: E712
        print(f"\nBreakout statistics:")
        for dir_name, count in breakout_stats.items():
            print(f" - {dir_name}: {count}")

    # Breakout strength statistics, over valid triangles only
    if "breakout_strength_up" in df.columns:
        valid_df = df[df["is_valid"] == True]  # noqa: E712
        if len(valid_df) > 0:
            up = valid_df["breakout_strength_up"]
            down = valid_df["breakout_strength_down"]
            print(f"\nBreakout strength (valid triangles):")
            print(f" - Up mean: {up.mean():.4f}, max: {up.max():.4f}")
            print(f" - Down mean: {down.mean():.4f}, max: {down.max():.4f}")


def _save_results(df: pd.DataFrame) -> None:
    """Write all_results.csv plus the strong up/down breakout CSVs."""
    outputs_dir = os.path.join(os.path.dirname(__file__), "..", "outputs", "converging_triangles")
    os.makedirs(outputs_dir, exist_ok=True)

    # Full CSV (utf-8-sig is used to support Chinese text)
    csv_path = os.path.join(outputs_dir, "all_results.csv")
    df.to_csv(csv_path, index=False, encoding="utf-8-sig")
    print(f"\nResults saved to: {csv_path}")

    if len(df) == 0:
        return

    # High-strength breakout records (strength > 0.3).  Both selections are
    # built before anything is written.
    # NOTE(review): a file is only written when its selection is non-empty, so
    # a strong_breakout_*.csv from an earlier run can remain on disk — confirm
    # whether downstream readers expect that.
    strong_threshold = 0.3
    valid_mask = df["is_valid"] == True  # noqa: E712
    selections = [
        ("up", "strong_breakout_up.csv",
         df[valid_mask & (df["breakout_strength_up"] > strong_threshold)]),
        ("down", "strong_breakout_down.csv",
         df[valid_mask & (df["breakout_strength_down"] > strong_threshold)]),
    ]
    for direction, filename, strong in selections:
        if len(strong) > 0:
            strong_path = os.path.join(outputs_dir, filename)
            strong.to_csv(strong_path, index=False, encoding="utf-8-sig")
            print(f"Strong {direction} breakouts ({len(strong)}): {strong_path}")
# Run the batch detection only when this file is executed as a script.
if __name__ == "__main__":
    main()