

/*
 *   	File:  method1.cu
 *   	author: Lung-Sheng Chien
 *			Department of Mathematics, Tsing Hua University, R.O.C. (Taiwan).
 *			Email: d947207@oz.nthu.edu.tw
 *	 	date: 2010/01/15
 *
 *		description: modification based on Volkov's code 
 *   
 *			see HandTunedSgemm_2010_v1.pdf 
 * 
 * How to compile .cu to .cubin  
 *	"C:\CUDA\bin64\nvcc.exe"  -ccbin "C:\Program Files (x86)\Microsoft Visual Studio 8\VC\bin" -I"C:\Program Files (x86)\Microsoft Visual Studio 8\VC\include"  -O2 -arch compute_13 -code sm_13 -cubin  method1.cu
 *
 */

// Written by Vasily Volkov.
// Copyright (c) 2009, The Regents of the University of California.
// All rights reserved.


// Tiling parameters for method1_sgemmNN (C = alpha*A*B + beta*C).
//
// VECTOR_LENGTH : rows of C in one "vector" of the tile; each of the
//                 THREAD_BLOCK_X * THREAD_BLOCK_Y (= 64) threads owns one row.
// NUM_VECTOR    : number of such row groups per thread block (the block covers
//                 VECTOR_LENGTH * NUM_VECTOR rows of C). NOTE: the kernel
//                 hard-codes two per-thread accumulator arrays (c0, c1), so
//                 changing this value alone is not sufficient.
// BLOCK_SIZE_Y  : columns of C per thread block, and the length of each
//                 per-thread accumulator array. NOTE: store_block hard-codes
//                 16 unrolled stores.
#define VECTOR_LENGTH  64
#define NUM_VECTOR     2 
#define BLOCK_SIZE_Y   16	

// BLOCK_SIZE_X  : depth of the k-tile of B staged in shared memory per
//                 iteration of the main loop.
#define BLOCK_SIZE_X   16

// Thread-block shape used at launch: blockDim = (THREAD_BLOCK_X, THREAD_BLOCK_Y).
// threadIdx.x also selects the row of the shared B tile, hence the equality below.
#define THREAD_BLOCK_X  16   // THREAD_BLOCK_X = BLOCK_SIZE_X
#define THREAD_BLOCK_Y  4    // THREAD_BLOCK_Y = VECTOR_LENGTH / BLOCK_SIZE_X


// Write up to 16 accumulated values back to one row of C.
//
//   num   : number of columns still inside C; nothing is written when
//           num <= 0, and at most 16 values are written when num > 16.
//   alpha : scale applied to each accumulator c[j].
//   c     : the 16 per-thread accumulators (one per column).
//   beta  : scale applied to the existing value of C; when beta == 0 the old
//           contents of C are never read, so C may hold inf/NaN on entry.
//   C     : address of the first output element; successive values are
//           written ldc floats apart.
__device__ void store_block( int num, float alpha, float *c, float beta, float *C, int ldc )
{
    float *dst = C;

    if( beta != 0 )
    {
        // General case: dst = alpha*c + beta*dst.
#pragma unroll
        for( int j = 0; j < 16; ++j )
        {
            if( j >= num ) return;
            if( j > 0 ) dst += ldc;   // step to the next column only when it will be written
            *dst = alpha*c[j] + beta*(*dst);
        }
    }
    else
    {
        // beta == 0: overwrite without reading dst, so inf/NaN already in C
        // cannot propagate through 0*NaN.
#pragma unroll
        for( int j = 0; j < 16; ++j )
        {
            if( j >= num ) return;
            if( j > 0 ) dst += ldc;
            *dst = alpha*c[j];
        }
    }
}

//
//  C = alpha*A*B + beta*C   (single precision, A and B not transposed)
//
//  All matrices are column-major: A is m x k (leading dimension lda),
//  B is k x n (ldb) and C is m x n (ldc); lda >= m and ldb >= k are assumed,
//  as in BLAS. Offsets are computed in int, so matrices with more than 2^31
//  elements are not supported.
//
//  Expected launch configuration (see the reference wrapper at the end of
//  this file):
//      blockDim = ( THREAD_BLOCK_X, THREAD_BLOCK_Y )
//      gridDim  = ( ceil(m / (VECTOR_LENGTH*NUM_VECTOR)), ceil(n / BLOCK_SIZE_Y) )
//
//  Each thread block computes a (VECTOR_LENGTH*NUM_VECTOR) x BLOCK_SIZE_Y tile
//  of C. Each thread owns two rows of that tile (row and row + VECTOR_LENGTH)
//  and keeps all BLOCK_SIZE_Y columns of both rows in the accumulators c0/c1.
//  The matching tile of B is staged through shared memory BLOCK_SIZE_X rows at
//  a time and is shared by both rows.
//
//  Any m, n, k is accepted: global loads are clamped to the last valid
//  row/column of A and B, so no thread reads outside the matrices. Values
//  computed from clamped loads are never stored (row < m is checked here and
//  column < n inside store_block).
//
/*
	Resource usage reported for the original (unclamped) sm_13 build;
	the clamped loads may change these figures:
	lmem = 0
	smem = 1168
	reg  = 47
	active threads = 320
 */
static __global__ void  method1_sgemmNN( int m, int n, const float *A, int lda, 
const float *B, int ldb, float* C, int ldc, int k, float alpha, float beta )
{
	// An empty C has nothing to store. The early return also keeps the clamp
	// bounds (m-1, n-1) below non-negative. m and n are uniform over the block,
	// so every thread takes this exit together and no barrier is skipped.
	if( m <= 0 || n <= 0 ) return;

	const int inx = threadIdx.x;
	const int iny = threadIdx.y;
	const int ibx = blockIdx.x * VECTOR_LENGTH * NUM_VECTOR ;
	const int iby = blockIdx.y * BLOCK_SIZE_Y ;
	int row = ibx + inx + iny*THREAD_BLOCK_X ;

	const int last_row = m - 1;
	const int last_col = n - 1;

	// A1 is the starting address of the second sub-matrix of A, VECTOR_LENGTH
	// rows below the first. Rows past the end of A are clamped to the last row
	// so the loads stay in bounds. The clamped rows never reach C because of
	// the row guards at the end of the kernel.
	const float *A1 = A + min( row + VECTOR_LENGTH, last_row );
	A += min( row, last_row );

	// B keeps pointing at row 0 of the current k-tile. The per-thread row
	// (inx) and column offsets are added at load time so they can be clamped.
	C += row + iby * ldc;

	float c0[BLOCK_SIZE_Y] ;
	float c1[BLOCK_SIZE_Y] ;
#pragma unroll
	for( int i = 0 ; i < BLOCK_SIZE_Y ; i++ ){
		c0[i] = 0.0f ;
		c1[i] = 0.0f ;
	}

	// Both sub-matrices of C share the same BLOCK_SIZE_X x BLOCK_SIZE_Y tile
	// of B. The +1 padding makes threads with consecutive inx write
	// b[inx][...] to different shared-memory banks.
	__shared__ float b[BLOCK_SIZE_X][BLOCK_SIZE_Y+1] ;

	for( ; k > 0; k -= BLOCK_SIZE_X ){
		// Cooperative load of B(tile row inx, columns iby .. iby+BLOCK_SIZE_Y-1).
		// In the final partial tile (k < BLOCK_SIZE_X), rows >= k are clamped
		// to k-1 and columns >= n to n-1. Those entries of b are never used.
		const int b_row = min( inx, k - 1 );
#pragma unroll
		for( int i = 0; i < BLOCK_SIZE_Y ; i += THREAD_BLOCK_Y ){
			const int col = min( iby + iny + i, last_col );
			b[inx][iny+i] = B[ b_row + col * ldb ];
		}
		__syncthreads();

		// Partial tile: leave it in shared memory for the remainder loop below.
		// k is uniform over the block, so all threads break together.
		if( k < BLOCK_SIZE_X )  break;

		const float *b_ptr = &b[0][0] ;
#pragma unroll
		for( int i = 0; i < BLOCK_SIZE_X; i++  ){
			const float A0_reg = A[0]  ; A  += lda ;
			const float A1_reg = A1[0] ; A1 += lda ;
			// Row i of the B tile: all threads read the same addresses.
#pragma unroll
			for( int j = 0 ; j < BLOCK_SIZE_Y ; j++){
				const float b_reg = b_ptr[j] ;
				c0[j] += A0_reg * b_reg ;
				c1[j] += A1_reg * b_reg ;
			}
			b_ptr += (BLOCK_SIZE_Y+1) ;	// advance to &b[i+1][0]
		}// for each column index of sub-matrix of A

		// Every thread must finish reading b before the next tile overwrites it.
		__syncthreads();

		B += BLOCK_SIZE_X ;
	}

	// Remainder update for the last k % BLOCK_SIZE_X columns of A (rows of B),
	// using the partial tile left in shared memory by the loop above.
	// If the loop ended without breaking, k <= 0 here and nothing remains.
	const float *b_ptr = &b[0][0] ;
	for( int i = 0 ; i < k ; i++ ){
		const float A0_reg = A[0]  ; A  += lda ;
		const float A1_reg = A1[0] ; A1 += lda ;
#pragma unroll
		for( int j = 0 ; j < BLOCK_SIZE_Y ; j++){
			const float b_reg = b_ptr[j] ;
			c0[j] += A0_reg * b_reg ;
			c1[j] += A1_reg * b_reg ;
		}
		b_ptr += (BLOCK_SIZE_Y+1) ;	// advance to &b[i+1][0]
	}

	// Write back the two sub-matrices of C: first the row owned by c0, then
	// the row VECTOR_LENGTH below it owned by c1. Each store is skipped once
	// its row lies outside C. store_block limits the columns to n - iby.
	if( row >= m )  return;
	store_block( n - iby, alpha, c0, beta, C, ldc );

	row += VECTOR_LENGTH ;
	if( row >= m )  return;

	C += VECTOR_LENGTH ;
	store_block( n - iby, alpha, c1, beta, C, ldc );
}


 
	  
//
//  Matrix-matrix multiplications
//  See http://www.netlib.org/blas/sgemm.f
//
//
//  C = alpha*A*B + beta*C
//
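/*
  Reference host wrapper (kept commented out). A is hA x wA, B is wA x wB and
  C is hA x wB, all column-major with lda = ldc = hA and ldb = wA. The kernel
  is declared static, so a wrapper has to live in this translation unit.

void  method1_sgemmNN_wrapper(float* C, float* A, float* B, int hA, int wA, int wB )
{
	int m = hA ;
	int n = wB ;
	dim3 threads( THREAD_BLOCK_X, THREAD_BLOCK_Y );
	dim3 grid( (m+VECTOR_LENGTH*NUM_VECTOR-1)/(VECTOR_LENGTH*NUM_VECTOR), (n+BLOCK_SIZE_Y-1)/BLOCK_SIZE_Y ) ;
	method1_sgemmNN<<<grid, threads>>>( m, n, A, hA, B, wA, C, hA, wA, 1.0f, 0.0f );
	// launch-configuration errors are reported only through cudaGetLastError()
}
*/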