train_dqn.py
import arrow
import joblib
import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import MinMaxScaler

from model.dqn_agent import DQNAgent
from utils.trader import Trader
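
# Trains a DQN agent on historical options data: TraderEnv replays encoded
# market rows while the repo's Trader simulates fills, expiries and P&L.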


class TraderEnv(object):
    def __init__(self):
        print("setting up trading env")
        df = pd.read_pickle("cache/encoded_rows.pkl")
        encoded = np.load("cache/unscaled_data.npy").astype(np.float32)
        self.trader = Trader()
        self.current_step = 1
        valid_tickers = self.trader.quotes.valid_tickers

        # keep only rows whose ticker has valid quote data
        valid_rows, valid_x = [], []
        for idx, row in df.iterrows():
            if row["Ticker"] in valid_tickers:
                valid_rows.append(row)
                valid_x.append(encoded[idx])
        df = pd.DataFrame(valid_rows)
        encoded = np.array(valid_x)

        # drop the oldest 40% of the data, then split the remainder 60/40
        # into a training set and a held-out test set
        split = int(0.4 * len(encoded))
        df, encoded = df.iloc[split:], encoded[split:]
        split = int(0.6 * len(encoded))
        encoded, encoded_test = encoded[:split], encoded[split:]
        self.df, self.df_test = df.iloc[:split], df.iloc[split:]
        self.day = arrow.get(self.df["Time"].iloc[0].format("YYYY-MM-DD"))

        # fit the scaler on the training split only, and persist it so the
        # same scaling can be reapplied at inference time
        scaler = MinMaxScaler()
        scaler.fit(encoded)
        self.encoded = scaler.transform(encoded)
        self.encoded_test = scaler.transform(encoded_test)
        joblib.dump(scaler, "cache/dqn_scaler.gz")
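
    # The env mimics the Gym reset/step interface closely enough for the
    # repo's DQNAgent, but it is not a gym.Env subclass: step() returns a
    # (next_state, reward, done) triple with no info dict.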
    def step(self, action):
        row = self.df.iloc[self.current_step]

        # a new trading day has started: settle expiries for the day that
        # just ended
        current_day = arrow.get(row["Time"].format("YYYY-MM-DD"))
        if current_day != self.day:
            self.trader.eod(self.day.format("YYYY-MM-DD"))
            self.day = current_day

        # action 0 opens a bullish trade on the current row's signal; any
        # other action is a no-op
        if action == 0:
            current_price = row["Spot"]
            expiry = row["Expiry"].format("YYYY-MM-DD")
            ticker = row["Ticker"]
            self.trader.trade_on_signal(ticker, "BULLISH", current_price, expiry)

        next_state = self.encoded[self.current_step]
        self.current_step += 1
        reward = self.trader.current_reward
        # terminate on a reward below -50 or when the data is exhausted
        done = reward < -50 or self.current_step == len(self.encoded)
        return next_state, reward, done

    def reset(self):
        self.trader = Trader()
        self.current_step = 1
        self.day = arrow.get(self.df["Time"].iloc[0].format("YYYY-MM-DD"))
        return self.encoded[0]
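

# __init__ builds a held-out split (df_test / encoded_test) that the training
# run below never touches. A minimal evaluation sketch over that split,
# assuming the trained agent exposes a greedy select_action(state) method
# (the method name is an assumption; adapt it to DQNAgent's real interface):
def evaluate(agent, env):
    trader = Trader()
    day = arrow.get(env.df_test["Time"].iloc[0].format("YYYY-MM-DD"))
    for i in range(1, len(env.encoded_test)):
        row = env.df_test.iloc[i]
        current_day = arrow.get(row["Time"].format("YYYY-MM-DD"))
        if current_day != day:
            trader.eod(day.format("YYYY-MM-DD"))
            day = current_day
        # act greedily on the previous encoded state, mirroring step()
        if agent.select_action(env.encoded_test[i - 1]) == 0:
            trader.trade_on_signal(
                row["Ticker"],
                "BULLISH",
                row["Spot"],
                row["Expiry"].format("YYYY-MM-DD"),
            )
    return trader.current_reward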


seed = 777


def seed_torch(seed):
    torch.manual_seed(seed)
    if torch.backends.cudnn.enabled:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True


np.random.seed(seed)
seed_torch(seed)
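
# Note: only NumPy and torch are seeded here; if DQNAgent also draws from
# Python's built-in random module, runs can still differ between executions
# (an assumption about the agent's internals, not controlled by this file).
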
# parameters
num_frames = int(1e7)
memory_size = 10000
batch_size = 32
target_update = 1000
epsilon_decay = 1 / 2000
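
# With a linear schedule in which epsilon drops by epsilon_decay every frame
# until it reaches its floor (an assumption about DQNAgent's internals),
# exploration fully anneals after about 1 / epsilon_decay = 2000 frames, a
# tiny fraction of the 1e7 training frames.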

env = TraderEnv()
agent = DQNAgent(
    env,
    env.encoded.shape[1],  # observation size = number of encoded features
    memory_size,
    batch_size,
    target_update,
    epsilon_decay,
    gamma=0.999,
)
agent.train(num_frames)
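
# At inference time, reload the persisted scaler so live features are scaled
# exactly as during training (the path matches the joblib.dump above; the
# raw_features name is illustrative):
#
#     scaler = joblib.load("cache/dqn_scaler.gz")
#     state = scaler.transform(raw_features.reshape(1, -1)).astype(np.float32)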