
/*
 *   	File:  method5_v2.cu
 *   	author: Lung-Sheng Chien
 *			Department of Mathematics, Tsing Hua University, R.O.C. (Taiwan).
 *			Email: d947207@oz.nthu.edu.tw
 *	 	date: 2010/01/30
 *
 *		description: see HandTunedSgemm_2010_v1.1.pdf 
 * 
 * How to compile .cu to .cubin  
 *	"C:\CUDA\bin64\nvcc.exe"  -ccbin "C:\Program Files (x86)\Microsoft Visual Studio 8\VC\bin" -I"C:\Program Files (x86)\Microsoft Visual Studio 8\VC\include"  -O2 -arch compute_13 -code sm_13 -cubin  method5_v2.cu
 *
 */

// Written by Vasily Volkov.
// Copyright (c) 2009, The Regents of the University of California.
// All rights reserved.

//#include "gpu_lapack_internal.h"

//
//  have to unroll some of the loops manually
//
__device__ void rank1_update( float a, const float *b, float *c )
{
	// c[0..15] += a * b[0..15]
	// The trip count is a compile-time constant, so the pragma unrolls the loop
	// fully and every c[] index becomes a constant.
#pragma unroll
	for( int col = 0; col < 16; ++col )
		c[col] += a*b[col];
}

__device__ void rankk_update( int k, const float *A, int lda, const float *b, int ldb, float *c )
{
    // Accumulates min(k, 16) rank-1 terms into c[0..15]:
    //     c[j] += A[p*lda] * b[p*ldb + j]   for p < min(k, 16), j < 16
    // Does nothing when k <= 0. The loop has a constant trip count of 16 so it
    // can be fully unrolled; the early return honors the runtime k.
#pragma unroll
    for( int p = 0; p < 16; ++p )
    {
        if( p >= k ) return;
        rank1_update( A[p*lda], b + p*ldb, c );
    }
}

__device__ void store_block( int num, float alpha, float *c, float beta, float *C, int ldc )
{
    // Writes the first min(num, 16) accumulator entries back to C. Entry j goes
    // to C[j*ldc], i.e. consecutive columns of a column-major C:
    //     C[j*ldc] = alpha*c[j] + beta*C[j*ldc]   (beta != 0)
    //     C[j*ldc] = alpha*c[j]                   (beta == 0)
    // Nothing is written when num <= 0.
    if( num <= 0 ) return;

    if( beta != 0 )
    {
#pragma unroll
        for( int j = 0; j < 16; ++j )
        {
            if( j >= num ) return;
            C[j*ldc] = alpha*c[j] + beta*C[j*ldc];
        }
        return;
    }

    // beta == 0: overwrite without reading C, so inf/NaN values already
    // stored in C cannot propagate into the result.
#pragma unroll
    for( int j = 0; j < 16; ++j )
    {
        if( j >= num ) return;
        C[j*ldc] = alpha*c[j];
    }
}

// case 1: no check 
__device__  void  method5_sgemmNN_case1( int m, int n, float *A, int lda, 
float *B, int ldb, float* C, int ldc, int k, float alpha, float beta,
float *A_bound, float* B_bound, float* b )
{
	// Case 1: no bounds checks on A or B. A_bound / B_bound are unused here;
	// they only keep the signature identical across all cases.
	//
	// Thread (tx, ty) owns row myRow of C and accumulates 16 consecutive
	// columns starting at colBase in acc[0..15]. The indexing covers a 64-row
	// slab only for blockDim = (16, 4) -- NOTE(review): inferred, confirm
	// against the host-side launch. b is the 16 x 17 shared staging buffer
	// (row stride 17; only the first 16 entries of each row are read).
	const int tx      = threadIdx.x;
	const int ty      = threadIdx.y;
	const int rowBase = blockIdx.x * 64;
	const int colBase = blockIdx.y * 16;
	const int myRow   = rowBase + ty*16 + tx;

	const float *aPtr   = A + myRow;                      // A(myRow, 0)
	const float *bPtr   = B + tx + ( colBase + ty )*ldb;  // B(tx, colBase + ty)
	float       *cPtr   = C + myRow + colBase*ldc;        // C(myRow, colBase)
	float       *bStore = b + tx*17 + ty;                 // &b[tx][ty]

	float acc[16];
#pragma unroll
	for( int j = 0; j < 16; ++j )
		acc[j] = 0.0f;

	int kLeft = k;
	while( kLeft > 0 )
	{
		// stage the next 16 x 16 panel of B: b[tx][ty + 4r] = B(kk0 + tx, colBase + ty + 4r)
#pragma unroll
		for( int r = 0; r < 4; ++r )
			bStore[4*r] = bPtr[4*r*ldb];
		__syncthreads();

		// a partial panel (< 16 rows of B) stays staged and is consumed by rankk_update below
		if( kLeft < 16 ) break;

		const float *bRow = b;
#pragma unroll
		for( int p = 0; p < 16; ++p )
		{
			const float aVal = aPtr[p*lda];
#pragma unroll
			for( int j = 0; j < 16; ++j )
				acc[j] += aVal * bRow[j];
			bRow += 17;
		}
		aPtr += 16*lda;
		__syncthreads();   // all reads of b finish before the next panel overwrites it

		bPtr  += 16;
		kLeft -= 16;
	}

	// remaining 0 < kLeft < 16 columns of A (no-op when kLeft <= 0)
	rankk_update( kLeft, aPtr, lda, b, 17, acc );

	if( myRow >= m ) return;
	store_block( n - colBase, alpha, acc, beta, cPtr, ldc );
}

// case 2: check B
__device__  void  method5_sgemmNN_case2( int m, int n, float *A, int lda, 
float *B, int ldb, float* C, int ldc, int k, float alpha, float beta,
float *A_bound, float* B_bound, float* b )
{
	// Case 2: every load of B is guarded by B_bound; A is read unchecked.
	// A_bound is unused.
	//
	// Thread (tx, ty) owns row myRow of C and accumulates 16 consecutive
	// columns starting at colBase in acc[0..15]. NOTE(review): the indexing
	// assumes blockDim = (16, 4); confirm against the host-side launch.
	// b is the 16 x 17 shared staging buffer.
	const int tx      = threadIdx.x;
	const int ty      = threadIdx.y;
	const int rowBase = blockIdx.x * 64;
	const int colBase = blockIdx.y * 16;
	const int myRow   = rowBase + ty*16 + tx;

	const float *aPtr   = A + myRow;                      // A(myRow, 0)
	const float *bPtr   = B + tx + ( colBase + ty )*ldb;  // B(tx, colBase + ty)
	float       *cPtr   = C + myRow + colBase*ldc;        // C(myRow, colBase)
	float       *bStore = b + tx*17 + ty;                 // &b[tx][ty]

	float acc[16];
#pragma unroll
	for( int j = 0; j < 16; ++j )
		acc[j] = 0.0f;

	int kLeft = k;
	while( kLeft > 0 )
	{
		// Stage the next panel of B, stopping at the first address at or past
		// B_bound. Slots that are skipped keep stale contents. Presumably they
		// belong to columns >= n, which store_block never writes -- NOTE(review):
		// confirm how the host sets B_bound.
		for( int r = 0; r < 4; ++r )
		{
			const float *src = bPtr + 4*r*ldb;
			if( src >= B_bound ) break;
			bStore[4*r] = *src;
		}
		__syncthreads();

		// a partial panel stays staged for rankk_update below
		if( kLeft < 16 ) break;

		const float *bRow = b;
#pragma unroll
		for( int p = 0; p < 16; ++p )
		{
			const float aVal = aPtr[p*lda];
#pragma unroll
			for( int j = 0; j < 16; ++j )
				acc[j] += aVal * bRow[j];
			bRow += 17;
		}
		aPtr += 16*lda;
		__syncthreads();   // all reads of b finish before the next panel overwrites it

		bPtr  += 16;
		kLeft -= 16;
	}

	// remaining 0 < kLeft < 16 columns of A (no-op when kLeft <= 0)
	rankk_update( kLeft, aPtr, lda, b, 17, acc );

	if( myRow >= m ) return;
	store_block( n - colBase, alpha, acc, beta, cPtr, ldc );
}

// case 3: check A
//
//  Case 3: every read of A is guarded by A_bound; B is loaded unchecked.
//
//  Same tiling as the other cases: thread (inx, iny) owns row
//  row = blockIdx.x*64 + inx + iny*16 of C and accumulates the 16 columns
//  starting at blockIdx.y*16 in c[0..15] (assumes blockDim = (16, 4) --
//  NOTE(review): inferred from the indexing, confirm against the launch).
//
//  A_bound: once a thread's A pointer is >= A_bound, that thread stops
//           accumulating. A only moves forward, so the check never passes
//           again. Presumably only rows >= m (never stored) hit this --
//           NOTE(review): depends on how the host sets A_bound.
//  B_bound: unused.
//  b:       16 x 17 shared staging buffer (row stride 17, 16 entries used).
//
__device__  void  method5_sgemmNN_case3( int m, int n, float *A, int lda, 
float *B, int ldb, float* C, int ldc, int k, float alpha, float beta,
float *A_bound, float* B_bound, float* b )
{	
	const int inx = threadIdx.x;
	const int iny = threadIdx.y;
	const int ibx = blockIdx.x * 64;
	const int iby = blockIdx.y * 16;
	const int row = ibx + inx + iny*16;
	
	// A -> A(row, 0), B -> B(inx, iby + iny), C -> C(row, iby)  (column-major)
	A += row;
	B += inx + ( iby + iny ) * ldb;
	C += row  + iby * ldc;
	
	float c[16] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
	
	float A_reg ;
	// this thread's staging slot: &b[inx][iny]; it fills b[inx][iny + i], i = 0,4,8,12
	float* b_base = (float*)b + inx*17 + iny ;
	for( ; k > 0; k -= 16 )
	{
		// stage the next panel of B (also done for a partial panel, k < 16)
#pragma unroll
		for( int i = 0; i < 16; i += 4 ){
//			b[inx][iny+i]  = B[i*ldb];
				b_base[i] = B[i*ldb];
		}
		__syncthreads();

		// a partial panel stays staged and is consumed by the tail loop below
		if( k < 16 )  break;
        
		float *b_ptr = (float*)b ;
//#pragma unroll		 		
		for( int i = 0; i < 16 ; i++  ){
			// guarded read of A(row, kk); stop this thread's work at A_bound
			if ( A < A_bound ){
				A_reg = A[0]  ; A += lda ;
			}else{
				break ;
			}
#pragma unroll
			for( int j = 0 ; j < 16 ; j++){
				float b_reg = b_ptr[j] ;
				c[j] += A_reg * b_reg ;
			} 						
			b_ptr += 17 ;	// advance to the next staged row, b[i+1][0]
		}// for each column index of sub-matrix of A
	
		// outside the divergent loop, so every thread reaches this barrier
		__syncthreads();
		
		B += 16;
	};

	// Tail: open-coded equivalent of the rankk_update call below, plus the
	// A_bound check. It runs 0 < k < 16 times, or not at all when k <= 0.
//    rankk_update( k, A, lda, &b[0][0], 17, c );
		float *b_ptr = (float*)b ;
		for( int i = 0; i < k ; i++ ){
			if ( A < A_bound ){
				A_reg = A[0]  ; A += lda ;
			}else{
				break ;
			}
#pragma unroll
			for( int j = 0 ; j < 16 ; j++){
				float b_reg = b_ptr[j] ;
				c[j] += A_reg * b_reg ;
			} 						
			b_ptr += 17 ;	// advance to the next staged row, b[i+1][0]
		}// for each column index of sub-matrix of A
		
    // rows past m are computed but never written back
    if( row >= m )  return;
    
    store_block( n - iby, alpha, c, beta, C, ldc);

}

// case 4: check A and check B
//
//  Case 4: every load of B is guarded by B_bound and every read of A is
//  guarded by A_bound.
//
//  Thread (inx, iny) owns row row = blockIdx.x*64 + inx + iny*16 of C and
//  accumulates the 16 columns starting at blockIdx.y*16 in c[0..15] (assumes
//  blockDim = (16, 4) -- NOTE(review): inferred, confirm against the launch).
//
//  B_bound: panel loads stop at the first address >= B_bound, leaving the
//           remaining staging slots stale. Presumably those feed only columns
//           >= n, which are never stored -- NOTE(review): confirm host-side.
//  A_bound: once a thread's A pointer is >= A_bound, it stops accumulating.
//  b:       16 x 17 shared staging buffer (row stride 17, 16 entries used).
//
__device__  void  method5_sgemmNN_case4( int m, int n, float *A, int lda, 
float *B, int ldb, float* C, int ldc, int k, float alpha, float beta,
float *A_bound, float* B_bound, float* b )
{	
	const int inx = threadIdx.x;
	const int iny = threadIdx.y;
	const int ibx = blockIdx.x * 64;
	const int iby = blockIdx.y * 16;
	const int row = ibx + inx + iny*16;
	
	// A -> A(row, 0), B -> B(inx, iby + iny), C -> C(row, iby)  (column-major)
	A += row;
	B += inx + ( iby + iny ) * ldb;
	C += row  + iby * ldc;
	
	float c[16] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
	
	float A_reg ;
	// this thread's staging slot: &b[inx][iny]; it fills b[inx][iny + i], i = 0,4,8,12
 	float* b_base = (float*)b + inx*17 + iny ;
	for( ; k > 0; k -= 16 )
	{
		// guarded staging of the next panel of B
//#pragma unroll
		for( int i = 0; i < 16; i += 4 ){
				if( &B[i*ldb] < B_bound ){
					b_base[i] = B[i*ldb];
				}else{
					break ;
				}
		}	
		__syncthreads();

		// a partial panel stays staged and is consumed by the tail loop below
		if( k < 16 )  break;

		float *b_ptr = (float*)b ;
//#pragma unroll		 		
		for( int i = 0; i < 16 ; i++  ){
			// guarded read of A(row, kk); stop this thread's work at A_bound
			if ( A < A_bound ){
				A_reg = A[0]  ; A += lda ;
			}else{
				break ;
			}
#pragma unroll
			for( int j = 0 ; j < 16 ; j++){
				float b_reg = b_ptr[j] ;
				c[j] += A_reg * b_reg ;
			} 						
			b_ptr += 17 ;	// advance to the next staged row, b[i+1][0]
		}// for each column index of sub-matrix of A
		// outside the divergent loops, so every thread reaches this barrier
		__syncthreads();
		
		B += 16;
	};

	// Tail: rankk_update equivalent with the A_bound check. It runs
	// 0 < k < 16 times, or not at all when k <= 0.
	float *b_ptr = (float*)b ;
//#pragma unroll		 		
	for( int i = 0; i < k ; i++  ){
			if ( A < A_bound ){
				A_reg = A[0]  ; A += lda ;
			}else{
				break ;
			}
#pragma unroll
			for( int j = 0 ; j < 16 ; j++){
				float b_reg = b_ptr[j] ;
				c[j] += A_reg * b_reg ;
			} 						
			b_ptr += 17 ;	// advance to the next staged row, b[i+1][0]
	}// for each column index of sub-matrix of A
		
	// rows past m are computed but never written back
	if( row >= m )  return;
    
	store_block( n - iby, alpha, c, beta, C, ldc); 
}	


// case 5: only the final rank-k (K-tail) update needs to check B
__device__  void  method5_sgemmNN_case5( int m, int n, float *A, int lda, 
float *B, int ldb, float* C, int ldc, int k, float alpha, float beta,
float *A_bound, float* B_bound, float* b )
{
	// Case 5: full 16-wide panels of B are loaded unchecked; only the final
	// partial panel (fewer than 16 remaining columns of A) is guarded by
	// B_bound. A is read unchecked and A_bound is unused.
	//
	// Thread (tx, ty) owns row myRow of C and accumulates 16 consecutive
	// columns starting at colBase in acc[0..15]. NOTE(review): the indexing
	// assumes blockDim = (16, 4); confirm against the host-side launch.
	const int tx      = threadIdx.x;
	const int ty      = threadIdx.y;
	const int rowBase = blockIdx.x * 64;
	const int colBase = blockIdx.y * 16;
	const int myRow   = rowBase + ty*16 + tx;

	const float *aPtr   = A + myRow;                      // A(myRow, 0)
	const float *bPtr   = B + tx + ( colBase + ty )*ldb;  // B(tx, colBase + ty)
	float       *cPtr   = C + myRow + colBase*ldc;        // C(myRow, colBase)
	float       *bStore = b + tx*17 + ty;                 // &b[tx][ty]

	float acc[16];
#pragma unroll
	for( int j = 0; j < 16; ++j )
		acc[j] = 0.0f;

	// full panels only
	int kLeft = k;
	while( kLeft >= 16 )
	{
#pragma unroll
		for( int r = 0; r < 4; ++r )
			bStore[4*r] = bPtr[4*r*ldb];
		__syncthreads();

		const float *bRow = b;
#pragma unroll
		for( int p = 0; p < 16; ++p )
		{
			const float aVal = aPtr[p*lda];
#pragma unroll
			for( int j = 0; j < 16; ++j )
				acc[j] += aVal * bRow[j];
			bRow += 17;
		}
		aPtr += 16*lda;
		__syncthreads();   // all reads of b finish before the next panel overwrites it

		bPtr  += 16;
		kLeft -= 16;
	}

	// partial panel: guarded staging of B, then the rank-k tail
	if( kLeft > 0 )
	{
		const float *src = bPtr;
		for( int r = 0; r < 4; ++r )
		{
			if( src >= B_bound ) break;
			bStore[4*r] = *src;
			src += 4*ldb;
		}
		__syncthreads();
		rankk_update( kLeft, aPtr, lda, b, 17, acc );
	}

	if( myRow >= m ) return;
	store_block( n - colBase, alpha, acc, beta, cPtr, ldc );
}

// case 6: check A and rank-k update needs to consider B
//
//  Case 6: every read of A is guarded by A_bound. Full 16-wide panels of B
//  are loaded unchecked; only the final partial panel (fewer than 16
//  remaining columns of A) is guarded by B_bound.
//
//  Thread (inx, iny) owns row row = blockIdx.x*64 + inx + iny*16 of C and
//  accumulates the 16 columns starting at blockIdx.y*16 in c[0..15] (assumes
//  blockDim = (16, 4) -- NOTE(review): inferred, confirm against the launch).
//  b is the 16 x 17 shared staging buffer (row stride 17, 16 entries used).
//
__device__  void  method5_sgemmNN_case6( int m, int n, float *A, int lda, 
float *B, int ldb, float* C, int ldc, int k, float alpha, float beta,
float *A_bound, float* B_bound, float* b )
{	
	const int inx = threadIdx.x;
	const int iny = threadIdx.y;
	const int ibx = blockIdx.x * 64;
	const int iby = blockIdx.y * 16;
	const int row = ibx + inx + iny*16;
	
	// A -> A(row, 0), B -> B(inx, iby + iny), C -> C(row, iby)  (column-major)
	A += row;
	B += inx + ( iby + iny ) * ldb;
	C += row  + iby * ldc;
	
	float c[16] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
	
	float A_reg ;
	// this thread's staging slot: &b[inx][iny]; it fills b[inx][iny + i], i = 0,4,8,12
	float* b_base = (float*)b + inx*17 + iny ;
	for( ; k > 0; k -= 16 )
	{
		// only full panels here; the partial panel is handled after the loop
		if( k < 16 )  break;

#pragma unroll
		for( int i = 0; i < 16; i += 4 ){
				b_base[i] = B[i*ldb];
		}	
		__syncthreads();
		
		float *b_ptr = (float*)b ;
//#pragma unroll		 		
		for( int i = 0; i < 16 ; i++  ){
			// guarded read of A(row, kk); stop this thread's work at A_bound
			if ( A < A_bound ){
				A_reg = A[0]  ; A += lda ;
			}else{
				break ;
			}
#pragma unroll
			for( int j = 0 ; j < 16 ; j++){
				float b_reg = b_ptr[j] ;
				c[j] += A_reg * b_reg ;
			} 						
			b_ptr += 17 ;	// advance to the next staged row, b[i+1][0]
		}// for each column index of sub-matrix of A				    
	    // outside the divergent loop, so every thread reaches this barrier
	    __syncthreads();
		
		B += 16;
	};
	
	// Partial panel (0 < k < 16): guarded staging of B, then the rank-k tail
	// with the A_bound check.
	if ( k > 0 ){
		for( int i = 0; i < 16; i += 4 ){
			if( B < B_bound ){
				b_base[i] = B[0] ;
				B += 4 * ldb ;
			}else{
				break ;
			}
		}
		__syncthreads();		
//		rankk_update( k, A, lda, b, 17, c );		
		float *b_ptr = (float*)b ;
		for( int i = 0; i < k ; i++  ){
			if ( A < A_bound ){
				A_reg = A[0]  ; A += lda ;
			}else{
				break ;
			}
#pragma unroll
			for( int j = 0 ; j < 16 ; j++){
				float b_reg = b_ptr[j] ;
				c[j] += A_reg * b_reg ;
			} 						
			b_ptr += 17 ;	// advance to the next staged row, b[i+1][0]
		}// for each column index of sub-matrix of A

	}// for (k > 0)
	
	// rows past m are computed but never written back
	if( row >= m )  return;
    
	store_block( n - iby, alpha, c, beta, C, ldc);
}	



//
//  C = alpha*A*B + beta*C
//

//	smem = 1188
//	reg  = 39
//  active threads = 384

//
//  Kernel entry: each thread block computes a 64 x 16 tile of column-major C
//  (rows blockIdx.x*64 .., columns blockIdx.y*16 ..). It picks one of the
//  method5_sgemmNN_caseN variants from `category` and the block's position:
//
//    category  | interior x,   | interior x, | last x,       | last x,
//              | interior y    | last y      | interior y    | last y
//    ----------+---------------+-------------+---------------+--------
//    1         | 1             | 5           | 1             | 5
//    2         | 1             | 2           | 1             | 2
//    3         | 1             | 5           | 3             | 6
//    4 / other | 1             | 2           | 3             | 4
//
//  Case 1: no checks. 2: check B. 3: check A. 4: check A and B.
//  5: check B in the K tail only. 6: check A, plus B in the K tail.
//
//  A block reads the same rows of A whatever its blockIdx.y. So when
//  categories 3/4 require the A_bound check in the last block-column (cases
//  6/4), the interior blocks of that column get it too (case 3). Previously
//  they ran unchecked in case 1 and case 3 was never selected.
//
//  NOTE(review): the category meanings are inferred from the cases they
//  select -- 1/2 apparently need no A checks, 3/4 do; 1/3 check B only in
//  the K tail, 2/4 throughout. The launch is presumably blockDim = (16, 4).
//  Confirm both against the host code.
//
__global__ void  method5_sgemmNN( int m, int n, const float *A, int lda, 
const float *B, int ldb, float* C, int ldc, int k, float alpha, float beta,
float *A_bound, float* B_bound, int category )
{
	__shared__ float b[16][17];

	// Both flags depend only on blockIdx, so every thread of a block takes the
	// same case and the __syncthreads() calls inside the cases are not divergent.
	const bool interior_y = ( blockIdx.y < (gridDim.y - 1) );
	const bool interior_x = ( blockIdx.x < (gridDim.x - 1) );

	int case_sel ;
	if ( 1 == category ){
		case_sel = interior_y ? 1 : 5 ;
	}else if ( 2 == category ){
		case_sel = interior_y ? 1 : 2 ;
	}else if ( 3 == category ){
		if ( interior_x ){
			case_sel = interior_y ? 1 : 5 ;
		}else{
			// last block-column: A needs checking in every block-row
			case_sel = interior_y ? 3 : 6 ;
		}
	}else{ // category = 4 (any other value is treated the same way)
		if ( interior_x ){
			case_sel = interior_y ? 1 : 2 ;
		}else{
			// last block-column: A needs checking in every block-row
			case_sel = interior_y ? 3 : 4 ;
		}
	}// if (category)

	switch( case_sel ){
	case 1:
		method5_sgemmNN_case1( m, n, (float*)A, lda, (float*)B, ldb, C, ldc, k, alpha, beta, A_bound, B_bound,
					(float*)b ) ;
		break ;
	case 2: 
		method5_sgemmNN_case2( m, n, (float*)A, lda, (float*)B, ldb, C, ldc, k, alpha, beta, A_bound, B_bound,
					(float*)b ) ;
		break ;
	case 3:
		method5_sgemmNN_case3( m, n, (float*)A, lda, (float*)B, ldb, C, ldc, k, alpha, beta, A_bound, B_bound,
					(float*)b ) ;
		break ;
	case 4:
		method5_sgemmNN_case4( m, n, (float*)A, lda, (float*)B, ldb, C, ldc, k, alpha, beta, A_bound, B_bound,
					(float*)b ) ;
		break ;
	case 5:
		method5_sgemmNN_case5( m, n, (float*)A, lda, (float*)B, ldb, C, ldc, k, alpha, beta, A_bound, B_bound,
					(float*)b ) ;
		break ;
	default: // case 6
		method5_sgemmNN_case6( m, n, (float*)A, lda, (float*)B, ldb, C, ldc, k, alpha, beta, A_bound, B_bound,
					(float*)b ) ;
		break;
	}// switch( case_sel)
}