Skip to content

Commit

Permalink
get_price, working both for stock and index
Browse files Browse the repository at this point in the history
  • Loading branch information
yssource committed May 27, 2021
1 parent baf48e9 commit 94cf161
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 82 deletions.
131 changes: 81 additions & 50 deletions abquant/apis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from functools import lru_cache
from abquant.helper import unnormalize_code
from abquant.utils.code import code_tolist
from abquant.utils.tdx import is_index_cn
from abquant.utils.logger import user_log as ulog
from abquant.helper import unnormalize_code
from pyabquant import PyAbquant, FQ_TYPE, INSTRUMENT_TYPE
from pyabquant import PyAbquant, FQ_TYPE, INSTRUMENT_TYPE # type: ignore


def get_price(
order_book_ids: Union[str, Iterable[str]],
start_date: Union[datetime.date, str],
order_book_ids: Union[str, List[str]],
start_date: Union[datetime.date, datetime.datetime, str],
end_date: Optional[Union[datetime.date, datetime.datetime, str]] = None,
frequency: Optional[str] = "1d",
fields: List[str] = [],
Expand Down Expand Up @@ -80,8 +81,6 @@ def get_price(
#...
"""

order_book_ids = code_tolist(order_book_ids)

if fields is None:
fields = ["open", "close", "high", "low", "vol"]

Expand All @@ -90,6 +89,10 @@ def get_price(
if isinstance(end_date, (datetime.datetime,)):
end_date = end_date.strftime("%Y-%m-%d %H:%M:%S")

if frequency in ["1d"]:
start_date = start_date.split(" ")[0] # type: ignore
end_date = end_date.split(" ")[0] # type: ignore

if isinstance(fields, (str,)):
fields = [fields]

Expand All @@ -110,9 +113,6 @@ def get_price(
"...",
], f"invalid {field}"

if isinstance(order_book_ids, (str,)):
order_book_ids = order_book_ids.split()

assert adjust_type in [
"pre",
None,
Expand All @@ -126,72 +126,103 @@ def get_price(
else:
fq = FQ_TYPE.NONE

if frequency in ["1d"]:
from pyabqstockday import PyStockDay as stockday
if isinstance(order_book_ids, (str,)):
order_book_ids = order_book_ids.split()

if not is_index_cn(order_book_ids[0]):
order_book_ids = code_tolist(order_book_ids)
if frequency in ["1d"]:
from pyabqstockday import PyStockDay as stockday # type: ignore

sm = stockday(order_book_ids, start_date, end_date, fq)
else:

from pyabqstockmin import PyStockMin as stockmin # type: ignore

sdm = stockday(order_book_ids, start_date, end_date, fq)
sm = stockmin(order_book_ids, start_date, end_date, frequency, fq)
else:
order_book_ids = code_tolist(order_book_ids)
if frequency in ["1d"]:
from pyabqindexday import PyIndexDay as indexday # type: ignore

from pyabqstockmin import PyStockMin as stockmin
sm = indexday(order_book_ids, start_date, end_date)
else:

sdm = stockmin(order_book_ids, start_date, end_date, frequency, fq)
from pyabqindexmin import PyIndexMin as indexmin # type: ignore

sm = indexmin(order_book_ids, start_date, end_date, frequency)

code = sm.to_series_string("code")
date = sm.to_series_string("date")
df = pd.DataFrame()
try:
if frequency in ["1d"]:
df = pd.DataFrame({"code": code, "date": date})
else:
datetime_ = sm.to_series_string("datetime")
df = pd.DataFrame({"code": code, "datetime": datetime_})
print(df)
except Exception:
raise RuntimeError("fail to get_price")

date = sdm.to_series_string("date")
code = sdm.to_series_string("code")
df = pd.DataFrame({"code": code, "date": date})
for field in fields:
if field in ["code", "date"]:
if field in ["code", "date", "datetime"]:
continue
if field in ["vol"]:
df["volume"] = sdm.to_series(field)
df[field] = sdm.to_series(field)
df.set_index(["code", "date"], inplace=True)
df["volume"] = sm.to_series(field)
df[field] = sm.to_series(field)

if frequency in ["1d"]:
df.set_index(["code", "date"], inplace=True)
else:
df.set_index(["code", "datetime"], inplace=True)

if len(fields) == 1:
# return pd.Series
return df[fields[0]]
return df[fields[0]] # type: ignore
# print(df)
return df


@lru_cache(maxsize=256)
@lru_cache(maxsize=1024)
def get_all_securities(
types: List[str] = [], date: Optional[str] = None
type_: INSTRUMENT_TYPE = INSTRUMENT_TYPE.CS, date: Optional[str] = None
) -> pd.DataFrame:
df = pd.DataFrame()
if "cs" in types:
from pyabqsecuritylist import PySecurityList as securitylist

sl = securitylist()
code = sl.to_series_string("code")
volunit = sl.to_series("volunit")
decimal_point = sl.to_series("decimal_point")
name = sl.to_series_string("name")
pre_close = sl.to_series("pre_close")
sse = sl.to_series_string("sse")
sec = sl.to_series_string("sec")
print(code[:10])

df = pd.DataFrame(
{
"code": code,
"volunit": volunit,
"decimal_point": decimal_point,
"name": name,
"pre_close": pre_close,
"sse": sse,
"sec": sec,
}
)
df.set_index(["code"], inplace=True)
from pyabqsecuritylist import PySecurityList as securitylist # type: ignore

sl = securitylist(type_)

code = sl.to_series_string("code")
volunit = sl.to_series("volunit")
decimal_point = sl.to_series("decimal_point")
name = sl.to_series_string("name")
pre_close = sl.to_series("pre_close")
sse = sl.to_series_string("sse")
sec = sl.to_series_string("sec")
# print(code[:10])

df = pd.DataFrame(
{
"code": code,
"volunit": volunit,
"decimal_point": decimal_point,
"name": name,
"pre_close": pre_close,
"sse": sse,
"sec": sec,
}
)
df.set_index(["code"], inplace=True)

return df


@lru_cache(maxsize=256)
@lru_cache(maxsize=1024)
def get_security_info(
code: str, ins_type: INSTRUMENT_TYPE = INSTRUMENT_TYPE.CS
) -> pd.Series:
from pyabqsecuritylist import PySecurityList as securitylist
from pyabqsecuritylist import PySecurityList as securitylist # type: ignore

sl = securitylist([unnormalize_code(code)], "", ins_type)

Expand Down
24 changes: 19 additions & 5 deletions abquant/utils/tdx.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from functools import lru_cache
import pandas as pd
import numpy as np
from abquant.utils.logger import system_log as slog
from abquant.helper import to_json_from_pandas
from abquant.helper import normalize_code, to_json_from_pandas
from abquant.utils.qa import make_datestamp
from abquant.config import Setting
import json


def for_sz(code):
def for_sz(code: str) -> str:
"""深市代码分类
Arguments:
code {[type]} -- [description]
Expand Down Expand Up @@ -58,7 +59,7 @@ def for_sz(code):
return "undefined"


def for_sh(code):
def for_sh(code: str) -> str:
if str(code)[0] == "6":
return "stock_cn"
elif str(code)[0:3] in ["000", "880"]:
Expand Down Expand Up @@ -211,10 +212,10 @@ def query_stock_day(
# code= [code] if isinstance(code,str) else code

# code checking
code = code.split()
codes = code.split()
cursor = collections.find(
{
"code": {"$in": code},
"code": {"$in": codes},
"date_stamp": {"$lte": make_datestamp(end), "$gte": make_datestamp(start)},
},
{"_id": 0},
Expand Down Expand Up @@ -335,3 +336,16 @@ def save_error_log(err: list, key: str):
errs[key] = list(set(err))
with Setting.ERROR_CODES_JSON.open(mode="w") as f:
json.dump(errs, f)


@lru_cache()
def is_index_cn(code: str) -> bool:
# FIXME: is not perfect here, if the code is bare digits, without "XSHG" or
# "XSHE", suffix
if "XSHG" in code and "000" in code:
return True
if "XSHG" in normalize_code(code) and for_sh(code) == "index_cn":
return True
if "XSHE" in normalize_code(code) and for_sz(code) == "index_cn":
return True
return False
Loading

0 comments on commit 94cf161

Please sign in to comment.