Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make talib support multi timescale. #581

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
add stride for talib
  • Loading branch information
eromoe committed Apr 4, 2023
commit 52a3a58eb537eeed8b36bab962e70cd05bbe3b66
5 changes: 5 additions & 0 deletions vectorbt/indicators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def talib(*args, **kwargs) -> tp.Type[IndicatorBase]:
"""Shortcut for `vectorbt.indicators.factory.IndicatorFactory.from_talib`."""
return IndicatorFactory.from_talib(*args, **kwargs)

def mtalib(*args, **kwargs) -> tp.Type[IndicatorBase]:
"""Shortcut for `vectorbt.indicators.factory.IndicatorFactory.from_talib`."""
return IndicatorFactory.from_mtalib(*args, **kwargs)

def pandas_ta(*args, **kwargs) -> tp.Type[IndicatorBase]:
"""Shortcut for `vectorbt.indicators.factory.IndicatorFactory.from_pandas_ta`."""
Expand All @@ -38,6 +41,7 @@ def ta(*args, **kwargs) -> tp.Type[IndicatorBase]:
__all__ = [
'IndicatorFactory',
'talib',
'mtalib',
'pandas_ta',
'ta',
'MA',
Expand All @@ -50,6 +54,7 @@ def ta(*args, **kwargs) -> tp.Type[IndicatorBase]:
'OBV'
]
__whitelist__ = [
'mtalib',
'talib',
'pandas_ta',
'ta'
Expand Down
130 changes: 130 additions & 0 deletions vectorbt/indicators/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3490,6 +3490,136 @@ def apply_func(input_list: tp.List[tp.AnyArray],
)
return TALibIndicator


@classmethod
def from_mtalib(cls, func_name: str, timescale:int, init_kwargs: tp.KwargsLike = None, **kwargs) -> tp.Type[IndicatorBase]:
"""Build an indicator class around a TA-Lib function.

Requires [TA-Lib](https://github.com/mrjbq7/ta-lib) installed.

For input, parameter and output names, see [docs](https://github.com/mrjbq7/ta-lib/blob/master/docs/index.md).

Args:
func_name (str): Function name.
timescale:
init_kwargs (dict): Keyword arguments passed to `IndicatorFactory`.
**kwargs: Keyword arguments passed to `IndicatorFactory.from_custom_func`.

Returns:
Indicator

Usage:
```pycon
>>> SMA = vbt.IndicatorFactory.from_talib('SMA')

>>> sma = SMA.run(price, timeperiod=[2, 3])
>>> sma.real
sma_timeperiod 2 3
a b a b
2020-01-01 NaN NaN NaN NaN
2020-01-02 1.5 4.5 NaN NaN
2020-01-03 2.5 3.5 2.0 4.0
2020-01-04 3.5 2.5 3.0 3.0
2020-01-05 4.5 1.5 4.0 2.0
```

* To get help on running the indicator, use the `help` command:

```pycon
>>> help(SMA.run)
Help on method run:

run(close, timeperiod=30, short_name='sma', hide_params=None, hide_default=True, **kwargs) method of builtins.type instance
Run `SMA` indicator.

* Inputs: `close`
* Parameters: `timeperiod`
* Outputs: `real`

Pass a list of parameter names as `hide_params` to hide their column levels.
Set `hide_default` to False to show the column levels of the parameters with a default value.

Other keyword arguments are passed to `vectorbt.indicators.factory.run_pipeline`.
```
"""
import talib
from talib import abstract

func_name = func_name.upper()
talib_func = getattr(talib, func_name)
info = abstract.Function(func_name).info
input_names = []
for in_names in info['input_names'].values():
if isinstance(in_names, (list, tuple)):
input_names.extend(list(in_names))
else:
input_names.append(in_names)
class_name = info['name']
class_docstring = "{}, {}".format(info['display_name'], info['group'])
param_names = list(info['parameters'].keys())
output_names = info['output_names']
output_flags = info['output_flags']

def apply_func(input_list: tp.List[tp.AnyArray],
in_output_tuple: tp.Tuple[tp.AnyArray, ...],
param_tuple: tp.Tuple[tp.Param, ...],
**kwargs) -> tp.Union[tp.Array2d, tp.List[tp.Array2d]]:

# TA-Lib functions can only process 1-dim arrays
n_input_cols = input_list[0].shape[1]
outputs = []

# 增加 timescale 支持
# 保证 是 timescale的倍数
n = input_list[0].shape[0]
start = n % timescale
if start:
n = n - start + timescale

for col in range(n_input_cols):

s = np.full(n, fill_value=np.nan).reshape(-1, timescale)
for i in range(timescale):
s[1 if start else 0:,i] = talib_func(
*map(lambda x: x[start+i::timescale, col], input_list),
*param_tuple,
**kwargs
)

if start:
# 有start 会跳过第一个不完整的stride
output = s.reshape(n,)[timescale-start:]
else:
# 没有则不跳
output = s.reshape(n, )

outputs.append(output)
if isinstance(outputs[0], tuple): # multiple outputs
outputs = list(zip(*outputs))
return list(map(np.column_stack, outputs))
return np.column_stack(outputs)

TALibIndicator = cls(
**merge_dicts(
dict(
class_name=class_name,
class_docstring=class_docstring,
input_names=input_names,
param_names=param_names, # 这里增加 timescale, 就得在apply_func 里剔除
output_names=output_names,
output_flags=output_flags
),
init_kwargs
)
).from_apply_func(
apply_func,
pass_packed=True,
**info['parameters'], # 输入强制要求 kv 形式
**kwargs
)
return TALibIndicator


@classmethod
def parse_pandas_ta_config(cls,
func: tp.Callable,
Expand Down