technical-patterns-lab/discuss/20260130-检测算法优化方案.md
褚宏光 bf6baa5483 Add scoring module and enhance HTML viewer with standardization
- Add scripts/scoring/ module with normalizer, sensitivity analysis, and config
- Enhance stock_viewer.html with standardized scoring display
- Add integration tests and normalization verification scripts
- Add documentation for standardization implementation and usage guides
- Add data distribution analysis reports for strength scoring dimensions
- Update discussion documents with algorithm optimization plans
2026-01-30 18:43:37 +08:00


Converging-Triangle Detection: Algorithm Optimization Plan

Target Scenario

  • Stock universe: 5,000 stocks
  • Run frequency: daily
  • Current runtime: ~10-60 s (estimated)
  • Optimization target: <5 s (~10x speedup)
  • Quality requirement: preserve detection accuracy

Performance Bottleneck Analysis

Current Algorithm Flow

Detection flow (per stock, per day):
┌──────────────────────────────────────────────────────────┐
│ 1. Pivot detection (pivots_fractal_hybrid)          30%  │ ← hotspot 1
│ 2. Boundary-line fitting (fit_pivot_line)           25%  │ ← hotspot 2
│ 3. Geometric validation (convergence/touches/slope) 20%  │
│ 4. Breakout strength (price/volume)                 15%  │
│ 5. DataFrame construction + data copies             10%  │ ← hotspot 3
└──────────────────────────────────────────────────────────┘

Key Bottlenecks

  1. Pivot detection: O(n*k) sliding window with redundant comparisons
  2. Boundary-line fitting: iterative outlier removal, repeated least-squares fits
  3. Python loops: a full stock × day double loop in the interpreter
  4. Memory allocation: frequent creation of temporary arrays
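The hotspot shares above should be verified empirically before optimizing. A minimal cProfile sketch, using a toy stand-in for the pipeline (in practice, profile detect_converging_triangle_batch on real data):

```python
import cProfile
import io
import pstats

# Toy stand-in for the detection pipeline; replace run_pipeline()
# with a call to detect_converging_triangle_batch in practice.
def hotspot(n):
    return sum(i * i for i in range(n))

def run_pipeline():
    return [hotspot(10_000) for _ in range(50)]

profiler = cProfile.Profile()
profiler.enable()
run_pipeline()
profiler.disable()

# Render the top-10 functions by cumulative time into a string
buf = io.StringIO()
pstats.Stats(profiler, stream=buf).sort_stats("cumulative").print_stats(10)
report = buf.getvalue()
print(report)
```

The function consuming the most cumulative time in the report is the one worth vectorizing first.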

Optimization Plan (Tiered Rollout)

🚀 Level 1: Vectorization (expected 2-3x speedup)

1.1 Vectorized Pivot Detection

Current implementation (O(n*k) sliding window):

import numpy as np

def pivots_fractal(high, low, k=15):
    """Find local extrema with a sliding window."""
    ph, pl = [], []
    for i in range(k, len(high) - k):
        # Compare against the k bars on each side
        if all(high[i] >= high[j] for j in range(i-k, i+k+1) if j != i):
            ph.append(i)
        if all(low[i] <= low[j] for j in range(i-k, i+k+1) if j != i):
            pl.append(i)
    return np.array(ph), np.array(pl)

Optimized version (vectorized):

def pivots_fractal_vectorized(high, low, k=15):
    """
    Vectorized pivot detection.

    Core idea:
    1. Use scipy.signal.argrelextrema to find all extrema in one pass
    2. Alternatively, vectorize via convolution / rolling windows

    Expected speedup: 3-5x.
    """
    from scipy.signal import argrelextrema

    # Local maxima (pivot highs)
    ph = argrelextrema(high, np.greater_equal, order=k)[0]

    # Local minima (pivot lows)
    pl = argrelextrema(low, np.less_equal, order=k)[0]

    return ph, pl


def pivots_fractal_rolling(high, low, k=15):
    """
    Implementation using pandas rolling windows.

    Expected speedup: 2-4x.
    """
    import pandas as pd

    high_series = pd.Series(high)
    low_series = pd.Series(low)

    # Rolling max/min over a centered window
    window = 2*k + 1
    high_rolling_max = high_series.rolling(window, center=True).max()
    low_rolling_min = low_series.rolling(window, center=True).min()

    # A bar is a pivot where it equals the window extremum
    # (notna() masks the NaN edges of the rolling result)
    ph = np.where((high_series == high_rolling_max) & high_rolling_max.notna())[0]
    pl = np.where((low_series == low_rolling_min) & low_rolling_min.notna())[0]

    return ph, pl

Implementation notes

  • File: src/converging_triangle.py
  • Functions: pivots_fractal(), pivots_fractal_hybrid()
  • Backward compatible: keep the original functions as a fallback
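The fallback contract can be made explicit with a thin dispatch wrapper. A sketch, assuming the function names from this plan and treating SciPy as an optional dependency:

```python
import numpy as np

def pivots_fractal(high, low, k=15):
    """Original O(n*k) sliding-window implementation (kept as fallback)."""
    ph, pl = [], []
    for i in range(k, len(high) - k):
        if all(high[i] >= high[j] for j in range(i - k, i + k + 1) if j != i):
            ph.append(i)
        if all(low[i] <= low[j] for j in range(i - k, i + k + 1) if j != i):
            pl.append(i)
    return np.array(ph, dtype=np.int64), np.array(pl, dtype=np.int64)

def pivots_fractal_dispatch(high, low, k=15):
    """Prefer the vectorized SciPy path; fall back to the original loop."""
    try:
        from scipy.signal import argrelextrema
    except ImportError:
        return pivots_fractal(high, low, k)
    ph = argrelextrema(high, np.greater_equal, order=k)[0]
    pl = argrelextrema(low, np.less_equal, order=k)[0]
    return ph, pl
```

Note that the two paths can disagree within the first/last k bars (argrelextrema's default clip mode still evaluates them), so regression tests should compare interior pivots only.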

1.2 Boundary-Line Fitting Optimization

Current implementation (iterative outlier removal):

def fit_pivot_line(pivot_indices, pivot_values, mode="upper", max_iter=10):
    """Repeated least-squares fits with outlier removal (simplified sketch)."""
    indices, values = pivot_indices, pivot_values
    for _ in range(max_iter):
        a, b = np.polyfit(indices, values, 1)  # least-squares line
        residuals = values - (a * indices + b)
        outliers = find_outliers(residuals)    # project-specific helper
        if len(outliers) == 0:
            break
        indices, values = remove_outliers(indices, values, outliers)
    return a, b

Option A (precompute + cache):

def fit_pivot_line_cached(pivot_indices, pivot_values, mode="upper", cache=None):
    """
    Cache intermediate results to avoid recomputation.

    Scenario: adjacent dates share most of their pivots.
    Strategy: cache fit results for the last N days; update incrementally.

    Expected speedup: 30-50% (for rolling-window workloads).
    """
    cache_key = (tuple(pivot_indices), tuple(pivot_values), mode)

    if cache and cache_key in cache:
        return cache[cache_key]

    # Original fitting logic
    result = _fit_pivot_line_core(pivot_indices, pivot_values, mode)

    if cache is not None:
        cache[cache_key] = result

    return result

Option B (faster fitting algorithm):

def fit_pivot_line_ransac(pivot_indices, pivot_values, mode="upper",
                          residual_threshold=1.0):
    """
    RANSAC fitting: robust to outliers in a single pass.

    Uses sklearn.linear_model.RANSACRegressor.
    Expected speedup: 2-3x.
    """
    from sklearn.linear_model import RANSACRegressor

    X = np.asarray(pivot_indices).reshape(-1, 1)
    y = np.asarray(pivot_values)

    ransac = RANSACRegressor(
        residual_threshold=residual_threshold,
        max_trials=100,
        random_state=42
    )
    ransac.fit(X, y)

    a = ransac.estimator_.coef_[0]
    b = ransac.estimator_.intercept_
    inlier_mask = ransac.inlier_mask_

    return a, b, np.where(inlier_mask)[0]

Recommendation: implement option A first; the cache is simple and its gains are reliable.


1.3 Eliminating Python Loops

Current implementation (double loop):

def detect_converging_triangle_batch(...):
    results = []
    for stock_idx in range(n_stocks):
        for date_idx in range(start_day, end_day + 1):
            result = detect_converging_triangle_single(
                stock_idx, date_idx, ...
            )
            results.append(result)
    return pd.DataFrame(results)

Optimized version (vectorize the outer layer):

def detect_converging_triangle_batch_vectorized(...):
    """
    Vectorize the outer loop.

    Strategy:
    1. Group by date_idx: process all stocks for one date at once
    2. Use numpy broadcasting for parallel computation

    Expected speedup: 1.5-2x.
    """
    all_results = []

    for date_idx in range(start_day, end_day + 1):
        # Run detection for every stock on this date in one shot:
        # extract the window data (vectorized)
        window_start = date_idx - window + 1
        high_windows = high_mtx[:, window_start:date_idx+1]  # (n_stocks, window)
        low_windows = low_mtx[:, window_start:date_idx+1]

        # Batch pivot detection via numpy vector ops
        pivots_batch = detect_pivots_batch(high_windows, low_windows)

        # Batch boundary-line fitting
        fits_batch = fit_lines_batch(pivots_batch)

        # Batch strength computation
        strengths_batch = calc_strengths_batch(fits_batch, ...)

        all_results.append(strengths_batch)

    return np.vstack(all_results)

Key point: this requires refactoring the algorithm so that each step operates on (n_stocks, window)-shaped arrays in a single call.
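As a sketch of that refactor, pivot-high detection over a whole (n_stocks, window) block can be done with numpy's sliding_window_view. The mask-returning interface below is an assumption for illustration, not the project's actual detect_pivots_batch signature:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def detect_pivots_batch(high_windows, k=15):
    """
    Flag pivot highs for all stocks at once.

    high_windows: (n_stocks, window) array.
    Returns a boolean mask of the same shape; the first/last k bars
    are always False, matching the scalar sliding-window version.
    """
    n_stocks, window = high_windows.shape
    # All length-(2k+1) sub-windows along the time axis:
    # shape (n_stocks, window - 2k, 2k + 1)
    views = sliding_window_view(high_windows, 2 * k + 1, axis=1)
    centers = high_windows[:, k:window - k]
    is_pivot = np.zeros_like(high_windows, dtype=bool)
    # A center bar is a pivot iff it equals the max of its own window
    is_pivot[:, k:window - k] = centers >= views.max(axis=2)
    return is_pivot
```

The same pattern with `views.min(axis=2)` and `<=` gives pivot lows, and np.nonzero on a mask row recovers index lists where needed.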


Level 2: Numba JIT Acceleration (expected 5-10x speedup)

2.1 Numba-Accelerated Core Functions

import numpy as np
from numba import jit, prange

@jit(nopython=True, parallel=True, cache=True)
def pivots_fractal_numba(high, low, k=15):
    """
    Numba-accelerated pivot detection.

    Why these flags:
    - nopython=True: compiles to machine code
    - parallel=True: multi-threaded outer loop
    - cache=True: caches the compiled result on disk

    Note: appending to a list inside prange is not thread-safe, so each
    iteration writes to its own slot of a preallocated boolean mask.

    Expected speedup: 10-20x vs. pure Python.
    """
    n = len(high)
    is_ph = np.zeros(n, dtype=np.bool_)
    is_pl = np.zeros(n, dtype=np.bool_)

    for i in prange(k, n - k):  # parallel loop
        high_pivot = True
        low_pivot = True
        for j in range(i - k, i + k + 1):
            if j == i:
                continue
            if high[i] < high[j]:
                high_pivot = False
            if low[i] > low[j]:
                low_pivot = False
            if not high_pivot and not low_pivot:
                break
        is_ph[i] = high_pivot
        is_pl[i] = low_pivot

    return np.where(is_ph)[0], np.where(is_pl)[0]


@jit(nopython=True, cache=True)
def fit_line_numba(x, y):
    """Numba-accelerated least-squares line fit."""
    n = len(x)
    x_mean = np.mean(x)
    y_mean = np.mean(y)
    
    numerator = np.sum((x - x_mean) * (y - y_mean))
    denominator = np.sum((x - x_mean) ** 2)
    
    a = numerator / denominator
    b = y_mean - a * x_mean
    
    return a, b


@jit(nopython=True, parallel=True)
def detect_batch_numba(
    high_mtx, low_mtx, close_mtx, volume_mtx,
    window, k, start_day, end_day
):
    """
    Numba-accelerated batch detection.

    Core optimizations:
    - No Python-object overhead
    - Parallelized outermost loop
    - Preallocated result arrays

    Expected speedup: 5-10x.
    """
    n_stocks, n_days = high_mtx.shape
    total_points = n_stocks * (end_day - start_day + 1)

    # Preallocate result arrays
    strength_up = np.zeros(total_points, dtype=np.float64)
    strength_down = np.zeros(total_points, dtype=np.float64)
    is_valid = np.zeros(total_points, dtype=np.bool_)

    # Process all detection points in parallel
    for idx in prange(total_points):
        stock_idx = idx // (end_day - start_day + 1)
        day_offset = idx % (end_day - start_day + 1)
        date_idx = start_day + day_offset

        # Extract the window data
        window_start = date_idx - window + 1
        high_win = high_mtx[stock_idx, window_start:date_idx+1]
        low_win = low_mtx[stock_idx, window_start:date_idx+1]

        # Detect pivots
        ph, pl = pivots_fractal_numba(high_win, low_win, k)

        # ... fitting/validation/strength elided; the computed_* names
        #     below are placeholders for those results ...

        strength_up[idx] = computed_strength_up
        strength_down[idx] = computed_strength_down
        is_valid[idx] = computed_is_valid

    return strength_up, strength_down, is_valid

Implementation notes

  • Numba requires purely numeric code: no pandas, dicts, or other Python objects
  • The first call pays a JIT compilation cost (~1-2 s); subsequent calls are fast
  • The algorithm must be split into pure-numeric kernels

🔥 Level 3: Parallelism + Caching (expected 10-20x speedup)

3.1 Multiprocess Parallelism

from multiprocessing import Pool, cpu_count
from functools import partial

def detect_stock_range(stock_indices, high_mtx, low_mtx, ...):
    """Run detection for one batch of stocks."""
    results = []
    for stock_idx in stock_indices:
        for date_idx in range(start_day, end_day + 1):
            result = detect_converging_triangle_single(
                stock_idx, date_idx, high_mtx, low_mtx, ...
            )
            results.append(result)
    return results


def detect_converging_triangle_parallel(
    high_mtx, low_mtx, close_mtx, volume_mtx,
    params, start_day, end_day,
    n_workers=None
):
    """
    Multiprocess parallel detection.

    Strategy:
    - Split the 5,000 stocks into n_workers groups
    - Each process handles one group
    - The parent process merges the results

    Expected speedup: near-linear (~7x on 8 cores).
    """
    n_stocks = high_mtx.shape[0]
    n_workers = n_workers or cpu_count() - 1

    # Assign work (group by stock index)
    stock_groups = np.array_split(range(n_stocks), n_workers)

    # Build a partial function with the fixed arguments bound
    detect_fn = partial(
        detect_stock_range,
        high_mtx=high_mtx,
        low_mtx=low_mtx,
        close_mtx=close_mtx,
        volume_mtx=volume_mtx,
        params=params,
        start_day=start_day,
        end_day=end_day
    )

    # Run in parallel
    with Pool(n_workers) as pool:
        results_groups = pool.map(detect_fn, stock_groups)

    # Merge the results
    all_results = []
    for group_results in results_groups:
        all_results.extend(group_results)

    return pd.DataFrame(all_results)

Caveats

  • Best suited to CPU-bound work
  • Requires enough memory (the matrices are copied into each child process)
  • For 5,000 stocks, 8-16 workers is typically optimal
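The per-worker copy can be avoided with the standard-library multiprocessing.shared_memory module (Python 3.8+): the parent places each matrix in a named shared block and workers attach by name. A minimal sketch; the helper names are illustrative, not from the project:

```python
import numpy as np
from multiprocessing import shared_memory

def share_matrix(arr):
    """Copy a matrix into a named shared block; workers attach by name
    instead of receiving a pickled copy through Pool."""
    shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
    view = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
    view[:] = arr  # one-time copy into shared memory
    return shm, arr.shape, arr.dtype

def attach_matrix(name, shape, dtype):
    """Inside a worker: map the shared block as a numpy array (no copy)."""
    shm = shared_memory.SharedMemory(name=name)
    return shm, np.ndarray(shape, dtype=dtype, buffer=shm.buf)
```

Workers receive only (shm.name, shape, dtype) as arguments; each worker calls shm.close() when done, and the parent calls shm.unlink() once after the pool finishes.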

3.2 Incremental Computation + Caching

class IncrementalDetector:
    """
    Incremental detector: caches historical results.

    Scenario: one new trading day arrives per day; reuse prior results.
    Strategy:
    1. Cache pivots/fitted lines for the most recent N days
    2. On a new date, compute only the incremental part
    3. Evict old entries LRU-style

    Expected gains:
    - First run: no speedup
    - Each subsequent daily run: 5-10x (only the newest day is computed)
    """

    def __init__(self, window=240, cache_size=100):
        self.window = window
        self.pivot_cache = {}  # {stock_idx: {date_idx: (ph, pl)}}
        self.fit_cache = {}    # {stock_idx: {date_idx: fit_result}}
        self.cache_size = cache_size

    def detect_incremental(self, stock_idx, new_date_idx, high, low, close, volume):
        """
        Incremental detection using the cache.

        Logic:
        1. Check whether the previous day's result is cached
        2. If yes:
           - update the pivots for the single new bar
           - reuse the historical fit results
        3. If no: compute from scratch and cache
        """
        prev_date_idx = new_date_idx - 1

        # Try the previous day's cached result
        if stock_idx in self.pivot_cache and prev_date_idx in self.pivot_cache[stock_idx]:
            # Incremental pivot update
            prev_ph, prev_pl = self.pivot_cache[stock_idx][prev_date_idx]
            new_ph, new_pl = self._update_pivots_incremental(
                prev_ph, prev_pl, high, low, new_date_idx
            )
        else:
            # Full computation
            new_ph, new_pl = pivots_fractal(high, low, k=15)

        # Cache the pivots
        if stock_idx not in self.pivot_cache:
            self.pivot_cache[stock_idx] = {}
        self.pivot_cache[stock_idx][new_date_idx] = (new_ph, new_pl)

        # ... fitting/validation/strength elided; `result` comes from those steps ...

        return result

    def _update_pivots_incremental(self, prev_ph, prev_pl, high, low, new_idx):
        """
        Incrementally update the pivot lists.

        Strategy:
        1. Most pivot positions are unchanged (relative indices shift by 1)
        2. Only the window edges need new-pivot/expired-pivot checks
        """
        # Simplified sketch: the real code needs more logic here;
        # it should check whether the latest k bars form a new pivot.
        k = 15
        last_points = high[-2*k:]

        # Check whether the newest bar is a pivot
        if self._is_pivot_high(last_points, k):
            prev_ph = np.append(prev_ph, new_idx)
        if self._is_pivot_low(low[-2*k:], k):
            prev_pl = np.append(prev_pl, new_idx)

        return prev_ph, prev_pl

Implementation priority

  • Medium (well suited to a production job that runs daily)
  • No gain on the first run; significant gains every day thereafter

💾 Level 4: Data-Structure Optimization (expected 1.5-2x speedup)

4.1 NumPy Structured Arrays Instead of DataFrames

def detect_converging_triangle_batch_numpy(...):
    """
    Use a numpy structured array instead of a pandas DataFrame.

    Benefits:
    - No pandas object overhead
    - Contiguous memory, cache-friendly
    - Returns a plain numpy array for downstream processing

    Expected speedup: 30-50% (less memory allocation).
    """
    n_stocks, n_days = close_mtx.shape
    total_points = n_stocks * (end_day - start_day + 1)

    # Define the result layout
    dtype = np.dtype([
        ('stock_idx', np.int32),
        ('date_idx', np.int32),
        ('is_valid', np.bool_),
        ('strength_up', np.float32),
        ('strength_down', np.float32),
        ('convergence_score', np.float32),
        ('volume_score', np.float32),
        ('geometry_score', np.float32),
        ('activity_score', np.float32),
        ('tilt_score', np.float32),
    ])

    # Preallocate the result array
    results = np.empty(total_points, dtype=dtype)

    idx = 0
    for stock_idx in range(n_stocks):
        for date_idx in range(start_day, end_day + 1):
            result = detect_single(stock_idx, date_idx, ...)

            # Write directly into the structured array (no intermediate objects)
            results[idx]['stock_idx'] = stock_idx
            results[idx]['date_idx'] = date_idx
            results[idx]['is_valid'] = result.is_valid
            results[idx]['strength_up'] = result.strength_up
            # ...
            idx += 1

    return results  # convert later if needed: pd.DataFrame(results)

4.2 Memory-Mapped Files (for large datasets)

def load_data_mmap(data_dir):
    """
    Load data via memory mapping.

    When it helps:
    - Dataset larger than available RAM
    - Sharing data across processes (avoids copies)

    Expected gains:
    - Load time: from seconds to milliseconds
    - Memory footprint: ~0 (pages are loaded on demand)
    """
    import os

    # Save as .npy, which supports mmap
    high_mmap = np.load(
        os.path.join(data_dir, 'high.npy'),
        mmap_mode='r'  # read-only
    )

    return high_mmap  # an mmap-backed array; data is paged in on access


def save_data_for_mmap(data, filepath):
    """Save data in an mmap-compatible format."""
    np.save(filepath, data)

Implementation Roadmap

Phase 1: Quick Wins (1-2 weeks, expected 2-3x speedup)

Priority: P0

  1. Vectorize pivot detection (scipy or pandas rolling)
  2. Cache boundary-line fits
  3. Remove the easy Python loops (vectorize whatever can be vectorized)

Expected outcome

  • Runtime: 30-60 s → 10-20 s
  • Difficulty: low
  • Risk: low (backward compatible)

Phase 2: Numba Acceleration (2-3 weeks, expected 5-10x speedup)

Priority: P1

  1. Compile the core functions with Numba (pivots/fit_line/calc_strength)
  2. Compile the batch-detection main loop with Numba
  3. ⚠️ Unit tests (verify numerical consistency)

Expected outcome

  • Runtime: 10-20 s → 2-5 s
  • Difficulty: medium
  • Risk: medium (numerical stability must be validated)

Notes

# Numba limitations
# 1. No pandas DataFrames (use numpy instead)
# 2. No dicts/Python lists (use numpy arrays)
# 3. No dynamic typing (explicit type annotations needed)

# Workarounds:
# - Keep the pandas logic in the outer layer
# - Implement the core computation in pure numpy
# - Add type annotations
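The separation the comments describe looks roughly like this. The strength formula below is a placeholder, not the project's metric, and the njit fallback keeps the sketch runnable even without Numba installed:

```python
import numpy as np
import pandas as pd

try:
    from numba import njit
except ImportError:  # degrade gracefully: run the kernel as plain Python
    def njit(*args, **kwargs):
        def wrap(func):
            return func
        return wrap

@njit(cache=True)
def _strength_core(close, volume):
    # Pure-numeric kernel: only numpy arrays and scalars cross this boundary.
    # Placeholder formula for illustration only.
    v_mean = volume.mean()
    return (close[-1] / close[0] - 1.0) * (volume[-1] / v_mean)

def calc_strength(df):
    """Outer layer keeps pandas; hands contiguous float64 arrays to the kernel."""
    close = np.ascontiguousarray(df["close"].to_numpy(dtype=np.float64))
    volume = np.ascontiguousarray(df["volume"].to_numpy(dtype=np.float64))
    return _strength_core(close, volume)
```

This pattern keeps DataFrames, dicts, and I/O outside the JIT boundary, so only numeric arrays reach the compiled code.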

Phase 3: Parallelization (1 week, expected 10-20x speedup)

Priority: P2 (only if more speed is needed after Phase 2)

  1. Multiprocess detection (Pool/ProcessPoolExecutor)
  2. Incremental computation + caching (for the daily production run)

Expected outcome

  • Runtime: 2-5 s → <1 s
  • Difficulty: low (assuming Phases 1-2 are done)
  • Risk: low

Phase 4: Extreme Optimization (only as needed)

Priority: P3 (only if the above is still insufficient):

  1. Rewrite the core module in Cython (C extension)
  2. GPU acceleration (CUDA/cupy)
  3. Rust extension (pyo3)

Expected outcome

  • Runtime: <1 s → milliseconds
  • Difficulty: high
  • Risk: high (maintenance cost)

Benchmarking

Use the existing test script

# The project already has a performance-test script
cd d:\project\technical-patterns-lab

# Run a small-scale test (10 stocks)
python scripts/test_performance.py

# Inspect the profiling results
pip install snakeviz
snakeviz outputs/performance/profile_*.prof

Create a 5,000-Stock Benchmark

# scripts/benchmark_5000_stocks.py
"""
5,000-stock performance benchmark.
"""
import time
import numpy as np
from src.converging_triangle import detect_converging_triangle_batch, ConvergingTriangleParams

def generate_synthetic_data(n_stocks=5000, n_days=500):
    """Generate synthetic OHLCV data for testing."""
    np.random.seed(42)

    base_price = 10 + np.random.randn(n_stocks, 1) * 2
    returns = np.random.randn(n_stocks, n_days) * 0.02

    close = base_price * np.cumprod(1 + returns, axis=1)
    high = close * (1 + np.abs(np.random.randn(n_stocks, n_days)) * 0.01)
    low = close * (1 - np.abs(np.random.randn(n_stocks, n_days)) * 0.01)
    open_ = close * (1 + np.random.randn(n_stocks, n_days) * 0.005)
    volume = np.random.randint(100000, 1000000, (n_stocks, n_days))

    return open_, high, low, close, volume

def benchmark_5000_stocks():
    print("=" * 80)
    print("5,000-stock performance benchmark")
    print("=" * 80)

    # Generate data
    print("\nGenerating test data...")
    open_, high, low, close, volume = generate_synthetic_data(5000, 500)
    print(f"Data shape: {close.shape}")

    # Configure parameters
    params = ConvergingTriangleParams(window=240, pivot_k=15)

    # Run
    print("\nRunning detection...")
    start = time.time()

    df = detect_converging_triangle_batch(
        open_mtx=open_,
        high_mtx=high,
        low_mtx=low,
        close_mtx=close,
        volume_mtx=volume,
        params=params,
        start_day=239,
        end_day=499,
        only_valid=True,
        verbose=False
    )

    elapsed = time.time() - start

    # Report
    print("\n" + "=" * 80)
    print("Results")
    print("=" * 80)
    print(f"Total time: {elapsed:.2f} s")
    print(f"Detection points: {5000 * 261}")  # 5000 stocks x 261 days (499 - 239 + 1)
    print(f"Throughput: {5000*261/elapsed:.1f} points/s")
    print(f"Valid patterns: {len(df)}")

    # Verdict
    print("\n" + "=" * 80)
    if elapsed < 5:
        print("✅ Excellent (<5 s)")
    elif elapsed < 10:
        print("✔️  Good (5-10 s)")
    elif elapsed < 30:
        print("⚠️  Mediocre (10-30 s); optimization recommended")
    else:
        print("❌ Poor (>30 s); optimization urgently needed")

if __name__ == '__main__':
    benchmark_5000_stocks()

Quality Assurance

Regression tests

# tests/test_optimization_correctness.py
"""
Correctness tests: ensure the optimized code matches the original results.
"""
import numpy as np
import pytest

def test_optimized_vs_original():
    """Compare the optimized implementation against the original."""
    # Load the test fixture (project-specific helper)
    data = load_test_case()

    # Original implementation
    result_orig = detect_original(data)

    # Optimized implementation
    result_opt = detect_optimized(data)

    # Results must match (tiny float tolerance allowed)
    np.testing.assert_allclose(
        result_orig['strength_up'],
        result_opt['strength_up'],
        rtol=1e-5,
        atol=1e-8
    )

    # Validity flags must match exactly
    assert (result_orig['is_valid'] == result_opt['is_valid']).all()

Expected Results Summary

| Stage   | Difficulty | Speedup | Cumulative | Total runtime (5,000 stocks) |
|---------|------------|---------|------------|------------------------------|
| Current | -          | -       | 1x         | 30-60 s                      |
| Phase 1 | Low        | 2-3x    | 2-3x       | 10-20 s                      |
| Phase 2 | Medium     | 5-10x   | 10-30x     | 2-5 s                        |
| Phase 3 | Low        | 10-20x  | 20-60x     | <1 s                         |
| Phase 4 | High       | 50-100x | 100-300x   | milliseconds                 |

Recommended path: Phase 1 → Phase 2 → check whether the target is met → Phase 3 only if needed


Next Steps

This week

  1. Benchmark current performance

    python scripts/benchmark_5000_stocks.py
    
  2. Confirm the bottleneck functions

    python -m cProfile -o profile.stats scripts/benchmark_5000_stocks.py
    python -m pstats profile.stats
    # >> stats
    # >> sort cumulative
    # >> stats 20
    
  3. Start with vectorized pivot detection (highest payoff, lowest difficulty)

