Skip to content

Commit 1acc5d9

Browse files
committed
numba impl cleanup
1 parent 553d53e commit 1acc5d9

File tree

5 files changed

+29
-214
lines changed

5 files changed

+29
-214
lines changed

src_numba/core_numba/mcts.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
from __future__ import annotations
2-
import random
2+
33
import math
4+
import random
5+
46
import numpy as np
57
from numba import njit
6-
from typing import List, Tuple
8+
79
from .othello import (
8-
make_move,
9-
get_valid_moves,
1010
STATE_BLACK_TURN,
11-
STATE_WHITE_TURN,
1211
STATE_BLACK_WON,
13-
STATE_WHITE_WON,
1412
STATE_DRAW,
13+
STATE_WHITE_TURN,
14+
STATE_WHITE_WON,
15+
get_valid_moves,
16+
make_move,
1517
)
1618

1719

18-
def mcts_move(board: np.ndarray, black_score: int, white_score: int, state: int, iterations: int) -> Tuple[int, int]:
20+
def mcts_move(board: np.ndarray, black_score: int, white_score: int, state: int, iterations: int):
1921
"""Returns the best move for the current turn using Monte Carlo Tree Search."""
22+
2023
valid_moves = [tuple(move) for move in get_valid_moves(board, state)] # Convert to list of tuples
2124
root = Node(None, (-1, -1), state, valid_moves)
2225

@@ -68,12 +71,12 @@ def mcts_move(board: np.ndarray, black_score: int, white_score: int, state: int,
6871
class Node:
6972
"""Node of the MCTS tree."""
7073

71-
def __init__(self, parent: Node | None, move: Tuple[int, int], turn: int, unexplored: List[Tuple[int, int]]):
74+
def __init__(self, parent: Node | None, move: tuple[int, int], turn: int, unexplored: list[tuple[int, int]]):
7275
self.move = move
7376
self.turn = turn
7477
self.unexplored = unexplored
7578
self.parent = parent
76-
self.children: List[Node] = []
79+
self.children: list[Node] = []
7780
self.visits = 0
7881
self.wins = 0
7982

src_numba/core_numba/minimax.py

+3
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
def minimax_move(board: np.ndarray, black_score: int, white_score: int, state: int, depth: int) -> Tuple[int, int]:
3636
"""Use minimax to find a good move for the current player. Returns (x, y)."""
37+
3738
moves = [tuple(move) for move in get_valid_moves(board, state)] # Convert to list of tuples
3839
if not moves:
3940
return (-1, -1)
@@ -67,6 +68,7 @@ def _minimax(
6768
beta: float,
6869
) -> Tuple[float, Tuple[int, int]]:
6970
"""Minimax with alpha-beta pruning. Returns (value, (x, y))."""
71+
7072
if depth == 0 or state not in (STATE_BLACK_TURN, STATE_WHITE_TURN):
7173
return _evaluate_board(board, black_score, white_score, state, my_turn), (-1, -1)
7274

@@ -111,6 +113,7 @@ def _minimax(
111113
@njit
112114
def _evaluate_board(board: np.ndarray, black_score: int, white_score: int, state: int, my_turn: int) -> float:
113115
"""Evaluate the board using the REWARDS matrix with Numba-compatible loops."""
116+
114117
if state == STATE_BLACK_WON:
115118
return np.float64(np.inf) if my_turn == STATE_BLACK_TURN else np.float64(-np.inf)
116119
if state == STATE_WHITE_WON:

src_numba/core_numba/othello.py

+14-114
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from numba import jit, njit
2+
from numba import njit
33

44
# Constants replacing Enums
55
CELL_EMPTY = 0
@@ -14,7 +14,19 @@
1414
STATE_DRAW = 5
1515

1616
# Directions for checking flipped cells
17-
DIRECTIONS = np.array([[1, 0], [-1, 0], [0, 1], [0, -1], [1, 1], [-1, -1], [1, -1], [-1, 1]], dtype=np.int32)
17+
DIRECTIONS = np.array(
18+
[
19+
[1, 0],
20+
[-1, 0],
21+
[0, 1],
22+
[0, -1],
23+
[1, 1],
24+
[-1, -1],
25+
[1, -1],
26+
[-1, 1],
27+
],
28+
dtype=np.int32,
29+
)
1830

1931

2032
@njit
@@ -220,115 +232,3 @@ def flipped_cells_in_direction(
220232
if not (0 <= x < 8 and 0 <= y < 8) or board[y, x] != player:
221233
return flipped[:0]
222234
return flipped[:count]
223-
224-
225-
@njit
226-
def make_move_reversible(
227-
board: np.ndarray,
228-
black_score: np.int32,
229-
white_score: np.int32,
230-
state: np.int32,
231-
move_x: int,
232-
move_y: int,
233-
):
234-
"""Make a move and return flipped cells for undoing. Returns (board, black_score, white_score, state, flipped)."""
235-
236-
if state not in (STATE_BLACK_TURN, STATE_WHITE_TURN) or board[move_y, move_x] != 3:
237-
return board, black_score, white_score, state, np.zeros((0, 2), dtype=np.int32)
238-
239-
player = CELL_BLACK if state == STATE_BLACK_TURN else CELL_WHITE
240-
opponent = CELL_WHITE if state == STATE_BLACK_TURN else CELL_BLACK
241-
board[move_y, move_x] = player
242-
243-
flipped = get_flipped_cells(board, move_x, move_y, player, opponent)
244-
num_flipped = flipped.shape[0]
245-
for x, y in flipped:
246-
board[y, x] = player
247-
248-
# Update scores
249-
if player == CELL_BLACK:
250-
black_score += 1 + num_flipped
251-
white_score -= num_flipped
252-
else:
253-
white_score += 1 + num_flipped
254-
black_score -= num_flipped
255-
256-
board, black_score, white_score, new_state, _ = update_state(board, black_score, white_score, state)
257-
return board, black_score, white_score, new_state, flipped
258-
259-
260-
@njit
261-
def undo_move(
262-
board: np.ndarray,
263-
move_x: int,
264-
move_y: int,
265-
flipped: np.ndarray,
266-
original_state: np.int32,
267-
):
268-
"""Undo a move by restoring the board."""
269-
270-
player = CELL_BLACK if original_state == STATE_BLACK_TURN else CELL_WHITE
271-
opponent = CELL_WHITE if original_state == STATE_BLACK_TURN else CELL_BLACK
272-
273-
# Restore flipped cells to opponent
274-
for x, y in flipped:
275-
board[y, x] = opponent
276-
277-
# Clear the move position
278-
board[move_y, move_x] = 0 # Will be updated to VALID later if needed
279-
280-
281-
# Example usage (not Numba-compiled)
282-
def print_board(board: np.ndarray) -> None:
283-
"""Print the board for debugging."""
284-
285-
for row in board:
286-
print(
287-
" ".join(
288-
["." if c == CELL_EMPTY else "V" if c == CELL_VALID else "B" if c == CELL_BLACK else "W" for c in row]
289-
)
290-
)
291-
print()
292-
293-
294-
def state_to_str(state: int) -> str:
295-
"""Convert state to string for printing."""
296-
297-
return {
298-
STATE_BLACK_TURN: "Black's turn",
299-
STATE_WHITE_TURN: "White's turn",
300-
STATE_BLACK_WON: "Black won",
301-
STATE_WHITE_WON: "White won",
302-
STATE_DRAW: "Draw",
303-
}.get(state, "Unknown")
304-
305-
306-
if __name__ == "__main__":
307-
# Initialize game
308-
board, black_score, white_score, state = init_game()
309-
print("Initial board:")
310-
print_board(board)
311-
312-
# Test a valid move
313-
print("Making move (3, 2):")
314-
board, black_score, white_score, state, success = make_move(board, black_score, white_score, state, 3, 2)
315-
if success:
316-
print("Move successful!")
317-
print_board(board)
318-
print(f"Black score: {black_score}, White score: {white_score}")
319-
print(f"State: {state_to_str(state)}")
320-
321-
# Print valid moves
322-
valid_moves = get_valid_moves(board, state)
323-
print("Valid moves:", valid_moves.tolist())
324-
else:
325-
print("Move failed!")
326-
327-
# Test an invalid move
328-
print("\nMaking invalid move (0, 0):")
329-
board, black_score, white_score, state, success = make_move(board, black_score, white_score, state, 0, 0)
330-
if success:
331-
print("Move successful!")
332-
print_board(board)
333-
else:
334-
print("Move failed!")

src_numba/core_numba/ui.py

-55
This file was deleted.

src_numba/main.py

-36
This file was deleted.

0 commit comments

Comments
 (0)