# lightsout.py

import gym
import numpy as np
import pickle
import random
import tensorflow as tf
import json
from collections import deque
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Dense, Embedding, Reshape
from tensorflow.keras.optimizers import Adam
from tf_agents.environments import py_environment
from tf_agents.environments import suite_gym
from tf_agents.specs import array_spec
from tf_agents.trajectories import time_step as ts

MAX_STEPS = 20  # Maximum number of steps when generating a board
BOARD_ROWS = 5
BOARD_COLS = 5
LIMIT = 100  # Start a new game if it takes this many moves


class LightsOutEnvironment(py_environment.PyEnvironment):
    board = None
    previous_action = None

    def _winner(self):
        ''' Returns 1 if all lights are out (we won), otherwise None '''
        for i in range(BOARD_ROWS):
            for j in range(BOARD_COLS):
                if self._state[i, j] != 0:
                    return None
        return 1

    def _flip(self, value):
        if value == 1:
            return 0
        return 1

    def _take_action(self, position, imaginary=False):
        ''' Applies the action and returns the resulting board '''
        newboard = self._state.copy()
        newboard[position] = self._flip(self._state[position])
        # Up (previous row)
        if position[0] > 0:
            newboard[(position[0]-1, position[1])] = self._flip(self._state[(position[0]-1, position[1])])
        # Down (next row)
        if position[0] < BOARD_ROWS-1:
            newboard[(position[0]+1, position[1])] = self._flip(self._state[(position[0]+1, position[1])])
        # Left (previous column)
        if position[1] > 0:
            newboard[(position[0], position[1]-1)] = self._flip(self._state[(position[0], position[1]-1)])
        # Right (next column)
        if position[1] < BOARD_COLS-1:
            newboard[(position[0], position[1]+1)] = self._flip(self._state[(position[0], position[1]+1)])
        if not imaginary:
            self.previous_action = position
            self._state = newboard
        return newboard

    def _available_positions(self):
        ''' We can push any button except the one we just did '''
        positions = []
        for i in range(BOARD_ROWS):
            for j in range(BOARD_COLS):
                if (i, j) != self.previous_action:
                    positions.append((i, j))  # needs to be a tuple
        return positions

    def _gen_solvable_board(self):
        ''' Generates a new solvable board by pressing random buttons on an empty board '''
        self._state = np.zeros((BOARD_ROWS, BOARD_COLS), dtype=np.int32)  # match the observation spec dtype
        steps = self.rng.integers(1, MAX_STEPS)
        self.previous_action = None
        for i in range(steps):
            positions = self._available_positions()
            idx = np.random.choice(len(positions))
            action = positions[idx]
            self._take_action(position=action, imaginary=False)
        self.previous_action = None  # the first real move may press any button

    def __init__(self):
        super().__init__()
        self.rng = np.random.default_rng()
        self._action_spec = array_spec.BoundedArraySpec(
            shape=(2,), dtype=np.int32, minimum=0,
            maximum=(BOARD_ROWS - 1, BOARD_COLS - 1), name='action')
        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(BOARD_ROWS, BOARD_COLS), dtype=np.int32, minimum=0, maximum=1, name='observation')
        self._gen_solvable_board()
        self._episode_ended = False
        self.current_steps = 0

    def action_spec(self):
        return self._action_spec

    def observation_spec(self):
        return self._observation_spec

    def _reset(self):
        self._gen_solvable_board()
        self._episode_ended = False
        self.current_steps = 0
        return ts.restart(self._state)

    def _step(self, action):
        if self._episode_ended:
            # The last action ended the episode. Ignore the current action and
            # start a new episode.
            return self._reset()
        self.current_steps += 1
        # Actions arrive as an array-like (row, col); index the board with a tuple.
        self._take_action(tuple(int(a) for a in action))
        if self._winner():
            self._episode_ended = True
            return ts.termination(self._state, 1)
        if self.current_steps >= MAX_STEPS:
            self._episode_ended = True
            return ts.termination(self._state, -1)
        return ts.transition(self._state, reward=0.0, discount=1.0)
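

# A minimal sketch, not part of the original file, of how the environment could
# be sanity-checked with tf_agents' spec validator before training on it.
# `_validate_lights_out_env` is a hypothetical helper name.
def _validate_lights_out_env(episodes=5):
    from tf_agents.environments import utils
    env = LightsOutEnvironment()
    # Plays `episodes` episodes with random actions and checks that every
    # TimeStep matches the declared observation/action specs.
    utils.validate_py_environment(env, episodes=episodes)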


def main():
    # New tensorflow version
    # Taxi-v3 replaced Taxi-v2 in current gym releases.
    enviroment = gym.make("Taxi-v3").env
    enviroment.render()
    print('Number of states: {}'.format(enviroment.observation_space.n))
    print('Number of actions: {}'.format(enviroment.action_space.n))

    class Agent:
        def __init__(self, enviroment, optimizer):
            # Initialize attributes
            self._state_size = enviroment.observation_space.n
            self._action_size = enviroment.action_space.n
            self._optimizer = optimizer
            self.expirience_replay = deque(maxlen=2000)
            # Initialize discount and exploration rate
            self.gamma = 0.6
            self.epsilon = 0.1
            # Build networks
            self.q_network = self._build_compile_model()
            self.target_network = self._build_compile_model()
            self.alighn_target_model()

        def store(self, state, action, reward, next_state, terminated):
            self.expirience_replay.append((state, action, reward, next_state, terminated))

        def _build_compile_model(self):
            model = Sequential()
            model.add(Embedding(self._state_size, 10, input_length=1))
            model.add(Reshape((10,)))
            model.add(Dense(50, activation='relu'))
            model.add(Dense(50, activation='relu'))
            model.add(Dense(self._action_size, activation='linear'))
            model.compile(loss='mse', optimizer=self._optimizer)
            return model

        def alighn_target_model(self):
            self.target_network.set_weights(self.q_network.get_weights())

        def act(self, state):
            if np.random.rand() <= self.epsilon:
                return enviroment.action_space.sample()
            q_values = self.q_network.predict(state)
            return np.argmax(q_values[0])

        def retrain(self, batch_size):
            minibatch = random.sample(self.expirience_replay, batch_size)
            for state, action, reward, next_state, terminated in minibatch:
                target = self.q_network.predict(state)
                if terminated:
                    target[0][action] = reward
                else:
                    t = self.target_network.predict(next_state)
                    target[0][action] = reward + self.gamma * np.amax(t)
                self.q_network.fit(state, target, epochs=1, verbose=0)
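
    # A minimal training-loop sketch (not in the original file) showing how the
    # Agent above could be driven against the gym environment. The episode and
    # batch sizes are illustrative, and the classic 4-tuple step API
    # (next_state, reward, terminated, info) is assumed.
    optimizer = Adam(learning_rate=0.01)
    agent = Agent(enviroment, optimizer)
    batch_size = 32
    num_of_episodes = 100
    timesteps_per_episode = 1000
    for e in range(num_of_episodes):
        # The Embedding layer expects a batch of one state index, hence the reshape.
        state = np.reshape(enviroment.reset(), [1, 1])
        for timestep in range(timesteps_per_episode):
            action = agent.act(state)
            next_state, reward, terminated, info = enviroment.step(action)
            next_state = np.reshape(next_state, [1, 1])
            agent.store(state, action, reward, next_state, terminated)
            state = next_state
            if terminated:
                agent.alighn_target_model()
                break
            if len(agent.expirience_replay) > batch_size:
                agent.retrain(batch_size)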


######################
## Old stuff


class State:
    def __init__(self, p1):
        self.rng = np.random.default_rng()
        self._state = np.zeros((BOARD_ROWS, BOARD_COLS))
        self.player = p1
        self.isEnd = False
        self._stateHash = None
        # init p1 plays first
        self.playerSymbol = 1
        self.previous_action = None  # We don't allow ourselves to hit the same button 2x
        self.record = {}
        self.record['wins'] = 0
        self.record['losses'] = 0
        self.record['longest'] = 0
        self.record['shortest'] = LIMIT
        self.record['current_rounds'] = 0
        self.record['decaying_average_wins'] = 0.0
        self.record['decaying_average_moves'] = 1.0 * LIMIT
        self.reset()

    # get unique hash of current board state
    def getHash(self):
        self._stateHash = str(self._state.reshape(BOARD_COLS * BOARD_ROWS))
        return self._stateHash

    def winner(self):
        if self.record['current_rounds'] > LIMIT:
            return -1
        for i in range(BOARD_ROWS):
            for j in range(BOARD_COLS):
                if self._state[i, j] != 0:
                    return None
        return 1

    def availablePositions(self):
        ''' We can push any button except the one we just did '''
        positions = []
        for i in range(BOARD_ROWS):
            for j in range(BOARD_COLS):
                if (i, j) != self.previous_action:
                    positions.append((i, j))  # needs to be a tuple
        return positions

    def _flip(self, value):
        if value == 1:
            return 0
        return 1

    def updateState(self, position):
        ''' Apply the chosen action by inverting the light at position and its neighbours (a plus shape) '''
        self._state[position] = self._flip(self._state[position])
        self.previous_action = position
        # Up (previous row)
        if position[0] > 0:
            self._state[(position[0]-1, position[1])] = self._flip(self._state[(position[0]-1, position[1])])
        # Down (next row)
        if position[0] < BOARD_ROWS-1:
            self._state[(position[0]+1, position[1])] = self._flip(self._state[(position[0]+1, position[1])])
        # Left (previous column)
        if position[1] > 0:
            self._state[(position[0], position[1]-1)] = self._flip(self._state[(position[0], position[1]-1)])
        # Right (next column)
        if position[1] < BOARD_COLS-1:
            self._state[(position[0], position[1]+1)] = self._flip(self._state[(position[0], position[1]+1)])

    # only called when the game ends
    def giveReward(self):
        result = self.winner()
        # backpropagate reward
        # While we could use result directly, we may want to tune rewards
        if result == 1:
            # print(f'********* WINNER *************')
            self.record['wins'] += 1
            self.record['decaying_average_wins'] = (99.0 * self.record['decaying_average_wins'] + 1) / 100.0
            self.record['decaying_average_moves'] = (99.0 * self.record['decaying_average_moves'] + self.record['current_rounds']) / 100.0
            if self.record['current_rounds'] > self.record['longest']:
                self.record['longest'] = self.record['current_rounds']
            if self.record['current_rounds'] < self.record['shortest']:
                self.record['shortest'] = self.record['current_rounds']
            self.player.feedReward(1)
        elif result == -1:
            # print(f'--------- LOSER ---------------')
            self.record['losses'] += 1
            self.record['decaying_average_wins'] = (99.0 * self.record['decaying_average_wins'] + 0) / 100.0
            self.record['decaying_average_moves'] = (99.0 * self.record['decaying_average_moves'] + self.record['current_rounds']) / 100.0
            if self.record['current_rounds'] > self.record['longest']:
                self.record['longest'] = self.record['current_rounds']
            self.player.feedReward(-1)
        else:
            self.player.feedReward(0)

    def gen_solvable_board(self, steps):
        ''' Generates a random solvable board by starting with an empty board
            and pressing buttons 'steps' times
        '''
        self._state = np.zeros((BOARD_ROWS, BOARD_COLS))
        for i in range(steps):
            positions = self.availablePositions()
            idx = np.random.choice(len(positions))
            action = positions[idx]
            self.updateState(action)
        self.previous_action = None

    # board reset
    def reset(self):
        ''' random board '''
        self.gen_solvable_board(self.rng.integers(1, MAX_STEPS))
        self._stateHash = str(self._state.reshape(BOARD_COLS * BOARD_ROWS))
        self.isEnd = False
        self.record['current_rounds'] = 0
        self.previous_action = None

    def play(self, rounds=100):
        showing = False
        for i in range(rounds):
            if (i % 100) == 99 and not showing:
                showing = True
            if (i % 100) == 0 and not showing:
                # print(f'1000 Rounds. Showing rest of game until win.')
                print(f'Round {i}; Stats: {json.dumps(self.record)}')
                showing = False
            while not self.isEnd:
                if showing:
                    self.showBoard()
                # Player
                positions = self.availablePositions()
                player_action = self.player.chooseAction(positions, self._state)
                # take action and update board state
                if showing:
                    print(f'Step {self.record["current_rounds"]}: Chose position: [{player_action}]')
                self.updateState(player_action)
                board_hash = self.getHash()
                self.player.addState(board_hash)
                # check whether the board has reached an end state
                self.record['current_rounds'] += 1
                win = self.winner()
                if win is not None:
                    # self.showBoard()
                    # game ended: either a win or a loss at the round limit
                    self.giveReward()
                    self.player.reset()
                    self.reset()
                    showing = False
                    break

    # play with a human
    def play2(self):
        while not self.isEnd:
            self.showBoard()
            positions = self.availablePositions()
            player_action = self.player.chooseAction(positions, self._state)
            # take action and update board state
            self.updateState(player_action)
            # check whether the board has reached an end state
            win = self.winner()
            if win is not None:
                if win == 1:
                    print("Player wins!")
                else:
                    print("You have extraordinary patience. But lost.")
                self.reset()
                break

    def showBoard(self):
        for i in range(0, BOARD_ROWS):
            print('-' * (4 * BOARD_COLS + 1))
            out = '| '
            for j in range(0, BOARD_COLS):
                if self._state[i, j] == 1:
                    token = 'O'
                if self._state[i, j] == 0:
                    token = ' '
                out += token + ' | '
            print(out)
        print('-' * (4 * BOARD_COLS + 1))


class Player:
    def __init__(self, name, exp_rate=0.01):
        self.name = name
        self.states = []  # record all positions taken
        self.lr = 0.2
        self.exp_rate = exp_rate
        self.decay_gamma = 0.9
        self.states_value = {}  # state -> value

    def getHash(self, board):
        boardHash = str(board.reshape(BOARD_COLS * BOARD_ROWS))
        return boardHash

    def _flip(self, value):
        if value == 1:
            return 0
        return 1

    def imagineState(self, newboard, position):
        ''' Return the board that would result from taking the action at position '''
        newboard[position] = self._flip(newboard[position])
        # Up (previous row)
        if position[0] > 0:
            newboard[(position[0]-1, position[1])] = self._flip(newboard[(position[0]-1, position[1])])
        # Down (next row)
        if position[0] < BOARD_ROWS-1:
            newboard[(position[0]+1, position[1])] = self._flip(newboard[(position[0]+1, position[1])])
        # Left (previous column)
        if position[1] > 0:
            newboard[(position[0], position[1]-1)] = self._flip(newboard[(position[0], position[1]-1)])
        # Right (next column)
        if position[1] < BOARD_COLS-1:
            newboard[(position[0], position[1]+1)] = self._flip(newboard[(position[0], position[1]+1)])
        return newboard

    def chooseAction(self, positions, current_board):
        value_max = -999
        found_good_state = False
        if np.random.uniform(0, 1) <= self.exp_rate:
            # take random action
            idx = np.random.choice(len(positions))
            action = positions[idx]
        else:
            for p in positions:
                next_board = current_board.copy()
                next_board = self.imagineState(next_board, p)
                next_boardHash = self.getHash(next_board)
                value = self.states_value.get(next_boardHash)
                if value is not None:
                    found_good_state = True
                else:
                    value = 0.0
                # print("value", value)
                if value >= value_max:
                    value_max = value
                    action = p
            # print("{} takes action {}".format(self.name, action))
        if not found_good_state:
            # We didn't find anything with a value, so explore
            idx = np.random.choice(len(positions))
            action = positions[idx]
        return action

    # append a hash state
    def addState(self, state):
        self.states.append(state)

    # at the end of the game, backpropagate and update state values
    def feedReward(self, reward):
        for st in reversed(self.states):
            if self.states_value.get(st) is None:
                self.states_value[st] = 0
            # Move the stored value toward the (discounted) reward, then pass the
            # updated value back as the reward for the preceding state.
            self.states_value[st] += self.lr * (self.decay_gamma * reward - self.states_value[st])
            reward = self.states_value[st]

    def reset(self):
        self.states = []

    def savePolicy(self):
        with open('policy_' + str(self.name), 'wb') as fw:
            pickle.dump(self.states_value, fw)

    def loadPolicy(self, file):
        with open(file, 'rb') as fr:
            self.states_value = pickle.load(fr)
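

# A small sketch, not in the original file, of round-tripping a learned policy
# with the savePolicy/loadPolicy helpers above. `_reload_trained_player` is a
# hypothetical helper; 'policy_player' is the filename savePolicy produces for
# a Player named "player".
def _reload_trained_player():
    trained = Player("trained", exp_rate=0)  # exp_rate=0 makes the reloaded player purely greedy
    trained.loadPolicy('policy_player')
    return trained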


class HumanPlayer:
    def __init__(self, name):
        self.name = name

    def chooseAction(self, positions, current_board):
        while True:
            row = int(input("Input your action row:"))
            col = int(input("Input your action col:"))
            action = (row, col)
            if action in positions:
                return action

    # append a hash state
    def addState(self, state):
        pass

    # at the end of the game, backpropagate and update state values
    def feedReward(self, reward):
        pass

    def reset(self):
        pass


if __name__ == "__main__":
    # training
    player = Player("player")
    st = State(player)
    print("training...")
    st.play(50000)
    # player.savePolicy()

    # play with human
    human = HumanPlayer("human")
    st = State(human)
    st.play2()