/* DeltaGo_gcc / cnn_calc.c: CNN forward pass producing per-point move probabilities. */
#include "go.h"
#include <assert.h>
#ifdef USE_SSE
#include <immintrin.h>
#endif

/*
 * 5x5 convolution + ReLU for a single output cell.
 *
 * For every input channel 0..inputNums-1, the 5x5 neighbourhood
 * hiddenPrev[x-1..x+3][y-1..y+3][input] (centred on padded cell (x+1, y+1))
 * is weighted by cnn_param->weight_cnn_5x5[input] and added into
 * hiddenNext[x][y][0..outputNums-1]. Afterwards every non-positive result is
 * clamped to 0. Because hiddenNext[x][y] is accumulated with +=, the caller
 * must have zeroed it beforehand.
 *
 * Weight layout per input channel: output channels are grouped VEC_WIDTH at
 * a time. Each group holds 25 taps x VEC_WIDTH lanes, and tap k = p + 5*q
 * belongs to neighbour hiddenPrev[x+p-1][y+q-1]. The AVX path and the scalar
 * path both read exactly this layout.
 *
 * Preconditions: outputNums is a multiple of VEC_WIDTH.
 * hdl is not used. Always returns 0.
 */
static inline int get_hiddenNext_5x5(game_hdl_t *hdl,  cnn_t hiddenNext[21][21][FILTER_SIZE], cnn_t hiddenPrev[23][23][CHANNEL_SIZE], 
	const int x, const int y, const int inputNums, const int outputNums)
{
	int input, output;

	/* Both paths below handle VEC_WIDTH outputs per step and have no tail loop. */
	assert (outputNums % VEC_WIDTH == 0);

	for (input = 0; input < inputNums; input++){
		/* iPQ = hiddenPrev[x + P - 1][y + Q - 1][input] */
		cnn_t i00 = hiddenPrev[(x) + (0) - 1][(y) + (0) - 1][input];
		cnn_t i01 = hiddenPrev[(x) + (0) - 1][(y) + (1) - 1][input];
		cnn_t i02 = hiddenPrev[(x) + (0) - 1][(y) + (2) - 1][input];
		cnn_t i03 = hiddenPrev[(x) + (0) - 1][(y) + (3) - 1][input];
		cnn_t i04 = hiddenPrev[(x) + (0) - 1][(y) + (4) - 1][input];
		cnn_t i10 = hiddenPrev[(x) + (1) - 1][(y) + (0) - 1][input];
		cnn_t i11 = hiddenPrev[(x) + (1) - 1][(y) + (1) - 1][input];
		cnn_t i12 = hiddenPrev[(x) + (1) - 1][(y) + (2) - 1][input];
		cnn_t i13 = hiddenPrev[(x) + (1) - 1][(y) + (3) - 1][input];
		cnn_t i14 = hiddenPrev[(x) + (1) - 1][(y) + (4) - 1][input];
		cnn_t i20 = hiddenPrev[(x) + (2) - 1][(y) + (0) - 1][input];
		cnn_t i21 = hiddenPrev[(x) + (2) - 1][(y) + (1) - 1][input];
		cnn_t i22 = hiddenPrev[(x) + (2) - 1][(y) + (2) - 1][input];
		cnn_t i23 = hiddenPrev[(x) + (2) - 1][(y) + (3) - 1][input];
		cnn_t i24 = hiddenPrev[(x) + (2) - 1][(y) + (4) - 1][input];
		cnn_t i30 = hiddenPrev[(x) + (3) - 1][(y) + (0) - 1][input];
		cnn_t i31 = hiddenPrev[(x) + (3) - 1][(y) + (1) - 1][input];
		cnn_t i32 = hiddenPrev[(x) + (3) - 1][(y) + (2) - 1][input];
		cnn_t i33 = hiddenPrev[(x) + (3) - 1][(y) + (3) - 1][input];
		cnn_t i34 = hiddenPrev[(x) + (3) - 1][(y) + (4) - 1][input];
		cnn_t i40 = hiddenPrev[(x) + (4) - 1][(y) + (0) - 1][input];
		cnn_t i41 = hiddenPrev[(x) + (4) - 1][(y) + (1) - 1][input];
		cnn_t i42 = hiddenPrev[(x) + (4) - 1][(y) + (2) - 1][input];
		cnn_t i43 = hiddenPrev[(x) + (4) - 1][(y) + (3) - 1][input];
		cnn_t i44 = hiddenPrev[(x) + (4) - 1][(y) + (4) - 1][input];
#ifdef USE_SSE
		/* An __m256 holds 8 floats, so this path relies on VEC_WIDTH == 8. */
		__m256 i000 = _mm256_set1_ps(i00);
		__m256 i001 = _mm256_set1_ps(i01);
		__m256 i002 = _mm256_set1_ps(i02);
		__m256 i003 = _mm256_set1_ps(i03);
		__m256 i004 = _mm256_set1_ps(i04);

		__m256 i010 = _mm256_set1_ps(i10);
		__m256 i011 = _mm256_set1_ps(i11);
		__m256 i012 = _mm256_set1_ps(i12);
		__m256 i013 = _mm256_set1_ps(i13);
		__m256 i014 = _mm256_set1_ps(i14);

		__m256 i020 = _mm256_set1_ps(i20);
		__m256 i021 = _mm256_set1_ps(i21);
		__m256 i022 = _mm256_set1_ps(i22);
		__m256 i023 = _mm256_set1_ps(i23);
		__m256 i024 = _mm256_set1_ps(i24);

		__m256 i030 = _mm256_set1_ps(i30);
		__m256 i031 = _mm256_set1_ps(i31);
		__m256 i032 = _mm256_set1_ps(i32);
		__m256 i033 = _mm256_set1_ps(i33);
		__m256 i034 = _mm256_set1_ps(i34);

		__m256 i040 = _mm256_set1_ps(i40);
		__m256 i041 = _mm256_set1_ps(i41);
		__m256 i042 = _mm256_set1_ps(i42);
		__m256 i043 = _mm256_set1_ps(i43);
		__m256 i044 = _mm256_set1_ps(i44);

		/* Walk the weight groups sequentially: 25 taps x VEC_WIDTH lanes per group. */
		const cnn_t *weight = cnn_param->weight_cnn_5x5[input][0];
		for (output = 0; output < outputNums; output += VEC_WIDTH){
			__m256 v = _mm256_mul_ps(_mm256_loadu_ps(&weight[0 * VEC_WIDTH]), i000);
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[1 * VEC_WIDTH]), i010));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[2 * VEC_WIDTH]), i020));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[3 * VEC_WIDTH]), i030));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[4 * VEC_WIDTH]), i040));

			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[5 * VEC_WIDTH]), i001));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[6 * VEC_WIDTH]), i011));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[7 * VEC_WIDTH]), i021));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[8 * VEC_WIDTH]), i031));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[9 * VEC_WIDTH]), i041));

			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[10 * VEC_WIDTH]), i002));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[11 * VEC_WIDTH]), i012));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[12 * VEC_WIDTH]), i022));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[13 * VEC_WIDTH]), i032));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[14 * VEC_WIDTH]), i042));

			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[15 * VEC_WIDTH]), i003));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[16 * VEC_WIDTH]), i013));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[17 * VEC_WIDTH]), i023));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[18 * VEC_WIDTH]), i033));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[19 * VEC_WIDTH]), i043));

			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[20 * VEC_WIDTH]), i004));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[21 * VEC_WIDTH]), i014));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[22 * VEC_WIDTH]), i024));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[23 * VEC_WIDTH]), i034));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[24 * VEC_WIDTH]), i044));

			__m256 prev = _mm256_loadu_ps(&hiddenNext[x][y][output]);
			_mm256_storeu_ps(&hiddenNext[x][y][output], _mm256_add_ps(prev, v));
			weight += VEC_WIDTH * 25;
		}
#else
		for (output = 0; output < outputNums; output += VEC_WIDTH){
			/* Same layout as the AVX path: tap k of lane j lives at weight[k * VEC_WIDTH + j]. */
			const cnn_t *weight = cnn_param->weight_cnn_5x5[input][output];
			int j;

			for (j = 0; j < VEC_WIDTH; j++){
				hiddenNext[x][y][output + j] +=
				i00 * weight[ 0 * VEC_WIDTH + j] + i10 * weight[ 1 * VEC_WIDTH + j] + i20 * weight[ 2 * VEC_WIDTH + j] + i30 * weight[ 3 * VEC_WIDTH + j] + i40 * weight[ 4 * VEC_WIDTH + j] +
				i01 * weight[ 5 * VEC_WIDTH + j] + i11 * weight[ 6 * VEC_WIDTH + j] + i21 * weight[ 7 * VEC_WIDTH + j] + i31 * weight[ 8 * VEC_WIDTH + j] + i41 * weight[ 9 * VEC_WIDTH + j] +
				i02 * weight[10 * VEC_WIDTH + j] + i12 * weight[11 * VEC_WIDTH + j] + i22 * weight[12 * VEC_WIDTH + j] + i32 * weight[13 * VEC_WIDTH + j] + i42 * weight[14 * VEC_WIDTH + j] +
				i03 * weight[15 * VEC_WIDTH + j] + i13 * weight[16 * VEC_WIDTH + j] + i23 * weight[17 * VEC_WIDTH + j] + i33 * weight[18 * VEC_WIDTH + j] + i43 * weight[19 * VEC_WIDTH + j] +
				i04 * weight[20 * VEC_WIDTH + j] + i14 * weight[21 * VEC_WIDTH + j] + i24 * weight[22 * VEC_WIDTH + j] + i34 * weight[23 * VEC_WIDTH + j] + i44 * weight[24 * VEC_WIDTH + j];
			}
		}
#endif
	}

	/* ReLU, applied once the contributions of all input channels are summed. */
	for (output = 0; output < outputNums; output++){
		if (hiddenNext[x][y][output] <= 0.0){
			hiddenNext[x][y][output] = 0.0;
		}
	}
	return 0;
}

/*
 * Runs the 5x5 layer over the whole board: for each x in 1..19, each y in
 * 1..19 is visited in ascending order, and get_hiddenNext_5x5 fills
 * hiddenNext[x][y] from the padded plane hiddenPrev.
 * Always returns 0.
 */
static int convolution5x5 (cnn_t hiddenNext[21][21][FILTER_SIZE], cnn_t hiddenPrev[23][23][CHANNEL_SIZE], game_hdl_t *hdl, 
			const int outputNums, const int inputNums)
{
	for (int row = 1; row <= 19; ++row){
		for (int col = 1; col <= 19; ++col){
			get_hiddenNext_5x5(hdl, hiddenNext, hiddenPrev, row, col, inputNums, outputNums);
		}
	}
	return 0;
}


/*
 * 3x3 convolution + ReLU for a single output cell of layer `layer`.
 *
 * For every input channel 0..inputNums-1, the 3x3 neighbourhood
 * hiddenPrev[x-1..x+1][y-1..y+1][input] (centred on (x, y)) is weighted by
 * cnn_param->weight_cnn_3x3[layer][input] and added into
 * hiddenNext[x][y][0..outputNums-1]. Afterwards every non-positive result is
 * clamped to 0. Because hiddenNext[x][y] is accumulated with +=, the caller
 * must have zeroed it beforehand.
 *
 * Weight layout per (layer, input channel): output channels are grouped
 * VEC_WIDTH at a time. Each group holds 9 taps x VEC_WIDTH lanes, and tap
 * k = p + 3*q belongs to neighbour hiddenPrev[x+p-1][y+q-1]. The AVX path and
 * the scalar path both read exactly this layout.
 *
 * Preconditions: outputNums is a multiple of VEC_WIDTH.
 * hdl is not used. Always returns 0.
 */
static inline int get_hiddenNext_3x3(game_hdl_t *hdl,  cnn_t hiddenNext[21][21][FILTER_SIZE], cnn_t hiddenPrev[21][21][FILTER_SIZE], 
	const int x, const int y, const int inputNums, const int outputNums, const int layer)
{
	int input, output;

	/* Both paths below handle VEC_WIDTH outputs per step and have no tail loop. */
	assert (outputNums % VEC_WIDTH == 0);

	for (input = 0; input < inputNums; input++){
		/* iPQ = hiddenPrev[x + P - 1][y + Q - 1][input] */
		cnn_t i00 = hiddenPrev[(x) + (0) - 1][(y) + (0) - 1][input];
		cnn_t i01 = hiddenPrev[(x) + (0) - 1][(y) + (1) - 1][input];
		cnn_t i02 = hiddenPrev[(x) + (0) - 1][(y) + (2) - 1][input];
		cnn_t i10 = hiddenPrev[(x) + (1) - 1][(y) + (0) - 1][input];
		cnn_t i11 = hiddenPrev[(x) + (1) - 1][(y) + (1) - 1][input];
		cnn_t i12 = hiddenPrev[(x) + (1) - 1][(y) + (2) - 1][input];
		cnn_t i20 = hiddenPrev[(x) + (2) - 1][(y) + (0) - 1][input];
		cnn_t i21 = hiddenPrev[(x) + (2) - 1][(y) + (1) - 1][input];
		cnn_t i22 = hiddenPrev[(x) + (2) - 1][(y) + (2) - 1][input];

#ifdef USE_SSE
		/* An __m256 holds 8 floats, so this path relies on VEC_WIDTH == 8. */
		__m256 i000 = _mm256_set1_ps(i00);
		__m256 i001 = _mm256_set1_ps(i01);
		__m256 i002 = _mm256_set1_ps(i02);
		__m256 i010 = _mm256_set1_ps(i10);
		__m256 i011 = _mm256_set1_ps(i11);
		__m256 i012 = _mm256_set1_ps(i12);
		__m256 i020 = _mm256_set1_ps(i20);
		__m256 i021 = _mm256_set1_ps(i21);
		__m256 i022 = _mm256_set1_ps(i22);

		/* Walk the weight groups sequentially: 9 taps x VEC_WIDTH lanes per group. */
		const cnn_t *weight = cnn_param->weight_cnn_3x3[layer][input][0];
		for (output = 0; output < outputNums; output += VEC_WIDTH){
			__m256 v = _mm256_mul_ps(_mm256_loadu_ps(&weight[0 * VEC_WIDTH]), i000);
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[1 * VEC_WIDTH]), i010));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[2 * VEC_WIDTH]), i020));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[3 * VEC_WIDTH]), i001));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[4 * VEC_WIDTH]), i011));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[5 * VEC_WIDTH]), i021));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[6 * VEC_WIDTH]), i002));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[7 * VEC_WIDTH]), i012));
			v = _mm256_add_ps(v, _mm256_mul_ps(_mm256_loadu_ps(&weight[8 * VEC_WIDTH]), i022));
			__m256 prev = _mm256_loadu_ps(&hiddenNext[x][y][output]);
			_mm256_storeu_ps(&hiddenNext[x][y][output], _mm256_add_ps(prev, v));
			weight += VEC_WIDTH * 9;
		}
#else
		int j;
		for (output = 0; output < outputNums; output += VEC_WIDTH){
			/* Same layout as the AVX path: tap k of lane j lives at weight[k * VEC_WIDTH + j]. */
			const cnn_t *weight = cnn_param->weight_cnn_3x3[layer][input][output];

			for (j = 0; j < VEC_WIDTH; j++){
				hiddenNext[x][y][output + j] += i00 * weight[0 * VEC_WIDTH + j] + i10 * weight[1 * VEC_WIDTH + j] + i20 * weight[2 * VEC_WIDTH + j] +
												i01 * weight[3 * VEC_WIDTH + j] + i11 * weight[4 * VEC_WIDTH + j] + i21 * weight[5 * VEC_WIDTH + j] +
												i02 * weight[6 * VEC_WIDTH + j] + i12 * weight[7 * VEC_WIDTH + j] + i22 * weight[8 * VEC_WIDTH + j];
			}
		}
#endif
	}

	/* ReLU, applied once the contributions of all input channels are summed. */
	for (output = 0; output < outputNums; output++){
		if (hiddenNext[x][y][output] <= 0.0){
			hiddenNext[x][y][output] = 0.0;
		}
	}
	return 0;
}

/*
 * Runs 3x3 layer `layer` over the whole board: for each x in 1..19, each y
 * in 1..19 is visited in ascending order, and get_hiddenNext_3x3 fills
 * hiddenNext[x][y] from hiddenPrev.
 * Always returns 0.
 */
static int convolution3x3 (const int layer, cnn_t hiddenNext[21][21][FILTER_SIZE], cnn_t hiddenPrev[21][21][FILTER_SIZE], game_hdl_t *hdl, 
			const int outputNums, const int inputNums)
{
	for (int row = 1; row <= 19; ++row){
		for (int col = 1; col <= 19; ++col){
			get_hiddenNext_3x3(hdl, hiddenNext, hiddenPrev, row, col, inputNums, outputNums, layer);
		}
	}
	return 0;
}

/*
 * Final 1x1 layer: collapses the inputNums channels of each board cell
 * (x, y in 1..19) to a single value.
 *
 * Each value is the per-cell bias weight_cnn_bias[(x-1) + 19*(y-1)] plus
 * sum_c hiddenPrev[x][y][c] * weight_cnn_1x1[c]. The result is written,
 * not accumulated, to hiddenOut[x][y]. No activation is applied here.
 * hdl is not used. Always returns 0.
 */
static inline int convolution13 (cnn_t hiddenOut[21][21], cnn_t hiddenPrev[21][21][FILTER_SIZE], game_hdl_t *hdl, const int inputNums)
{
	for (int row = 1; row <= 19; ++row){
		for (int col = 1; col <= 19; ++col){
			const cnn_t *cell = hiddenPrev[row][col];
			cnn_t acc = cnn_param->weight_cnn_bias[19 * (col - 1) + (row - 1)];
			for (int ch = 0; ch < inputNums; ++ch){
				acc += cell[ch] * cnn_param->weight_cnn_1x1[ch];
			}
			hiddenOut[row][col] = acc;
		}
	}
	return 0;
}

/*
 * Runs the network on the current position and picks a move.
 *
 * Pipeline:
 *   1. Copy ban->feature into the padded input plane hdl->hiddenIn.
 *   2. Run one 5x5 layer, the 3x3 layers 2..12 and a final 1x1 layer into
 *      hdl->hiddenOut.
 *   3. Turn the on-board outputs into a softmax distribution.
 *
 * ban         - position; feature[][][], color[] and judge_eff_te() are read.
 * hdl         - scratch buffers for the activations (overwritten).
 * pos_to_prob - receives the softmax probability of each on-board point at
 *               index X_Y_TO_POS(x, y). Other entries are left untouched.
 *
 * Returns the point with the highest network output among those whose
 * color is SP and for which judge_eff_te() is non-zero. Ties go to the point
 * scanned last (y outer, x inner). Returns 0 when no point qualifies.
 * NOTE(review): 0 presumably means "pass"; confirm against callers.
 */
int get_cnn_prob(goban_t *ban, game_hdl_t *hdl, double pos_to_prob[XY_SIZE])
{
	int x, y, input;

	/* Board point (x, y) goes to hiddenIn[x + 1][y + 1]. The two-cell margin
	 * around it is the padding read by the 5x5 layer.
	 * NOTE(review): that margin is never cleared here, so hiddenIn is assumed
	 * to be zero-initialised when hdl is allocated; confirm. */
	for (input = 0; input < CHANNEL_SIZE; input++){
		for (x = 1; x <= 19; x++){
			for (y = 1; y <= 19; y++){
				hdl->hiddenIn[x + 1][y + 1][input] = (cnn_t)ban->feature[input][x][y];
			}
		}
	}

	/* The layers accumulate into their outputs with +=, and rows/columns 0
	 * and 20 of each hidden plane act as zero padding, so everything must
	 * start at 0.
	 * NOTE(review): 14 must match the declared first dimension of
	 * hdl->hidden (indices 1..12 are used below); confirm. */
	memset (hdl->hidden, 0, 14 * 21 * 21 * FILTER_SIZE * sizeof(cnn_t));
	memset (hdl->hiddenOut, 0, 21 * 21 * sizeof(cnn_t));

	convolution5x5(    hdl->hidden[1],  hdl->hiddenIn,   hdl, FILTER_SIZE, CHANNEL_SIZE);
	convolution3x3(2,  hdl->hidden[2],  hdl->hidden[1],  hdl, FILTER_SIZE, FILTER_SIZE);
	convolution3x3(3,  hdl->hidden[3],  hdl->hidden[2],  hdl, FILTER_SIZE, FILTER_SIZE);
	convolution3x3(4,  hdl->hidden[4],  hdl->hidden[3],  hdl, FILTER_SIZE, FILTER_SIZE);
	convolution3x3(5,  hdl->hidden[5],  hdl->hidden[4],  hdl, FILTER_SIZE, FILTER_SIZE);
	convolution3x3(6,  hdl->hidden[6],  hdl->hidden[5],  hdl, FILTER_SIZE, FILTER_SIZE);
	convolution3x3(7,  hdl->hidden[7],  hdl->hidden[6],  hdl, FILTER_SIZE, FILTER_SIZE);
	convolution3x3(8,  hdl->hidden[8],  hdl->hidden[7],  hdl, FILTER_SIZE, FILTER_SIZE);
	convolution3x3(9,  hdl->hidden[9],  hdl->hidden[8],  hdl, FILTER_SIZE, FILTER_SIZE);
	convolution3x3(10, hdl->hidden[10], hdl->hidden[9],  hdl, FILTER_SIZE, FILTER_SIZE);
	convolution3x3(11, hdl->hidden[11], hdl->hidden[10], hdl, FILTER_SIZE, FILTER_SIZE);
	convolution3x3(12, hdl->hidden[12], hdl->hidden[11], hdl, FILTER_SIZE, FILTER_SIZE);
	convolution13(     hdl->hiddenOut, hdl->hidden[12], hdl, FILTER_SIZE);

	/* Softmax, shifted by the largest output so that exp() stays in range.
	 * Without the shift, a large logit overflows to inf (inf/inf = NaN), and
	 * uniformly very negative logits underflow to a zero total. */
	double max_logit = hdl->hiddenOut[1][1];
	for (y = 1; y <= BOARD_EDGE_SIZE; y++){
		for (x = 1; x <= BOARD_EDGE_SIZE; x++){
			if (hdl->hiddenOut[x][y] > max_logit){
				max_logit = hdl->hiddenOut[x][y];
			}
		}
	}

	/* The largest term is exp(0) == 1, so prob_tot >= 1 and the division
	 * below needs no zero guard. */
	double prob_tot = 0.0;
	for (y = 1; y <= BOARD_EDGE_SIZE; y++){
		for (x = 1; x <= BOARD_EDGE_SIZE; x++){
			double prob = exp(hdl->hiddenOut[x][y] - max_logit);
			pos_to_prob[X_Y_TO_POS(x, y)] = prob;
			prob_tot += prob;
		}
	}

	/* Normalise and pick the best legal point. Raw outputs are compared, not
	 * probabilities, so that points whose probabilities round to the same
	 * double (e.g. both underflow to 0) are still ordered correctly. */
	int best_pos = 0;
	int have_best = 0;
	cnn_t best_logit = 0;
	for (y = 1; y <= BOARD_EDGE_SIZE; y++){
		for (x = 1; x <= BOARD_EDGE_SIZE; x++){
			int pos = X_Y_TO_POS(x, y);
			pos_to_prob[pos] /= prob_tot;
			if (ban->color[pos] == SP && judge_eff_te(pos, ban)){
				if (!have_best || hdl->hiddenOut[x][y] >= best_logit){
					best_logit = hdl->hiddenOut[x][y];
					best_pos = pos;
					have_best = 1;
				}
			}
		}
	}
	return best_pos;
}