Skip to content

Commit

Permalink
py, refactoring visitor pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
yssource committed Feb 12, 2022
1 parent 4b6188c commit 64130b0
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 49 deletions.
100 changes: 62 additions & 38 deletions abquant/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

from __future__ import annotations
import abc
from typing import List
from typing import List, TYPE_CHECKING
from abquant.helper import time_counter, INSTRUMENT_TYPE

if TYPE_CHECKING:
from abquant.data.tdx import Stock, Future, Etf


class ISecurity(object, metaclass=abc.ABCMeta):
"""
Expand All @@ -14,58 +17,79 @@ class ISecurity(object, metaclass=abc.ABCMeta):

@abc.abstractmethod
def accept(self, visitor: ISecurityVisitor) -> None:
pass
...

def getClassName(self):
return self.__class__.__name__
@abc.abstractmethod
def create_list(self, iSecurity: ISecurity) -> None:
...

@abc.abstractmethod
def create_day(self, iSecurity: ISecurity) -> None:
...

class ISecurityVisitor(object, metaclass=abc.ABCMeta):
@abc.abstractmethod
def create_list(self, iSecurity: ISecurity):
pass
def create_min(self, iSecurity: ISecurity) -> None:
...

def getClassName(self) -> str:
return self.__class__.__name__


class ISecurityVisitor(object, metaclass=abc.ABCMeta):
@abc.abstractmethod
def create_day(self, iSecurity: ISecurity):
pass
def visit_stock(self, stock: Stock) -> None:
...

@abc.abstractmethod
def create_min(self, iSecurity: ISecurity):
pass
def visit_future(self, future: Future) -> None:
...

@abc.abstractmethod
def create_xdxr(self, iSecurity: ISecurity):
pass
def visit_etf(self, stock: Etf) -> None:
...


class SecurityVisitor(ISecurityVisitor):
def __init__(self, *args, **kwargs):
"docstring"
pass

def create_list(self, iSecurity: ISecurity):
if getattr(iSecurity, "create_list"):
iSecurity.create_list(ins_type=INSTRUMENT_TYPE.INDX)
iSecurity.create_list(ins_type=INSTRUMENT_TYPE.ETF)
iSecurity.create_list(ins_type=INSTRUMENT_TYPE.CS)

def create_day(self, iSecurity: ISecurity):
if getattr(iSecurity, "create_day"):
iSecurity.create_day(ins_type=INSTRUMENT_TYPE.INDX)
iSecurity.create_day(ins_type=INSTRUMENT_TYPE.ETF)
iSecurity.create_day(ins_type=INSTRUMENT_TYPE.CS)

def create_min(self, iSecurity: ISecurity):
if getattr(iSecurity, "create_min"):
iSecurity.create_min(ins_type=INSTRUMENT_TYPE.INDX)
iSecurity.create_min(ins_type=INSTRUMENT_TYPE.ETF)
iSecurity.create_min(ins_type=INSTRUMENT_TYPE.CS)

def create_xdxr(self, iSecurity: ISecurity):
if getattr(iSecurity, "create_xdxr", None):
pass
# please manually `abquant stock xdxr`
# iSecurity.create_xdxr()
...

def visit_stock(self, stock: Stock):
if getattr(stock, "create_list"):
stock.create_list(ins_type=INSTRUMENT_TYPE.INDX)
stock.create_list(ins_type=INSTRUMENT_TYPE.CS)

if getattr(stock, "create_day"):
stock.create_day(ins_type=INSTRUMENT_TYPE.INDX)
stock.create_day(ins_type=INSTRUMENT_TYPE.CS)

if getattr(stock, "create_min"):
stock.create_min(ins_type=INSTRUMENT_TYPE.INDX)
stock.create_min(ins_type=INSTRUMENT_TYPE.CS)

if getattr(stock, "create_xdxr"):
stock.create_xdxr(ins_type=INSTRUMENT_TYPE.INDX)
stock.create_xdxr(ins_type=INSTRUMENT_TYPE.CS)

def visit_future(self, future: Future):
if getattr(future, "create_list"):
future.create_list(ins_type=INSTRUMENT_TYPE.FUTURE)

if getattr(future, "create_day"):
future.create_day(ins_type=INSTRUMENT_TYPE.FUTURE)

if getattr(future, "create_min"):
future.create_min(ins_type=INSTRUMENT_TYPE.FUTURE)

def visit_etf(self, etf: Etf):
if getattr(etf, "create_list"):
etf.create_list(ins_type=INSTRUMENT_TYPE.ETF)

if getattr(etf, "create_day"):
etf.create_day(ins_type=INSTRUMENT_TYPE.ETF)

if getattr(etf, "create_min"):
etf.create_min(ins_type=INSTRUMENT_TYPE.ETF)


def get_broker(broker="tdx"):
Expand Down
18 changes: 9 additions & 9 deletions abquant/data/tdx.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,7 @@ def __init__(self, *args, **kwargs):
self.freqs = kwargs.get("freqs", ["1min", "5min", "15min", "30min", "60min"])

def accept(self, visitor: ISecurityVisitor) -> None:
visitor.create_day(self)
visitor.create_min(self)
visitor.create_xdxr(self)
visitor.visit_stock(self)

def create_list(self, *args, **kwargs):
"""save security list
Expand Down Expand Up @@ -88,6 +86,7 @@ def create_day(self, *args, **kwargs):
client {[type]} -- [description] (default: {DATABASE})
"""
ins_type = kwargs.pop("ins_type", INSTRUMENT_TYPE.CS)

instruments = (
self.codes
if self.codes
Expand Down Expand Up @@ -465,9 +464,12 @@ def __init__(self, *args, **kwargs):
self.freqs = []

def accept(self, visitor: ISecurityVisitor) -> None:
visitor.create_day(self)
visitor.create_min(self)
visitor.create_xdxr(self)
visitor.visit_future(self)

def create_list(self, *args, **kwargs):
"""create_list, for Future
"""
return

def create_day(self, *args, **kwargs):
ins_type = kwargs.pop("ins_type", "future")
Expand All @@ -489,9 +491,7 @@ def __init__(self, *args, **kwargs):
self.freqs = kwargs.get("freqs", ["1min", "5min", "15min", "30min", "60min"])

def accept(self, visitor: ISecurityVisitor) -> None:
visitor.create_list(self)
visitor.create_day(self)
visitor.create_min(self)
visitor.visit_etf(self)

def create_list(self, *args, **kwargs):
"""save etf list
Expand Down
3 changes: 1 addition & 2 deletions abquant/data/ts_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import re
import pandas as pd # type: ignore
from abquant.utils.ts import (
code_to_symbol,
LIVE_DATA_URL,
Expand All @@ -10,6 +8,7 @@
LIVE_DATA_COLS,
)
from abquant.utils.logger import user_log as ulog
import pandas as pd # type: ignore
import re
import requests
from requests.exceptions import Timeout
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ simplejson = "^3.17.2"
pymongo = "^3.11.3"
pyarrow = "^4.0.0"
requests = "^2.27.1"
tqdm = "^4.62.3"

[tool.poetry.dev-dependencies]
conan = "^1.34.0"
Expand Down

0 comments on commit 64130b0

Please sign in to comment.