- Updated all_results.csv with additional stock data and breakout strength metrics. - Revised report.md to improve clarity and detail on stock selection criteria and results. - Expanded strong_breakout_down.csv and strong_breakout_up.csv with new entries reflecting recent analysis. - Introduced new chart images for selected stocks to visualize breakout patterns. - Added plot_converging_triangles.py script for generating visualizations of stocks meeting convergence criteria. - Enhanced report_converging_triangles.py to allow for date-specific reporting and improved output formatting. - Optimized run_converging_triangle.py for performance and added execution time logging.
245 lines
8.3 KiB
Python
245 lines
8.3 KiB
Python
"""
|
|
批量滚动检测收敛三角形 - 从 pkl 文件读取 OHLCV 数据
|
|
每个股票的每个交易日都会计算,输出 DataFrame
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import pickle
|
|
import time
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
# 让脚本能找到 src/ 下的模块
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
|
|
|
from converging_triangle import (
|
|
ConvergingTriangleParams,
|
|
ConvergingTriangleResult,
|
|
detect_converging_triangle_batch,
|
|
)
|
|
|
|
# ============================================================================
# [Tunable parameter section]
# ============================================================================

# --- Data source ---
# Directory holding open.pkl / high.pkl / low.pkl / close.pkl / volume.pkl.
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")

# --- Detection parameters ---
# NOTE(review): individual parameter semantics are inferred from their names;
# confirm against ConvergingTriangleParams in src/converging_triangle.py.
PARAMS = ConvergingTriangleParams(
    window=120,
    pivot_k=15,
    boundary_n_segments=2,
    boundary_source="full",
    upper_slope_max=0.10,
    lower_slope_min=-0.10,
    touch_tol=0.10,
    touch_loss_max=0.10,
    shrink_ratio=0.8,
    break_tol=0.001,
    vol_window=20,
    vol_k=1.3,
    false_break_m=5,
)

# --- Computation range ---
# None computes the full history; a concrete number restricts the run to
# only the most recent N trading days.
RECENT_DAYS = 500  # default: only the last 500 days, to keep runtime bounded

# --- Output control ---
ONLY_VALID = True  # True: output only records where a triangle was detected; False: output all records
VERBOSE = True
|
|
|
# ============================================================================
|
|
# pkl 数据加载
|
|
# ============================================================================
|
|
|
|
class FakeModule:
    """Stub standing in for the project's absent `model` package.

    Instances are registered in ``sys.modules`` (see ``load_pkl``) so that
    ``pickle.load`` can resolve references to ``model`` /
    ``model.index_info`` without the real dependency being installed.
    """

    # The only attribute the pkl files are known (from this file) to need.
    ndarray = np.ndarray
|
|
|
|
|
|
def load_pkl(pkl_path: str) -> dict:
    """Load one pkl data file and return its payload.

    The pkl files were serialized in an environment that had a ``model``
    package, so stub modules are installed in ``sys.modules`` first; this
    lets ``pickle.load`` resolve those references without the real
    dependency.

    Args:
        pkl_path: Path to the pickle file.

    Returns:
        The unpickled object -- callers expect a dict with keys such as
        ``mtx``, ``dtes``, ``tkrs`` and ``tkrs_name``.

    Note:
        ``pickle.load`` can execute arbitrary code from the file; only use
        it on trusted local data files.
    """
    # setdefault instead of plain assignment: never clobber a genuinely
    # imported `model` module, and avoid rebuilding the stubs on every call.
    sys.modules.setdefault('model', FakeModule())
    sys.modules.setdefault('model.index_info', FakeModule())

    with open(pkl_path, 'rb') as f:
        return pickle.load(f)
|
|
|
|
|
|
def load_ohlcv_from_pkl(data_dir: str) -> tuple:
    """Load OHLCV matrices plus metadata from the pkl files in *data_dir*.

    Returns:
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx: arrays of
            shape (n_stocks, n_days)
        dates: shape (n_days,), real calendar dates (e.g. 20050104)
        tkrs: shape (n_stocks,), ticker codes
        tkrs_name: shape (n_stocks,), stock names
    """
    # Load each field exactly once; dict preserves this insertion order,
    # which also fixes the order of the returned matrices.
    fields = ("open", "high", "low", "close", "volume")
    raw = {
        name: load_pkl(os.path.join(data_dir, f"{name}.pkl"))
        for name in fields
    }

    # Metadata (dates / tickers / names) is taken from the close file.
    meta = raw["close"]

    matrices = tuple(raw[name]["mtx"] for name in fields)
    return matrices + (meta["dtes"], meta["tkrs"], meta["tkrs_name"])
|
|
|
|
|
|
# ============================================================================
# Main pipeline
# ============================================================================
|
def main() -> None:
    """End-to-end batch run.

    Loads the pkl OHLCV matrices, runs converging-triangle detection over
    the configured date range, prints summary statistics, and writes CSV
    outputs under outputs/converging_triangles/.
    """
    start_time = time.time()

    print("=" * 70)
    print("Converging Triangle Batch Detection")
    print("=" * 70)

    # 1. Load data
    print("\n[1] Loading OHLCV pkl files...")
    load_start = time.time()
    open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(DATA_DIR)
    load_time = time.time() - load_start
    n_stocks, n_days = close_mtx.shape
    print(f" Stocks: {n_stocks}")
    print(f" Days: {n_days}")
    print(f" Date range: {dates[0]} ~ {dates[-1]}")
    print(f" 加载耗时: {load_time:.2f} 秒")

    # 2. Find the valid data range (skip columns that are all NaN)
    any_valid = np.any(~np.isnan(close_mtx), axis=0)
    valid_day_idx = np.where(any_valid)[0]
    if len(valid_day_idx) == 0:
        print("No valid data found!")
        return
    last_valid_day = valid_day_idx[-1]
    print(f" Last valid day: {last_valid_day} (date: {dates[last_valid_day]})")

    # 3. Detection range: a full lookback window must precede the first day
    start_day = PARAMS.window - 1
    end_day = last_valid_day  # use the last day with data, not the matrix end

    if RECENT_DAYS is not None:
        start_day = max(start_day, end_day - RECENT_DAYS + 1)

    print(f"\n[2] Detection range: day {start_day} ~ {end_day}")
    print(f" Window size: {PARAMS.window}")
    print(f" Total points: {n_stocks} x {end_day - start_day + 1} = {n_stocks * (end_day - start_day + 1)}")

    # 4. Batch detection
    print("\n[3] Running batch detection...")
    detect_start = time.time()
    df = detect_converging_triangle_batch(
        open_mtx=open_mtx,
        high_mtx=high_mtx,
        low_mtx=low_mtx,
        close_mtx=close_mtx,
        volume_mtx=volume_mtx,
        params=PARAMS,
        start_day=start_day,
        end_day=end_day,
        only_valid=ONLY_VALID,
        verbose=VERBOSE,
    )
    detect_time = time.time() - detect_start
    print(f" 检测耗时: {detect_time:.2f} 秒")

    # 5. Attach ticker code, stock name and real calendar date
    if len(df) > 0:
        df["stock_code"] = df["stock_idx"].map(lambda x: tkrs[x] if x < len(tkrs) else "")
        df["stock_name"] = df["stock_idx"].map(lambda x: tkrs_name[x] if x < len(tkrs_name) else "")
        df["date"] = df["date_idx"].map(lambda x: dates[x] if x < len(dates) else 0)

    # 6. Print summary
    print("\n" + "=" * 70)
    print("Detection Results")
    print("=" * 70)

    if ONLY_VALID:
        print(f"\nTotal valid triangles detected: {len(df)}")
    else:
        # Guard against an empty frame: .sum() would KeyError on the missing
        # column and the percentage would raise ZeroDivisionError.
        valid_count = df["is_valid"].sum() if len(df) > 0 else 0
        valid_pct = valid_count / len(df) * 100 if len(df) > 0 else 0.0
        print(f"\nTotal records: {len(df)}")
        print(f"Valid triangles: {valid_count} ({valid_pct:.1f}%)")

    # Statistics by breakout direction
    if len(df) > 0 and "breakout_dir" in df.columns:
        breakout_stats = df[df["is_valid"] == True]["breakout_dir"].value_counts()
        print(f"\nBreakout statistics:")
        for dir_name, count in breakout_stats.items():
            print(f" - {dir_name}: {count}")

    # Breakout strength statistics
    if len(df) > 0 and "breakout_strength_up" in df.columns:
        valid_df = df[df["is_valid"] == True]
        if len(valid_df) > 0:
            print(f"\nBreakout strength (valid triangles):")
            print(f" - Up mean: {valid_df['breakout_strength_up'].mean():.4f}, max: {valid_df['breakout_strength_up'].max():.4f}")
            print(f" - Down mean: {valid_df['breakout_strength_down'].mean():.4f}, max: {valid_df['breakout_strength_down'].max():.4f}")

    # 7. Save results
    outputs_dir = os.path.join(os.path.dirname(__file__), "..", "outputs", "converging_triangles")
    os.makedirs(outputs_dir, exist_ok=True)

    # Full CSV (utf-8-sig so spreadsheet apps render Chinese names correctly)
    csv_path = os.path.join(outputs_dir, "all_results.csv")
    df.to_csv(csv_path, index=False, encoding="utf-8-sig")
    print(f"\nResults saved to: {csv_path}")

    # Strong-breakout records (strength > 0.3). The column-existence check
    # mirrors the statistics section above, which treats the strength
    # columns as optional.
    if len(df) > 0 and "breakout_strength_up" in df.columns:
        strong_up = df[(df["is_valid"] == True) & (df["breakout_strength_up"] > 0.3)]
        strong_down = df[(df["is_valid"] == True) & (df["breakout_strength_down"] > 0.3)]

        if len(strong_up) > 0:
            strong_up_path = os.path.join(outputs_dir, "strong_breakout_up.csv")
            strong_up.to_csv(strong_up_path, index=False, encoding="utf-8-sig")
            print(f"Strong up breakouts ({len(strong_up)}): {strong_up_path}")

        if len(strong_down) > 0:
            strong_down_path = os.path.join(outputs_dir, "strong_breakout_down.csv")
            strong_down.to_csv(strong_down_path, index=False, encoding="utf-8-sig")
            print(f"Strong down breakouts ({len(strong_down)}): {strong_down_path}")

    # 8. Show a sample of the output
    if len(df) > 0:
        print("\n" + "-" * 70)
        print("Sample results (first 10):")
        print("-" * 70)
        display_cols = [
            "stock_code", "date", "is_valid",
            "breakout_strength_up", "breakout_strength_down",
            "breakout_dir", "width_ratio"
        ]
        # Only show columns the detector actually produced.
        display_cols = [c for c in display_cols if c in df.columns]
        print(df[display_cols].head(10).to_string(index=False))

    total_time = time.time() - start_time
    print("\n" + "=" * 70)
    print(f"总耗时: {total_time:.2f} 秒 ({total_time/60:.2f} 分钟)")
    print("=" * 70)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|