Skip to content

Commit

Permalink
added multiple enemy feature to env and solved discrete env.
Browse files Browse the repository at this point in the history
  • Loading branch information
khush3 committed Jun 18, 2019
1 parent af9ba35 commit 62d9b71
Show file tree
Hide file tree
Showing 2 changed files with 249 additions and 18 deletions.
222 changes: 222 additions & 0 deletions custom_env_solution.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Solution to custom environment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from myenv import ENVIRONMENT\n",
"import numpy as np\n",
"import torch\n",
"import cv2\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"HM_EPISODES = 25000\n",
"MOVE_PENALTY = 1\n",
"ENEMY_PENALTY = 300\n",
"FOOD_REWARD = 25\n",
"epsilon = 0.9\n",
"EPS_DECAY = 0.9998 \n",
"SHOW_EVERY = 1000 \n",
"DISPLAY_EVERY= 500\n",
"SIZE = 10\n",
"LEARNING_RATE = 0.1\n",
"DISCOUNT = 0.95\n",
"total_reward = 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_q_table(start_q_table=None,size=10,action=4):\n",
" \n",
" if start_q_table is None:\n",
" q_table = np.random.randn(size,size,action)\n",
" print(q_table.size)\n",
"\n",
" else:\n",
" with open(start_q_table, \"rb\") as f:\n",
" q_table = pickle.load(f)\n",
" \n",
" return q_table\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env = ENVIRONMENT(diagonal=False,size=10,num_enemy = 3, num_food = 1)\n",
"q = get_q_table(size=10)\n",
"# Test Environmet by rendering once\n",
"\n",
"print(env.startover())\n",
"\n",
"for i in range(100):\n",
" print(env.step(np.random.randint(0,4)))\n",
" env.render()\n",
" cv2.waitKey(100)\n",
"cv2.destroyAllWindows()\n",
"\n",
"print(env.startover())\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# for i in range(1):\n",
"# print(env.step(2))\n",
"# env.render()\n",
"# cv2.waitKey(0) \n",
"# cv2.destroyAllWindows()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Improve q-value lookup table\n",
"for episode in range(HM_EPISODES):\n",
" \n",
" state, reward, done = env.startover(newpos=True)\n",
" \n",
" while not done:\n",
" \n",
" current_q = q[state[0],state[1],:]\n",
" \n",
" if random.random() > epsilon:\n",
" \n",
" action = np.argmax(current_q)\n",
"\n",
" else:\n",
" action = np.random.randint(0,4)\n",
"\n",
" next_state, (next_reward, done) = env.step(action)\n",
" total_reward += next_reward\n",
" future_q = q[next_state[0],next_state[1],:]\n",
" q[state[0],state[1],action] = (1 - LEARNING_RATE) * current_q[action] + LEARNING_RATE * ( next_reward + DISCOUNT * max(future_q) - current_q[action])\n",
"\n",
" if done and next_reward == 100:\n",
" q[state[0],state[1] :] = 0\n",
"\n",
"# if done and next_reward == -100:\n",
"# print('Hell i fucked!')\n",
" \n",
" if episode%SHOW_EVERY == 0:\n",
" env.render()\n",
" cv2.waitKey(100)\n",
" \n",
" state = next_state\n",
" \n",
" cv2.destroyAllWindows()\n",
" epsilon *= EPS_DECAY\n",
" \n",
" if episode%DISPLAY_EVERY == 0:\n",
" print('Episode: ',episode,'state:',state,'| Total Average Reward:', total_reward/500,'| Epsilon:', epsilon)\n",
" total_reward= 0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"## TESTING\n",
"# import time\n",
"# with open(f\"qtable-{int(time.time())}.pickle\", \"wb\") as f:\n",
"# pickle.dump(q_table, f)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# from myenv import ENVIRONMENT\n",
"# import cv2 \n",
"# env = ENVIRONMENT(num_enemy = 3, size= 20)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# env.render()\n",
"# cv2.waitKey(0)\n",
"# cv2.destroyAllWindows()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
45 changes: 27 additions & 18 deletions myenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pickle
from matplotlib import style
import time
import numpy as np

style.use("ggplot")

Expand Down Expand Up @@ -38,7 +39,6 @@ def act(self, choice, diagonal = False):
self.move(x=1, y=-1)

else:
print(choice)
if choice == 0:
self.move(x=0, y=1)
elif choice == 1:
Expand All @@ -51,20 +51,16 @@ def act(self, choice, diagonal = False):

def move(self, x=-100, y=-100):

# If no value for x, move randomly
if x == -100:
self.x += np.random.randint(-1, 2)
else:
self.x += x

# If no value for y, move randomly
if y == -100:
self.y += np.random.randint(-1, 2)
else:
self.y += y


# If we are out of bounds, fix!
if self.x < 0:
self.x = 0
elif self.x > self.size-1:
Expand All @@ -79,24 +75,35 @@ class ENVIRONMENT():



def __init__(self, player_number=1, enemy_numer=1, food_number=1, size = 10, DIAGONAL = False):
def __init__(self, num_player=1, num_enemy=1, num_food=1, size = 10, diagonal = False):
self.size = size
self.diagonal = DIAGONAL
self.diagonal = diagonal
self.num_enemy = num_enemy
self.num_food = num_food
self.player = Blob(size)
self.enemy = Blob(size)
self.food = Blob(size)
self.enemy = [Blob() for _ in range(self.num_enemy)]
self.food = [Blob() for _ in range(self.num_food)]
self.reward = 0
self.colors = {1: (255, 0, 0),
2: (0, 255, 0),
3: (0, 0, 255)}
self.px,self.py = self.player.x,self.player.y
self.ex,self.ey = [self.enemy[iter].x for iter in range(self.num_enemy)], [self.enemy[iter].y for iter in range(self.num_enemy)]
self.fx,self.fy = [self.food[iter].x for iter in range(self.num_food)], [self.food[iter].y for iter in range(self.num_food)]


def reset(self):
self.player = Blob(self.size)
self.enemy = Blob(self.size)
self.food = Blob(self.size)
def startover(self, newpos=False):

self.player.x, self.player.y = self.px, self.py
for iter in range(self.num_enemy):
self.enemy[iter].x, self.enemy[iter].y = self.ex[iter], self.ey[iter]
for iter in range(self.num_food):
self.food[iter].x, self.food[iter].y = self.fx[iter], self.fy[iter]
if newpos == True:
self.player = Blob(self.size)
self.reward = 0

return (self.player.x, self.player.y), self.reward
return (self.player.x, self.player.y), self.reward, False

def step(self, action):

Expand All @@ -106,10 +113,10 @@ def step(self, action):

def calculate_reward(self):

if self.player.x == self.enemy.x and self.player.y == self.enemy.y:
if self.player.x in [self.enemy[iter].x for iter in range(self.num_enemy)] and self.player.y in [self.enemy[iter].y for iter in range(self.num_enemy)]:
return -100, True

if self.player.x == self.food.x and self.player.y == self.food.y:
if self.player.x in [self.food[iter].x for iter in range(self.num_food)] and self.player.y in [self.food[iter].y for iter in range(self.num_food)]:
return 100, True

else:
Expand All @@ -119,9 +126,11 @@ def calculate_reward(self):
def render(self):

env = np.zeros((self.size, self.size, 3), dtype=np.uint8)
env[self.food.x][self.food.y] = self.colors[2]
for iter in range(self.num_food):
env[self.food[iter].x][self.food[iter].y] = self.colors[2]
for iter in range(self.num_enemy):
env[self.enemy[iter].x][self.enemy[iter].y] = self.colors[3]
env[self.player.x][self.player.y] = self.colors[1]
env[self.enemy.x][self.enemy.y] = self.colors[3]
img = Image.fromarray(env, 'RGB')
img = img.resize((300, 300))
cv2.imshow("image", np.array(img))
Expand Down

0 comments on commit 62d9b71

Please sign in to comment.