#include "go.h"

#include <assert.h>
#include <math.h>
#include <string.h>

#ifdef USE_SSE
#include <immintrin.h>
#endif

/*
 * Forward pass of the move-prediction CNN.
 *
 * The network evaluated here is: one 5x5 convolution (into hidden[1]),
 * eleven 3x3 convolutions (hidden[2] .. hidden[12]), and a final 1x1
 * convolution with a per-point bias.  Every 5x5/3x3 layer clamps its
 * non-positive outputs to zero.  get_cnn_prob() turns the final plane into
 * probabilities with a softmax.
 *
 * NOTE(review): cnn_param, cnn_t, VEC_WIDTH, FILTER_SIZE and CHANNEL_SIZE are
 * expected to come from go.h; they are not defined in this file.
 */

enum {
    CNN_EDGE = 19,              /* points per side swept by the fixed-size planes */
    CNN_TAPS_5X5 = 5 * 5,       /* kernel elements of the first layer */
    CNN_TAPS_3X3 = 3 * 3,       /* kernel elements of layers 2..12 */
    CNN_FIRST_3X3_LAYER = 2,
    CNN_LAST_3X3_LAYER = 12
};

/*
 * Weights are consumed in groups of VEC_WIDTH outputs.  Inside a group the
 * layout is [tap][lane]: the weight for tap k and output (group + j) sits at
 * index k * VEC_WIDTH + j.
 *
 * CNN_WEIGHT_GROUP yields the first weight of the group that starts at
 * `output`, where `table` is the weight table of one input channel.  The SIMD
 * build walks the table as one flat run from its first element, the scalar
 * build indexes the output row; each build keeps the addressing it always
 * used.  NOTE(review): the two forms coincide only if the innermost dimension
 * of the table equals the tap count - confirm against go.h.
 */
#ifdef USE_SSE

/* One input sample broadcast to all 8 float lanes of a 256-bit register. */
typedef __m256 cnn_tap_t;

#define CNN_WEIGHT_GROUP(table, output, tapNums) ((table)[0] + (output) * (tapNums))

static inline cnn_tap_t cnn_make_tap(const cnn_t value)
{
    return _mm256_set1_ps(value);
}

/*
 * Adds sum_k(taps[k] * weight[k * VEC_WIDTH + lane]) to each of the VEC_WIDTH
 * values at `out`.  Products are summed in ascending tap order before being
 * added to the stored value.
 */
static inline void cnn_accumulate_group(cnn_t *out, const cnn_t *weight,
                                        const cnn_tap_t *taps, const int tapNums)
{
    __m256 sum = _mm256_mul_ps(_mm256_loadu_ps(&weight[0]), taps[0]);
    int k;

    for (k = 1; k < tapNums; k++) {
        sum = _mm256_add_ps(sum, _mm256_mul_ps(_mm256_loadu_ps(&weight[k * VEC_WIDTH]), taps[k]));
    }
    _mm256_storeu_ps(out, _mm256_add_ps(_mm256_loadu_ps(out), sum));
}

#else

/* Scalar build: a tap is just the input sample itself. */
typedef cnn_t cnn_tap_t;

#define CNN_WEIGHT_GROUP(table, output, tapNums) ((table)[(output)])

static inline cnn_tap_t cnn_make_tap(const cnn_t value)
{
    return value;
}

/*
 * Adds sum_k(taps[k] * weight[k * VEC_WIDTH + lane]) to each of the VEC_WIDTH
 * values at `out`.  Products are summed in ascending tap order before being
 * added to the stored value.
 */
static inline void cnn_accumulate_group(cnn_t *out, const cnn_t *weight,
                                        const cnn_tap_t *taps, const int tapNums)
{
    int j, k;

    for (j = 0; j < VEC_WIDTH; j++) {
        cnn_t sum = taps[0] * weight[j];

        for (k = 1; k < tapNums; k++) {
            sum += taps[k] * weight[k * VEC_WIDTH + j];
        }
        out[j] += sum;
    }
}

#endif

/* Replaces every value that is <= 0 (including -0.0) with 0. */
static inline void cnn_relu(cnn_t *values, const int count)
{
    int i;

    for (i = 0; i < count; i++) {
        if (values[i] <= 0.0) {
            values[i] = 0.0;
        }
    }
}

/*
 * Computes all output channels of the 5x5 layer at one point (x, y).
 *
 * Reference formula (p, q in 0..4):
 *   hiddenNext[x][y][output] +=
 *       hiddenPrev[x + p - 1][y + q - 1][input] * kernel[p + q * 5]
 *
 * The result is accumulated into hiddenNext[x][y][*] (the caller is expected
 * to have zeroed it) and then clamped with cnn_relu().
 *
 * hdl         unused, kept for interface compatibility
 * inputNums   number of input channels to sum over
 * outputNums  number of output channels; must be a multiple of VEC_WIDTH
 * Returns 0.
 */
static inline int get_hiddenNext_5x5(game_hdl_t *hdl, cnn_t hiddenNext[21][21][FILTER_SIZE], cnn_t hiddenPrev[23][23][CHANNEL_SIZE], const int x, const int y, const int inputNums, const int outputNums)
{
    int input, output, p, q;

    (void)hdl;
    /* The accumulation below always writes whole VEC_WIDTH groups. */
    assert(outputNums % VEC_WIDTH == 0);

    for (input = 0; input < inputNums; input++) {
        cnn_tap_t taps[CNN_TAPS_5X5];

        /* Gather the 5x5 neighbourhood in kernel order: index = p + q * 5. */
        for (q = 0; q < 5; q++) {
            for (p = 0; p < 5; p++) {
                taps[p + q * 5] = cnn_make_tap(hiddenPrev[x + p - 1][y + q - 1][input]);
            }
        }

        for (output = 0; output < outputNums; output += VEC_WIDTH) {
            cnn_accumulate_group(&hiddenNext[x][y][output],
                                 CNN_WEIGHT_GROUP(cnn_param->weight_cnn_5x5[input], output, CNN_TAPS_5X5),
                                 taps, CNN_TAPS_5X5);
        }
    }

    cnn_relu(hiddenNext[x][y], outputNums);
    return 0;
}

/*
 * Applies the 5x5 layer at every point with x and y in 1..19.
 * Returns 0.
 */
static int convolution5x5(cnn_t hiddenNext[21][21][FILTER_SIZE], cnn_t hiddenPrev[23][23][CHANNEL_SIZE], game_hdl_t *hdl, const int outputNums, const int inputNums)
{
    int x, y;

    for (x = 1; x <= CNN_EDGE; x++) {
        for (y = 1; y <= CNN_EDGE; y++) {
            get_hiddenNext_5x5(hdl, hiddenNext, hiddenPrev, x, y, inputNums, outputNums);
        }
    }
    return 0;
}

/*
 * Computes all output channels of 3x3 layer `layer` at one point (x, y).
 *
 * Reference formula (p, q in 0..2):
 *   hiddenNext[x][y][output] +=
 *       hiddenPrev[x + p - 1][y + q - 1][input] * kernel[p + q * 3]
 *
 * The result is accumulated into hiddenNext[x][y][*] (the caller is expected
 * to have zeroed it) and then clamped with cnn_relu().
 *
 * hdl         unused, kept for interface compatibility
 * inputNums   number of input channels to sum over
 * outputNums  number of output channels; must be a multiple of VEC_WIDTH
 * layer       first index into cnn_param->weight_cnn_3x3
 * Returns 0.
 */
static inline int get_hiddenNext_3x3(game_hdl_t *hdl, cnn_t hiddenNext[21][21][FILTER_SIZE], cnn_t hiddenPrev[21][21][FILTER_SIZE], const int x, const int y, const int inputNums, const int outputNums, const int layer)
{
    int input, output, p, q;

    (void)hdl;
    /* The accumulation below always writes whole VEC_WIDTH groups. */
    assert(outputNums % VEC_WIDTH == 0);

    for (input = 0; input < inputNums; input++) {
        cnn_tap_t taps[CNN_TAPS_3X3];

        /* Gather the 3x3 neighbourhood in kernel order: index = p + q * 3. */
        for (q = 0; q < 3; q++) {
            for (p = 0; p < 3; p++) {
                taps[p + q * 3] = cnn_make_tap(hiddenPrev[x + p - 1][y + q - 1][input]);
            }
        }

        for (output = 0; output < outputNums; output += VEC_WIDTH) {
            cnn_accumulate_group(&hiddenNext[x][y][output],
                                 CNN_WEIGHT_GROUP(cnn_param->weight_cnn_3x3[layer][input], output, CNN_TAPS_3X3),
                                 taps, CNN_TAPS_3X3);
        }
    }

    cnn_relu(hiddenNext[x][y], outputNums);
    return 0;
}

/*
 * Applies 3x3 layer `layer` at every point with x and y in 1..19.
 * Returns 0.
 */
static int convolution3x3(const int layer, cnn_t hiddenNext[21][21][FILTER_SIZE], cnn_t hiddenPrev[21][21][FILTER_SIZE], game_hdl_t *hdl, const int outputNums, const int inputNums)
{
    int x, y;

    for (x = 1; x <= CNN_EDGE; x++) {
        for (y = 1; y <= CNN_EDGE; y++) {
            get_hiddenNext_3x3(hdl, hiddenNext, hiddenPrev, x, y, inputNums, outputNums, layer);
        }
    }
    return 0;
}

/*
 * Final 1x1 layer: for every point with x and y in 1..19, stores
 *   bias[(x - 1) + 19 * (y - 1)] + sum_input(hiddenPrev[x][y][input] * w[input])
 * into hiddenOut[x][y].  No clamping is applied.
 *
 * hdl is unused and kept for interface compatibility.  Returns 0.
 */
static inline int convolution13(cnn_t hiddenOut[21][21], cnn_t hiddenPrev[21][21][FILTER_SIZE], game_hdl_t *hdl, const int inputNums)
{
    int input, x, y;

    (void)hdl;

    for (x = 1; x <= CNN_EDGE; x++) {
        for (y = 1; y <= CNN_EDGE; y++) {
            cnn_t tmp = cnn_param->weight_cnn_bias[(x - 1) + CNN_EDGE * (y - 1)];

            for (input = 0; input < inputNums; input++) {
                tmp += hiddenPrev[x][y][input] * cnn_param->weight_cnn_1x1[input];
            }
            hiddenOut[x][y] = tmp;
        }
    }
    return 0;
}

/*
 * Runs the network on the features of `ban` and fills pos_to_prob with the
 * softmax of the output plane, indexed by X_Y_TO_POS(x, y).
 *
 * ban          board; ban->feature supplies the input planes and ban->color
 *              is consulted when choosing the returned position
 * hdl          scratch buffers (hiddenIn, hidden, hiddenOut) for the pass
 * pos_to_prob  receives one probability per board point
 *
 * Returns the position with the highest probability among points whose color
 * is SP and for which judge_eff_te() is non-zero (the last such point on
 * ties), or 0 if no point qualifies.
 * NOTE(review): SP presumably marks an empty point and judge_eff_te() a
 * playable move - confirm against go.h.
 */
int get_cnn_prob(goban_t *ban, game_hdl_t *hdl, double pos_to_prob[XY_SIZE])
{
    int x, y, input, layer;
    int best_pos = 0;
    double max = -1000000.0;
    double peak;
    double prob_tot = 0.0;

    /*
     * Copy the feature planes into the input buffer, shifted by one so that
     * board coordinate 1 lands on index 2.
     * NOTE(review): the surrounding border cells of hiddenIn are not written
     * here; presumably they are zeroed elsewhere - confirm.
     */
    for (input = 0; input < CHANNEL_SIZE; input++) {
        for (x = 1; x <= CNN_EDGE; x++) {
            for (y = 1; y <= CNN_EDGE; y++) {
                hdl->hiddenIn[x + 1][y + 1][input] = (cnn_t)ban->feature[input][x][y];
            }
        }
    }

    /* The convolution layers accumulate into their outputs, so start from zero. */
    memset (hdl->hidden, 0, 14 * 21 * 21 * FILTER_SIZE * sizeof(cnn_t));
    memset (hdl->hiddenOut, 0, 21 * 21 * sizeof(cnn_t));

    convolution5x5(hdl->hidden[1], hdl->hiddenIn, hdl, FILTER_SIZE, CHANNEL_SIZE);
    for (layer = CNN_FIRST_3X3_LAYER; layer <= CNN_LAST_3X3_LAYER; layer++) {
        convolution3x3(layer, hdl->hidden[layer], hdl->hidden[layer - 1], hdl, FILTER_SIZE, FILTER_SIZE);
    }
    convolution13(hdl->hiddenOut, hdl->hidden[CNN_LAST_3X3_LAYER], hdl, FILTER_SIZE);

    /*
     * Softmax with the largest logit subtracted first: the result is
     * mathematically unchanged, but exp() can no longer overflow to infinity
     * and turn the normalised probabilities into NaN.
     */
    peak = hdl->hiddenOut[1][1];
    for (y = 1; y <= BOARD_EDGE_SIZE; y++) {
        for (x = 1; x <= BOARD_EDGE_SIZE; x++) {
            if (hdl->hiddenOut[x][y] > peak) {
                peak = hdl->hiddenOut[x][y];
            }
        }
    }
    if (!isfinite(peak)) {
        /* Shifting by inf/NaN would poison every term; use the raw logits. */
        peak = 0.0;
    }

    // get prob_tot
    for (y = 1; y <= BOARD_EDGE_SIZE; y++) {
        for (x = 1; x <= BOARD_EDGE_SIZE; x++) {
            prob_tot += exp(hdl->hiddenOut[x][y] - peak);
        }
    }
    if (prob_tot == 0.0) {
        /* Avoid dividing by zero when every term underflowed. */
        prob_tot = 1.0;
    }

    // get best_pos and fill pos_to_prob
    for (y = 1; y <= BOARD_EDGE_SIZE; y++) {
        for (x = 1; x <= BOARD_EDGE_SIZE; x++) {
            const int pos = X_Y_TO_POS(x, y);
            const double prob = exp(hdl->hiddenOut[x][y] - peak);

            pos_to_prob[pos] = prob / prob_tot;
            if (ban->color[pos] == SP && judge_eff_te(pos, ban) && prob >= max) {
                max = prob;
                best_pos = pos;
            }
        }
    }
    return best_pos;
}