Skip to content

Commit 5ebaaa2

Browse files
committed
Implementation of AI algorithm and AI game
Implement the MCTS algorithm to play tic-tac-toe in kernel space. To support the algorithm, we also implement fixed-point arithmetic, including a logarithm and an integer square root. The main control loop of the AI tic-tac-toe game is the function "ai_game()", which is currently located in simrupt.c. However, it presents a severe problem: it consumes too much CPU computation time within one game. Future improvements are needed to separate the function into multiple tasklets.
1 parent 7c3f05a commit 5ebaaa2

File tree

8 files changed

+430
-27
lines changed

8 files changed

+430
-27
lines changed

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
TARGET = kmldrv
2-
kmldrv-objs = simrupt.o
2+
kmldrv-objs = simrupt.o game.o mcts.o
33
obj-m := $(TARGET).o
44

55
KDIR ?= /lib/modules/$(shell uname -r)/build

game.c

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#include <linux/slab.h>
2+
3+
#include "game.h"
4+
5+
6+
const line_t lines[4] = {
7+
{1, 0, 0, 0, BOARD_SIZE - GOAL + 1, BOARD_SIZE}, // ROW
8+
{0, 1, 0, 0, BOARD_SIZE, BOARD_SIZE - GOAL + 1}, // COL
9+
{1, 1, 0, 0, BOARD_SIZE - GOAL + 1, BOARD_SIZE - GOAL + 1}, // PRIMARY
10+
{1, -1, 0, GOAL - 1, BOARD_SIZE - GOAL + 1, BOARD_SIZE}, // SECONDARY
11+
};
12+
13+
/*
 * Test one GOAL-long segment of the board.
 *
 * The segment starts at cell (i, j) and advances by (line.i_shift,
 * line.j_shift) per step. Returns the mark found in the starting cell when
 * all GOAL cells hold that same mark, and ' ' otherwise (including when the
 * starting cell is empty).
 *
 * When ALLOW_EXCEED is 0, a segment is also rejected if the cell immediately
 * before or immediately after it holds the same mark. LOOKUP is defined
 * outside this file; presumably it yields the supplied default (' ') for
 * off-board coordinates -- confirm in game.h.
 */
static char check_line_segment_win(const char *t, int i, int j, line_t line)
{
    const char mark = t[GET_INDEX(i, j)];

    if (mark == ' ')
        return ' ';

    int row = i;
    int col = j;
    for (int step = 1; step < GOAL; step++) {
        row += line.i_shift;
        col += line.j_shift;
        if (t[GET_INDEX(row, col)] != mark)
            return ' ';
    }

#if !ALLOW_EXCEED
    if (mark == LOOKUP(t, i - line.i_shift, j - line.j_shift, ' '))
        return ' ';
    if (mark == LOOKUP(t, i + GOAL * line.i_shift, j + GOAL * line.j_shift, ' '))
        return ' ';
#endif
    return mark;
}
31+
32+
char check_win(char *t)
33+
{
34+
for (int i_line = 0; i_line < 4; ++i_line) {
35+
line_t line = lines[i_line];
36+
for (int i = line.i_lower_bound; i < line.i_upper_bound; ++i) {
37+
for (int j = line.j_lower_bound; j < line.j_upper_bound; ++j) {
38+
char win = check_line_segment_win(t, i, j, line);
39+
if (win != ' ')
40+
return win;
41+
}
42+
}
43+
}
44+
for (int i = 0; i < N_GRIDS; i++)
45+
if (t[i] == ' ')
46+
return ' ';
47+
return 'D';
48+
}
49+
50+
fixed_point_t calculate_win_value(char win, char player)
51+
{
52+
if (win == player)
53+
return 1U << FIXED_SCALE_BITS;
54+
if (win == (player ^ 'O' ^ 'X'))
55+
return 0U;
56+
return 1U << (FIXED_SCALE_BITS - 1);
57+
}
58+
59+
int *available_moves(const char *table)
60+
{
61+
int *moves = kzalloc(N_GRIDS * sizeof(int), GFP_KERNEL);
62+
int m = 0;
63+
for (int i = 0; i < N_GRIDS; i++)
64+
if (table[i] == ' ')
65+
moves[m++] = i;
66+
if (m < N_GRIDS)
67+
moves[m] = -1;
68+
return moves;
69+
}

game.h

+13-1
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,23 @@ typedef struct {
1717
int i_lower_bound, j_lower_bound, i_upper_bound, j_upper_bound;
1818
} line_t;
1919

20+
/* Self-defined fixed-point type, using the lowest FIXED_SCALE_BITS (8) bits
21+
 * as fractional bits, starting from the LSB */
22+
#define FIXED_SCALE_BITS 8
23+
#define FIXED_MAX (~0U)
24+
#define FIXED_MIN (0U)
25+
#define GET_SIGN(x) ((x) & (1U << 31))
26+
#define SET_SIGN(x) ((x) | (1U << 31))
27+
#define CLR_SIGN(x) ((x) & ((1U << 31) - 1U))
28+
typedef unsigned fixed_point_t;
29+
2030
#define DRAW_SIZE (N_GRIDS + BOARD_SIZE)
2131
#define DRAWBUFFER_SIZE \
2232
((BOARD_SIZE * (BOARD_SIZE + 1) << 1) + (BOARD_SIZE * BOARD_SIZE) + \
2333
((BOARD_SIZE << 1) + 1))
2434

2535
extern const line_t lines[4];
2636

27-
// void draw_board(const char *t);
37+
int *available_moves(const char *table);
38+
char check_win(char *t);
39+
fixed_point_t calculate_win_value(char win, char player);

mcts.c

+206
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
#include <linux/slab.h>
2+
#include <linux/string.h>
3+
4+
#include "game.h"
5+
#include "mcts.h"
6+
#include "util.h"
7+
#include "wyhash.h"
8+
9+
struct node {
10+
int move;
11+
char player;
12+
int n_visits;
13+
fixed_point_t score;
14+
struct node *parent;
15+
struct node *children[N_GRIDS];
16+
};
17+
18+
static struct node *new_node(int move, char player, struct node *parent)
19+
{
20+
struct node *node = kzalloc(sizeof(struct node), GFP_KERNEL);
21+
node->move = move;
22+
node->player = player;
23+
node->n_visits = 0;
24+
node->score = 0;
25+
node->parent = parent;
26+
memset(node->children, 0, sizeof(node->children));
27+
return node;
28+
}
29+
30+
static void free_node(struct node *node)
31+
{
32+
for (int i = 0; i < N_GRIDS; i++)
33+
if (node->children[i])
34+
free_node(node->children[i]);
35+
kfree(node);
36+
}
37+
38+
/*
 * Square root of the fixed-point value @x, returned in the same format.
 *
 * The result is built bit by bit from the most significant candidate bit
 * downwards: a bit is kept when the square of the candidate, rescaled by
 * FIXED_SCALE_BITS, does not exceed @x.
 *
 * The square is evaluated in 64-bit arithmetic. With a 32-bit product the
 * square wraps as soon as the candidate reaches 1 << 16 (i.e. for any
 * x >= 1 << 16), which made wrapped values pass the "<= x" test and produced
 * wrong results. For smaller inputs the result is unchanged.
 */
fixed_point_t fixed_sqrt(fixed_point_t x)
{
    /* sqrt(0) == 0 and sqrt(1.0) == 1.0 need no work. */
    if (!x || x == (1U << FIXED_SCALE_BITS))
        return x;

    fixed_point_t s = 0U;
    for (int i = (31 - __builtin_clz(x | 1)); i >= 0; i--) {
        unsigned long long candidate = (unsigned long long) s + (1U << i);

        if (((candidate * candidate) >> FIXED_SCALE_BITS) <= x)
            s = (fixed_point_t) candidate;
    }
    return s;
}
51+
52+
/*
 * Natural logarithm of the fixed-point value @v, in the same format.
 *
 * Negative results (v < 1.0) are returned in sign-magnitude form: the
 * magnitude with bit 31 set via SET_SIGN(). v == 0 and v == 1.0 return 0.
 *
 * Method: write v = m * 2^k with m in [1, 2), so that
 *     ln(v) = k * ln(2) + ln(m),
 * and evaluate ln(m) with the series
 *     ln(m) = 2 * (y + y^3/3 + y^5/5 + ...),  y = (m - 1) / (m + 1).
 * After the reduction y < 1/3, so the series converges in a few terms and
 * no intermediate exceeds 32 bits. Feeding v straight into the series (as
 * was done before) overflowed "numerator << FIXED_SCALE_BITS" for
 * v >= ~2^24 and converged far too slowly for large v (y close to 1).
 */
fixed_point_t fixed_log(fixed_point_t v)
{
    const fixed_point_t one = 1U << FIXED_SCALE_BITS;
    /* ln(2) * 2^32 == 0xB17217F7, truncated to FIXED_SCALE_BITS bits. */
    const fixed_point_t ln2 = 0xB17217F7U >> (32 - FIXED_SCALE_BITS);

    if (!v || v == one)
        return 0;

    /* Normalise: k is the power of two, m the mantissa in [one, 2 * one). */
    int k = (31 - __builtin_clz(v)) - FIXED_SCALE_BITS;
    fixed_point_t m = (k >= 0) ? (v >> k) : (v << -k);

    /* y = (m - 1) / (m + 1); the numerator is below one, so no overflow. */
    fixed_point_t y = ((m - one) << FIXED_SCALE_BITS) / (m + one);
    fixed_point_t y_sq = (y * y) >> FIXED_SCALE_BITS;

    /* Sum the odd powers of y, stopping once the terms underflow to 0. */
    fixed_point_t power = y;
    fixed_point_t sum = 0U;
    for (unsigned i = 1; power && i < 20; i += 2) {
        sum += power / i;
        power = (power * y_sq) >> FIXED_SCALE_BITS;
    }
    fixed_point_t ln_m = sum << 1;

    if (k >= 0)
        return (fixed_point_t) k * ln2 + ln_m;

    /* v < 1.0: |k| * ln(2) always exceeds ln(m) because m < 2. */
    return SET_SIGN((fixed_point_t) (-k) * ln2 - ln_m);
}
84+
85+
/*
 * UCT value of a child:
 *     score / n_visits + EXPLORATION_FACTOR * sqrt(ln(n_total) / n_visits)
 * in fixed-point arithmetic.
 *
 * @n_total:  visit count of the parent
 * @n_visits: visit count of the child
 * @score:    accumulated fixed-point reward of the child
 *
 * Returns FIXED_MAX for an unvisited child so that it is tried first.
 *
 * The mean is computed as score / n_visits. The previous expression
 * "score << FIXED_SCALE_BITS / (n_visits << FIXED_SCALE_BITS)" parsed as
 * "score << (FIXED_SCALE_BITS / (...))" because '/' binds tighter than
 * '<<', so it returned the raw accumulated score instead of the mean.
 */
static inline fixed_point_t uct_score(int n_total,
                                      int n_visits,
                                      fixed_point_t score)
{
    if (n_visits == 0)
        return FIXED_MAX;

    /* score is already fixed-point, so dividing by the count keeps scale. */
    fixed_point_t mean = score / (fixed_point_t) n_visits;

    fixed_point_t exploration =
        EXPLORATION_FACTOR *
        fixed_sqrt(fixed_log((fixed_point_t) n_total << FIXED_SCALE_BITS) /
                   (fixed_point_t) n_visits);
    /* Product of two fixed-point values: drop the extra scale factor. */
    exploration >>= FIXED_SCALE_BITS;

    return mean + exploration;
}
101+
102+
static struct node *select_move(struct node *node)
103+
{
104+
struct node *best_node = NULL;
105+
fixed_point_t best_score = 0U;
106+
for (int i = 0; i < N_GRIDS; i++) {
107+
if (!node->children[i])
108+
continue;
109+
fixed_point_t score =
110+
uct_score(node->n_visits, node->children[i]->n_visits,
111+
node->children[i]->score);
112+
if (score > best_score) {
113+
best_score = score;
114+
best_node = node->children[i];
115+
}
116+
}
117+
return best_node;
118+
}
119+
120+
static fixed_point_t simulate(char *table, char player)
121+
{
122+
char current_player = player;
123+
char temp_table[N_GRIDS];
124+
memcpy(temp_table, table, N_GRIDS);
125+
while (1) {
126+
int *moves = available_moves(temp_table);
127+
if (moves[0] == -1) {
128+
kfree(moves);
129+
break;
130+
}
131+
int n_moves = 0;
132+
while (n_moves < N_GRIDS && moves[n_moves] != -1)
133+
++n_moves;
134+
int move = moves[wyhash64() % n_moves];
135+
kfree(moves);
136+
temp_table[move] = current_player;
137+
char win;
138+
if ((win = check_win(temp_table)) != ' ')
139+
return calculate_win_value(win, player);
140+
current_player ^= 'O' ^ 'X';
141+
}
142+
return (fixed_point_t) (1UL << (FIXED_SCALE_BITS - 1));
143+
}
144+
145+
/*
 * Walk from @node up to the root, counting one visit per node and adding
 * the reward. The reward is complemented at every level (1.0 - score)
 * because consecutive levels belong to opposite players.
 *
 * The complement uses the fixed-point constant 1.0, i.e.
 * 1U << FIXED_SCALE_BITS. Using the integer 1 here made the unsigned
 * subtraction wrap to a huge value for any score above 1/256.
 */
static void backpropagate(struct node *node, fixed_point_t score)
{
    const fixed_point_t one = 1U << FIXED_SCALE_BITS;

    while (node) {
        node->n_visits++;
        node->score += score;
        node = node->parent;
        score = one - score;
    }
}
154+
155+
static void expand(struct node *node, char *table)
156+
{
157+
int *moves = available_moves(table);
158+
int n_moves = 0;
159+
while (n_moves < N_GRIDS && moves[n_moves] != -1)
160+
++n_moves;
161+
for (int i = 0; i < n_moves; i++) {
162+
node->children[i] = new_node(moves[i], node->player ^ 'O' ^ 'X', node);
163+
}
164+
kfree(moves);
165+
}
166+
167+
int mcts(char *table, char player)
168+
{
169+
char win;
170+
struct node *root = new_node(-1, player, NULL);
171+
for (int i = 0; i < ITERATIONS; i++) {
172+
struct node *node = root;
173+
char temp_table[N_GRIDS];
174+
memcpy(temp_table, table, N_GRIDS);
175+
while (1) {
176+
if ((win = check_win(temp_table)) != ' ') {
177+
fixed_point_t score =
178+
calculate_win_value(win, node->player ^ 'O' ^ 'X');
179+
backpropagate(node, score);
180+
break;
181+
}
182+
if (node->n_visits == 0) {
183+
fixed_point_t score = simulate(temp_table, node->player);
184+
backpropagate(node, score);
185+
break;
186+
}
187+
if (node->children[0] == NULL)
188+
expand(node, temp_table);
189+
node = select_move(node);
190+
if (!node)
191+
return -1;
192+
temp_table[node->move] = node->player ^ 'O' ^ 'X';
193+
}
194+
}
195+
struct node *best_node = root;
196+
int most_visits = -1;
197+
for (int i = 0; i < N_GRIDS; i++) {
198+
if (root->children[i] && root->children[i]->n_visits > most_visits) {
199+
most_visits = root->children[i]->n_visits;
200+
best_node = root->children[i];
201+
}
202+
}
203+
int best_move = best_node->move;
204+
free_node(root);
205+
return best_move;
206+
}

mcts.h

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#pragma once
2+
3+
#define ITERATIONS 100000
4+
#define EXPLORATION_FACTOR fixed_sqrt(1U << (FIXED_SCALE_BITS + 1))
5+
6+
int mcts(char *table, char player);

0 commit comments

Comments
 (0)