Advanced Computing Platform for Theoretical Physics

commit大文件会使得服务器变得不稳定,请大家尽量只commit代码,不要commit大的文件。

Commit a61f7cf1 authored by mikeaclark's avatar mikeaclark
Browse files

Tidy of Float2 loads in blas kernels

git-svn-id: http://lattice.bu.edu/qcdalg/cuda/quda@491 be54200a-260c-0410-bdd7-ce6af2a381ab
parent bf3ede9d
......@@ -483,7 +483,7 @@ void axpbyCuda(double a, ParitySpinor x, double b, ParitySpinor y) {
int blocks = min(REDUCE_MAX_BLOCKS, max(x.length/REDUCE_THREADS, 1));
dim3 dimBlock(REDUCE_THREADS, 1, 1);
dim3 dimGrid(blocks, 1, 1);
blas_quda_bytes += 3*x.length*sizeof(x.precision);
blas_quda_bytes += 3*x.length*x.precision;
if (x.precision == QUDA_DOUBLE_PRECISION) {
axpbyKernel<<<dimGrid, dimBlock>>>(a, (double*)x.spinor, b, (double*)y.spinor, x.length);
} else if (x.precision == QUDA_SINGLE_PRECISION) {
......@@ -534,7 +534,7 @@ void xpyCuda(ParitySpinor x, ParitySpinor y) {
int blocks = min(REDUCE_MAX_BLOCKS, max(x.length/REDUCE_THREADS, 1));
dim3 dimBlock(REDUCE_THREADS, 1, 1);
dim3 dimGrid(blocks, 1, 1);
blas_quda_bytes += 3*x.length*sizeof(x.precision);
blas_quda_bytes += 3*x.length*x.precision;
if (x.precision == QUDA_DOUBLE_PRECISION) {
xpyKernel<<<dimGrid, dimBlock>>>((double*)x.spinor, (double*)y.spinor, x.length);
} else if (x.precision == QUDA_SINGLE_PRECISION) {
......@@ -584,7 +584,7 @@ void axpyCuda(double a, ParitySpinor x, ParitySpinor y) {
int blocks = min(REDUCE_MAX_BLOCKS, max(x.length/REDUCE_THREADS, 1));
dim3 dimBlock(REDUCE_THREADS, 1, 1);
dim3 dimGrid(blocks, 1, 1);
blas_quda_bytes += 3*x.length*sizeof(x.precision);
blas_quda_bytes += 3*x.length*x.precision;
if (x.precision == QUDA_DOUBLE_PRECISION) {
axpyKernel<<<dimGrid, dimBlock>>>(a, (double*)x.spinor, (double*)y.spinor, x.length);
} else if (x.precision == QUDA_SINGLE_PRECISION) {
......@@ -634,7 +634,7 @@ void xpayCuda(ParitySpinor x, double a, ParitySpinor y) {
int blocks = min(REDUCE_MAX_BLOCKS, max(x.length/REDUCE_THREADS, 1));
dim3 dimBlock(REDUCE_THREADS, 1, 1);
dim3 dimGrid(blocks, 1, 1);
blas_quda_bytes += 3*x.length*sizeof(x.precision);
blas_quda_bytes += 3*x.length*x.precision;
if (x.precision == QUDA_DOUBLE_PRECISION) {
xpayKernel<<<dimGrid, dimBlock>>>((double*)x.spinor, a, (double*)y.spinor, x.length);
} else if (x.precision == QUDA_SINGLE_PRECISION) {
......@@ -685,7 +685,7 @@ void mxpyCuda(ParitySpinor x, ParitySpinor y) {
int blocks = min(REDUCE_MAX_BLOCKS, max(x.length/REDUCE_THREADS, 1));
dim3 dimBlock(REDUCE_THREADS, 1, 1);
dim3 dimGrid(blocks, 1, 1);
blas_quda_bytes += 3*x.length*sizeof(x.precision);
blas_quda_bytes += 3*x.length*x.precision;
if (x.precision == QUDA_DOUBLE_PRECISION) {
mxpyKernel<<<dimGrid, dimBlock>>>((double*)x.spinor, (double*)y.spinor, x.length);
} else if (x.precision == QUDA_SINGLE_PRECISION) {
......@@ -729,7 +729,7 @@ void axCuda(double a, ParitySpinor x) {
int blocks = min(REDUCE_MAX_BLOCKS, max(x.length/REDUCE_THREADS, 1));
dim3 dimBlock(REDUCE_THREADS, 1, 1);
dim3 dimGrid(blocks, 1, 1);
blas_quda_bytes += 2*x.length*sizeof(x.precision);
blas_quda_bytes += 2*x.length*x.precision;
if (x.precision == QUDA_DOUBLE_PRECISION) {
axKernel<<<dimGrid, dimBlock>>>(a, (double*)x.spinor, x.length);
} else if (x.precision == QUDA_SINGLE_PRECISION) {
......@@ -743,15 +743,21 @@ void axCuda(double a, ParitySpinor x) {
blas_quda_flops += x.length;
}
float2 __device__ make_Float2(float x, float y) {
return make_float2(x, y);
}
double2 __device__ make_Float2(double x, double y) {
return make_double2(x, y);
}
template <typename Float2>
__global__ void caxpyKernel(Float2 a, Float2 *x, Float2 *y, int len) {
unsigned int i = blockIdx.x*(blockDim.x) + threadIdx.x;
unsigned int gridSize = gridDim.x*blockDim.x;
while (i < len) {
Float2 Z;
Z.x = x[i].x;
Z.y = x[i].y;
Float2 Z = make_Float2(x[i].x, x[i].y);
y[i].x += a.x*Z.x - a.y*Z.y;
y[i].y += a.y*Z.x + a.x*Z.y;
i += gridSize;
......@@ -785,7 +791,8 @@ void caxpyCuda(double2 a, ParitySpinor x, ParitySpinor y) {
int blocks = min(REDUCE_MAX_BLOCKS, max(length/REDUCE_THREADS, 1));
dim3 dimBlock(REDUCE_THREADS, 1, 1);
dim3 dimGrid(blocks, 1, 1);
blas_quda_bytes += 3*x.length*sizeof(x.precision);
blas_quda_bytes += 3*x.length*x.precision;
blas_quda_flops += 4*x.length;
if (x.precision == QUDA_DOUBLE_PRECISION) {
caxpyKernel<<<dimGrid, dimBlock>>>(a, (double2*)x.spinor, (double2*)y.spinor, length);
} else if (x.precision == QUDA_SINGLE_PRECISION) {
......@@ -800,7 +807,6 @@ void caxpyCuda(double2 a, ParitySpinor x, ParitySpinor y) {
float2 af2 = make_float2((float)a.x, (float)a.y);
caxpyHKernel<<<dimGrid, dimBlock>>>(af2, (short4*)y.spinor, (float*)y.spinorNorm, y.length/spinorSiteSize);
}
blas_quda_flops += 4*x.length;
}
template <typename Float2>
......@@ -809,11 +815,8 @@ __global__ void caxpbyKernel(Float2 a, Float2 *x, Float2 b, Float2 *y, int len)
unsigned int i = blockIdx.x*(blockDim.x) + threadIdx.x;
unsigned int gridSize = gridDim.x*blockDim.x;
while (i < len) {
Float2 Z1, Z2;
Z1.x = x[i].x;
Z2.y = x[i].y;
Z1.x = y[i].x;
Z2.y = y[i].y;
Float2 Z1 = make_Float2(x[i].x, x[i].y);
Float2 Z2 = make_Float2(y[i].x, y[i].y);
y[i].x = a.x*Z1.x + b.x*Z2.x - a.y*Z1.y - b.y*Z2.y;
y[i].y = a.y*Z1.x + b.y*Z2.x + a.x*Z1.y + b.x*Z2.y;
i += gridSize;
......@@ -846,7 +849,7 @@ void caxpbyCuda(double2 a, ParitySpinor x, double2 b, ParitySpinor y) {
int blocks = min(REDUCE_MAX_BLOCKS, max(length/REDUCE_THREADS, 1));
dim3 dimBlock(REDUCE_THREADS, 1, 1);
dim3 dimGrid(blocks, 1, 1);
blas_quda_bytes += 3*x.length*sizeof(x.precision);
blas_quda_bytes += 3*x.length*x.precision;
blas_quda_flops += 7*x.length;
if (x.precision == QUDA_DOUBLE_PRECISION) {
caxpbyKernel<<<dimGrid, dimBlock>>>(a, (double2*)x.spinor, b, (double2*)y.spinor, length);
......@@ -872,21 +875,16 @@ __global__ void cxpaypbzKernel(Float2 *x, Float2 a, Float2 *y, Float2 b, Float2
unsigned int i = blockIdx.x*(blockDim.x) + threadIdx.x;
unsigned int gridSize = gridDim.x*blockDim.x;
while (i < len) {
Float2 T1, T2, T3;
T1.x = x[i].x;
T1.y = x[i].y;
T2.x = y[i].x;
T2.y = y[i].y;
T3.x = z[i].x;
T3.y = z[i].y;
Float2 T1 = make_Float2(x[i].x, x[i].y);
Float2 T2 = make_Float2(y[i].x, y[i].y);
Float2 T3 = make_Float2(z[i].x, z[i].y);
T1.x += a.x*T2.x - a.y*T2.y;
T1.y += a.y*T2.x + a.x*T2.y;
T1.x += b.x*T3.x - b.y*T3.y;
T1.y += b.y*T3.x + b.x*T3.y;
z[i].x = T1.x;
z[i].y = T1.y;
z[i] = make_Float2(T1.x, T1.y);
i += gridSize;
}
......@@ -920,7 +918,8 @@ void cxpaypbzCuda(ParitySpinor x, double2 a, ParitySpinor y, double2 b, ParitySp
int blocks = min(REDUCE_MAX_BLOCKS, max(length/REDUCE_THREADS, 1));
dim3 dimBlock(REDUCE_THREADS, 1, 1);
dim3 dimGrid(blocks, 1, 1);
blas_quda_bytes += 4*x.length*sizeof(x.precision);
blas_quda_bytes += 4*x.length*x.precision;
blas_quda_flops += 8*x.length;
if (x.precision == QUDA_DOUBLE_PRECISION) {
cxpaypbzKernel<<<dimGrid, dimBlock>>>((double2*)x.spinor, a, (double2*)y.spinor, b, (double2*)z.spinor, length);
} else if (x.precision == QUDA_SINGLE_PRECISION) {
......@@ -939,7 +938,6 @@ void cxpaypbzCuda(ParitySpinor x, double2 a, ParitySpinor y, double2 b, ParitySp
float2 bf2 = make_float2((float)b.x, (float)b.y);
cxpaypbzHKernel<<<dimGrid, dimBlock>>>(af2, bf2, (short4*)z.spinor, (float*)z.spinorNorm, z.length/spinorSiteSize);
}
blas_quda_flops += 8*x.length;
}
template <typename Float>
......@@ -988,7 +986,7 @@ void axpyZpbxCuda(double a, ParitySpinor x, ParitySpinor y, ParitySpinor z, doub
int blocks = min(REDUCE_MAX_BLOCKS, max(x.length/REDUCE_THREADS, 1));
dim3 dimBlock(REDUCE_THREADS, 1, 1);
dim3 dimGrid(blocks, 1, 1);
blas_quda_bytes += 5*x.length*sizeof(x.precision);
blas_quda_bytes += 5*x.length*x.precision;
if (x.precision == QUDA_DOUBLE_PRECISION) {
axpyZpbxKernel<<<dimGrid, dimBlock>>>(a, (double*)x.spinor, (double*)y.spinor, (double*)z.spinor, b, x.length);
} else if (x.precision == QUDA_SINGLE_PRECISION) {
......@@ -1013,25 +1011,23 @@ __global__ void caxpbypzYmbwKernel(Float2 a, Float2 *x, Float2 b, Float2 *y, Flo
unsigned int i = blockIdx.x*(blockDim.x) + threadIdx.x;
unsigned int gridSize = gridDim.x*blockDim.x;
while (i < len) {
Float2 X, Y, Z, W;
X.x = x[i].x;
X.y = x[i].y;
Y.x = y[i].x;
Y.y = y[i].y;
W.x = w[i].x;
W.y = w[i].y;
Z.x = a.x*X.x - a.y*X.y;
Z.y = a.y*X.x + a.x*X.y;
Float2 X = make_Float2(x[i].x, x[i].y);
Float2 Z = make_Float2(z[i].x, z[i].y);
Z.x += a.x*X.x - a.y*X.y;
Z.y += a.y*X.x + a.x*X.y;
Float2 Y = make_Float2(y[i].x, y[i].y);
Z.x += b.x*Y.x - b.y*Y.y;
Z.y += b.y*Y.x + b.x*Y.y;
z[i] = make_Float2(Z.x, Z.y);
Float2 W = make_Float2(w[i].x, w[i].y);
Y.x -= b.x*W.x - b.y*W.y;
Y.y -= b.y*W.x + b.x*W.y;
z[i].x += Z.x;
z[i].y += Z.y;
y[i].x = Y.x;
y[i].y = Y.y;
y[i] = make_Float2(Y.x, Y.y);
i += gridSize;
}
}
......@@ -1073,7 +1069,7 @@ void caxpbypzYmbwCuda(double2 a, ParitySpinor x, double2 b, ParitySpinor y,
int blocks = min(REDUCE_MAX_BLOCKS, max(length/REDUCE_THREADS, 1));
dim3 dimBlock(REDUCE_THREADS, 1, 1);
dim3 dimGrid(blocks, 1, 1);
blas_quda_bytes += 6*x.length*sizeof(x.precision);
blas_quda_bytes += 6*x.length*x.precision;
if (x.precision == QUDA_DOUBLE_PRECISION) {
caxpbypzYmbwKernel<<<dimGrid, dimBlock>>>(a, (double2*)x.spinor, b, (double2*)y.spinor,
(double2*)z.spinor, (double2*)w.spinor, length);
......@@ -1200,7 +1196,7 @@ template <typename Float>
double sumCuda(ParitySpinor a) {
blas_quda_flops += a.length;
blas_quda_bytes += a.length*sizeof(a.precision);
blas_quda_bytes += a.length*a.precision;
if (a.precision == QUDA_DOUBLE_PRECISION) {
return sumFCuda((double*)a.spinor, a.length);
} else if (a.precision == QUDA_SINGLE_PRECISION) {
......@@ -1255,7 +1251,7 @@ template <typename Float>
double normCuda(ParitySpinor a) {
blas_quda_flops += 2*a.length;
blas_quda_bytes += a.length*sizeof(a.precision);
blas_quda_bytes += a.length*a.precision;
if (a.precision == QUDA_DOUBLE_PRECISION) {
return normFCuda((double*)a.spinor, a.length);
} else if (a.precision == QUDA_SINGLE_PRECISION) {
......@@ -1314,7 +1310,7 @@ template <typename Float>
double reDotProductCuda(ParitySpinor a, ParitySpinor b) {
blas_quda_flops += 2*a.length;
checkSpinor(a, b);
blas_quda_bytes += 2*a.length*sizeof(a.precision);
blas_quda_bytes += 2*a.length*a.precision;
if (a.precision == QUDA_DOUBLE_PRECISION) {
return reDotProductFCuda((double*)a.spinor, (double*)b.spinor, a.length);
} else if (a.precision == QUDA_SINGLE_PRECISION) {
......@@ -1381,7 +1377,7 @@ template <typename Float>
double axpyNormCuda(double a, ParitySpinor x, ParitySpinor y) {
blas_quda_flops += 4*x.length;
checkSpinor(x,y);
blas_quda_bytes += 3*x.length*sizeof(x.precision);
blas_quda_bytes += 3*x.length*x.precision;
if (x.precision == QUDA_DOUBLE_PRECISION) {
return axpyNormFCuda(a, (double*)x.spinor, (double*)y.spinor, x.length);
} else if (x.precision == QUDA_SINGLE_PRECISION) {
......@@ -1449,7 +1445,7 @@ template <typename Float>
double xmyNormCuda(ParitySpinor x, ParitySpinor y) {
blas_quda_flops +=3*x.length;
checkSpinor(x,y);
blas_quda_bytes += 3*x.length*sizeof(x.precision);
blas_quda_bytes += 3*x.length*x.precision;
if (x.precision == QUDA_DOUBLE_PRECISION) {
return xmyNormFCuda((double*)x.spinor, (double*)y.spinor, x.length);
} else if (x.precision == QUDA_SINGLE_PRECISION) {
......@@ -1522,7 +1518,7 @@ double2 cDotProductCuda(ParitySpinor x, ParitySpinor y) {
blas_quda_flops += 4*x.length;
checkSpinor(x,y);
int length = x.length/2;
blas_quda_bytes += 2*x.length*sizeof(x.precision);
blas_quda_bytes += 2*x.length*x.precision;
if (x.precision == QUDA_DOUBLE_PRECISION) {
char c = NULL;
return cDotProductFCuda((double2*)x.spinor, (double2*)y.spinor, c, length);
......@@ -1609,7 +1605,7 @@ double2 xpaycDotzyCuda(ParitySpinor x, double a, ParitySpinor y, ParitySpinor z)
checkSpinor(x,y);
checkSpinor(x,z);
int length = x.length/2;
blas_quda_bytes += 4*x.length*sizeof(x.precision);
blas_quda_bytes += 4*x.length*x.precision;
if (x.precision == QUDA_DOUBLE_PRECISION) {
return xpaycDotzyFCuda((double2*)x.spinor, a, (double2*)y.spinor, (double2*)z.spinor, length);
} else if (x.precision == QUDA_SINGLE_PRECISION) {
......@@ -1699,7 +1695,7 @@ double3 cDotProductNormACuda(ParitySpinor x, ParitySpinor y) {
blas_quda_flops += 6*x.length;
checkSpinor(x,y);
int length = x.length/2;
blas_quda_bytes += 2*x.length*sizeof(x.precision);
blas_quda_bytes += 2*x.length*x.precision;
if (x.precision == QUDA_DOUBLE_PRECISION) {
return cDotProductNormAFCuda((double2*)x.spinor, (double2*)y.spinor, length);
} else if (x.precision == QUDA_SINGLE_PRECISION) {
......@@ -1786,7 +1782,7 @@ double3 cDotProductNormBCuda(ParitySpinor x, ParitySpinor y) {
blas_quda_flops += 6*x.length;
checkSpinor(x,y);
int length = x.length/2;
blas_quda_bytes += 2*x.length*sizeof(x.precision);
blas_quda_bytes += 2*x.length*x.precision;
if (x.precision == QUDA_DOUBLE_PRECISION) {
return cDotProductNormBFCuda((double2*)x.spinor, (double2*)y.spinor, length);
} else if (x.precision == QUDA_SINGLE_PRECISION) {
......@@ -1810,27 +1806,21 @@ template <typename Float2>
#define REDUCE_FUNC_NAME(suffix) caxpbypzYmbwcDotProductWYNormYF##suffix
#define REDUCE_TYPES Float2 a, Float2 *x, Float2 b, Float2 *y, Float2 *z, Float2 *w, Float2 *u
#define REDUCE_PARAMS a, x, b, y, z, w, u
#define REDUCE_X_AUXILIARY(i) \
Float2 W, X, Y; \
X.x = x[i].x; \
X.y = x[i].y; \
Y.x = y[i].x; \
Y.y = y[i].y; \
W.x = w[i].x; \
W.y = w[i].y;
#define REDUCE_Y_AUXILIARY(i) \
Float2 Z; \
Z.x = a.x*X.x - a.y*X.y; \
Z.y = a.y*X.x + a.x*X.y; \
Z.x += b.x*Y.x - b.y*Y.y; \
Z.y += b.y*Y.x + b.x*Y.y; \
Y.x -= b.x*W.x - b.y*W.y; \
#define REDUCE_X_AUXILIARY(i) \
Float2 X = make_Float2(x[i].x, x[i].y); \
Float2 Y = make_Float2(y[i].x, y[i].y); \
Float2 W = make_Float2(w[i].x, w[i].y);
#define REDUCE_Y_AUXILIARY(i) \
Float2 Z = make_Float2(z[i].x, z[i].y); \
Z.x += a.x*X.x - a.y*X.y; \
Z.y += a.y*X.x + a.x*X.y; \
Z.x += b.x*Y.x - b.y*Y.y; \
Z.y += b.y*Y.x + b.x*Y.y; \
Y.x -= b.x*W.x - b.y*W.y; \
Y.y -= b.y*W.x + b.x*W.y;
#define REDUCE_Z_AUXILIARY(i) \
z[i].x += Z.x; \
z[i].y += Z.y; \
y[i].x = Y.x; \
y[i].y = Y.y;
z[i] = make_Float2(Z.x, Z.y); \
y[i] = make_Float2(Y.x, Y.y);
#define REDUCE_X_OPERATION(i) (u[i].x*y[i].x + u[i].y*y[i].y)
#define REDUCE_Y_OPERATION(i) (u[i].x*y[i].y - u[i].y*y[i].x)
#define REDUCE_Z_OPERATION(i) (y[i].x*y[i].x + y[i].y*y[i].y)
......@@ -1919,7 +1909,7 @@ double3 caxpbypzYmbwcDotProductWYNormYQuda(double2 a, ParitySpinor x, double2 b,
checkSpinor(x,w);
checkSpinor(x,u);
int length = x.length/2;
blas_quda_bytes += 7*x.length*sizeof(x.precision);
blas_quda_bytes += 7*x.length*x.precision;
if (x.precision == QUDA_DOUBLE_PRECISION) {
return caxpbypzYmbwcDotProductWYNormYFCuda(a, (double2*)x.spinor, b, (double2*)y.spinor, (double2*)z.spinor,
(double2*)w.spinor, (double2*)u.spinor, length);
......
......@@ -202,7 +202,7 @@ int main(int argc, char** argv) {
benchmark(kernels[i]);
}
nIters = 1000;
nIters = 300;
for (int i = 0; i <= 20; i++) {
blas_quda_flops = 0;
blas_quda_bytes = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment