# Converging Triangle Detection: Algorithm Optimization Plan

## Target Scenario

- **Stock universe**: 5,000 stocks
- **Run frequency**: daily
- **Current runtime**: 10-60 seconds (estimated)
- **Optimization target**: <5 seconds (about a 10x speedup)
- **Quality bar**: detection accuracy must not regress

---

## Performance Bottleneck Analysis

### Current pipeline

```
Detection pipeline (per stock, per day):
┌─────────────────────────────────────────────────────────┐
│ 1. Pivot detection (pivots_fractal_hybrid)        30%   │ ← hotspot 1
│ 2. Boundary-line fitting (fit_pivot_line)         25%   │ ← hotspot 2
│ 3. Geometric validation (convergence/touches/slope) 20% │
│ 4. Breakout strength (price/volume)               15%   │
│ 5. DataFrame construction + data copies           10%   │ ← hotspot 3
└─────────────────────────────────────────────────────────┘
```

### Key bottlenecks

1. **Pivot detection**: O(n*k) sliding window with repeated comparisons
2. **Boundary-line fitting**: iterative outlier removal, multiple least-squares passes
3. **Python loops**: a large stock × day double loop
4. **Memory allocation**: frequent creation of temporary arrays

---

## Optimization Plan (Tiered Rollout)

### 🚀 Level 1: Vectorization (expected speedup: 2-3x)

#### 1.1 Vectorized pivot detection

**Current implementation** (O(n*k) sliding window):

```python
def pivots_fractal(high, low, k=15):
    """Find local extrema with a sliding window."""
    ph, pl = [], []
    for i in range(k, len(high) - k):
        # Compare against the k points on either side
        if all(high[i] >= high[j] for j in range(i-k, i+k+1) if j != i):
            ph.append(i)
        if all(low[i] <= low[j] for j in range(i-k, i+k+1) if j != i):
            pl.append(i)
    return np.array(ph), np.array(pl)
```

**Optimized version** (vectorized):

```python
def pivots_fractal_vectorized(high, low, k=15):
    """
    Vectorized pivot detection.

    Core ideas:
    1. Use scipy.signal.argrelextrema to find all extrema in one pass,
    2. or vectorize via convolution / rolling windows.

    Expected speedup: 3-5x
    """
    from scipy.signal import argrelextrema

    # Local maxima (pivot highs)
    ph = argrelextrema(high, np.greater_equal, order=k)[0]

    # Local minima (pivot lows)
    pl = argrelextrema(low, np.less_equal, order=k)[0]

    return ph, pl


def pivots_fractal_rolling(high, low, k=15):
    """
    Rolling-window implementation using pandas.

    Expected speedup: 2-4x
    """
    import pandas as pd

    high_series = pd.Series(high)
    low_series = pd.Series(low)

    # Rolling max/min over the centered window
    window = 2*k + 1
    high_rolling_max = high_series.rolling(window, center=True).max()
    low_rolling_min = low_series.rolling(window, center=True).min()

    # A point is a pivot where it equals the window extremum
    # (exclude the NaN edges produced by the centered rolling window)
    ph = np.where((high_series == high_rolling_max) & high_rolling_max.notna())[0]
    pl = np.where((low_series == low_rolling_min) & low_rolling_min.notna())[0]

    return ph, pl
```

**Implementation notes**:
- File: `src/converging_triangle.py`
- Functions: `pivots_fractal()` and `pivots_fractal_hybrid()`
- Backward compatibility: keep the original functions as a fallback

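The fallback can be wired up with a simple import-time dispatch, so callers always get the fastest available implementation. This is a sketch under the assumption that scipy is an optional dependency; the dispatch name `pivots` is illustrative.

```python
import numpy as np

def pivots_fractal(high, low, k=15):
    # Original O(n*k) reference implementation, kept as the fallback.
    ph, pl = [], []
    for i in range(k, len(high) - k):
        if all(high[i] >= high[j] for j in range(i - k, i + k + 1) if j != i):
            ph.append(i)
        if all(low[i] <= low[j] for j in range(i - k, i + k + 1) if j != i):
            pl.append(i)
    return np.array(ph, dtype=np.int64), np.array(pl, dtype=np.int64)

try:
    from scipy.signal import argrelextrema

    def pivots(high, low, k=15):
        # Fast path: scipy finds all local extrema in one vectorized pass.
        ph = argrelextrema(high, np.greater_equal, order=k)[0]
        pl = argrelextrema(low, np.less_equal, order=k)[0]
        return ph, pl
except ImportError:
    pivots = pivots_fractal  # fall back to the pure-Python version
```

Callers import only `pivots`, so swapping implementations later needs no call-site changes.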
---

#### 1.2 Faster boundary-line fitting

**Current implementation** (iterative outlier removal):

```python
def fit_pivot_line(pivot_indices, pivot_values, mode="upper", max_iter=10):
    """Iteratively drop outliers and refit with least squares."""
    indices = np.asarray(pivot_indices, dtype=float)
    values = np.asarray(pivot_values, dtype=float)
    for _ in range(max_iter):
        a, b = np.polyfit(indices, values, 1)          # least-squares line
        residuals = values - (a * indices + b)
        # Illustrative outlier rule: residual beyond 2 standard deviations
        mask = np.abs(residuals) <= 2 * residuals.std()
        if mask.all():
            break
        indices, values = indices[mask], values[mask]  # drop outliers, refit
    return a, b
```

**Option A** (precompute + cache):

```python
def fit_pivot_line_cached(pivot_indices, pivot_values, mode="upper", cache=None):
    """
    Cache intermediate results to avoid recomputation.

    Scenario: pivot sets on adjacent dates mostly overlap.
    Strategy: cache fit results for the last N days and update incrementally.

    Expected speedup: 30-50% (rolling-window scenarios)
    """
    cache_key = (tuple(pivot_indices), tuple(pivot_values), mode)

    if cache and cache_key in cache:
        return cache[cache_key]

    # Original fitting logic
    result = _fit_pivot_line_core(pivot_indices, pivot_values, mode)

    if cache is not None:
        cache[cache_key] = result

    return result
```

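The caching approach above can be exercised end to end with a self-contained stand-in, where the core fit is replaced by plain least squares for illustration (the real `_fit_pivot_line_core` lives in the module):

```python
import numpy as np

def fit_pivot_line_cached(pivot_indices, pivot_values, mode="upper", cache=None):
    # Memoized least-squares fit, keyed on the exact pivot set.
    key = (tuple(pivot_indices), tuple(pivot_values), mode)
    if cache is not None and key in cache:
        return cache[key]
    a, b = np.polyfit(np.asarray(pivot_indices, float),
                      np.asarray(pivot_values, float), 1)
    result = (a, b)
    if cache is not None:
        cache[key] = result
    return result

# One cache dict shared across the rolling daily loop: consecutive days
# whose pivot set is unchanged hit the cache instead of refitting.
cache = {}
idx, vals = [0, 5, 9], [10.0, 11.0, 12.2]
first = fit_pivot_line_cached(idx, vals, cache=cache)
second = fit_pivot_line_cached(idx, vals, cache=cache)  # cache hit
```

The cache key is exact, so any change to the pivot set triggers a fresh fit; stale entries never leak between days.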
**Option B** (fast robust fit):

```python
def fit_pivot_line_ransac(pivot_indices, pivot_values, mode="upper",
                          residual_threshold=None):
    """
    Fast robust fitting with RANSAC (tolerant of outliers).

    Uses sklearn.linear_model.RANSACRegressor;
    residual_threshold=None lets sklearn pick its MAD-based default.
    Expected speedup: 2-3x
    """
    from sklearn.linear_model import RANSACRegressor

    X = np.asarray(pivot_indices).reshape(-1, 1)
    y = np.asarray(pivot_values)

    ransac = RANSACRegressor(
        residual_threshold=residual_threshold,
        max_trials=100,
        random_state=42
    )
    ransac.fit(X, y)

    a = ransac.estimator_.coef_[0]
    b = ransac.estimator_.intercept_
    inlier_mask = ransac.inlier_mask_

    return a, b, np.where(inlier_mask)[0]
```

**Recommendation**: implement Option A (caching) first; it is simple and its gains are reliable.

---

#### 1.3 Eliminating Python loops

**Current implementation** (double loop):

```python
def detect_converging_triangle_batch(...):
    results = []
    for stock_idx in range(n_stocks):
        for date_idx in range(start_day, end_day + 1):
            result = detect_converging_triangle_single(
                stock_idx, date_idx, ...
            )
            results.append(result)
    return pd.DataFrame(results)
```

**Optimized version** (vectorize the outer loop):

```python
def detect_converging_triangle_batch_vectorized(...):
    """
    Vectorize the outer loop.

    Strategy:
    1. Group by date_idx and process all stocks at once.
    2. Use numpy broadcasting for parallel computation.

    Expected speedup: 1.5-2x
    """
    all_results = []

    for date_idx in range(start_day, end_day + 1):
        # Handle every stock's detection for this date in one shot.
        # Slice the window data (vectorized)
        window_start = date_idx - window + 1
        high_windows = high_mtx[:, window_start:date_idx+1]  # (n_stocks, window)
        low_windows = low_mtx[:, window_start:date_idx+1]

        # Batch pivot detection (numpy vector ops)
        pivots_batch = detect_pivots_batch(high_windows, low_windows)

        # Batch boundary-line fitting
        fits_batch = fit_lines_batch(pivots_batch)

        # Batch strength computation
        strengths_batch = calc_strengths_batch(fits_batch, ...)

        all_results.append(strengths_batch)

    return np.vstack(all_results)
```

**Key point**: the algorithm must be refactored so each step operates on arrays of shape (n_stocks, window).

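One way a `detect_pivots_batch` helper could operate on a whole (n_stocks, window) matrix is `numpy.lib.stride_tricks.sliding_window_view`. This is a sketch; returning boolean pivot masks (rather than index lists) is an illustrative design choice that keeps everything rectangular and broadcast-friendly:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def detect_pivots_batch(high_windows, low_windows, k=15):
    # Boolean pivot masks for every stock at once, shape (n_stocks, window).
    # A bar is a pivot high if it equals the max of the 2k+1 bars around it;
    # sliding_window_view turns that into one vectorized comparison.
    def _extrema_mask(mtx, is_high):
        n_stocks, n_bars = mtx.shape
        win = 2 * k + 1
        views = sliding_window_view(mtx, win, axis=1)  # (n_stocks, n_bars-2k, win)
        center = views[:, :, k]
        hit = (center == views.max(axis=2)) if is_high \
              else (center == views.min(axis=2))
        mask = np.zeros_like(mtx, dtype=bool)
        mask[:, k:n_bars - k] = hit                    # edges cannot be pivots
        return mask

    return _extrema_mask(high_windows, True), _extrema_mask(low_windows, False)
```

Per-stock index arrays can still be recovered with `np.nonzero(mask[i])` where a downstream step needs them.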
---

### ⚡ Level 2: Numba JIT acceleration (expected speedup: 5-10x)

#### 2.1 Numba-accelerated core functions

```python
import numpy as np
from numba import jit, prange

@jit(nopython=True, parallel=True, cache=True)
def pivots_fractal_numba(high, low, k=15):
    """
    Numba-accelerated pivot detection.

    Flags:
    - nopython=True: compile to machine code
    - parallel=True: multithreaded prange loop
    - cache=True: cache the compiled result

    Expected speedup: 10-20x over pure Python.

    Note: results go into per-index boolean arrays; appending to a Python
    list inside a prange loop is a data race and is not supported.
    """
    n = len(high)
    is_ph = np.zeros(n, dtype=np.bool_)
    is_pl = np.zeros(n, dtype=np.bool_)

    for i in prange(k, n - k):  # parallel loop
        # Pivot-high check
        is_high_pivot = True
        for j in range(i - k, i + k + 1):
            if j != i and high[i] < high[j]:
                is_high_pivot = False
                break
        is_ph[i] = is_high_pivot

        # Pivot-low check
        is_low_pivot = True
        for j in range(i - k, i + k + 1):
            if j != i and low[i] > low[j]:
                is_low_pivot = False
                break
        is_pl[i] = is_low_pivot

    return np.where(is_ph)[0], np.where(is_pl)[0]


@jit(nopython=True, cache=True)
def fit_line_numba(x, y):
    """Numba-accelerated least-squares line fit."""
    x_mean = np.mean(x)
    y_mean = np.mean(y)

    numerator = np.sum((x - x_mean) * (y - y_mean))
    denominator = np.sum((x - x_mean) ** 2)

    a = numerator / denominator
    b = y_mean - a * x_mean

    return a, b


@jit(nopython=True, parallel=True)
def detect_batch_numba(
    high_mtx, low_mtx, close_mtx, volume_mtx,
    window, k, start_day, end_day
):
    """
    Numba-accelerated batch detection.

    Core optimizations:
    - eliminate Python object overhead
    - parallelize the outermost loop
    - preallocate the result arrays

    Expected speedup: 5-10x
    """
    n_stocks, n_days = high_mtx.shape
    total_points = n_stocks * (end_day - start_day + 1)

    # Preallocate result arrays
    strength_up = np.zeros(total_points, dtype=np.float64)
    strength_down = np.zeros(total_points, dtype=np.float64)
    is_valid = np.zeros(total_points, dtype=np.bool_)

    # Process every detection point in parallel
    for idx in prange(total_points):
        stock_idx = idx // (end_day - start_day + 1)
        day_offset = idx % (end_day - start_day + 1)
        date_idx = start_day + day_offset

        # Slice the window data
        window_start = date_idx - window + 1
        high_win = high_mtx[stock_idx, window_start:date_idx+1]
        low_win = low_mtx[stock_idx, window_start:date_idx+1]

        # Detect pivots
        ph, pl = pivots_fractal_numba(high_win, low_win, k)

        # ... downstream processing (line fits, strength computation) ...
        # computed_* below are placeholders for those results

        strength_up[idx] = computed_strength_up
        strength_down[idx] = computed_strength_down
        is_valid[idx] = computed_is_valid

    return strength_up, strength_down, is_valid
```

**Implementation notes**:
- Numba requires pure numerical code: no pandas, dicts, or other Python objects inside jitted functions
- The first call pays a JIT compilation cost (~1-2 s); subsequent calls are fast
- The algorithm must be split into purely numerical functions

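The one-time compilation cost can be paid at import time by warming the jitted function on a tiny dummy input, so the first real detection run is already fast. A sketch, using the least-squares fit above; the `try`/`except` fallback (run uncompiled when numba is absent) is an assumption about how the project wants to degrade:

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # run without JIT, same results, just slower
    def njit(*args, **kwargs):
        return args[0] if args and callable(args[0]) else (lambda f: f)

@njit(cache=True)
def fit_line_numba(x, y):
    # Closed-form least-squares slope/intercept.
    x_mean = x.mean()
    y_mean = y.mean()
    a = ((x - x_mean) * (y - y_mean)).sum() / ((x - x_mean) ** 2).sum()
    b = y_mean - a * x_mean
    return a, b

# Warm-up: trigger JIT compilation once at startup so the first real
# batch run does not pay the ~1-2 s compile cost.
_dummy = np.arange(4, dtype=np.float64)
fit_line_numba(_dummy, 2.0 * _dummy + 1.0)
```

With `cache=True` the compiled artifact also persists on disk, so later interpreter sessions skip recompilation entirely.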
---

### 🔥 Level 3: Parallelism + caching (expected speedup: 10-20x)

#### 3.1 Multiprocess parallelism

```python
from multiprocessing import Pool, cpu_count
from functools import partial

def detect_stock_range(stock_indices, high_mtx, low_mtx, ...):
    """Run detection for one batch of stocks."""
    results = []
    for stock_idx in stock_indices:
        for date_idx in range(start_day, end_day + 1):
            result = detect_converging_triangle_single(
                stock_idx, date_idx, high_mtx, low_mtx, ...
            )
            results.append(result)
    return results


def detect_converging_triangle_parallel(
    high_mtx, low_mtx, close_mtx, volume_mtx,
    params, start_day, end_day,
    n_workers=None
):
    """
    Multiprocess parallel detection.

    Strategy:
    - split the 5,000 stocks into n_workers groups
    - each process handles one group
    - the parent process merges the results

    Expected speedup: near linear (about 7x on 8 cores)
    """
    n_stocks = high_mtx.shape[0]
    n_workers = n_workers or cpu_count() - 1

    # Partition the work (by stock index)
    stock_groups = np.array_split(range(n_stocks), n_workers)

    # Bind the fixed arguments
    detect_fn = partial(
        detect_stock_range,
        high_mtx=high_mtx,
        low_mtx=low_mtx,
        close_mtx=close_mtx,
        volume_mtx=volume_mtx,
        params=params,
        start_day=start_day,
        end_day=end_day
    )

    # Run in parallel
    with Pool(n_workers) as pool:
        results_groups = pool.map(detect_fn, stock_groups)

    # Merge the per-group results
    all_results = []
    for group_results in results_groups:
        all_results.extend(group_results)

    return pd.DataFrame(all_results)
```

**Caveats**:
- Best suited to CPU-bound work
- Needs enough memory (the matrices are copied into each child process)
- For the 5,000-stock scenario, 8-16 cores is the sweet spot

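The per-child copy overhead can be avoided with `multiprocessing.shared_memory`: the parent places each matrix in a shared block once, and workers attach by name with zero copies. A sketch; the helper names are illustrative:

```python
import numpy as np
from multiprocessing import shared_memory

def share_matrix(mtx):
    # Place a matrix in shared memory; returns the block plus the metadata
    # (name, shape, dtype) a worker needs to attach.
    shm = shared_memory.SharedMemory(create=True, size=mtx.nbytes)
    buf = np.ndarray(mtx.shape, dtype=mtx.dtype, buffer=shm.buf)
    buf[:] = mtx  # one copy into shared memory, instead of one per worker
    return shm, (shm.name, mtx.shape, str(mtx.dtype))

def attach_matrix(meta):
    # Called inside a worker: attach to the shared block instead of copying.
    name, shape, dtype = meta
    shm = shared_memory.SharedMemory(name=name)
    return shm, np.ndarray(shape, dtype=np.dtype(dtype), buffer=shm.buf)
```

Workers must keep their `SharedMemory` handle alive while the view is in use and release the numpy view before calling `close()`; the creator calls `unlink()` once all workers are done.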
---

#### 3.2 Incremental computation + caching

```python
class IncrementalDetector:
    """
    Incremental detector: caches results from prior runs.

    Scenario: one new trading day per run; reuse prior detection work.
    Strategy:
    1. Cache pivots / fitted lines for the last N days.
    2. On a new date, compute only the increment.
    3. Evict old entries LRU-style.

    Expected payoff:
    - first run: no speedup
    - every later daily run: 5-10x faster (only the newest day is computed)
    """

    def __init__(self, window=240, cache_size=100):
        self.window = window
        self.pivot_cache = {}  # {stock_idx: {date_idx: (ph, pl)}}
        self.fit_cache = {}    # {stock_idx: {date_idx: fit_result}}
        self.cache_size = cache_size

    def detect_incremental(self, stock_idx, new_date_idx, high, low, close, volume):
        """
        Incremental detection using the cache.

        Logic:
        1. Check the cache for the previous day's result.
        2. If present, only:
           - update the pivots (one new day of data)
           - reuse the historical fit results
        3. Otherwise, compute from scratch and cache.
        """
        prev_date_idx = new_date_idx - 1

        # Try the cache for the previous day
        if stock_idx in self.pivot_cache and prev_date_idx in self.pivot_cache[stock_idx]:
            # Incremental pivot update
            prev_ph, prev_pl = self.pivot_cache[stock_idx][prev_date_idx]
            new_ph, new_pl = self._update_pivots_incremental(
                prev_ph, prev_pl, high, low, new_date_idx
            )
        else:
            # Full recomputation
            new_ph, new_pl = pivots_fractal(high, low, k=15)

        # Cache the pivots
        if stock_idx not in self.pivot_cache:
            self.pivot_cache[stock_idx] = {}
        self.pivot_cache[stock_idx][new_date_idx] = (new_ph, new_pl)

        # ... downstream processing (fits, strengths) producing `result` ...

        return result

    def _update_pivots_incremental(self, prev_ph, prev_pl, high, low, new_idx):
        """
        Incrementally update the pivot set.

        Strategy:
        1. Most pivots are unchanged (relative indices shift by 1).
        2. Only the window edges need checking for added/removed pivots.
        """
        # Simplified sketch; the real logic is more involved.
        # It should check whether the most recent k points form a new pivot.
        k = 15
        last_points = high[-2*k:]

        # Is the newest point a pivot?
        if self._is_pivot_high(last_points, k):
            prev_ph = np.append(prev_ph, new_idx)
        if self._is_pivot_low(low[-2*k:], k):
            prev_pl = np.append(prev_pl, new_idx)

        return prev_ph, prev_pl
```

**Implementation priority**:
- Medium (well suited to daily production runs)
- No benefit on the first run; significant savings on every subsequent day

---

### 💾 Level 4: Data-structure optimization (expected speedup: 1.5-2x)

#### 4.1 Numpy structured arrays instead of DataFrame

```python
def detect_converging_triangle_batch_numpy(...):
    """
    Use a numpy structured array instead of a pandas DataFrame.

    Advantages:
    - avoids pandas object overhead
    - contiguous memory, cache friendly
    - returns a numpy array directly for downstream processing

    Expected speedup: 30-50% (less memory allocation)
    """
    n_stocks, n_days = close_mtx.shape
    total_points = n_stocks * (end_day - start_day + 1)

    # Result record layout
    dtype = np.dtype([
        ('stock_idx', np.int32),
        ('date_idx', np.int32),
        ('is_valid', np.bool_),
        ('strength_up', np.float32),
        ('strength_down', np.float32),
        ('convergence_score', np.float32),
        ('volume_score', np.float32),
        ('geometry_score', np.float32),
        ('activity_score', np.float32),
        ('tilt_score', np.float32),
    ])

    # Preallocate the result array
    results = np.empty(total_points, dtype=dtype)

    idx = 0
    for stock_idx in range(n_stocks):
        for date_idx in range(start_day, end_day + 1):
            result = detect_single(stock_idx, date_idx, ...)

            # Write straight into the structured array (no intermediate objects)
            results[idx]['stock_idx'] = stock_idx
            results[idx]['date_idx'] = date_idx
            results[idx]['is_valid'] = result.is_valid
            results[idx]['strength_up'] = result.strength_up
            # ...
            idx += 1

    return results  # convert later if needed: pd.DataFrame(results)
```

---

#### 4.2 Memory-mapped files (large datasets)

```python
def load_data_mmap(data_dir):
    """
    Load data via memory mapping.

    When to use:
    - dataset larger than available memory
    - sharing data across processes (avoids copies)

    Expected payoff:
    - load time: seconds down to milliseconds
    - memory footprint: near zero (pages are loaded on demand)
    """
    import os

    # Saved as .npy (mmap-compatible)
    high_mmap = np.load(
        os.path.join(data_dir, 'high.npy'),
        mmap_mode='r'  # read-only
    )

    return high_mmap  # mmap object; data is paged in on demand


def save_data_for_mmap(data, filepath):
    """Save data in an mmap-compatible format."""
    np.save(filepath, data)
```

---

## Implementation Roadmap

### Phase 1: Quick wins (1-2 weeks, expected 2-3x speedup)

**Priority P0**:
1. ✅ Vectorize pivot detection (scipy or pandas rolling)
2. ✅ Cache boundary-line fits
3. ✅ Remove the easy Python loops (vectorize wherever possible)

**Expected payoff**:
- Runtime: 30-60 s → 10-20 s
- Difficulty: low
- Risk: low (backward compatible)

---

### Phase 2: Numba acceleration (2-3 weeks, expected 5-10x speedup)

**Priority P1**:
1. ✅ Numba-compile the core functions (pivots / fit_line / calc_strength)
2. ✅ Numba-compile the main batch-detection loop
3. ⚠️ Unit tests (verify numerical results match)

**Expected payoff**:
- Runtime: 10-20 s → 2-5 s
- Difficulty: medium
- Risk: medium (numerical stability must be verified)

**Caveats**:
```python
# Numba limitations:
# 1. No pandas DataFrame support (use numpy instead)
# 2. No plain dict/list support (use numpy arrays instead)
# 3. No dynamic typing (add explicit type annotations)

# Workarounds:
# - Move pandas logic to the outer layer
# - Implement core computation in pure numpy
# - Add type annotations
```

---

### Phase 3: Parallelism (1 week, expected 10-20x speedup)

**Priority P2** (only if still needed after Phase 2):
1. ✅ Multiprocess detection (Pool / ProcessPoolExecutor)
2. ✅ Incremental computation + caching (daily production runs)

**Expected payoff**:
- Runtime: 2-5 s → <1 s
- Difficulty: low (given Phases 1-2 are done)
- Risk: low

---

### Phase 4: Extreme optimization (as needed)

**Priority P3** (only if the earlier phases fall short):
1. Rewrite the core modules in Cython (C extension)
2. GPU acceleration (CUDA / cupy)
3. Rust extension (pyo3)

**Expected payoff**:
- Runtime: <1 s → milliseconds
- Difficulty: high
- Risk: high (maintenance cost)

---

## Benchmark Scripts

### Using the existing test script

```bash
# The project already has a performance-test script
cd d:\project\technical-patterns-lab

# Small-scale run (10 stocks)
python scripts/test_performance.py

# Inspect the profiling output
pip install snakeviz
snakeviz outputs/performance/profile_*.prof
```

### A 5,000-stock benchmark script

```python
# scripts/benchmark_5000_stocks.py
"""
Performance test on 5,000 stocks.
"""
import time
import numpy as np
from src.converging_triangle import detect_converging_triangle_batch, ConvergingTriangleParams

def generate_synthetic_data(n_stocks=5000, n_days=500):
    """Generate synthetic OHLCV data for testing."""
    np.random.seed(42)

    base_price = 10 + np.random.randn(n_stocks, 1) * 2
    returns = np.random.randn(n_stocks, n_days) * 0.02

    close = base_price * np.cumprod(1 + returns, axis=1)
    high = close * (1 + np.abs(np.random.randn(n_stocks, n_days)) * 0.01)
    low = close * (1 - np.abs(np.random.randn(n_stocks, n_days)) * 0.01)
    open_ = close * (1 + np.random.randn(n_stocks, n_days) * 0.005)
    volume = np.random.randint(100000, 1000000, (n_stocks, n_days))

    return open_, high, low, close, volume

def benchmark_5000_stocks():
    print("=" * 80)
    print("5,000-stock performance test")
    print("=" * 80)

    # Generate data
    print("\nGenerating test data...")
    open_, high, low, close, volume = generate_synthetic_data(5000, 500)
    print(f"Data shape: {close.shape}")

    # Parameters
    params = ConvergingTriangleParams(window=240, pivot_k=15)

    # Run
    print("\nRunning detection...")
    start = time.time()

    df = detect_converging_triangle_batch(
        open_mtx=open_,
        high_mtx=high,
        low_mtx=low,
        close_mtx=close,
        volume_mtx=volume,
        params=params,
        start_day=239,
        end_day=499,
        only_valid=True,
        verbose=False
    )

    elapsed = time.time() - start

    # Results
    print("\n" + "=" * 80)
    print("Results")
    print("=" * 80)
    print(f"Total time: {elapsed:.2f} s")
    print(f"Detection points: {5000 * 261}")  # days 239..499 inclusive = 261
    print(f"Throughput: {5000*261/elapsed:.1f} points/s")
    print(f"Valid patterns: {len(df)}")

    # Verdict
    print("\n" + "=" * 80)
    if elapsed < 5:
        print("✅ Excellent (<5 s)")
    elif elapsed < 10:
        print("✔️ Good (5-10 s)")
    elif elapsed < 30:
        print("⚠️ Mediocre (10-30 s); optimization recommended")
    else:
        print("❌ Poor (>30 s); optimization urgently needed")

if __name__ == '__main__':
    benchmark_5000_stocks()
```

---

## Quality Assurance

### Regression tests

```python
# tests/test_optimization_correctness.py
"""
Correctness tests: the optimized path must reproduce the original results.
"""
import numpy as np
import pytest

def test_optimized_vs_original():
    """Compare the optimized implementation against the original."""
    # Load the test fixture
    data = load_test_case()

    # Original implementation
    result_orig = detect_original(data)

    # Optimized implementation
    result_opt = detect_optimized(data)

    # Results must match (tiny numerical tolerance allowed)
    np.testing.assert_allclose(
        result_orig['strength_up'],
        result_opt['strength_up'],
        rtol=1e-5,
        atol=1e-8
    )

    # Validity flags must match exactly
    assert (result_orig['is_valid'] == result_opt['is_valid']).all()
```

---

## Expected Impact Summary

| Stage | Difficulty | Speedup | Cumulative | Total runtime (5,000 stocks) |
|-------|-----------|---------|------------|------------------------------|
| **Current** | - | - | 1x | 30-60 s |
| Phase 1 | Low | 2-3x | 2-3x | 10-20 s |
| Phase 2 | Medium | 5-10x | 10-30x | 2-5 s |
| Phase 3 | Low | 10-20x | 20-60x | <1 s |
| Phase 4 | High | 50-100x | 100-300x | milliseconds |

**Recommended path**: Phase 1 → Phase 2 → check whether the target is met → move to Phase 3 only if needed.

---

## Immediate Next Steps

### This Week

1. **Benchmark current performance**
   ```bash
   python scripts/benchmark_5000_stocks.py
   ```

2. **Confirm the bottleneck functions**
   ```bash
   python -m cProfile -o profile.stats scripts/benchmark_5000_stocks.py
   python -m pstats profile.stats
   # >> sort cumulative
   # >> stats 20
   ```

3. **Implement first**: vectorized pivot detection (largest payoff, lowest difficulty)