forked from linyiLYi/street-fighter-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
127 lines (103 loc) · 4.01 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright 2023 LIN Yi. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import sys
import retro
from stable_baselines3 import PPO
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.vec_env import SubprocVecEnv
from street_fighter_custom_wrapper import StreetFighterCustomWrapper
# Number of environment factories built in main() and handed to SubprocVecEnv.
NUM_ENV = 16
# Log directory; created at import time if it does not already exist.
LOG_DIR = 'logs'
os.makedirs(LOG_DIR, exist_ok=True)
# Linear scheduler
def linear_schedule(initial_value, final_value=0.0):
    """Build a linear interpolation schedule between two values.

    The returned callable maps ``progress`` to
    ``final_value + progress * (initial_value - final_value)``, so it yields
    ``initial_value`` at ``progress == 1`` and ``final_value`` at
    ``progress == 0``.

    Args:
        initial_value: Value produced at ``progress == 1``. May be a number or
            a numeric string; must be strictly positive.
        final_value: Value produced at ``progress == 0``. May be a number or a
            numeric string.

    Returns:
        A function ``scheduler(progress)`` returning the interpolated float.

    Raises:
        ValueError: If either value cannot be converted to ``float`` or if
            ``initial_value`` is not strictly positive.
    """
    # Coerce each argument on its own so that a string ``final_value`` paired
    # with a numeric ``initial_value`` (or vice versa) is handled as well.
    initial_value = float(initial_value)
    final_value = float(final_value)

    # Explicit check instead of ``assert`` so it still runs under ``python -O``;
    # the negated comparison also rejects NaN.
    if not initial_value > 0.0:
        raise ValueError(
            f"initial_value must be positive, got {initial_value!r}"
        )

    def scheduler(progress):
        return final_value + progress * (initial_value - final_value)

    return scheduler
def make_env(game, state, seed=0):
    """Return a zero-argument factory that builds one environment.

    Nothing is created until the returned callable is invoked. When invoked,
    it makes a retro environment for ``game``/``state`` with filtered actions
    and image observations, wraps it in ``StreetFighterCustomWrapper`` and
    then ``Monitor``, seeds it with ``seed``, and returns the wrapped
    environment.

    Args:
        game: Game identifier forwarded to ``retro.make``.
        state: Save-state name forwarded to ``retro.make``.
        seed: Seed passed to the wrapped environment's ``seed`` method.

    Returns:
        A callable taking no arguments that constructs the environment.
    """
    def _init():
        raw_env = retro.make(
            game=game,
            state=state,
            use_restricted_actions=retro.Actions.FILTERED,
            obs_type=retro.Observations.IMAGE,
        )
        # Custom wrapper goes innermost, Monitor outermost.
        monitored_env = Monitor(StreetFighterCustomWrapper(raw_env))
        monitored_env.seed(seed)
        return monitored_env

    return _init
def main():
    """Train a PPO agent on Street Fighter II and save the resulting models.

    Builds ``NUM_ENV`` subprocess environments, creates a ``CnnPolicy`` PPO
    model with linearly scheduled learning rate and clip range, and trains it
    while periodically checkpointing into ``trained_models``. Everything
    printed to stdout during training is written to
    ``trained_models/training_log.txt``. The final model is saved once
    training completes.

    stdout is restored and the environments are closed even if training
    raises; in that case the exception propagates and no final model is saved.
    """
    # Local import: only this function needs it.
    from contextlib import redirect_stdout

    # Set up the environment and model
    game = "StreetFighterIISpecialChampionEdition-Genesis"
    env = SubprocVecEnv(
        [
            make_env(game, state="Champion.Level12.RyuVsBison", seed=i)
            for i in range(NUM_ENV)
        ]
    )

    # Set linear schedule for learning rate
    # Start
    lr_schedule = linear_schedule(2.5e-4, 2.5e-6)

    # fine-tune
    # lr_schedule = linear_schedule(5.0e-5, 2.5e-6)

    # Set linear scheduler for clip range
    # Start
    clip_range_schedule = linear_schedule(0.15, 0.025)

    # fine-tune
    # clip_range_schedule = linear_schedule(0.075, 0.025)

    model = PPO(
        "CnnPolicy",
        env,
        device="cuda",
        verbose=1,
        n_steps=512,
        batch_size=512,
        n_epochs=4,
        gamma=0.94,
        learning_rate=lr_schedule,
        clip_range=clip_range_schedule,
        # Same directory the module creates at import time.
        tensorboard_log=LOG_DIR,
    )

    # Set the save directory
    save_dir = "trained_models"
    os.makedirs(save_dir, exist_ok=True)

    # Load the model from file
    # model_path = "trained_models/ppo_ryu_7000000_steps.zip"

    # Load model and modify the learning rate and entropy coefficient
    # custom_objects = {
    #     "learning_rate": lr_schedule,
    #     "clip_range": clip_range_schedule,
    #     "n_steps": 512
    # }
    # model = PPO.load(model_path, env=env, device="cuda", custom_objects=custom_objects)

    # Set up callbacks
    # Note that 1 timestep = 6 frames
    checkpoint_interval = 31250  # checkpoint_interval * num_envs = total_steps_per_checkpoint
    checkpoint_callback = CheckpointCallback(
        save_freq=checkpoint_interval,
        save_path=save_dir,
        name_prefix="ppo_ryu",
    )

    # Writing the training logs from stdout to a file. redirect_stdout puts
    # sys.stdout back on exit, including when training raises.
    log_file_path = os.path.join(save_dir, "training_log.txt")
    with open(log_file_path, 'w') as log_file, redirect_stdout(log_file):
        try:
            model.learn(
                total_timesteps=int(100000000),  # total_timesteps = stage_interval * num_envs * num_stages (1120 rounds)
                callback=[checkpoint_callback],  # , stage_increase_callback]
            )
        finally:
            # Always shut the worker processes down, even on failure.
            env.close()

    # Save the final model
    model.save(os.path.join(save_dir, "ppo_sf2_ryu_final.zip"))
# Run training only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()