

/*
 *   File: check_general_sgemm_suqare.cpp
 *   author: Lung-Sheng Chien
 *		Department of Mathematics, Tsing Hua University, R.O.C. (Taiwan).
 *		Email: d947207@oz.nthu.edu.tw
 *	 date: 2010/01/31
 *
 *	 description: automatically check C = alpha*A*B + beta*C for any m,n,k
 *
 *		one can change the values of alpha and beta in function "check_general_sgemm_square"
 */


#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <assert.h>

#include <iostream>
#include <fstream>
#include <iomanip>

using namespace std ;

#include <lsc_cuda_utility.h>

/*
 *  Pointer to a host-side wrapper that runs one sgemm kernel.
 *
 *  In this file the wrapper is invoked as
 *      (*sgemmWrapper)( sgemm, d_C, d_A, d_B, m, k, n, lda, ldb, ldc, alpha, beta )
 *  i.e. hfunc is the kernel handle, C/A/B are device pointers,
 *  hA = m, wA = k, wB = n, and lda/ldb/ldc are the leading dimensions
 *  of A, B and C. The computation checked is C = alpha*A*B + beta*C.
 */
typedef  void  (*sgemmWrapper_prototype)(CUfunction hfunc,
		CUdeviceptr C, CUdeviceptr A, CUdeviceptr B, int hA, int wA, int wB,
		int lda, int ldb, int ldc,
		float alpha, float beta ) ;

/*
 *  Forward declaration: checks a single (m, n, k) problem and writes the
 *  relative maximum error into rel_max_err. Defined further down in this file.
 */
void  check_general_sgemm_square_unit( sgemmWrapper_prototype sgemmWrapper, CUfunction  sgemm, 
		unsigned int m, unsigned int n, unsigned int k, float alpha, float beta,
		float &rel_max_err ) ;

void  check_general_sgemm_square( char* cubin_filename, char* sgemm_funcName, 
						   sgemmWrapper_prototype sgemmWrapper )
{
 
	cout << "check sgemm C = alpha*A*B + beta *C on square matriices A, B, C " << endl ;

// ------------------ load context ------------------------
	CUresult status ;

// Create module from binary file
	CUmodule cuModule;
	status = cuModuleLoad(&cuModule, cubin_filename );
	if ( CUDA_SUCCESS != status ){
		cerr << "Error: module " << cubin_filename << " cannot be loaded" << endl ; 
		exit(1) ;
	}else{
		cout << "Succ: load module " << cubin_filename << endl ; 
	}
 
// Get function handle from module
	CUfunction  sgemm ;
	status = cuModuleGetFunction(&sgemm, cuModule, sgemm_funcName ) ;
	if ( CUDA_SUCCESS != status ){
		cerr << "Error: kernel " <<  sgemm_funcName << " cannot be found" << endl ;
		exit(1) ;
	}else{
		cout << "Succ: load kernel " <<  sgemm_funcName << endl ;
	}

// --------------------------------------------------------
	CUdevice device ;
	cuCtxGetDevice( &device) ;
	char dev_name[128] ;
	cuDeviceGetName(dev_name, 128, device) ;
	
	float  rel_max_err = 1.0 ;
	float  eps = 1.E-4f ;
	float  alpha = 72.41f ;
	float  beta  = 37.72f ; 
	
	int n1, n2, n3 ;
	for( n1 = 5 ; n1 <= 257 ; n1++){
		for( n2 = 5 ; n2 <= 257 ; n2++){
			for( n3 = 5 ; n3 <= 129 ; n3++){
					check_general_sgemm_square_unit(sgemmWrapper, sgemm,  
							n1, n2, n3, alpha, beta, rel_max_err ) ;

					if ( eps < rel_max_err ){
						printf("Error: rel_max_err(n = %d) = %7.2E \n", n1, rel_max_err);
						return ;
					}
					printf("(n1,n2,n3) = (%d, %d, %d) is complete with rel_max_err = %.2E\n", 
						n1, n2, n3, rel_max_err );
				}// for n3
		} // for n2 
	}// for n1

	printf("alpha = %f, beta = %f: complete\n", alpha, beta );
}


/*
 *  Allocate mem_size bytes of host memory for the matrix called "label".
 *  Prints an error and exits with status 1 when the allocation fails, so the
 *  check is still present in builds where assert() is compiled out (NDEBUG).
 */
static float*  alloc_host_matrix_or_exit( unsigned int mem_size, const char* label )
{
	float* ptr = (float*) malloc( mem_size ) ;
	if ( NULL == ptr ){
		cerr << "Error: cannot allocate " << mem_size 
			 << " bytes of host memory for " << label << endl ;
		exit(1) ;
	}
	return ptr ;
}

/*
 *  Check C = alpha*A*B + beta*C for one problem size.
 *
 *  parameters:
 *      sgemmWrapper : host wrapper used to run the kernel under test
 *      sgemm        : handle of the kernel under test
 *      m, n, k      : A holds lda*k floats with lda = m,
 *                     B holds ldb*n floats with ldb = k,
 *                     C holds ldc*n floats with ldc = m
 *                     (sizes consistent with column-major m x k, k x n, m x n
 *                      matrices)
 *      alpha, beta  : scalars of the update
 *      rel_max_err  : output, filled in by compare_supnorm
 *
 *  A, B and C are filled by randomInit; the device result is compared by
 *  compare_supnorm against the result of matrixMul_cublas applied to a copy
 *  of the initial C. All host and device buffers are released before return.
 *  Exits with status 1 if a host allocation fails.
 */
void  check_general_sgemm_square_unit( sgemmWrapper_prototype sgemmWrapper, CUfunction  sgemm, 
		unsigned int m, unsigned int n, unsigned int k, float alpha, float beta,
		float &rel_max_err )
{
	int lda, ldb, ldc ;
	float* h_A ;
	float* h_B ;
	float* h_C ;
	float* reference ;
	float max_err = 0.0f ; 

	lda = m ;
	ldb = k ;
	ldc = m ; 

// step 1: allocate and initialize host memory
	// matrices A and B
	unsigned int size_A = lda*k ;
	unsigned int mem_size_A = sizeof(float) * size_A ;
	h_A = alloc_host_matrix_or_exit( mem_size_A, "A" ) ;

	unsigned int size_B = ldb*n;
	unsigned int mem_size_B = sizeof(float) * size_B;
	h_B = alloc_host_matrix_or_exit( mem_size_B, "B" ) ;

	// result C and the buffer that will hold the reference solution
	unsigned int size_C = ldc*n;
	unsigned int mem_size_C = sizeof(float) * size_C;
	h_C = alloc_host_matrix_or_exit( mem_size_C, "C" ) ;
	
	reference = alloc_host_matrix_or_exit( mem_size_C, "reference" ) ;

	randomInit(h_A, size_A);
	randomInit(h_B, size_B);
	randomInit(h_C, size_C);
	// beta*C uses the initial C, so the reference must start from the same data
	memcpy( reference, h_C, mem_size_C ) ;

	// allocate device memory
	CUdeviceptr d_A, d_B, d_C;
	cutilDrvSafeCallNoSync( cuMemAlloc( &d_C, mem_size_C ) ) ;
	cutilDrvSafeCallNoSync( cuMemAlloc( &d_A, mem_size_A ) ) ;
	cutilDrvSafeCallNoSync( cuMemAlloc( &d_B, mem_size_B ) ) ;
	
// step 2: copy data from host to device
// cuMemcpyHtoD (CUdeviceptr dstDevice, const void *srcHost, unsigned int ByteCount)
	cutilDrvSafeCallNoSync( cuMemcpyHtoD(d_A, h_A, mem_size_A) ) ;
	cutilDrvSafeCallNoSync( cuMemcpyHtoD(d_B, h_B, mem_size_B) ) ;
	cutilDrvSafeCallNoSync( cuMemcpyHtoD(d_C, h_C, mem_size_C) ) ;
	
// step 3: execute the kernel once (no timing is measured here)
	(*sgemmWrapper)( sgemm, d_C, d_A, d_B, m, k, n, lda, ldb, ldc, alpha, beta ) ;
	// NOTE(review): runtime-API synchronization inside driver-API code; kept
	// because cutilCheckMsg below is a project helper whose requirements are
	// not visible here - confirm before switching to cuCtxSynchronize.
	cudaThreadSynchronize();

	// check if kernel execution generated an error
	cutilCheckMsg("Kernel execution failed");

	// copy result from device to host
	cutilDrvSafeCallNoSync( cuMemcpyDtoH(h_C, d_C, mem_size_C) ) ;
	
// step 4: compute reference solution (matrixMul_cublas, presumably via CUBLAS)
	matrixMul_cublas(reference, h_A, h_B, m, k, n, lda, ldb, ldc, alpha, beta ) ;
 
// step 5: check result; compare_supnorm fills max_err and rel_max_err
	compare_supnorm( m, n, reference, ldc, h_C, ldc, max_err, rel_max_err ) ;

// step 6: cleanup memory
	free( h_A );
	free( h_B );
	free( h_C );
	free( reference);

	cutilDrvSafeCallNoSync( cuMemFree(d_A) ) ;
	cutilDrvSafeCallNoSync( cuMemFree(d_B) ) ;
	cutilDrvSafeCallNoSync( cuMemFree(d_C) ) ;

}