# Converging Triangle Detection — Optimization Plan

## Target Scenario

- **Universe**: 5,000 stocks
- **Run frequency**: once per day
- **Current runtime**: 10-60 s (estimated)
- **Target**: < 5 s (≈10x speedup)
- **Quality bar**: detection accuracy must not regress

---

## Performance Bottleneck Analysis

### Current Pipeline

```
Detection pipeline (per stock / per day):
┌─────────────────────────────────────────────────────────┐
│ 1. Pivot detection (pivots_fractal_hybrid)          30% │ ← hotspot 1
│ 2. Boundary-line fitting (fit_pivot_line)           25% │ ← hotspot 2
│ 3. Geometric validation (convergence/touches/slope) 20% │
│ 4. Breakout strength (price/volume)                 15% │
│ 5. DataFrame construction + data copies             10% │ ← hotspot 3
└─────────────────────────────────────────────────────────┘
```

### Key Bottlenecks

1. **Pivot detection**: O(n*k) sliding window with redundant comparisons
2. **Boundary-line fitting**: iterative outlier removal, repeated least-squares fits
3. **Python loops**: a full stock × day double loop in the interpreter
4. **Memory allocation**: frequent creation of temporary arrays

---

## Optimization Plan (Tiered Rollout)

### 🚀 Level 1: Vectorization (expected 2-3x)

#### 1.1 Vectorized Pivot Detection

**Current implementation** (O(n*k) sliding window):

```python
def pivots_fractal(high, low, k=15):
    """Find local extrema with a sliding window."""
    ph, pl = [], []
    for i in range(k, len(high) - k):
        # Compare against the k points on each side
        if all(high[i] >= high[j] for j in range(i-k, i+k+1) if j != i):
            ph.append(i)
        if all(low[i] <= low[j] for j in range(i-k, i+k+1) if j != i):
            pl.append(i)
    return np.array(ph), np.array(pl)
```

**Optimized** (vectorized):

```python
def pivots_fractal_vectorized(high, low, k=15):
    """
    Vectorized pivot detection.

    Core idea:
    1. Use scipy.signal.argrelextrema to find all extrema in one pass,
    2.
       or vectorize the computation with a convolution / rolling window.

    Expected speedup: 3-5x
    """
    from scipy.signal import argrelextrema

    # Local maxima (highs)
    ph = argrelextrema(high, np.greater_equal, order=k)[0]
    # Local minima (lows)
    pl = argrelextrema(low, np.less_equal, order=k)[0]
    return ph, pl


def pivots_fractal_rolling(high, low, k=15):
    """
    pandas rolling-window implementation.

    Expected speedup: 2-4x
    """
    import pandas as pd

    high_series = pd.Series(high)
    low_series = pd.Series(low)

    # Rolling max/min over a centered window
    window = 2*k + 1
    high_rolling_max = high_series.rolling(window, center=True).max()
    low_rolling_min = low_series.rolling(window, center=True).min()

    # A point that equals its window extreme is a pivot
    # (the NaN edges of the rolling result are masked out explicitly)
    ph = np.where((high_series == high_rolling_max) & high_rolling_max.notna())[0]
    pl = np.where((low_series == low_rolling_min) & low_rolling_min.notna())[0]
    return ph, pl
```

**Where to apply**:
- File: `src/converging_triangle.py`
- Functions: `pivots_fractal()` and `pivots_fractal_hybrid()`
- Backward compatibility: keep the original functions as a fallback

---

#### 1.2 Boundary-Line Fitting

**Current implementation** (iterative outlier removal, sketch):

```python
def fit_pivot_line(pivot_indices, pivot_values, mode="upper", max_iter=10):
    """Iteratively drop outliers, refitting least squares each round."""
    indices, values = pivot_indices, pivot_values
    for _ in range(max_iter):
        a, b = np.polyfit(indices, values, 1)  # least squares
        residuals = values - (a * indices + b)
        outliers = find_outliers(residuals)
        if len(outliers) == 0:
            break
        indices = np.delete(indices, outliers)
        values = np.delete(values, outliers)
    return a, b
```

**Option A** (precompute + cache):

```python
def fit_pivot_line_cached(pivot_indices, pivot_values, mode="upper", cache=None):
    """
    Cache intermediate results to avoid recomputation.

    Scenario: pivots on adjacent dates overlap almost entirely.
    Strategy: cache the last N days of fit results and update incrementally.
    Expected speedup: 30-50% (for rolling-window workloads)
    """
    cache_key = (tuple(pivot_indices), tuple(pivot_values), mode)
    if cache and cache_key in cache:
        return cache[cache_key]

    # Existing fitting logic
    result = _fit_pivot_line_core(pivot_indices, pivot_values, mode)

    if cache is not None:
        cache[cache_key] = result
    return result
```

**Option B** (faster fitting algorithm):

```python
def fit_pivot_line_ransac(pivot_indices, pivot_values, mode="upper"):
    """
    Fast, outlier-robust fit via RANSAC
    (sklearn.linear_model.RANSACRegressor).

    Expected speedup: 2-3x
    """
    from sklearn.linear_model import RANSACRegressor

    X = pivot_indices.reshape(-1, 1)
    y = pivot_values

    ransac = RANSACRegressor(
        residual_threshold=threshold,  # threshold: project-specific residual tolerance
        max_trials=100,
        random_state=42,
    )
    ransac.fit(X, y)

    a = ransac.estimator_.coef_[0]
    b = ransac.estimator_.intercept_
    inlier_mask = ransac.inlier_mask_
    return a, b, np.where(inlier_mask)[0]
```

**Recommendation**: implement Option A (caching) first — it is simple and the payoff is reliable.

---

#### 1.3 Eliminating Python Loops

**Current implementation** (double loop):

```python
def detect_converging_triangle_batch(...):
    results = []
    for stock_idx in range(n_stocks):
        for date_idx in range(start_day, end_day + 1):
            result = detect_converging_triangle_single(
                stock_idx, date_idx, ...
            )
            results.append(result)
    return pd.DataFrame(results)
```

**Optimized** (vectorize the outer loop):

```python
def detect_converging_triangle_batch_vectorized(...):
    """
    Vectorize the outer loop.

    Strategy:
    1. Group by date_idx and process all stocks at once.
    2. Use numpy broadcasting for the parallel arithmetic.

    Expected speedup: 1.5-2x
    """
    all_results = []

    for date_idx in range(start_day, end_day + 1):
        # Process all stocks for this date in one shot

        # Extract the windows (vectorized)
        window_start = date_idx - window + 1
        high_windows = high_mtx[:, window_start:date_idx+1]  # (n_stocks, window)
        low_windows = low_mtx[:, window_start:date_idx+1]

        # Batch pivot detection (numpy vector ops)
        pivots_batch = detect_pivots_batch(high_windows, low_windows)

        # Batch boundary-line fitting
        fits_batch = fit_lines_batch(pivots_batch)

        # Batch strength computation
        strengths_batch = calc_strengths_batch(fits_batch, ...)
        all_results.append(strengths_batch)

    return np.vstack(all_results)
```

**Key point**: this requires refactoring so that each stage can operate on `(n_stocks, window)`-shaped arrays.

---

### ⚡ Level 2: Numba JIT (expected 5-10x)

#### 2.1 Numba-Accelerated Core Functions

```python
from numba import jit, prange

@jit(nopython=True, parallel=True, cache=True)
def pivots_fractal_numba(high, low, k=15):
    """
    Numba-accelerated pivot detection.

    Why it is fast:
    - nopython=True: compiled to machine code
    - parallel=True: multithreaded loop
    - cache=True: compilation result cached on disk

    Expected speedup: 10-20x over pure Python.
    """
    n = len(high)
    # Use boolean masks rather than list.append — appending to a Python
    # list inside a prange loop is not thread-safe under Numba.
    is_ph = np.zeros(n, dtype=np.bool_)
    is_pl = np.zeros(n, dtype=np.bool_)

    for i in prange(k, n - k):  # parallel loop
        # High pivot?
        is_high_pivot = True
        for j in range(i - k, i + k + 1):
            if j != i and high[i] < high[j]:
                is_high_pivot = False
                break
        is_ph[i] = is_high_pivot

        # Low pivot?
        is_low_pivot = True
        for j in range(i - k, i + k + 1):
            if j != i and low[i] > low[j]:
                is_low_pivot = False
                break
        is_pl[i] = is_low_pivot

    return np.where(is_ph)[0], np.where(is_pl)[0]


@jit(nopython=True, cache=True)
def fit_line_numba(x, y):
    """Numba-accelerated least-squares line fit."""
    x_mean = np.mean(x)
    y_mean = np.mean(y)
    numerator = np.sum((x - x_mean) * (y - y_mean))
    denominator = np.sum((x - x_mean) ** 2)
    a = numerator / denominator
    b = y_mean - a * x_mean
    return a, b


@jit(nopython=True, parallel=True)
def detect_batch_numba(
    high_mtx, low_mtx, close_mtx, volume_mtx,
    window, k, start_day, end_day
):
    """
    Numba-accelerated batch detection.

    Core optimizations:
    - no Python-object overhead
    - parallelize the outermost loop
    - preallocate the result arrays

    Expected speedup: 5-10x
    """
    n_stocks, n_days = high_mtx.shape
    total_points = n_stocks * (end_day - start_day + 1)

    # Preallocate result arrays
    strength_up = np.zeros(total_points, dtype=np.float64)
    strength_down = np.zeros(total_points, dtype=np.float64)
    is_valid = np.zeros(total_points, dtype=np.bool_)

    # Process every detection point in parallel
    for idx in prange(total_points):
        stock_idx = idx // (end_day - start_day + 1)
        day_offset = idx % (end_day - start_day + 1)
        date_idx = start_day + day_offset

        # Extract the window
        window_start = date_idx - window + 1
        high_win = high_mtx[stock_idx, window_start:date_idx+1]
        low_win = low_mtx[stock_idx, window_start:date_idx+1]

        # Detect pivots
        ph, pl = pivots_fractal_numba(high_win, low_win, k)
        # ...
        # ... remaining per-point processing ...
        strength_up[idx] = computed_strength_up
        strength_down[idx] = computed_strength_down
        is_valid[idx] = computed_is_valid

    return strength_up, strength_down, is_valid
```

**Implementation notes**:
- Numba functions must be pure numeric code — no pandas, dicts, or other Python objects.
- The first call pays a JIT-compilation cost (~1-2 s); subsequent calls are fast.
- The algorithm must be split into pure numeric kernels.

---

### 🔥 Level 3: Parallelism + Caching (expected 10-20x)

#### 3.1 Multiprocessing

```python
from multiprocessing import Pool, cpu_count
from functools import partial

def detect_stock_range(stock_indices, high_mtx, low_mtx, ...):
    """Run detection for one batch of stocks."""
    results = []
    for stock_idx in stock_indices:
        for date_idx in range(start_day, end_day + 1):
            result = detect_converging_triangle_single(
                stock_idx, date_idx, high_mtx, low_mtx, ...
            )
            results.append(result)
    return results

def detect_converging_triangle_parallel(
    high_mtx, low_mtx, close_mtx, volume_mtx,
    params, start_day, end_day, n_workers=None
):
    """
    Multiprocess parallel detection.

    Strategy:
    - split the 5,000 stocks into n_workers groups
    - each worker process handles one group
    - the parent process merges the results

    Expected speedup: near-linear (~7x on 8 cores)
    """
    n_stocks = high_mtx.shape[0]
    n_workers = n_workers or cpu_count() - 1

    # Partition the work by stock index
    stock_groups = np.array_split(range(n_stocks), n_workers)

    # Bind the fixed arguments
    detect_fn = partial(
        detect_stock_range,
        high_mtx=high_mtx,
        low_mtx=low_mtx,
        close_mtx=close_mtx,
        volume_mtx=volume_mtx,
        params=params,
        start_day=start_day,
        end_day=end_day
    )

    # Run in parallel
    with Pool(n_workers) as pool:
        results_groups = pool.map(detect_fn, stock_groups)

    # Merge
    all_results = []
    for group_results in results_groups:
        all_results.extend(group_results)

    return pd.DataFrame(all_results)
```

**Caveats**:
- Suited to CPU-bound work.
- Needs enough memory (the data is copied into each child process).
- For 5,000 stocks, 8-16 cores is the sweet spot.

---

#### 3.2 Incremental Computation + Caching

```python
class IncrementalDetector:
    """
    Incremental detector: cache results from previous runs.

    Scenario: one new trading day arrives each day; reuse prior results.
    Strategy:
    1. Cache the last N days of pivots / fitted lines.
    2. On a new date, compute only the increment.
    3.
       Evict stale entries LRU-style.

    Expected payoff:
    - First run: no speedup.
    - Each subsequent day: 5-10x (only the newest day is computed).
    """

    def __init__(self, window=240, cache_size=100):
        self.window = window
        self.pivot_cache = {}  # {stock_idx: {date_idx: (ph, pl)}}
        self.fit_cache = {}    # {stock_idx: {date_idx: fit_result}}
        self.cache_size = cache_size

    def detect_incremental(self, stock_idx, new_date_idx, high, low, close, volume):
        """
        Incremental detection via the cache.

        Logic:
        1. Check the cache for the previous day's result.
        2. If present, only:
           - update the pivots with the one new day of data
           - reuse the historical fit results
        3. If absent, compute from scratch and cache.
        """
        prev_date_idx = new_date_idx - 1

        # Try the cache for the previous day's result
        if stock_idx in self.pivot_cache and prev_date_idx in self.pivot_cache[stock_idx]:
            # Incremental pivot update
            prev_ph, prev_pl = self.pivot_cache[stock_idx][prev_date_idx]
            new_ph, new_pl = self._update_pivots_incremental(
                prev_ph, prev_pl, high, low, new_date_idx
            )
        else:
            # Full computation
            new_ph, new_pl = pivots_fractal(high, low, k=15)

        # Cache the pivots
        if stock_idx not in self.pivot_cache:
            self.pivot_cache[stock_idx] = {}
        self.pivot_cache[stock_idx][new_date_idx] = (new_ph, new_pl)

        # ... remaining processing ...
        return result

    def _update_pivots_incremental(self, prev_ph, prev_pl, high, low, new_idx):
        """
        Incrementally update the pivots.

        Strategy:
        1. Most pivot positions are unchanged (relative index +1);
        2.
       only the window edges need re-checking for added/removed pivots.
        """
        # Simplified — the real logic is more involved: it should check
        # whether the latest k points form a new pivot.
        k = 15
        last_points = high[-2*k:]

        # Is the newest point a pivot?
        if self._is_pivot_high(last_points, k):
            prev_ph = np.append(prev_ph, new_idx)
        if self._is_pivot_low(low[-2*k:], k):
            prev_pl = np.append(prev_pl, new_idx)

        return prev_ph, prev_pl
```

**Priority**:
- Medium (well suited to a production job that runs daily).
- No benefit on the first run; significant benefit every day thereafter.

---

### 💾 Level 4: Data-Structure Optimization (expected 1.5-2x)

#### 4.1 Numpy Structured Arrays Instead of DataFrames

```python
def detect_converging_triangle_batch_numpy(...):
    """
    Replace the pandas DataFrame with a numpy structured array.

    Advantages:
    - no pandas object overhead
    - contiguous memory, cache-friendly
    - returns a numpy array directly for downstream use

    Expected speedup: 30-50% (less memory allocation)
    """
    n_stocks, n_days = close_mtx.shape
    total_points = n_stocks * (end_day - start_day + 1)

    # Result record layout
    dtype = np.dtype([
        ('stock_idx', np.int32),
        ('date_idx', np.int32),
        ('is_valid', np.bool_),
        ('strength_up', np.float32),
        ('strength_down', np.float32),
        ('convergence_score', np.float32),
        ('volume_score', np.float32),
        ('geometry_score', np.float32),
        ('activity_score', np.float32),
        ('tilt_score', np.float32),
    ])

    # Preallocate the result array
    results = np.empty(total_points, dtype=dtype)

    idx = 0
    for stock_idx in range(n_stocks):
        for date_idx in range(start_day, end_day + 1):
            result = detect_single(stock_idx, date_idx, ...)

            # Write straight into the structured array (no intermediate objects)
            results[idx]['stock_idx'] = stock_idx
            results[idx]['date_idx'] = date_idx
            results[idx]['is_valid'] = result.is_valid
            results[idx]['strength_up'] = result.strength_up
            # ...
            idx += 1

    return results  # convert later if needed: pd.DataFrame(results)
```

---

#### 4.2 Memory-Mapped Files (large datasets)

```python
def load_data_mmap(data_dir):
    """
    Load data via memory mapping.

    When to use:
    - the data is larger than available memory
    - multiple processes share the data (avoids copies)

    Expected payoff:
    - load time: seconds → milliseconds
    - memory footprint: ~0 (pages are loaded on demand)
    """
    import os

    # Data saved as .npy (mmap-compatible)
    high_mmap = np.load(
        os.path.join(data_dir, 'high.npy'),
        mmap_mode='r'  # read-only
    )
    return high_mmap  # mmap object; pages load on demand

def save_data_for_mmap(data, filepath):
    """Save data in an mmap-compatible format."""
    np.save(filepath, data)
```

---

## Implementation Roadmap

### Phase 1: Quick Wins (1-2 weeks, expected 2-3x)

**Priority P0**:
1. ✅ Vectorize pivot detection (scipy or pandas rolling)
2. ✅ Cache boundary-line fits
3. 
   ✅ Eliminate easy Python loops (vectorize whatever can be vectorized)

**Expected payoff**:
- Runtime: 30-60 s → 10-20 s
- Difficulty: low
- Risk: low (backward compatible)

---

### Phase 2: Numba (2-3 weeks, expected 5-10x)

**Priority P1**:
1. ✅ Numba-compile the core kernels (pivots / fit_line / calc_strength)
2. ✅ Numba-compile the main batch-detection loop
3. ⚠️ Unit tests (verify numerical agreement)

**Expected payoff**:
- Runtime: 10-20 s → 2-5 s
- Difficulty: medium
- Risk: medium (numerical stability must be verified)

**Caveats**:

```python
# Numba restrictions:
# 1. No pandas DataFrames (switch to numpy)
# 2. No dicts / heterogeneous lists (switch to numpy arrays)
# 3. No dynamic typing (add explicit type annotations)

# Mitigations:
# - keep the pandas logic in the outer layer
# - implement the core computation in pure numpy
# - add type annotations
```

---

### Phase 3: Parallelism (1 week, expected 10-20x)

**Priority P2** (if still needed after Phase 2):
1. ✅ Multiprocess detection (Pool / ProcessPoolExecutor)
2. ✅ Incremental computation + caching (for the daily production run)

**Expected payoff**:
- Runtime: 2-5 s → < 1 s
- Difficulty: low (given Phases 1-2 are done)
- Risk: low

---

### Phase 4: Last-Mile Optimization (as needed)

**Priority P3** (only if the above is insufficient):
1. Rewrite the core module in Cython (C extension)
2. GPU acceleration (CUDA / cupy)
3. Rust extension (pyo3)

**Expected payoff**:
- Runtime: < 1 s → milliseconds
- Difficulty: high
- Risk: high (maintenance cost)

---

## Benchmark Scripts

### Use the Existing Test Script

```bash
# The project already has a performance test script
cd d:\project\technical-patterns-lab

# Small-scale run (10 stocks)
python scripts/test_performance.py

# Inspect the profiling output
pip install snakeviz
snakeviz outputs/performance/profile_*.prof
```

### Create a 5,000-Stock Benchmark

```python
# scripts/benchmark_5000_stocks.py
"""Performance test with 5,000 stocks."""

import time
import numpy as np
from src.converging_triangle import detect_converging_triangle_batch, ConvergingTriangleParams

def generate_synthetic_data(n_stocks=5000, n_days=500):
    """Generate synthetic OHLCV data for testing."""
    np.random.seed(42)

    base_price = 10 + np.random.randn(n_stocks, 1) * 2
    returns = np.random.randn(n_stocks, n_days) * 0.02
    close = base_price * np.cumprod(1 + returns, axis=1)

    high = close * (1 + np.abs(np.random.randn(n_stocks, n_days)) * 0.01)
    low = close * (1 - np.abs(np.random.randn(n_stocks, n_days)) * 0.01)
    open_ = close * (1 + np.random.randn(n_stocks, n_days) * 0.005)
    volume = np.random.randint(100000, 1000000, (n_stocks, n_days))

    return open_, high, low, close, volume

def benchmark_5000_stocks():
    print("=" * 80)
    print("5,000-stock performance test")
    print("=" * 80)

    # Generate data
    print("\nGenerating test data...")
    open_, high, low, close, volume = generate_synthetic_data(5000, 500)
    print(f"Data shape: {close.shape}")

    # Parameters
    params = \
        ConvergingTriangleParams(window=240, pivot_k=15)

    # Run
    print("\nRunning detection...")
    start = time.time()

    df = detect_converging_triangle_batch(
        open_mtx=open_,
        high_mtx=high,
        low_mtx=low,
        close_mtx=close,
        volume_mtx=volume,
        params=params,
        start_day=239,
        end_day=499,
        only_valid=True,
        verbose=False
    )

    elapsed = time.time() - start

    # Results
    print("\n" + "=" * 80)
    print("Results")
    print("=" * 80)
    print(f"Total time: {elapsed:.2f} s")
    print(f"Detection points: {5000 * 261}")  # 261 days: 239..499 inclusive
    print(f"Throughput: {5000*261/elapsed:.1f} points/s")
    print(f"Valid patterns: {len(df)}")

    # Verdict
    print("\n" + "=" * 80)
    if elapsed < 5:
        print("✅ Excellent (< 5 s)")
    elif elapsed < 10:
        print("✔️ Good (5-10 s)")
    elif elapsed < 30:
        print("⚠️ Mediocre (10-30 s) — optimization recommended")
    else:
        print("❌ Poor (> 30 s) — optimization urgently needed")

if __name__ == '__main__':
    benchmark_5000_stocks()
```

---

## Quality Assurance

### Regression Tests

```python
# tests/test_optimization_correctness.py
"""Correctness tests: the optimized path must match the original."""

import numpy as np
import pytest

def test_optimized_vs_original():
    """Compare the optimized and original implementations."""
    # Load a test case
    data = load_test_case()

    # Original version
    result_orig = detect_original(data)

    # Optimized version
    result_opt = detect_optimized(data)

    # Results must agree (tiny numerical tolerance allowed)
    np.testing.assert_allclose(
        result_orig['strength_up'],
        result_opt['strength_up'],
        rtol=1e-5, atol=1e-8
    )

    # Validity flags must match exactly
    assert (result_orig['is_valid'] == result_opt['is_valid']).all()
```

---

## Expected Impact Summary

| Phase | Difficulty | Speedup | Cumulative | Runtime (5,000 stocks) |
|---------|---------|---------|---------|---------------|
| **Current** | - | - | 1x | 30-60 s |
| Phase 1 | low | 2-3x | 2-3x | 10-20 s |
| Phase 2 | medium | 5-10x | 10-30x | 2-5 s |
| Phase 3 | low | 10-20x | 20-60x | < 1 s |
| Phase 4 | high | 50-100x | 100-300x | milliseconds |

**Recommended path**: Phase 1 → Phase 2 → check whether the target is met → Phase 3 only if needed.

---

## Immediate Next Steps

### This Week

1. **Benchmark the current implementation**
   ```bash
   python scripts/benchmark_5000_stocks.py
   ```

2. **Confirm the hotspot functions**
   ```bash
   python -m cProfile -o profile.stats scripts/benchmark_5000_stocks.py
   python -m pstats profile.stats
   # >> stats
   # >> sort cumulative
   # >> stats 20
   ```

3. **Start with**: vectorizing pivot detection (largest payoff, lowest difficulty).

---

Want me to start implementing the optimizations — say, beginning with the vectorized pivot detection?
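As a concrete starting point for that first action item, here is a minimal self-contained sketch of pivot detection vectorized with NumPy's `sliding_window_view` (requires NumPy ≥ 1.20; the function name is illustrative). It keeps the `>=` / `<=` semantics of the original `pivots_fractal`, so plateau points are flagged on every equal bar, same as the `argrelextrema(..., np.greater_equal)` variant:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def pivots_fractal_sliding(high, low, k=15):
    """Vectorized fractal pivots: index i is a high (low) pivot when
    high[i] equals the max (low[i] the min) of the centered window
    of width 2k+1 around it."""
    window = 2 * k + 1
    # Row r of the view covers original indices r .. r+2k (a view, no copy)
    high_wins = sliding_window_view(high, window)
    low_wins = sliding_window_view(low, window)
    # Center index of each window: k .. n-k-1, aligned with the rows above
    centers = np.arange(k, len(high) - k)
    ph = centers[high[centers] >= high_wins.max(axis=1)]
    pl = centers[low[centers] <= low_wins.min(axis=1)]
    return ph, pl

# Tiny sanity check on a hand-made series: one peak at 4, one at 10,
# one trough at 8 when the same series is read as lows
high = np.array([0., 1, 2, 3, 4, 3, 2, 1, 0, 1, 2, 1, 0])
ph, pl = pivots_fractal_sliding(high, high, k=2)
```

The view itself allocates nothing, though the `max`/`min` reductions materialize one temporary per axis; for a 240-bar window per stock this stays well within the Level 1 memory budget, and the result can be asserted against the original loop implementation in the regression tests above.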