# lightsout.tablebased.py

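"""Lights Out learned from a state-value table.

The State class generates solvable 5x5 Lights Out boards and runs games; the
Player class learns by recording every board hash it visits and, when a game
ends, propagating the win/lose reward backwards through those states into a
value table. Moves are chosen epsilon-greedily: usually the press whose
imagined result has the highest learned value, occasionally a random one.
"""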
import json
import numpy as np
import pickle

MAX_STEPS = 20   # maximum number of button presses used when generating a board
BOARD_ROWS = 5
BOARD_COLS = 5
LIMIT = 100      # give up and start a new game if it takes this many moves
class State:
    def __init__(self, p1):
        self.rng = np.random.default_rng()
        self.board = np.zeros((BOARD_ROWS, BOARD_COLS))
        self.player = p1
        self.isEnd = False
        self.boardHash = None
        self.playerSymbol = 1  # not used in this single-player game
        self.previous_action = None  # we never allow hitting the same button twice in a row
        # Running statistics across games
        self.record = {
            'wins': 0,
            'losses': 0,
            'longest': 0,
            'shortest': LIMIT,
            'current_rounds': 0,
            'decaying_average_wins': 0.0,
            'decaying_average_moves': 1.0 * LIMIT,
        }
        self.reset()
    # get a unique hash of the current board state
    def getHash(self):
        self.boardHash = str(self.board.reshape(BOARD_COLS * BOARD_ROWS))
        return self.boardHash
    def winner(self):
        ''' Return 1 if every light is off (win), -1 if we ran out of moves,
        and None if the game is still in progress. '''
        if self.record['current_rounds'] > LIMIT:
            return -1
        for i in range(BOARD_ROWS):
            for j in range(BOARD_COLS):
                if self.board[i, j] != 0:
                    return None
        return 1
    def availablePositions(self):
        ''' We can press any button except the one we just pressed '''
        positions = []
        for i in range(BOARD_ROWS):
            for j in range(BOARD_COLS):
                if (i, j) != self.previous_action:
                    positions.append((i, j))  # positions need to be tuples
        return positions
    def _flip(self, value):
        if value == 1:
            return 0
        return 1
    def updateState(self, position):
        ''' Press the button at 'position': invert that light and each of its
        orthogonal neighbours (a plus shape), respecting the board edges. '''
        self.board[position] = self._flip(self.board[position])
        self.previous_action = position
        # Up
        if position[0] > 0:
            self.board[(position[0]-1, position[1])] = self._flip(self.board[(position[0]-1, position[1])])
        # Down
        if position[0] < BOARD_ROWS-1:
            self.board[(position[0]+1, position[1])] = self._flip(self.board[(position[0]+1, position[1])])
        # Left
        if position[1] > 0:
            self.board[(position[0], position[1]-1)] = self._flip(self.board[(position[0], position[1]-1)])
        # Right
        if position[1] < BOARD_COLS-1:
            self.board[(position[0], position[1]+1)] = self._flip(self.board[(position[0], position[1]+1)])
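    # A small worked example of the plus-shaped flip: starting from an all-off
    # board, pressing (2, 2) turns on exactly
    #
    #   . . . . .
    #   . . O . .
    #   . O O O .
    #   . . O . .
    #   . . . . .
    #
    # Every press is its own inverse and presses commute, which is why a board
    # built from random presses (see gen_solvable_board) is always solvable.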
    # called only when a game ends
    def giveReward(self):
        result = self.winner()
        # Backpropagate the reward. We could pass 'result' through directly,
        # but keeping the branches explicit makes the rewards easy to tune.
        if result == 1:
            # print('********* WINNER *************')
            self.record['wins'] += 1
            self.record['decaying_average_wins'] = (99.0 * self.record['decaying_average_wins'] + 1) / 100.0
            self.record['decaying_average_moves'] = (99.0 * self.record['decaying_average_moves'] + self.record['current_rounds']) / 100.0
            if self.record['current_rounds'] > self.record['longest']:
                self.record['longest'] = self.record['current_rounds']
            if self.record['current_rounds'] < self.record['shortest']:
                self.record['shortest'] = self.record['current_rounds']
            self.player.feedReward(1)
        elif result == -1:
            # print('--------- LOSER ---------------')
            self.record['losses'] += 1
            self.record['decaying_average_wins'] = (99.0 * self.record['decaying_average_wins'] + 0) / 100.0
            self.record['decaying_average_moves'] = (99.0 * self.record['decaying_average_moves'] + self.record['current_rounds']) / 100.0
            if self.record['current_rounds'] > self.record['longest']:
                self.record['longest'] = self.record['current_rounds']
            self.player.feedReward(-1)
        else:
            self.player.feedReward(0)
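    # The 'decaying_average_*' statistics are exponential moving averages with
    # a 0.99 decay factor, so they roughly track the win rate and game length
    # over the last ~100 games: for example, a run of quick wins pulls
    # decaying_average_wins towards 1.0 and decaying_average_moves towards the
    # length of those winning games.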
    def gen_solvable_board(self, steps):
        ''' Generate a random solvable board by starting from an empty (solved)
        board and pressing 'steps' random buttons. '''
        self.board = np.zeros((BOARD_ROWS, BOARD_COLS))
        for _ in range(steps):
            positions = self.availablePositions()
            idx = np.random.choice(len(positions))
            action = positions[idx]
            self.updateState(action)
        self.previous_action = None
    # board reset
    def reset(self):
        ''' Start a new game on a freshly generated random board '''
        self.gen_solvable_board(self.rng.integers(1, MAX_STEPS))
        self.boardHash = self.getHash()
        self.isEnd = False
        self.record['current_rounds'] = 0
        self.previous_action = None
    def play(self, rounds=100):
        showing = False
        for i in range(rounds):
            if (i % 100) == 99 and not showing:
                showing = True  # show this game move by move
            if (i % 100) == 0 and not showing:
                print(f'Round {i}; Stats: {json.dumps(self.record)}')  # periodic progress report
                showing = False
            while not self.isEnd:
                if showing:
                    self.showBoard()
                # Ask the player for a move
                positions = self.availablePositions()
                player_action = self.player.chooseAction(positions, self.board)
                # Take the action and update the board state
                if showing:
                    print(f'Step {self.record["current_rounds"]}: Chose position: [{player_action}]')
                self.updateState(player_action)
                board_hash = self.getHash()
                self.player.addState(board_hash)
                # Check whether the game has ended
                self.record['current_rounds'] += 1
                win = self.winner()
                if win is not None:
                    # Game over: hand out the reward and start a fresh board
                    self.giveReward()
                    self.player.reset()
                    self.reset()
                    showing = False
                    break
    # play interactively with a human
    def play2(self):
        while not self.isEnd:
            self.showBoard()
            positions = self.availablePositions()
            player_action = self.player.chooseAction(positions, self.board)
            # take the action and update the board state
            self.updateState(player_action)
            self.record['current_rounds'] += 1  # count moves so the LIMIT loss can trigger
            # check whether the game has ended
            win = self.winner()
            if win is not None:
                if win == 1:
                    print("Player wins!")
                else:
                    print("You have extraordinary patience. But lost.")
                self.reset()
                break
    def showBoard(self):
        for i in range(0, BOARD_ROWS):
            print('-' * (4 * BOARD_COLS + 1))
            out = '| '
            for j in range(0, BOARD_COLS):
                token = 'O' if self.board[i, j] == 1 else ' '
                out += token + ' | '
            print(out)
        print('-' * (4 * BOARD_COLS + 1))
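    # Sample rendering with only the light at row 0, col 1 switched on
    # (the remaining rows print the same empty pattern):
    #
    #   ---------------------
    #   |   | O |   |   |   |
    #   ---------------------
    #   |   |   |   |   |   |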
class Player:
    def __init__(self, name, exp_rate=0.01):
        self.name = name
        self.states = []          # board hashes visited during the current game
        self.lr = 0.2             # learning rate for the value-table update
        self.exp_rate = exp_rate  # probability of exploring with a random move
        self.decay_gamma = 0.9    # discount applied as the reward propagates backwards
        self.states_value = {}    # state hash -> learned value
    def getHash(self, board):
        boardHash = str(board.reshape(BOARD_COLS * BOARD_ROWS))
        return boardHash

    def _flip(self, value):
        if value == 1:
            return 0
        return 1
    def imagineState(self, newboard, position):
        ''' Return the board that pressing 'position' would produce '''
        newboard[position] = self._flip(newboard[position])
        # Up
        if position[0] > 0:
            newboard[(position[0]-1, position[1])] = self._flip(newboard[(position[0]-1, position[1])])
        # Down
        if position[0] < BOARD_ROWS-1:
            newboard[(position[0]+1, position[1])] = self._flip(newboard[(position[0]+1, position[1])])
        # Left
        if position[1] > 0:
            newboard[(position[0], position[1]-1)] = self._flip(newboard[(position[0], position[1]-1)])
        # Right
        if position[1] < BOARD_COLS-1:
            newboard[(position[0], position[1]+1)] = self._flip(newboard[(position[0], position[1]+1)])
        return newboard
    def chooseAction(self, positions, current_board):
        value_max = -999
        found_good_state = False
        if np.random.uniform(0, 1) <= self.exp_rate:
            # explore: take a random action
            idx = np.random.choice(len(positions))
            action = positions[idx]
        else:
            # exploit: imagine each legal press and keep the one whose resulting
            # board has the highest learned value
            for p in positions:
                next_board = current_board.copy()
                next_board = self.imagineState(next_board, p)
                next_boardHash = self.getHash(next_board)
                value = self.states_value.get(next_boardHash)
                if value is not None:
                    found_good_state = True
                else:
                    value = 0.0
                if value >= value_max:
                    value_max = value
                    action = p
        if not found_good_state:
            # No imagined next state had a learned value (or we took the
            # exploration branch above), so fall back to a random press.
            idx = np.random.choice(len(positions))
            action = positions[idx]
        return action
    # record a board hash the player has visited
    def addState(self, state):
        self.states.append(state)

    # at the end of a game, backpropagate the reward and update the state values
    def feedReward(self, reward):
        for st in reversed(self.states):
            if self.states_value.get(st) is None:
                self.states_value[st] = 0
            # move V(st) a fraction 'lr' of the way towards the discounted reward
            self.states_value[st] += self.lr * (self.decay_gamma * reward - self.states_value[st])
            reward = self.states_value[st]
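    # A small worked example of the update, assuming the default lr=0.2 and
    # decay_gamma=0.9, a win (reward=1), and previously unseen states: the last
    # state visited gets V = 0 + 0.2 * (0.9 * 1 - 0) = 0.18, the one before it
    # gets V = 0.2 * (0.9 * 0.18 - 0) = 0.0324, and so on, so states closer to
    # the solved board end up with higher values.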
    def reset(self):
        self.states = []
    def savePolicy(self):
        with open('policy_' + str(self.name), 'wb') as fw:
            pickle.dump(self.states_value, fw)

    def loadPolicy(self, file):
        with open(file, 'rb') as fr:
            self.states_value = pickle.load(fr)
class HumanPlayer:
    def __init__(self, name):
        self.name = name

    def chooseAction(self, positions, current_board):
        # keep asking until the move is legal
        while True:
            row = int(input("Input your action row:"))
            col = int(input("Input your action col:"))
            action = (row, col)
            if action in positions:
                return action

    # the human keeps no state history
    def addState(self, state):
        pass

    # and does not learn from rewards
    def feedReward(self, reward):
        pass

    def reset(self):
        pass
if __name__ == "__main__":
    # training
    player = Player("player")
    st = State(player)
    print("training...")
    st.play(50000)
    # player.savePolicy()

    # then let a human play one game by hand
    human = HumanPlayer("human")
    st = State(human)
    st.play2()
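    # A possible follow-up, sketched here because it is not part of the flow
    # above: uncomment player.savePolicy() (which pickles the table to a file
    # named 'policy_' + the player's name, e.g. 'policy_player') and reload it
    # later with Player.loadPolicy(), so the agent can play without retraining:
    #
    #   agent = Player("player", exp_rate=0.0)
    #   agent.loadPolicy("policy_player")
    #   State(agent).play(1000)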