technical-patterns-lab/discuss/20260130-检测算法优化方案.md
褚宏光 bf6baa5483 Add scoring module and enhance HTML viewer with standardization
- Add scripts/scoring/ module with normalizer, sensitivity analysis, and config
- Enhance stock_viewer.html with standardized scoring display
- Add integration tests and normalization verification scripts
- Add documentation for standardization implementation and usage guides
- Add data distribution analysis reports for strength scoring dimensions
- Update discussion documents with algorithm optimization plans
2026-01-30 18:43:37 +08:00

# Converging Triangle Detection Algorithm Optimization Plan
## Target Scenario
- **Stock universe**: 5,000 stocks
- **Run frequency**: daily
- **Current runtime**: estimated 10-60 seconds
- **Optimization target**: <5 seconds (~10x speedup)
- **Quality requirement**: preserve detection accuracy
---
## Performance Bottleneck Analysis
### Current Algorithm Flow
```
Detection flow (per stock / per day):
┌──────────────────────────────────────────────────────┐
│ 1. Pivot detection (pivots_fractal_hybrid)       30% │ ← hotspot 1
│ 2. Boundary fitting (fit_pivot_line)             25% │ ← hotspot 2
│ 3. Geometric validation (convergence/touches/slope) 20% │
│ 4. Breakout strength (price/volume)              15% │
│ 5. DataFrame construction + data copies          10% │ ← hotspot 3
└──────────────────────────────────────────────────────┘
```
### Key Bottlenecks
1. **Pivot detection**: O(n*k) sliding window recomputes the same comparisons
2. **Boundary fitting**: iterative outlier removal runs least squares multiple times
3. **Python loops**: a full stock × day double loop in interpreted Python
4. **Memory allocation**: frequent creation of temporary arrays
---
## Optimization Plan (Tiered Rollout)
### 🚀 Level 1: Vectorization (expected 2-3x speedup)
#### 1.1 Vectorized Pivot Detection
**Current implementation** (O(n*k) sliding window):
```python
def pivots_fractal(high, low, k=15):
    """Find local extrema with a sliding window."""
    ph, pl = [], []
    for i in range(k, len(high) - k):
        # compare against k points on each side
        if all(high[i] >= high[j] for j in range(i - k, i + k + 1) if j != i):
            ph.append(i)
        if all(low[i] <= low[j] for j in range(i - k, i + k + 1) if j != i):
            pl.append(i)
    return np.array(ph), np.array(pl)
```
**Optimized version** (vectorized):
```python
def pivots_fractal_vectorized(high, low, k=15):
    """
    Vectorized pivot detection.
    Core idea:
    1. Use scipy.signal.argrelextrema to find all extrema in one pass,
    2. or vectorize via convolution / rolling windows.
    Expected speedup: 3-5x.
    """
    from scipy.signal import argrelextrema
    # local maxima (pivot highs)
    ph = argrelextrema(high, np.greater_equal, order=k)[0]
    # local minima (pivot lows)
    pl = argrelextrema(low, np.less_equal, order=k)[0]
    return ph, pl

def pivots_fractal_rolling(high, low, k=15):
    """
    pandas rolling-window implementation.
    Expected speedup: 2-4x.
    """
    import pandas as pd
    high_series = pd.Series(high)
    low_series = pd.Series(low)
    # rolling max/min over the centered window
    window = 2 * k + 1
    high_rolling_max = high_series.rolling(window, center=True).max()
    low_rolling_min = low_series.rolling(window, center=True).min()
    # a point equal to its window extreme is a pivot (NaN edges excluded)
    ph = np.where((high_series == high_rolling_max) & high_rolling_max.notna())[0]
    pl = np.where((low_series == low_rolling_min) & low_rolling_min.notna())[0]
    return ph, pl
```
**Implementation**
- File: `src/converging_triangle.py`
- Functions: `pivots_fractal()`, `pivots_fractal_hybrid()`
- Backward compatible: keep the original functions as a fallback
---
#### 1.2 Boundary Line Fitting Optimization
**Current implementation** (iterative outlier removal):
```python
def fit_pivot_line(pivot_indices, pivot_values, mode="upper", max_iter=10):
    """Iteratively drop outliers, refitting by least squares each pass."""
    indices = np.asarray(pivot_indices, dtype=float)
    values = np.asarray(pivot_values, dtype=float)
    for _ in range(max_iter):
        a, b = np.polyfit(indices, values, 1)  # least-squares line
        residuals = values - (a * indices + b)
        keep = np.abs(residuals) <= 2.0 * residuals.std()  # flag outliers
        if keep.all():
            break
        indices, values = indices[keep], values[keep]
    return a, b
```
**Option A** (precompute + cache):
```python
def fit_pivot_line_cached(pivot_indices, pivot_values, mode="upper", cache=None):
    """
    Cache intermediate results to avoid recomputation.
    Scenario: pivot sets on adjacent dates mostly overlap.
    Strategy: cache the most recent N days of fits and update incrementally.
    Expected speedup: 30-50% (rolling-window scenario).
    """
    cache_key = (tuple(pivot_indices), tuple(pivot_values), mode)
    if cache and cache_key in cache:
        return cache[cache_key]
    # fall through to the existing fitting logic
    result = _fit_pivot_line_core(pivot_indices, pivot_values, mode)
    if cache is not None:
        cache[cache_key] = result
    return result
```
**Option B** (fast robust fitting):
```python
def fit_pivot_line_ransac(pivot_indices, pivot_values, mode="upper"):
    """
    RANSAC fitting: fast and robust to outliers.
    Uses sklearn.linear_model.RANSACRegressor.
    Expected speedup: 2-3x.
    """
    from sklearn.linear_model import RANSACRegressor
    X = pivot_indices.reshape(-1, 1)
    y = pivot_values
    ransac = RANSACRegressor(
        residual_threshold=None,  # default: median absolute deviation of y
        max_trials=100,
        random_state=42
    )
    ransac.fit(X, y)
    a = ransac.estimator_.coef_[0]
    b = ransac.estimator_.intercept_
    inlier_mask = ransac.inlier_mask_
    return a, b, np.where(inlier_mask)[0]
```
**Recommendation**: implement Option A first; the cache is simple and its payoff is reliable.
---
#### 1.3 Eliminating Python Loops
**Current implementation** (double loop):
```python
def detect_converging_triangle_batch(...):
    results = []
    for stock_idx in range(n_stocks):
        for date_idx in range(start_day, end_day + 1):
            result = detect_converging_triangle_single(
                stock_idx, date_idx, ...
            )
            results.append(result)
    return pd.DataFrame(results)
```
**Optimized version** (vectorize the outer loop):
```python
def detect_converging_triangle_batch_vectorized(...):
    """
    Vectorize the outer loop.
    Strategy:
    1. Group by date_idx and process all stocks at once.
    2. Use numpy broadcasting for parallel computation.
    Expected speedup: 1.5-2x.
    """
    all_results = []
    for date_idx in range(start_day, end_day + 1):
        # detect across all stocks for the same day in one shot
        # extract window data (vectorized)
        window_start = date_idx - window + 1
        high_windows = high_mtx[:, window_start:date_idx+1]  # (n_stocks, window)
        low_windows = low_mtx[:, window_start:date_idx+1]
        # batch pivot detection via numpy vector ops
        pivots_batch = detect_pivots_batch(high_windows, low_windows)
        # batch boundary-line fitting
        fits_batch = fit_lines_batch(pivots_batch)
        # batch strength calculation
        strengths_batch = calc_strengths_batch(fits_batch, ...)
        all_results.append(strengths_batch)
    return np.vstack(all_results)
```
**Key point**: the algorithm must be refactored so that a single function can operate on (n_stocks, window) arrays.
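One way to get pivot detection into that (n_stocks, window) shape is `numpy.lib.stride_tricks.sliding_window_view`. A sketch under that assumption; `pivot_high_mask_batch` is an illustrative name, not an existing project function:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def pivot_high_mask_batch(high_windows, k=15):
    """
    high_windows: (n_stocks, window) array of highs.
    Returns a boolean mask of the same shape marking pivot highs:
    points equal to the max of their centered 2k+1 neighborhood.
    """
    n_stocks, n_days = high_windows.shape
    w = 2 * k + 1
    # strided views of every centered neighborhood: (n_stocks, n_days - w + 1, w)
    neighborhoods = sliding_window_view(high_windows, w, axis=1)
    win_max = neighborhoods.max(axis=2)
    mask = np.zeros_like(high_windows, dtype=bool)
    # a center point equal to its neighborhood max is a pivot high
    mask[:, k:n_days - k] = high_windows[:, k:n_days - k] >= win_max
    return mask
```

The mirror image with `.min(axis=2)` and `<=` gives pivot lows, and `np.nonzero(mask[s])` recovers the per-stock index arrays the current code produces.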
---
### ⚡ Level 2: Numba JIT (expected 5-10x speedup)
#### 2.1 Numba-Accelerated Core Functions
```python
import numpy as np
from numba import jit, prange

@jit(nopython=True, parallel=True, cache=True)
def pivots_fractal_numba(high, low, k=15):
    """
    Numba-accelerated pivot detection.
    Flags:
    - nopython=True: compile to machine code
    - parallel=True: multithreaded loop
    - cache=True: cache the compiled function
    Expected speedup: 10-20x vs. pure Python.
    Note: boolean masks replace list.append, which is not safe inside prange.
    """
    n = len(high)
    is_ph = np.zeros(n, dtype=np.bool_)
    is_pl = np.zeros(n, dtype=np.bool_)
    for i in prange(k, n - k):  # parallel loop
        # pivot high check
        is_high_pivot = True
        for j in range(i - k, i + k + 1):
            if j != i and high[i] < high[j]:
                is_high_pivot = False
                break
        is_ph[i] = is_high_pivot
        # pivot low check
        is_low_pivot = True
        for j in range(i - k, i + k + 1):
            if j != i and low[i] > low[j]:
                is_low_pivot = False
                break
        is_pl[i] = is_low_pivot
    return np.where(is_ph)[0], np.where(is_pl)[0]

@jit(nopython=True, cache=True)
def fit_line_numba(x, y):
    """Numba-accelerated least-squares fit."""
    x_mean = np.mean(x)
    y_mean = np.mean(y)
    numerator = np.sum((x - x_mean) * (y - y_mean))
    denominator = np.sum((x - x_mean) ** 2)
    a = numerator / denominator
    b = y_mean - a * x_mean
    return a, b

@jit(nopython=True, parallel=True)
def detect_batch_numba(
    high_mtx, low_mtx, close_mtx, volume_mtx,
    window, k, start_day, end_day
):
    """
    Numba-accelerated batch detection.
    Core optimizations:
    - no Python object overhead
    - parallelized outermost loop
    - preallocated result arrays
    Expected speedup: 5-10x.
    """
    n_stocks, n_days = high_mtx.shape
    total_points = n_stocks * (end_day - start_day + 1)
    # preallocate result arrays
    strength_up = np.zeros(total_points, dtype=np.float64)
    strength_down = np.zeros(total_points, dtype=np.float64)
    is_valid = np.zeros(total_points, dtype=np.bool_)
    # process each detection point in parallel
    for idx in prange(total_points):
        stock_idx = idx // (end_day - start_day + 1)
        day_offset = idx % (end_day - start_day + 1)
        date_idx = start_day + day_offset
        # extract window data
        window_start = date_idx - window + 1
        high_win = high_mtx[stock_idx, window_start:date_idx+1]
        low_win = low_mtx[stock_idx, window_start:date_idx+1]
        # detect pivots
        ph, pl = pivots_fractal_numba(high_win, low_win, k)
        # ... downstream processing ...
        strength_up[idx] = computed_strength_up
        strength_down[idx] = computed_strength_down
        is_valid[idx] = computed_is_valid
    return strength_up, strength_down, is_valid
```
**Implementation notes**
- Numba requires purely numerical functions: no pandas, dicts, or other Python objects.
- The first run pays a JIT compilation cost (~1-2 seconds); subsequent calls are fast.
- The algorithm must be split into pure numerical functions.
---
### 🔥 Level 3: Parallelism + Caching (expected 10-20x speedup)
#### 3.1 Multiprocessing
```python
from multiprocessing import Pool, cpu_count
from functools import partial

def detect_stock_range(stock_indices, high_mtx, low_mtx, close_mtx,
                       volume_mtx, params, start_day, end_day):
    """Run detection for one batch of stocks."""
    results = []
    for stock_idx in stock_indices:
        for date_idx in range(start_day, end_day + 1):
            result = detect_converging_triangle_single(
                stock_idx, date_idx, high_mtx, low_mtx, ...
            )
            results.append(result)
    return results

def detect_converging_triangle_parallel(
    high_mtx, low_mtx, close_mtx, volume_mtx,
    params, start_day, end_day,
    n_workers=None
):
    """
    Multiprocess detection.
    Strategy:
    - split the 5000 stocks into n_workers groups
    - each process handles one group
    - the parent process merges the results
    Expected speedup: near-linear (~7x on 8 cores).
    """
    n_stocks = high_mtx.shape[0]
    n_workers = n_workers or cpu_count() - 1
    # assign work (group by stock index)
    stock_groups = np.array_split(range(n_stocks), n_workers)
    # bind the fixed arguments
    detect_fn = partial(
        detect_stock_range,
        high_mtx=high_mtx,
        low_mtx=low_mtx,
        close_mtx=close_mtx,
        volume_mtx=volume_mtx,
        params=params,
        start_day=start_day,
        end_day=end_day
    )
    # run in parallel
    with Pool(n_workers) as pool:
        results_groups = pool.map(detect_fn, stock_groups)
    # merge
    all_results = []
    for group_results in results_groups:
        all_results.extend(group_results)
    return pd.DataFrame(all_results)
```
**Notes**
- Suited to CPU-bound work.
- Needs enough memory: the data is copied into each child process.
- For 5000 stocks, 8-16 cores is the sweet spot.
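If the per-process copy cost becomes a problem, the standard library's `multiprocessing.shared_memory` lets all workers attach to one shared copy of the price matrices. A sketch; the helper names are illustrative:

```python
import numpy as np
from multiprocessing import shared_memory

def publish_shared(arr):
    """Copy an array into shared memory once; returns the segment and a view."""
    shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
    view = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
    view[:] = arr  # one copy, shared by all workers
    return shm, view

def attach_shared(name, shape, dtype):
    """Attach to an existing segment by name inside a worker process."""
    shm = shared_memory.SharedMemory(name=name)
    return shm, np.ndarray(shape, dtype=dtype, buffer=shm.buf)
```

The parent calls `publish_shared` and hands workers `(shm.name, shape, dtype)` instead of the matrices themselves; workers call `shm.close()` when done, and the parent calls `shm.unlink()` at shutdown.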
---
#### 3.2 Incremental Computation + Caching
```python
class IncrementalDetector:
    """
    Incremental detector: caches historical results.
    Scenario: one new trading day arrives each day; reuse prior detections.
    Strategy:
    1. Cache the most recent N days of pivots / fitted lines.
    2. On a new date, compute only the increment.
    3. Evict stale entries LRU-style.
    Expected payoff:
    - first run: no speedup
    - each subsequent day: 5-10x (only the newest day is computed)
    """
    def __init__(self, window=240, cache_size=100):
        self.window = window
        self.pivot_cache = {}  # {stock_idx: {date_idx: (ph, pl)}}
        self.fit_cache = {}    # {stock_idx: {date_idx: fit_result}}
        self.cache_size = cache_size

    def detect_incremental(self, stock_idx, new_date_idx, high, low, close, volume):
        """
        Incremental detection using the cache.
        Logic:
        1. Check the cache for the previous day's result.
        2. If present:
           - update the pivots with the one new day of data
           - reuse the historical fit
        3. Otherwise, compute from scratch and cache.
        """
        prev_date_idx = new_date_idx - 1
        # try the previous day's cached result
        if stock_idx in self.pivot_cache and prev_date_idx in self.pivot_cache[stock_idx]:
            # incremental pivot update
            prev_ph, prev_pl = self.pivot_cache[stock_idx][prev_date_idx]
            new_ph, new_pl = self._update_pivots_incremental(
                prev_ph, prev_pl, high, low, new_date_idx
            )
        else:
            # full computation
            new_ph, new_pl = pivots_fractal(high, low, k=15)
        # cache the pivots
        if stock_idx not in self.pivot_cache:
            self.pivot_cache[stock_idx] = {}
        self.pivot_cache[stock_idx][new_date_idx] = (new_ph, new_pl)
        # ... downstream processing (fit + strength) omitted in this sketch ...
        return result

    def _update_pivots_incremental(self, prev_ph, prev_pl, high, low, new_idx):
        """
        Incremental pivot update.
        Strategy:
        1. Most pivot positions are unchanged (relative index shifts by 1).
        2. Only the window edges need re-checking for additions/removals.
        """
        # Simplified sketch: the real logic is more involved.
        # It should check whether the latest k points form a new pivot.
        k = 15
        last_points = high[-2*k:]
        # is the newest point a pivot?
        if self._is_pivot_high(last_points, k):
            prev_ph = np.append(prev_ph, new_idx)
        if self._is_pivot_low(low[-2*k:], k):
            prev_pl = np.append(prev_pl, new_idx)
        return prev_ph, prev_pl
```
**Implementation priority**
- Medium: a good fit for a production job that runs daily.
- No benefit on the first run; significant savings every day afterwards.
---
### 💾 Level 4: Data Structure Optimization (expected 1.5-2x speedup)
#### 4.1 Numpy Structured Arrays Instead of DataFrames
```python
def detect_converging_triangle_batch_numpy(...):
    """
    Use a numpy structured array instead of a pandas DataFrame.
    Advantages:
    - no pandas object overhead
    - contiguous memory, cache friendly
    - returns a numpy array directly for downstream use
    Expected speedup: 30-50% (fewer allocations).
    """
    n_stocks, n_days = close_mtx.shape
    total_points = n_stocks * (end_day - start_day + 1)
    # result record layout
    dtype = np.dtype([
        ('stock_idx', np.int32),
        ('date_idx', np.int32),
        ('is_valid', np.bool_),
        ('strength_up', np.float32),
        ('strength_down', np.float32),
        ('convergence_score', np.float32),
        ('volume_score', np.float32),
        ('geometry_score', np.float32),
        ('activity_score', np.float32),
        ('tilt_score', np.float32),
    ])
    # preallocate the result array
    results = np.empty(total_points, dtype=dtype)
    idx = 0
    for stock_idx in range(n_stocks):
        for date_idx in range(start_day, end_day + 1):
            result = detect_single(stock_idx, date_idx, ...)
            # write straight into the structured array (no intermediate objects)
            results[idx]['stock_idx'] = stock_idx
            results[idx]['date_idx'] = date_idx
            results[idx]['is_valid'] = result.is_valid
            results[idx]['strength_up'] = result.strength_up
            # ...
            idx += 1
    return results  # convert later if needed: pd.DataFrame(results)
```
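Downstream code can filter the structured array directly and convert to pandas only at the boundary. A small usage sketch with made-up field values:

```python
import numpy as np
import pandas as pd

# a cut-down version of the record layout above
dtype = np.dtype([('stock_idx', np.int32),
                  ('is_valid', np.bool_),
                  ('strength_up', np.float32)])
results = np.zeros(3, dtype=dtype)
results['stock_idx'] = [0, 1, 2]
results['is_valid'] = [True, False, True]
results['strength_up'] = [0.8, 0.0, 0.5]

valid = results[results['is_valid']]   # boolean filtering, no pandas involved
df = pd.DataFrame(valid)               # convert only where a DataFrame is needed
```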
---
#### 4.2 Memory-Mapped Files (Large Datasets)
```python
def load_data_mmap(data_dir):
    """
    Load data via memory mapping.
    Use when:
    - the data is larger than available memory
    - multiple processes share the data (avoids copies)
    Expected payoff:
    - load time: seconds -> milliseconds
    - memory footprint: ~0 (pages load on demand)
    """
    import os
    # .npy files support mmap
    high_mmap = np.load(
        os.path.join(data_dir, 'high.npy'),
        mmap_mode='r'  # read-only
    )
    return high_mmap  # an mmap-backed array; data is paged in on access

def save_data_for_mmap(data, filepath):
    """Save data in an mmap-compatible format."""
    np.save(filepath, data)
```
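A short round-trip showing the save/load pattern (the temp directory and array sizes are illustrative):

```python
import os
import tempfile
import numpy as np

data_dir = tempfile.mkdtemp()
high = np.random.rand(100, 500).astype(np.float32)
np.save(os.path.join(data_dir, 'high.npy'), high)   # write once

# opens near-instantly; pages are faulted in only when sliced
high_mmap = np.load(os.path.join(data_dir, 'high.npy'), mmap_mode='r')
window = np.asarray(high_mmap[:, -240:])            # materialize only the needed slice
```

Only the trailing 240-day window is copied into memory, which matches the detection window the algorithm actually reads.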
---
## Implementation Roadmap
### Phase 1: Quick Wins (1-2 weeks, expected 2-3x speedup)
**Priority: P0**
1. Vectorize pivot detection (scipy or pandas rolling).
2. Cache boundary-line fits.
3. Eliminate the easy Python loops (vectorize whatever can be vectorized).
**Expected payoff**
- Runtime: 30-60 s → 10-20 s
- Difficulty: low
- Risk: low (backward compatible)
---
### Phase 2: Numba (2-3 weeks, expected 5-10x speedup)
**Priority: P1**
1. Numba-compile the core functions (pivots / fit_line / calc_strength).
2. Numba-compile the batch detection main loop.
3. Unit tests to confirm numerical agreement.
**Expected payoff**
- Runtime: 10-20 s → 2-5 s
- Difficulty: medium
- Risk: numerical stability must be validated
**Caveats**
```python
# Numba limitations:
# 1. No pandas DataFrames (use numpy instead)
# 2. No dicts/lists of Python objects (use numpy arrays)
# 3. No dynamic typing (explicit type annotations needed)
# Workarounds:
# - keep the pandas logic in the outer layer
# - implement the core computation in pure numpy
# - add type annotations
```
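The "pandas outside, numpy core inside" split can look like this. A sketch with illustrative names; the inner function is the shape that can later take `@numba.jit(nopython=True)` without changes:

```python
import numpy as np
import pandas as pd

def _fit_slope_core(x, y):
    """Pure-numpy core: no pandas objects, so it is Numba-compilable as-is."""
    x_mean, y_mean = x.mean(), y.mean()
    return ((x - x_mean) * (y - y_mean)).sum() / ((x - x_mean) ** 2).sum()

def fit_slope(df, col):
    """Outer layer: all pandas handling stays here, outside the jit boundary."""
    y = df[col].to_numpy(dtype=np.float64)
    x = np.arange(len(y), dtype=np.float64)
    return _fit_slope_core(x, y)
```

Keeping the boundary at plain float64 arrays means the core can be swapped between pure-numpy and Numba-compiled builds without touching callers.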
---
### Phase 3: Parallelism (1 week, expected 10-20x speedup)
**Priority: P2** (only if Phase 2 is not enough)
1. Multiprocess detection (Pool / ProcessPoolExecutor).
2. Incremental computation + caching (for the daily production run).
**Expected payoff**
- Runtime: 2-5 s → <1 s
- Difficulty: low (once Phases 1-2 are done)
- Risk: low
---
### Phase 4: Extreme Optimization (as needed)
**Priority: P3** (only if the earlier phases fall short)
1. Rewrite the core module in Cython (C extension).
2. GPU acceleration (CUDA / cupy).
3. Rust extension (pyo3).
**Expected payoff**
- Runtime: <1 s → milliseconds
- Difficulty: high
- Risk: high maintenance cost
---
## Benchmark Scripts
### Using the Existing Test Script
```bash
# the project already ships a performance test script
cd d:\project\technical-patterns-lab
# small-scale run (10 stocks)
python scripts/test_performance.py
# inspect the profiling results
pip install snakeviz
snakeviz outputs/performance/profile_*.prof
```
### Creating a 5000-Stock Benchmark
```python
# scripts/benchmark_5000_stocks.py
"""
5000-stock performance benchmark.
"""
import time
import numpy as np
from src.converging_triangle import detect_converging_triangle_batch, ConvergingTriangleParams

def generate_synthetic_data(n_stocks=5000, n_days=500):
    """Generate synthetic OHLCV data for benchmarking."""
    np.random.seed(42)
    base_price = 10 + np.random.randn(n_stocks, 1) * 2
    returns = np.random.randn(n_stocks, n_days) * 0.02
    close = base_price * np.cumprod(1 + returns, axis=1)
    high = close * (1 + np.abs(np.random.randn(n_stocks, n_days)) * 0.01)
    low = close * (1 - np.abs(np.random.randn(n_stocks, n_days)) * 0.01)
    open_ = close * (1 + np.random.randn(n_stocks, n_days) * 0.005)
    volume = np.random.randint(100000, 1000000, (n_stocks, n_days))
    return open_, high, low, close, volume

def benchmark_5000_stocks():
    print("=" * 80)
    print("5000-stock performance benchmark")
    print("=" * 80)
    # generate data
    print("\nGenerating test data...")
    open_, high, low, close, volume = generate_synthetic_data(5000, 500)
    print(f"Data shape: {close.shape}")
    # parameters
    params = ConvergingTriangleParams(window=240, pivot_k=15)
    # run
    print("\nRunning detection...")
    start = time.time()
    df = detect_converging_triangle_batch(
        open_mtx=open_,
        high_mtx=high,
        low_mtx=low,
        close_mtx=close,
        volume_mtx=volume,
        params=params,
        start_day=239,
        end_day=499,
        only_valid=True,
        verbose=False
    )
    elapsed = time.time() - start
    # results
    print("\n" + "=" * 80)
    print("Results")
    print("=" * 80)
    print(f"Total time: {elapsed:.2f} s")
    print(f"Detection points: {5000 * 261}")  # 499 - 239 + 1 = 261 days
    print(f"Throughput: {5000*261/elapsed:.1f} points/s")
    print(f"Valid patterns: {len(df)}")
    # verdict
    print("\n" + "=" * 80)
    if elapsed < 5:
        print("✅ Excellent (<5 s)")
    elif elapsed < 10:
        print("✔️ Good (5-10 s)")
    elif elapsed < 30:
        print("⚠️ Fair (10-30 s), optimization recommended")
    else:
        print("❌ Poor (>30 s), optimization urgently needed")

if __name__ == '__main__':
    benchmark_5000_stocks()
```
---
## Quality Assurance
### Regression Testing
```python
# tests/test_optimization_correctness.py
"""
Optimization correctness tests: the optimized path must match the original.
"""
import numpy as np
import pytest

def test_optimized_vs_original():
    """Compare the optimized and original implementations."""
    # load the test case
    data = load_test_case()
    # original version
    result_orig = detect_original(data)
    # optimized version
    result_opt = detect_optimized(data)
    # results must agree (tiny numerical tolerance allowed)
    np.testing.assert_allclose(
        result_orig['strength_up'],
        result_opt['strength_up'],
        rtol=1e-5,
        atol=1e-8
    )
    # validity flags must match exactly
    assert (result_orig['is_valid'] == result_opt['is_valid']).all()
```
---
## Expected Results Summary
| Stage | Difficulty | Speedup | Cumulative | Total Runtime (5000 stocks) |
|---------|---------|---------|---------|---------------|
| **Current** | - | - | 1x | 30-60 s |
| Phase 1 | Low | 2-3x | 2-3x | 10-20 s |
| Phase 2 | Medium | 5-10x | 10-30x | 2-5 s |
| Phase 3 | Low | 10-20x | 20-60x | <1 s |
| Phase 4 | High | 50-100x | 100-300x | milliseconds |
**Recommended path**: Phase 1 → Phase 2 → check whether the target is met → Phase 3 only if needed.
---
## Immediate Actions
### This Week's Tasks
1. **Benchmark current performance**
```bash
python scripts/benchmark_5000_stocks.py
```
2. **Confirm the bottleneck functions**
```bash
python -m cProfile -o profile.stats scripts/benchmark_5000_stocks.py
python -m pstats profile.stats
# >> stats
# >> sort cumulative
# >> stats 20
```
3. **Implement first**: pivot detection vectorization (highest payoff, lowest difficulty).
---
Want help implementing the specific optimizations, for example starting with pivot detection vectorization?