A clean implementation based on AlphaZero for any game in any framework + tutorial + Othello/Gobang/TicTacToe/Connect4 and more


Alpha Zero General (any game, any framework!)

A simplified, highly flexible, commented and (hopefully) easy-to-understand implementation of self-play-based reinforcement learning following the AlphaGo Zero paper (Silver et al.). It is designed to be easy to adapt to any two-player, turn-based adversarial game and any deep learning framework of your choice. A sample implementation has been provided for the game of Othello in PyTorch, Keras, TensorFlow and Chainer. An accompanying tutorial can be found here. We also have implementations for GoBang and TicTacToe.

To use a game of your choice, subclass the classes in Game.py and NeuralNet.py and implement their functions. Example implementations for Othello can be found in othello/OthelloGame.py and othello/{pytorch,keras,tensorflow,chainer}/NNet.py.
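As a concrete starting point, here is a minimal, self-contained sketch of what such a subclass might look like for 3x3 TicTacToe. The method names are the ones declared in Game.py as far as I know (getInitBoard, getBoardSize, getActionSize, getNextState, getValidMoves, getGameEnded, getCanonicalForm, getSymmetries, stringRepresentation); please check the signatures and return conventions against Game.py before relying on them. The class name TinyTicTacToeGame is hypothetical, and it does not inherit from Game here only so that the snippet runs on its own.

```python
import numpy as np


class TinyTicTacToeGame:
    """Hypothetical 3x3 TicTacToe; in the repo this would subclass Game from Game.py."""

    n = 3

    def getInitBoard(self):
        return np.zeros((self.n, self.n), dtype=int)  # 1 = player 1, -1 = player -1, 0 = empty

    def getBoardSize(self):
        return (self.n, self.n)

    def getActionSize(self):
        return self.n * self.n  # one action per cell; this sketch has no pass action

    def getNextState(self, board, player, action):
        b = board.copy()
        b[action // self.n, action % self.n] = player
        return b, -player

    def getValidMoves(self, board, player):
        return (board.reshape(-1) == 0).astype(int)  # 1 for every empty cell

    def getGameEnded(self, board, player):
        # 0 = not ended, 1 = `player` won, -1 = `player` lost, small non-zero value = draw
        lines = list(board) + list(board.T) + [board.diagonal(), np.fliplr(board).diagonal()]
        for line in lines:
            total = line.sum()
            if abs(total) == self.n:  # a full line owned by one side
                winner = 1 if total > 0 else -1
                return 1 if winner == player else -1
        if not (board == 0).any():
            return 1e-4
        return 0

    def getCanonicalForm(self, board, player):
        return player * board  # always view the board from the side to move

    def getSymmetries(self, board, pi):
        # The 8 rotations/reflections of the board, with the policy transformed identically
        pi_board = np.reshape(pi, (self.n, self.n))
        symmetries = []
        for k in range(4):
            for flip in (False, True):
                b, p = np.rot90(board, k), np.rot90(pi_board, k)
                if flip:
                    b, p = np.fliplr(b), np.fliplr(p)
                symmetries.append((b, list(p.ravel())))
        return symmetries

    def stringRepresentation(self, board):
        return board.tobytes()  # hashable key used by MCTS
```

Playing actions 0, 3, 1, 4, 2 from the initial board completes the top row for player 1, after which getGameEnded(board, 1) returns 1. The repo's own TicTacToe and GoBang implementations remain the better reference for details such as a pass action.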

Coach.py contains the core training loop and MCTS.py performs the Monte Carlo Tree Search. The parameters for the self-play can be specified in main.py. Additional neural network parameters are in othello/{pytorch,keras,tensorflow,chainer}/NNet.py (cuda flag, batch size, epochs, learning rate etc.).

To start training a model for Othello:

python main.py

Choose your framework and game in main.py.

Docker Installation

For easy environment setup, we can use nvidia-docker. Once you have nvidia-docker set up, we can then simply run:


to set up a (default: PyTorch) Jupyter Docker container. We can now open a new terminal and enter:

docker exec -ti pytorch_notebook python main.py


We trained a PyTorch model for 6x6 Othello (~80 iterations, 100 episodes per iteration and 25 MCTS simulations per turn). This took about 3 days on an NVIDIA Tesla K80. The pretrained model (PyTorch) can be found in pretrained_models/othello/pytorch/. You can play a game against it using pit.py. Below is the performance of the model against a random and a greedy baseline as a function of the number of iterations.

[Figure: performance of the pretrained model against random and greedy baselines vs. number of iterations]

A concise description of our algorithm can be found here.


While the current code is fairly functional, we could benefit from the following contributions:

  • Game logic files for more games that follow the specifications in Game.py, along with their neural networks
  • Neural networks in other frameworks
  • Pre-trained models for different game configurations
  • An asynchronous version of the code: parallel processes for self-play, neural net training and model comparison.
  • Asynchronous MCTS as described in the paper

Contributors and Credits

  • invalid value in division

    When I ran main.py under TensorFlow, I also got a runtime warning:

    .../othello/tensorflow/NNet.py:103: RuntimeWarning: invalid value encountered in divide
      pi = np.exp(pi) / np.sum(np.exp(pi))

    Here we calculate a softmax. When one value in pi is large, np.exp(pi) overflows to inf, so both the numerator and the denominator can become inf, and inf / inf gives NaN.

    We can try the following method to avoid it:

        x = np.exp(pi - np.max(pi))
        pi = x / x.sum()

    But I am not sure whether the above warning was caused by overflow.
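    The suggested fix works because softmax is unchanged when the same constant is subtracted from every logit; subtracting the maximum makes the largest exponent exp(0) = 1, so nothing overflows. A minimal sketch comparing both forms (function names are illustrative, not from the repo). It also helps with the open question: if pi already contains NaN (for example from diverged weights), the stable version returns NaN too, so a warning that persists after this change would point away from overflow.

```python
import numpy as np


def stable_softmax(logits):
    """Softmax that subtracts the maximum logit first so np.exp never overflows."""
    logits = np.asarray(logits, dtype=np.float64)
    shifted = np.exp(logits - np.max(logits))  # largest term becomes exp(0) = 1
    return shifted / shifted.sum()


def naive_softmax(logits):
    """The original formulation: large logits give inf / inf = NaN."""
    logits = np.asarray(logits, dtype=np.float64)
    return np.exp(logits) / np.sum(np.exp(logits))
```

    For example, stable_softmax([1000.0, 1000.0]) gives [0.5, 0.5], whereas naive_softmax on the same input returns NaNs along with a RuntimeWarning.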


  • Masking Ps[s]*valids may give an array of zeros

    In this line we divide by the sum of the initial policy, which was previously masked by the valid moves (valids).

    I observe cases where the product Ps[s]*valids is an array of zeros. Then sum(Ps[s]) is also zero, so the division in that line triggers a NumPy warning about division by zero, and afterwards every Ps[s][a] becomes NaN. NumPy doesn't raise an error, so execution continues, and on the next visits of state s we get best_act = -1 and therefore a = -1. We then invoke next_s, next_player = self.game.getNextState(canonicalBoard, 1, a=-1), which is conceptually an illegal operation.

    Such cases occur in the pit phase of the 1st iteration and continue afterwards. When I debug one case, I see that nn.predict() returns a few nonzero values in Ps[s], none of which fall on a valid move, so the product is all zeros. Over subsequent iterations the number of such cases gradually decreases.

    I think the occurrences of these cases depend on the random number generator, because they are platform-dependent: if I run the same program on different physical computers, I may or may not observe them.

    The question is: should we detect such cases and avoid NaNs in Ps[s]? If such a case is detected, we could, for example, make all valid moves equally probable, i.e.

        self.Ps[s] = self.Ps[s] + valids
        self.Ps[s] /= np.sum(self.Ps[s])
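    The detection and fallback can be packaged as one small function. This is a sketch with a hypothetical name, masked_policy, not the repo's code: it masks the priors, renormalizes when some probability mass lands on a legal move, and otherwise returns a uniform distribution over the legal moves, as proposed above. Because NaN > 0 is False, NaN priors also take the fallback path.

```python
import numpy as np


def masked_policy(ps, valids):
    """Mask network priors by legal moves and renormalize.

    Falls back to a uniform distribution over legal moves when the masked
    priors sum to zero (or to NaN)."""
    ps = np.asarray(ps, dtype=np.float64)
    valids = np.asarray(valids, dtype=np.float64)
    masked = ps * valids
    total = masked.sum()
    if total > 0:  # False for 0 and for NaN
        return masked / total
    # Assumes at least one legal move exists; "no legal moves" must be handled by the game.
    return valids / valids.sum()
```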

  • Add 3D TicTacToe

    Adds 3D TicTacToe, along with a pretrained model (although it was not trained for very long). I'm not sure all of the code is clean, so it might need review, but it seems to work.

  • Max recursion depth exceeded

    Othello with TensorFlow fails as shown below, on iteration 1, epoch 10. BTW: less verbose output from the TensorFlow version would be helpful (currently 400 lines per epoch).

    Training Net |################################| (403/403) Data: 0.000s | Batch: 0.031s | Total: 0:00:12 | ETA: 0:00:01 | Loss_pi: 3.4789 | Loss_v: 0.382
    PITTING AGAINST PREVIOUS VERSION
    /home/brian/hitme/bin/alpha-zero-general/MCTS.py:80: RuntimeWarning: invalid value encountered in true_divide
      self.Ps[s] /= np.sum(self.Ps[s])    # renormalize
    Traceback (most recent call last):
      File "main.py", line 29, in <module>
        c.learn()
      File "/home/brian/hitme/bin/alpha-zero-general/Coach.py", line 99, in learn
        pwins, nwins, draws = arena.playGames(self.args.arenaCompare)
      File "/home/brian/hitme/bin/alpha-zero-general/Arena.py", line 81, in playGames
        gameResult = self.playGame(verbose=verbose)
      File "/home/brian/hitme/bin/alpha-zero-general/Arena.py", line 46, in playGame
        action = players[curPlayer+1](self.game.getCanonicalForm(board, curPlayer))
      File "/home/brian/hitme/bin/alpha-zero-general/Coach.py", line 98, in <lambda>
        lambda x: np.argmax(nmcts.getActionProb(x, temp=0)), self.game)
      File "/home/brian/hitme/bin/alpha-zero-general/MCTS.py", line 31, in getActionProb
        self.search(canonicalBoard)
      File "/home/brian/hitme/bin/alpha-zero-general/MCTS.py", line 106, in search
        v = self.search(next_s)
      File "/home/brian/hitme/bin/alpha-zero-general/MCTS.py", line 106, in search
        v = self.search(next_s)
      File "/home/brian/hitme/bin/alpha-zero-general/MCTS.py", line 106, in search
        v = self.search(next_s)
      [Previous line repeated 983 more times]
      File "/home/brian/hitme/bin/alpha-zero-general/MCTS.py", line 103, in search
        next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
      File "/home/brian/hitme/bin/alpha-zero-general/othello/OthelloGame.py", line 31, in getNextState
        b = Board(self.n)
      File "/home/brian/hitme/bin/alpha-zero-general/othello/OthelloLogic.py", line 24, in __init__
        for i in range(self.n):
    RecursionError: maximum recursion depth exceeded in comparison
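    The warning printed right before the recursion matches the one described in the "Masking Ps[s]*valids" issue. A simplified sketch of a PUCT selection step (not a copy of MCTS.py) shows why NaN priors are dangerous: every comparison with NaN is False, so best_act never moves off its initial -1. The function name and the restriction to unvisited children (Q = 0, N(s,a) = 0) are simplifying assumptions.

```python
import math


def select_action(priors, valids, n_s, cpuct=1.0):
    """Simplified PUCT choice for a node whose children are all unvisited.

    With Q(s,a) = 0 and N(s,a) = 0 the upper confidence bound reduces to
    cpuct * P(s,a) * sqrt(N(s))."""
    cur_best = -float("inf")
    best_act = -1
    for a, is_valid in enumerate(valids):
        if not is_valid:
            continue
        u = cpuct * priors[a] * math.sqrt(n_s + 1e-8)
        if u > cur_best:  # always False when u is NaN
            cur_best = u
            best_act = a
    return best_act
```

    With finite priors this returns a legal action whenever one exists; with all-NaN priors it returns -1, an illegal action that is then passed on to getNextState.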

  • Keras out of GPU memory

    Othello with Keras ran out of GPU memory after 34 iterations. I have an 8 GB GTX 1070, but limited TensorFlow to per_process_gpu_memory_fraction = 0.4 (about 3.2 GB).

    Of course, I can run with more memory, but perhaps the README should include some GPU memory guidelines, assuming this is not a bug.

    Caused by op 'batch_normalization_199/FusedBatchNorm', defined at:
      File "main.py", line 29, in <module>
        c.learn()
      File "/home/brian/hitme/bin/alpha-zero-general/Coach.py", line 90, in learn
        pnet = self.nnet.__class__(self.game)
      File "/home/brian/hitme/bin/alpha-zero-general/othello/keras/NNet.py", line 27, in __init__
        self.nnet = onnet(game, args)
      File "/home/brian/hitme/bin/alpha-zero-general/othello/keras/OthelloNNet.py", line 37, in __init__
        h_conv1 = Activation('relu')(BatchNormalization(axis=3)(Conv2D(args.num_channels, 3, padding='same')(x_image)))    # batch_size x board_x x board_y x num_channels
      File "/home/brian/hitme/lib/python3.6/site-packages/keras/engine/topology.py", line 617, in __call__
        output = self.call(inputs, **kwargs)
      File "/home/brian/hitme/lib/python3.6/site-packages/keras/layers/normalization.py", line 181, in call
        epsilon=self.epsilon)
      File "/home/brian/hitme/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 1824, in normalize_batch_in_training
        epsilon=epsilon)
      File "/home/brian/hitme/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 1799, in _fused_normalize_batch_in_training
        data_format=tf_data_format)
      File "/home/brian/hitme/lib/python3.6/site-packages/tensorflow/python/ops/nn_impl.py", line 831, in fused_batch_norm
        name=name)
      File "/home/brian/hitme/lib/python3.6/site-packages/tensorflow/python/ops/gen_nn_ops.py", line 2034, in _fused_batch_norm
        is_training=is_training, name=name)
      File "/home/brian/hitme/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
        op_def=op_def)
      File "/home/brian/hitme/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2956, in create_op
        op_def=op_def)
      File "/home/brian/hitme/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1470, in __init__
        self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

    ResourceExhaustedError (see above for traceback): OOM when allocating tensor with shape[64,512,6,6]
         [[Node: batch_normalization_199/FusedBatchNorm = FusedBatchNorm[T=DT_FLOAT, data_format="NHWC", epsilon=0.001, is_training=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](conv2d_133/BiasAdd, batch_normalization_199/gamma/read, batch_normalization_199/beta/read, batch_normalization_199/Const_4, batch_normalization_199/Const_4)]]
         [[Node: loss_33/add/_17985 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_3237_loss_33/add", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

  • Eliminate bottleneck in `stringRepresentation` [SPEED IMPROVEMENT]

    I profiled the code, and this line becomes a bottleneck because of the string conversion:


        def stringRepresentation(self, board):
            return str(self._base_board.with_np_pieces(np_pieces=board))

    Another issue is that it sometimes returns '-0.' instead of just '0'.

    My suggestion is to substitute it with something about 100 times faster, although less readable:

        def stringRepresentation(self, board):
            return board.astype(int).tobytes()

    It is true that this skips the assert in Board.__init__, which runs when we call .with_np_pieces, but running that assert on every iteration may be too much anyway.
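    A quick check of why the astype(int) cast matters: -0.0 and 0.0 compare equal but print differently and have different raw bytes (the sign bit), so without the cast two identical positions could get different MCTS keys. A minimal sketch, using tobytes(), which is the current NumPy name for the older tostring(); the function name is illustrative.

```python
import numpy as np


def string_representation(board):
    """Fast hashable key: cast to int first so -0.0 and 0.0 produce identical bytes."""
    return board.astype(int).tobytes()
```

    Two caveats of this design: the bytes carry no shape or dtype information, so keys are only comparable between boards of the same shape, and astype(int) truncates fractional values, which is fine for boards holding only -1, 0 and 1.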

  • Four Player Game

    I am working on an AI for a four-player card game and would like to apply this repo to it. Since only two-player games are supported so far, I would like to extend it to the four-player case. Do you think this is possible in a reasonable amount of time? If so, could you point me to the main places where changes would be needed? I suspect getSymmetries() in Game would have to go.

  • Coach accepting and rejecting new model

    Hi, I may have discovered an error at the end of the Coach learning loop, where it decides whether to discard or keep the new model:

        if pwins+nwins > 0 and float(nwins)/(pwins+nwins) < self.args.updateThreshold:
            print('REJECTING NEW MODEL')
        else:
            print('ACCEPTING NEW MODEL')

    The problem arises when pwins+nwins is zero, i.e. every game was a draw: the condition is then false, so the new model is accepted. I don't think accepting it is a good idea, because we don't actually know how good a contender that only produced draws is. A possible solution:

        if pwins+nwins == 0 or float(nwins)/(pwins+nwins) < self.args.updateThreshold:
            print('REJECTING NEW MODEL')
        else:
            print('ACCEPTING NEW MODEL')

    Because `or` short-circuits, this never divides by zero, and it rejects a model whose games all ended in draws.
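    The proposed rule can be isolated into a small, testable function. should_accept is a hypothetical name; update_threshold plays the role of self.args.updateThreshold, and a win rate exactly at the threshold is accepted, mirroring the `<` test used for rejection above.

```python
def should_accept(pwins, nwins, update_threshold):
    """Return True if the new model should replace the previous one.

    Draws are ignored; if every game was a draw there is no evidence
    that the new model is better, so it is rejected."""
    decided = pwins + nwins
    if decided == 0:
        return False
    return nwins / decided >= update_threshold
```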

  • flips assertion failure

    Hi Surag, thanks for sharing this good software. When I ran it under the TensorFlow framework, I got a flips assertion failure.

    ------ITER 19------
    Self Play |###################### | (71/100) Eps Time: 1.669s | Total: 0:01:58 | ETA: 0:00:51
    Traceback (most recent call last):
      File "main.py", line 30, in <module>
        c.learn()
      File "/home/***/tools/alpha-zero-general/Coach.py", line 78, in learn
        trainExamples += self.executeEpisode()
      File "/home/***/tools/alpha-zero-general/Coach.py", line 46, in executeEpisode
        pi = self.mcts.getActionProb(canonicalBoard, temp=temp)
      File "/home/***/tools/alpha-zero-general/MCTS.py", line 31, in getActionProb
        self.search(canonicalBoard)
      File "/home/***/tools/alpha-zero-general/MCTS.py", line 106, in search
        v = self.search(next_s)
      File "/home/***/tools/alpha-zero-general/MCTS.py", line 106, in search
        v = self.search(next_s)
      File "/home/***/tools/alpha-zero-general/MCTS.py", line 106, in search
        v = self.search(next_s)
      File "/home/***/tools/alpha-zero-general/MCTS.py", line 106, in search
        v = self.search(next_s)
      File "/home/***/tools/alpha-zero-general/MCTS.py", line 103, in search
        next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
      File "/home/***/tools/alpha-zero-general/othello/OthelloGame.py", line 34, in getNextState
        b.execute_move(move, player)
      File "/home/***/tools/alpha-zero-general/othello/OthelloLogic.py", line 111, in execute_move
        assert len(list(flips))>0
    AssertionError

    System info: tensorflow-gpu 1.1.0, master branch, head at:

        commit 263eccb2de3ca5eae7d5615e8b5b7d13b481b569
        Author: suragnair
        Date:   Wed Jan 3 12:05:11 2018 +0530

            added dim to pytorch log_softmax (UserWarning)

    I wonder whether it is possible for flips to be empty?
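    Under standard Othello rules, a stone may only be placed where it flips at least one opponent stone, so for an empty square the flips are empty exactly when the move is illegal. An empty flips list therefore means execute_move received an illegal action; one way that could happen is the NaN-policy problem described in the "Masking Ps[s]*valids" issue, although that has not been verified for this traceback. The sketch below is a generic flip computation on a board encoded as 1 / -1 / 0, with hypothetical names, not the repo's OthelloLogic code.

```python
import numpy as np

# The eight compass directions a line of flipped stones can run in.
DIRECTIONS = [(-1, -1), (-1, 0), (-1, 1), (0, -1), (0, 1), (1, -1), (1, 0), (1, 1)]


def flips_for_move(board, row, col, player):
    """Return the opponent stones that `player` would flip by playing at (row, col)."""
    n = board.shape[0]
    if board[row, col] != 0:
        return []  # occupied square: never a legal move
    flips = []
    for dr, dc in DIRECTIONS:
        line = []
        r, c = row + dr, col + dc
        # Walk over a contiguous run of opponent stones...
        while 0 <= r < n and 0 <= c < n and board[r, c] == -player:
            line.append((r, c))
            r, c = r + dr, c + dc
        # ...which is flipped only if it is capped by one of the player's own stones.
        if line and 0 <= r < n and 0 <= c < n and board[r, c] == player:
            flips.extend(line)
    return flips
```

    On a 4x4 opening position, playing above the opponent's stone flips it, while a corner that brackets nothing yields an empty list, which is the situation the assertion guards against.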



  • input: possibilities instead of situation

    I've come up with a method that works much faster than the solutions I see here. In all the implementations, the current position on the board is given as input. I imagine the network always has to interpret that position first, and only then does the real thinking begin.

    I could not use the position as input, because my game does not have a board with fixed dimensions. So I came up with something else: I simply skip that first step. Instead of feeding the network the current position, I give it only the possible moves. Fields where the player to move has a legal move are assigned "1", fields where the opponent has a legal move are marked "-1", and all other (illegal) fields are marked "0". I get significantly better results, and I can handle many more situations. An important advantage (for me) is that a single network can handle multiple board dimensions.

    I'm a chess player myself, and I see the same distinction between amateurs and professionals. When a professional looks at the board, he sees possibilities; when an amateur looks, he first sees wooden pieces, and only somewhere in the distance does he see possibilities. For that reason alone, he plays worse.

    But I understand that this method is not suitable for every game, perhaps not for chess, and so on. However, I am sure it can be very beneficial for certain games. I suspect it would work for Go with a few changes.
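    The input encoding described above can be sketched as a single function. encode_possibilities is a hypothetical name, and the description does not say what happens when a field is legal for both players; this sketch assumes the player to move takes priority. Since it only uses element-wise operations, it works for any board shape.

```python
import numpy as np


def encode_possibilities(my_legal, opp_legal):
    """Build the input plane: 1 = legal for the player to move,
    -1 = legal only for the opponent, 0 = legal for neither."""
    my_legal = np.asarray(my_legal, dtype=bool)
    opp_legal = np.asarray(opp_legal, dtype=bool)
    plane = np.zeros(my_legal.shape, dtype=np.int8)
    plane[opp_legal] = -1
    plane[my_legal] = 1  # assumption: a field legal for both sides counts for the mover
    return plane
```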

  • Better model/weights

    I tried to upload the weights, but the file is too large. This is after checkpoint 153. Results when pitted against the pretrained best model (1580):

    Arena.playGames |################################| (4097/2048) Eps Time: 3.473s | Total: 3:57:03 | ETA: 0:00:03 (1580, 2516, 0)

  • mainTafl.py does not work for me

    mainTafl.py does not work; it looks to me like an inconsistency in how the board is handled. For trainExamples, Coach expects getNextState to return an ndarray (as it does for Othello), but it gets Board objects.

      Traceback (most recent call last):
        File "mainTafl.py", line 34, in <module>
        File "/content/Coach.py", line 113, in learn
        File "/content/tafl/pytorch/NNet.py", line 55, in train
          boards = torch.FloatTensor(np.array(boards).astype(np.float64))
      ValueError: setting an array element with a sequence.
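    Following the diagnosis above, one defensive option is to make sure every board that reaches NNet.train is a plain ndarray of a single shape before it is stacked. The helper below is hypothetical: to_array is a caller-supplied converter (for Tafl's Board objects it would be whatever accessor returns the underlying array, which has not been checked here). Mixing array-like objects of different shapes is one common source of NumPy's "setting an array element with a sequence" error, but it is not confirmed that this is exactly what happens in mainTafl.py.

```python
import numpy as np


def stack_boards(boards, to_array=np.asarray):
    """Convert each training board to an ndarray and stack them into one float64 batch."""
    arrays = [np.asarray(to_array(b), dtype=np.float64) for b in boards]
    shapes = {a.shape for a in arrays}
    if len(shapes) > 1:
        # Fail with a readable message instead of NumPy's generic one.
        raise ValueError(f"boards have inconsistent shapes: {sorted(shapes)}")
    return np.stack(arrays)
```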

  • Othello Keras Pretrained Model: UnpicklingError

    The pretrained model in https://github.com/suragnair/alpha-zero-general/tree/master/pretrained_models/othello/keras cannot be loaded.


    import torch
    model=torch.load("6x6 checkpoint_145.pth.tar",map_location='cpu')

    Error Message:

    UnpicklingError                           Traceback (most recent call last)
    Cell In [2], line 2
          1 import torch
    ----> 2 model=torch.load("6x6 checkpoint_145.pth.tar",map_location='cpu')
    File c:\Users\entdi\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\serialization.py:713, in load(f, map_location, pickle_module, **pickle_load_args)
        711             return torch.jit.load(opened_file)
        712         return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    --> 713 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
    File c:\Users\entdi\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\serialization.py:920, in _legacy_load(f, map_location, pickle_module, **pickle_load_args)
        914 if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        915     raise RuntimeError(
        916         "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
        917         f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
        918         "functionality.")
    --> 920 magic_number = pickle_module.load(f, **pickle_load_args)
        921 if magic_number != MAGIC_NUMBER:
        922     raise RuntimeError("Invalid magic number; corrupt file?")
    UnpicklingError: invalid load key, 'H'.

    The models in other folders can be loaded with similar code, but not this one.

    Also, why is this a .pth.tar file instead of .h5?

    Sorry if this is a naive question.
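    The error is consistent with the file being HDF5 (the format Keras uses for saved weights) despite its .pth.tar name. HDF5 files normally start with the 8-byte signature \x89HDF\r\n\x1a\n. Byte 0x89 happens to be a valid pickle opcode (NEWFALSE), so a pickle-based loader such as torch.load's legacy path only fails at the second byte, reporting invalid load key, 'H'. The sketch below checks a file's leading bytes and reproduces that message; the helper names are hypothetical.

```python
import pickle

# HDF5 files normally begin with this 8-byte signature.
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"


def has_hdf5_signature(header: bytes) -> bool:
    return header[:len(HDF5_SIGNATURE)] == HDF5_SIGNATURE


def looks_like_hdf5(path) -> bool:
    """Check whether a checkpoint file is HDF5, regardless of its extension."""
    with open(path, "rb") as f:
        return has_hdf5_signature(f.read(len(HDF5_SIGNATURE)))


def pickle_error_for(data: bytes) -> str:
    """Return the message a pickle-based loader produces for these bytes ("" if none)."""
    try:
        pickle.loads(data)
    except pickle.UnpicklingError as exc:
        return str(exc)
    return ""
```

    If looks_like_hdf5("6x6 checkpoint_145.pth.tar") returns True, the file should presumably be loaded through Keras (for example with a model's load_weights) rather than with torch.load.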

