

#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime_api.h>
#include <cublas.h>

typedef float2 Complex;

/*
 *   Report a fatal error on stderr and terminate the process with a
 *   failure exit status (so callers and scripts can detect the error).
 */
static void matrixMul_cublas_fail(const char* msg)
{
	fprintf(stderr, "%s", msg);
	exit(EXIT_FAILURE);
}

/*
 *   Abort with `msg` unless `stat` is CUBLAS_STATUS_SUCCESS.
 */
static void matrixMul_cublas_check(cublasStatus stat, const char* msg)
{
	if (CUBLAS_STATUS_SUCCESS != stat) {
		matrixMul_cublas_fail(msg);
	}
}

/*
 *   Given host matrices A(1:hA, 1:wA), B(1:wA, 1:wB) and C(1:hA, 1:wB),
 *   stored with leading dimensions lda, ldb and ldc respectively,
 *
 *   compute C = alpha*A*B + beta*C via CUBLAS (single-precision complex
 *   GEMM, no transposition) and write the result back into host array C.
 *
 *   Parameters:
 *     C            in/out host matrix, hA x wB, leading dimension ldc
 *     A            input host matrix,  hA x wA, leading dimension lda
 *     B            input host matrix,  wA x wB, leading dimension ldb
 *     hA, wA, wB   matrix dimensions
 *     lda,ldb,ldc  leading dimensions of A, B, C (also used on the device)
 *     alpha, beta  complex scalars of the GEMM update
 *
 *   The function initialises CUBLAS, allocates device copies of all three
 *   matrices, runs the GEMM, copies C back, frees the device memory and
 *   shuts CUBLAS down again.
 *
 *   Any CUBLAS/CUDA failure prints a message to stderr and terminates the
 *   process with EXIT_FAILURE.
 *
 *   NOTE(review): cublasInit/cublasShutdown are executed on every call;
 *   presumably acceptable for occasional use -- confirm if this is called
 *   in a loop.
 */
void matrixMul_cublas(Complex* C, Complex* A, Complex* B, 
		const int hA, const int wA, const int wB,
		const int lda, const int ldb, const int ldc,
		Complex alpha, Complex beta )
{
	Complex *devPtrA = NULL, *devPtrB = NULL, *devPtrC = NULL ;
	cudaError_t cudaStat ;

	// initialization of the CUBLAS library
	matrixMul_cublas_check( cublasInit(),
		"Error: cublasInit fails \n" ) ;

	// device buffers: (leading dimension) x (number of columns) elements each
	matrixMul_cublas_check(
		cublasAlloc( lda*wA, sizeof(Complex), (void**) &devPtrA ),
		"device memory allocatin failed for matrix A \n" ) ;

	matrixMul_cublas_check(
		cublasAlloc( ldb*wB, sizeof(Complex), (void**) &devPtrB ),
		"device memory allocatin failed for matrix B \n" ) ;

	matrixMul_cublas_check(
		cublasAlloc( ldc*wB, sizeof(Complex), (void**) &devPtrC ),
		"device memory allocatin failed for matrix C \n" ) ;

	// transfer host data to device
	matrixMul_cublas_check(
		cublasSetMatrix( hA, wA, sizeof(Complex), A, lda, devPtrA, lda ),
		"Error: transfer host data to device, hA --> dA \n" ) ;

	matrixMul_cublas_check(
		cublasSetMatrix( wA, wB, sizeof(Complex), B, ldb, devPtrB, ldb ),
		"Error: transfer host data to device, hB --> dB \n" ) ;

	// C is uploaded as well because it is an input of the beta*C term
	matrixMul_cublas_check(
		cublasSetMatrix( hA, wB, sizeof(Complex), C, ldc, devPtrC, ldc ),
		"Error: transfer host data to device, hC --> dC \n" ) ;

	// compute C = alpha*A*B + beta*C in device
	cublasCgemm('N', 'N', hA, wB, wA, alpha, devPtrA, lda, 
		devPtrB, ldb, beta, devPtrC, ldc) ;

	// wait for the GEMM to finish so that execution errors surface here
	cudaStat = cudaDeviceSynchronize() ;
	if ( cudaSuccess != cudaStat ){
		fprintf(stderr, "Error: device synchronization after cublasCgemm fails: %s \n",
			cudaGetErrorString(cudaStat));
		exit(EXIT_FAILURE) ;
	}

	// the legacy GEMM call returns void; its status is fetched separately
	matrixMul_cublas_check( cublasGetError(),
		"Error: cublasCgemm fails \n" ) ;

	// blocking copy of the result back to the host, no extra sync required
	matrixMul_cublas_check(
		cublasGetMatrix( hA, wB, sizeof(Complex), devPtrC, ldc, C, ldc ),
		"Error: device to host, dC --> hC \n" ) ;

	matrixMul_cublas_check( cublasFree( devPtrA ),
		"Error: cublasFree fails for matrix A \n" ) ;
	matrixMul_cublas_check( cublasFree( devPtrB ),
		"Error: cublasFree fails for matrix B \n" ) ;
	matrixMul_cublas_check( cublasFree( devPtrC ),
		"Error: cublasFree fails for matrix C \n" ) ;

	matrixMul_cublas_check( cublasShutdown(),
		"Error: cublasShutdown fails \n" ) ;
}