- Added support for daily best stocks reporting, including a new CSV output for daily best triangles based on strength. - Introduced a logging mechanism to capture detailed execution logs, improving traceability and debugging. - Implemented a v2 optimization for batch detection, significantly reducing detection time from 92 seconds to under 2 seconds. - Updated the .gitignore file to include new log files and outputs for better management. - Enhanced the pipeline script to allow for flexible configuration of detection parameters and improved user experience. Files modified: - scripts/run_converging_triangle.py: Added logging and v2 optimization. - scripts/pipeline_converging_triangle.py: Updated for new features and logging. - scripts/plot_converging_triangles.py: Adjusted for new plotting options. - New files: discuss/20260127-拟合线.md, discuss/20260128-拟合线.md, and several images for visual documentation.
343 lines
13 KiB
Python
343 lines
13 KiB
Python
"""
|
||
批量滚动检测收敛三角形 - 从 pkl 文件读取 OHLCV 数据
|
||
每个股票的每个交易日都会计算,输出 DataFrame
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import pickle
|
||
import time
|
||
import numpy as np
|
||
import pandas as pd
|
||
from datetime import datetime
|
||
from io import StringIO
|
||
|
||
# 让脚本能找到 src/ 下的模块
|
||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "src"))
|
||
|
||
from converging_triangle import (
|
||
ConvergingTriangleParams,
|
||
ConvergingTriangleResult,
|
||
detect_converging_triangle_batch,
|
||
detect_converging_triangle_batch_v2, # v2优化版本
|
||
)
|
||
|
||
# ============================================================================
|
||
# 【可调参数区】
|
||
# ============================================================================
|
||
|
||
# --- 数据源 ---
|
||
DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
|
||
|
||
# --- 导入统一的参数配置 ---
|
||
from triangle_config import (
|
||
DETECTION_PARAMS as PARAMS,
|
||
RECENT_DAYS,
|
||
ONLY_VALID,
|
||
VERBOSE,
|
||
REALTIME_MODE, # 新增
|
||
FLEXIBLE_ZONE, # 新增
|
||
USE_V2_OPTIMIZATION, # v2优化开关
|
||
)
|
||
|
||
|
||
# ============================================================================
|
||
# pkl 数据加载
|
||
# ============================================================================
|
||
|
||
class FakeModule:
    """Stub module used to bypass the unavailable ``model`` dependency.

    The pkl files reference the ``model`` / ``model.index_info`` modules;
    registering instances of this stub in ``sys.modules`` (see ``load_pkl``)
    lets ``pickle.load`` resolve those references without the real package.
    """
    # Pickled array types resolve to plain numpy arrays.
    ndarray = np.ndarray
|
||
|
||
|
||
def load_pkl(pkl_path: str) -> dict:
    """Load a pkl file and return its dict {mtx, dtes, tkrs, tkrs_name, ...}.

    Installs ``FakeModule`` stubs for the ``model`` package first so that
    unpickling succeeds even though the original ``model`` dependency is
    not importable here.
    """
    # Register stub modules before unpickling; the pickled payload refers
    # to these module paths.
    for stub_name in ('model', 'model.index_info'):
        sys.modules[stub_name] = FakeModule()

    with open(pkl_path, 'rb') as fh:
        return pickle.load(fh)
|
||
|
||
|
||
def load_ohlcv_from_pkl(data_dir: str) -> tuple:
    """Load OHLCV matrices and metadata from the pkl files in *data_dir*.

    Returns:
        open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx: shape=(n_stocks, n_days)
        dates: shape=(n_days,) real dates (e.g. 20050104)
        tkrs: shape=(n_stocks,) stock codes
        tkrs_name: shape=(n_stocks,) stock names
    """
    fields = ("open", "high", "low", "close", "volume")
    loaded = {
        name: load_pkl(os.path.join(data_dir, f"{name}.pkl"))
        for name in fields
    }

    # Metadata (dates, tickers, names) is taken from the close file.
    close_meta = loaded["close"]
    matrices = tuple(loaded[name]["mtx"] for name in fields)
    return matrices + (
        close_meta["dtes"],
        close_meta["tkrs"],
        close_meta["tkrs_name"],
    )
|
||
|
||
|
||
# ============================================================================
|
||
# 日志工具
|
||
# ============================================================================
|
||
|
||
class Logger:
    """Tee-style logger: mirrors every message to the console and to an
    in-memory buffer, which :meth:`save` flushes to a log file."""

    def __init__(self, log_path: str):
        # Destination file for the buffered log.
        self.log_path = log_path
        self.buffer = StringIO()

    def print(self, *args, **kwargs):
        """Emit the message to stdout and record it in the buffer."""
        for sink in (sys.stdout, self.buffer):
            print(*args, **kwargs, file=sink)

    def save(self):
        """Write the accumulated log buffer to ``self.log_path``."""
        with open(self.log_path, 'w', encoding='utf-8') as out:
            out.write(self.buffer.getvalue())
        print(f"\n日志已保存到: {self.log_path}")
|
||
|
||
|
||
# ============================================================================
|
||
# 主流程
|
||
# ============================================================================
|
||
|
||
def main() -> None:
    """Run the batch converging-triangle detection pipeline.

    Steps: load OHLCV matrices from pkl files, run the (v1 or v2) batch
    detector over the configured day range, enrich results with stock
    codes/names and real dates, print summary statistics, save CSV outputs
    (all results, strong breakouts, daily best), and write the run log.
    """
    start_time = time.time()

    # Initialise the logger; the buffered log is written to disk at the end.
    outputs_dir = os.path.join(os.path.dirname(__file__), "..", "outputs", "converging_triangles")
    os.makedirs(outputs_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_path = os.path.join(outputs_dir, f"run_log_{timestamp}.txt")
    log = Logger(log_path)

    log.print("=" * 70)
    log.print("Converging Triangle Batch Detection")
    log.print(f"运行时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    log.print("=" * 70)

    # 1. Load data.
    log.print("\n[1] Loading OHLCV pkl files...")
    load_start = time.time()
    open_mtx, high_mtx, low_mtx, close_mtx, volume_mtx, dates, tkrs, tkrs_name = load_ohlcv_from_pkl(DATA_DIR)
    load_time = time.time() - load_start
    n_stocks, n_days = close_mtx.shape
    log.print(f" Stocks: {n_stocks}")
    log.print(f" Days: {n_days}")
    log.print(f" Date range: {dates[0]} ~ {dates[-1]}")
    log.print(f" 加载耗时: {load_time:.2f} 秒")

    # 2. Find the valid data range (exclude columns that are all NaN).
    any_valid = np.any(~np.isnan(close_mtx), axis=0)
    valid_day_idx = np.where(any_valid)[0]
    if len(valid_day_idx) == 0:
        log.print("No valid data found!")
        return
    last_valid_day = valid_day_idx[-1]
    log.print(f" Last valid day: {last_valid_day} (date: {dates[last_valid_day]})")

    # 3. Compute the detection range.
    start_day = PARAMS.window - 1
    end_day = last_valid_day  # use the last valid day, not the matrix end

    if RECENT_DAYS is not None:
        start_day = max(start_day, end_day - RECENT_DAYS + 1)

    log.print(f"\n[2] Detection range: day {start_day} ~ {end_day}")
    log.print(f" Window size: {PARAMS.window}")
    log.print(f" Total points: {n_stocks} x {end_day - start_day + 1} = {n_stocks * (end_day - start_day + 1)}")

    # 4. Batch detection.
    log.print("\n[3] Running batch detection...")
    detect_start = time.time()

    # Choose between v1 and v2 based on configuration.
    if USE_V2_OPTIMIZATION and not REALTIME_MODE:
        # v2 optimised version (precomputed pivot points; no realtime support).
        log.print(" 使用: v2优化版本(预计算枢轴点)")
        df = detect_converging_triangle_batch_v2(
            open_mtx=open_mtx,
            high_mtx=high_mtx,
            low_mtx=low_mtx,
            close_mtx=close_mtx,
            volume_mtx=volume_mtx,
            params=PARAMS,
            start_day=start_day,
            end_day=end_day,
            only_valid=ONLY_VALID,
            verbose=VERBOSE,
        )
    else:
        # v1 original version (supports realtime mode).
        if USE_V2_OPTIMIZATION and REALTIME_MODE:
            log.print(" 注意: v2优化不支持实时模式,自动回退到v1")
        log.print(" 使用: v1原版")
        df = detect_converging_triangle_batch(
            open_mtx=open_mtx,
            high_mtx=high_mtx,
            low_mtx=low_mtx,
            close_mtx=close_mtx,
            volume_mtx=volume_mtx,
            params=PARAMS,
            start_day=start_day,
            end_day=end_day,
            only_valid=ONLY_VALID,
            verbose=VERBOSE,
            real_time_mode=REALTIME_MODE,
            flexible_zone=FLEXIBLE_ZONE,
        )

    detect_time = time.time() - detect_start
    log.print(f" 检测耗时: {detect_time:.2f} 秒")
    if not USE_V2_OPTIMIZATION or REALTIME_MODE:
        log.print(f" 检测模式: {'实时模式' if REALTIME_MODE else '标准模式'}")
        if REALTIME_MODE:
            log.print(f" 灵活区域: {FLEXIBLE_ZONE} 天")

    # 5. Attach stock codes, names and real dates.
    if len(df) > 0:
        df["stock_code"] = df["stock_idx"].map(lambda x: tkrs[x] if x < len(tkrs) else "")
        df["stock_name"] = df["stock_idx"].map(lambda x: tkrs_name[x] if x < len(tkrs_name) else "")
        df["date"] = df["date_idx"].map(lambda x: dates[x] if x < len(dates) else 0)
        # Combined strength: max of the upward and downward breakout strengths.
        df["max_strength"] = df[["breakout_strength_up", "breakout_strength_down"]].max(axis=1)

    # 6. Report results.
    log.print("\n" + "=" * 70)
    log.print("Detection Results")
    log.print("=" * 70)

    if ONLY_VALID:
        log.print(f"\nTotal valid triangles detected: {len(df)}")
    else:
        valid_count = df["is_valid"].sum()
        log.print(f"\nTotal records: {len(df)}")
        # NOTE(review): divides by len(df) — raises ZeroDivisionError when the
        # detector returns an empty DataFrame in this mode; confirm upstream
        # always yields at least one row, or guard this print.
        log.print(f"Valid triangles: {valid_count} ({valid_count/len(df)*100:.1f}%)")

    # Breakout direction statistics (valid triangles only).
    if len(df) > 0 and "breakout_dir" in df.columns:
        breakout_stats = df[df["is_valid"] == True]["breakout_dir"].value_counts()
        log.print(f"\nBreakout statistics:")
        for dir_name, count in breakout_stats.items():
            log.print(f" - {dir_name}: {count}")

    # Breakout strength statistics (valid triangles only).
    if len(df) > 0 and "breakout_strength_up" in df.columns:
        valid_df = df[df["is_valid"] == True]
        if len(valid_df) > 0:
            log.print(f"\nBreakout strength (valid triangles):")
            log.print(f" - Up mean: {valid_df['breakout_strength_up'].mean():.4f}, max: {valid_df['breakout_strength_up'].max():.4f}")
            log.print(f" - Down mean: {valid_df['breakout_strength_down'].mean():.4f}, max: {valid_df['breakout_strength_down'].max():.4f}")

    # 7. Save results.
    # Full CSV (utf-8-sig so Chinese text opens correctly in Excel).
    csv_path = os.path.join(outputs_dir, "all_results.csv")
    df.to_csv(csv_path, index=False, encoding="utf-8-sig")
    log.print(f"\nResults saved to: {csv_path}")

    # Save high-strength breakout records (strength > 0.3).
    if len(df) > 0:
        strong_up = df[(df["is_valid"] == True) & (df["breakout_strength_up"] > 0.3)]
        strong_down = df[(df["is_valid"] == True) & (df["breakout_strength_down"] > 0.3)]

        if len(strong_up) > 0:
            strong_up_path = os.path.join(outputs_dir, "strong_breakout_up.csv")
            strong_up.to_csv(strong_up_path, index=False, encoding="utf-8-sig")
            log.print(f"Strong up breakouts ({len(strong_up)}): {strong_up_path}")

        if len(strong_down) > 0:
            strong_down_path = os.path.join(outputs_dir, "strong_breakout_down.csv")
            strong_down.to_csv(strong_down_path, index=False, encoding="utf-8-sig")
            log.print(f"Strong down breakouts ({len(strong_down)}): {strong_down_path}")

    # 8. Daily best-stock report (key new feature): for each trading day,
    # pick the stock with the highest combined strength.
    if len(df) > 0 and "date" in df.columns:
        log.print("\n" + "=" * 70)
        log.print("每日最佳收敛三角形股票 (按日期)")
        log.print("=" * 70)

        valid_df = df[df["is_valid"] == True].copy()
        if len(valid_df) > 0:
            # Group by date and keep the row with the highest score per day.
            daily_best = valid_df.loc[valid_df.groupby("date")["max_strength"].idxmax()]
            daily_best = daily_best.sort_values("date", ascending=True)

            log.print(f"\n共 {len(daily_best)} 个交易日检测到收敛三角形:")
            log.print("-" * 90)
            log.print(f"{'日期':<12} {'股票代码':<10} {'股票名称':<12} {'强度↑':>8} {'强度↓':>8} {'方向':<6} {'收敛比':>8}")
            log.print("-" * 90)

            for _, row in daily_best.iterrows():
                log.print(
                    f"{int(row['date']):<12} "
                    f"{row['stock_code']:<10} "
                    f"{row['stock_name']:<12} "
                    f"{row['breakout_strength_up']:>8.4f} "
                    f"{row['breakout_strength_down']:>8.4f} "
                    f"{row['breakout_dir']:<6} "
                    f"{row['width_ratio']:>8.3f}"
                )

            log.print("-" * 90)

            # Save the daily best rows to CSV.
            daily_best_path = os.path.join(outputs_dir, "daily_best.csv")
            daily_best.to_csv(daily_best_path, index=False, encoding="utf-8-sig")
            log.print(f"\n每日最佳已保存到: {daily_best_path}")

            # Ranking: how often each stock was selected as daily best.
            log.print("\n" + "-" * 70)
            log.print("股票被选为每日最佳的次数排行:")
            log.print("-" * 70)
            stock_counts = daily_best.groupby(["stock_code", "stock_name"]).size().reset_index(name="count")
            stock_counts = stock_counts.sort_values("count", ascending=False).head(20)
            for _, row in stock_counts.iterrows():
                log.print(f" {row['stock_code']} {row['stock_name']}: {row['count']} 次")
        else:
            log.print("\n没有检测到有效的收敛三角形")

    # 9. Show a sample of the results.
    if len(df) > 0:
        log.print("\n" + "-" * 70)
        log.print("Sample results (first 10):")
        log.print("-" * 70)
        display_cols = [
            "stock_code", "date", "is_valid",
            "breakout_strength_up", "breakout_strength_down",
            "breakout_dir", "width_ratio"
        ]
        display_cols = [c for c in display_cols if c in df.columns]
        log.print(df[display_cols].head(10).to_string(index=False))

    total_time = time.time() - start_time
    log.print("\n" + "=" * 70)
    log.print(f"总耗时: {total_time:.2f} 秒 ({total_time/60:.2f} 分钟)")
    log.print("=" * 70)

    # Persist the buffered log to the timestamped log file.
    log.save()
|
||
|
||
|
||
# Script entry point: run the full batch-detection pipeline.
if __name__ == "__main__":
    main()
|