Advanced Computing Platform for Theoretical Physics

Commit 72aac654 authored by mikeaclark's avatar mikeaclark
Browse files

Clean up of spinor_quda.cpp

git-svn-id: http://lattice.bu.edu/qcdalg/cuda/quda@405 be54200a-260c-0410-bdd7-ce6af2a381ab
parent 937f5509
......@@ -537,7 +537,7 @@ int dslashCudaSharedBytes() {
// Apply the even-odd preconditioned Dirac operator
void MatPCCuda(ParitySpinor out, FullGauge gauge, ParitySpinor in, double kappa,
ParitySpinor tmp, MatPCType matpc_type) {
ParitySpinor tmp, MatPCType matpc_type, int dagger) {
checkSpinor(in, out);
checkSpinor(in, tmp);
......@@ -546,64 +546,27 @@ void MatPCCuda(ParitySpinor out, FullGauge gauge, ParitySpinor in, double kappa,
if (in.precision == QUDA_DOUBLE_PRECISION) {
if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
dslashDCuda(tmp, gauge, in, 1, 0);
dslashXpayDCuda(out, gauge, tmp, 0, 0, in, kappa2);
dslashDCuda(tmp, gauge, in, 1, dagger);
dslashXpayDCuda(out, gauge, tmp, 0, dagger, in, kappa2);
} else {
dslashDCuda(tmp, gauge, in, 0, 0);
dslashXpayDCuda(out, gauge, tmp, 1, 0, in, kappa2);
dslashDCuda(tmp, gauge, in, 0, dagger);
dslashXpayDCuda(out, gauge, tmp, 1, dagger, in, kappa2);
}
} else if (in.precision == QUDA_SINGLE_PRECISION) {
if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
dslashSCuda(tmp, gauge, in, 1, 0);
dslashXpaySCuda(out, gauge, tmp, 0, 0, in, kappa2);
dslashSCuda(tmp, gauge, in, 1, dagger);
dslashXpaySCuda(out, gauge, tmp, 0, dagger, in, kappa2);
} else {
dslashSCuda(tmp, gauge, in, 0, 0);
dslashXpaySCuda(out, gauge, tmp, 1, 0, in, kappa2);
dslashSCuda(tmp, gauge, in, 0, dagger);
dslashXpaySCuda(out, gauge, tmp, 1, dagger, in, kappa2);
}
} else if (in.precision == QUDA_HALF_PRECISION) {
if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
dslashHCuda(tmp, gauge, in, 1, 0);
dslashXpayHCuda(out, gauge, tmp, 0, 0, in, kappa2);
dslashHCuda(tmp, gauge, in, 1, dagger);
dslashXpayHCuda(out, gauge, tmp, 0, dagger, in, kappa2);
} else {
dslashHCuda(tmp, gauge, in, 0, 0);
dslashXpayHCuda(out, gauge, tmp, 1, 0, in, kappa2);
}
}
}
// Apply the even-odd preconditioned Dirac operator
void MatPCDagCuda(ParitySpinor out, FullGauge gauge, ParitySpinor in, double kappa,
ParitySpinor tmp, MatPCType matpc_type) {
checkSpinor(in, out);
checkSpinor(in, tmp);
double kappa2 = -kappa*kappa;
if (in.precision == QUDA_DOUBLE_PRECISION) {
if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
dslashDCuda(tmp, gauge, in, 1, 1);
dslashXpayDCuda(out, gauge, tmp, 0, 1, in, kappa2);
} else {
dslashDCuda(tmp, gauge, in, 0, 1);
dslashXpayDCuda(out, gauge, tmp, 1, 1, in, kappa2);
}
} else if (in.precision == QUDA_SINGLE_PRECISION) {
if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
dslashSCuda(tmp, gauge, in, 1, 1);
dslashXpaySCuda(out, gauge, tmp, 0, 1, in, kappa2);
} else {
dslashSCuda(tmp, gauge, in, 0, 1);
dslashXpaySCuda(out, gauge, tmp, 1, 1, in, kappa2);
}
} else {
if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
dslashHCuda(tmp, gauge, in, 1, 1);
dslashXpayHCuda(out, gauge, tmp, 0, 1, in, kappa2);
} else {
dslashHCuda(tmp, gauge, in, 0, 1);
dslashXpayHCuda(out, gauge, tmp, 1, 1, in, kappa2);
dslashHCuda(tmp, gauge, in, 0, dagger);
dslashXpayHCuda(out, gauge, tmp, 1, dagger, in, kappa2);
}
}
......@@ -611,43 +574,27 @@ void MatPCDagCuda(ParitySpinor out, FullGauge gauge, ParitySpinor in, double kap
void MatPCDagMatPCCuda(ParitySpinor out, FullGauge gauge, ParitySpinor in,
double kappa, ParitySpinor tmp, MatPCType matpc_type) {
MatPCCuda(out, gauge, in, kappa, tmp, matpc_type);
MatPCDagCuda(out, gauge, out, kappa, tmp, matpc_type);
MatPCCuda(out, gauge, in, kappa, tmp, matpc_type, 0);
MatPCCuda(out, gauge, out, kappa, tmp, matpc_type, 1);
}
// Apply the full operator
void MatCuda(FullSpinor out, FullGauge gauge, FullSpinor in, double kappa) {
void MatCuda(FullSpinor out, FullGauge gauge, FullSpinor in, double kappa, int dagger) {
checkSpinor(in.even, out.even);
if (in.even.precision == QUDA_DOUBLE_PRECISION) {
dslashXpayDCuda(out.odd, gauge, in.even, 1, 0, in.odd, -kappa);
dslashXpayDCuda(out.even, gauge, in.odd, 0, 0, in.even, -kappa);
dslashXpayDCuda(out.odd, gauge, in.even, 1, dagger, in.odd, -kappa);
dslashXpayDCuda(out.even, gauge, in.odd, 0, dagger, in.even, -kappa);
} else if (in.even.precision == QUDA_SINGLE_PRECISION) {
dslashXpaySCuda(out.odd, gauge, in.even, 1, 0, in.odd, -kappa);
dslashXpaySCuda(out.even, gauge, in.odd, 0, 0, in.even, -kappa);
dslashXpaySCuda(out.odd, gauge, in.even, 1, dagger, in.odd, -kappa);
dslashXpaySCuda(out.even, gauge, in.odd, 0, dagger, in.even, -kappa);
} else if (in.even.precision == QUDA_HALF_PRECISION) {
dslashXpayHCuda(out.odd, gauge, in.even, 1, 0, in.odd, -kappa);
dslashXpayHCuda(out.even, gauge, in.odd, 0, 0, in.even, -kappa);
dslashXpayHCuda(out.odd, gauge, in.even, 1, dagger, in.odd, -kappa);
dslashXpayHCuda(out.even, gauge, in.odd, 0, dagger, in.even, -kappa);
}
}
// Apply the full operator dagger
void MatDaggerCuda(FullSpinor out, FullGauge gauge, FullSpinor in, double kappa) {
checkSpinor(in.even, out.even);
if (in.even.precision == QUDA_SINGLE_PRECISION) {
dslashXpayDCuda(out.odd, gauge, in.even, 1, 1, in.odd, -kappa);
dslashXpayDCuda(out.even, gauge, in.odd, 0, 1, in.even, -kappa);
} else if (in.even.precision == QUDA_SINGLE_PRECISION) {
dslashXpaySCuda(out.odd, gauge, in.even, 1, 1, in.odd, -kappa);
dslashXpaySCuda(out.even, gauge, in.odd, 0, 1, in.even, -kappa);
} else if (in.even.precision == QUDA_HALF_PRECISION) {
dslashXpayHCuda(out.odd, gauge, in.even, 1, 1, in.odd, -kappa);
dslashXpayHCuda(out.even, gauge, in.odd, 0, 1, in.even, -kappa);
}
}
/*
// Apply the even-odd preconditioned Dirac operator
......
......@@ -59,13 +59,11 @@ extern "C" {
ParitySpinor x, double a);
// Full Wilson matrix
void MatCuda(FullSpinor out, FullGauge gauge, FullSpinor in, double kappa);
void MatDagCuda(FullSpinor out, FullGauge gauge, FullSpinor in, double kappa);
void MatCuda(FullSpinor out, FullGauge gauge, FullSpinor in, double kappa, int daggerBit);
void MatPCCuda(ParitySpinor outEven, FullGauge gauge, ParitySpinor inEven,
double kappa, ParitySpinor tmp, MatPCType matpc_type);
void MatPCDagCuda(ParitySpinor outEven, FullGauge gauge, ParitySpinor inEven,
double kappa, ParitySpinor tmp, MatPCType matpc_type);
double kappa, ParitySpinor tmp, MatPCType matpc_type, int daggerBit);
void MatPCDagMatPCCuda(ParitySpinor outEven, FullGauge gauge, ParitySpinor inEven,
double kappa, ParitySpinor tmp, MatPCType matpc_type);
......
......@@ -8,7 +8,7 @@
#include <gauge_quda.h>
// What test are we doing (0 = dslash, 1 = MatPC, 2 = Mat)
int test_type = 1;
int test_type = 2;
QudaGaugeParam gaugeParam;
QudaInvertParam inv_param;
......@@ -25,7 +25,7 @@ void *spinorEven, *spinorOdd;
double kappa = 1.0;
int ODD_BIT = 0;
int DAGGER_BIT = 0;
int TRANSFER = 0; // include transfer time in the benchmark?
int TRANSFER = 1; // include transfer time in the benchmark?
void init() {
......@@ -126,12 +126,12 @@ double dslashCUDA() {
else dslashCuda(cudaSpinor.odd, gauge, cudaSpinor.even, ODD_BIT, DAGGER_BIT);
break;
case 1:
if (TRANSFER) MatPCQuda(spinorOdd, spinorEven, &inv_param);
else MatPCCuda(cudaSpinor.odd, gauge, cudaSpinor.even, kappa, tmp, QUDA_MATPC_EVEN_EVEN);
if (TRANSFER) MatPCQuda(spinorOdd, spinorEven, &inv_param, DAGGER_BIT);
else MatPCCuda(cudaSpinor.odd, gauge, cudaSpinor.even, kappa, tmp, QUDA_MATPC_EVEN_EVEN, DAGGER_BIT);
break;
case 2:
if (TRANSFER) MatQuda(spinorGPU, spinor, &inv_param);
else MatCuda(cudaSpinorOut, gauge, cudaSpinor, kappa);
if (TRANSFER) MatQuda(spinorGPU, spinor, &inv_param, DAGGER_BIT);
else MatCuda(cudaSpinorOut, gauge, cudaSpinor, kappa, DAGGER_BIT);
}
}
......
......@@ -36,7 +36,11 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy,
zeroCuda(x_sloppy);
copyCuda(b, src);
copyCuda(r_sloppy, src_sloppy);
copyCuda(r_sloppy, src);
/*MatPCDagCuda(y, gaugePrecise, src, invert_param->kappa, tmp, invert_param->matpc_type);
copyCuda(src_sloppy, y);*/ // uncomment for BiCRstab
zeroCuda(y);
double b2 = normCuda(b);
......@@ -73,7 +77,7 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy,
while (r2 > stop && k<invert_param->maxiter) {
if (k==0) {
rho = make_cuDoubleComplex(r2, 0.0);
rho = make_cuDoubleComplex(r2, 0.0); // cDotProductCuda(src_sloppy, r_sloppy); // BiCRstab
copyCuda(p, r_sloppy);
} else {
alpha_omega = cuCdiv(alpha, omega);
......@@ -85,11 +89,9 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy,
cxpaypbzCuda(r_sloppy, beta_omega, v, beta, p); // 8
}
if (dag_type == QUDA_DAG_NO)
MatPCCuda(v, gaugeSloppy, p, invert_param->kappa, tmp_sloppy, invert_param->matpc_type);
else
MatPCDagCuda(v, gaugeSloppy, p, invert_param->kappa, tmp_sloppy, invert_param->matpc_type);
MatPCCuda(v, gaugeSloppy, p, invert_param->kappa, tmp_sloppy, invert_param->matpc_type, dag_type);
// rv = (r0,v)
rv = cDotProductCuda(src_sloppy, v);
alpha = cuCdiv(rho, rv);
......@@ -99,10 +101,7 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy,
caxpyCuda(alpha, v, r_sloppy); // 4
alpha.x *= -1.0; alpha.y *= -1.0;
if (dag_type == QUDA_DAG_NO)
MatPCCuda(t, gaugeSloppy, r_sloppy, invert_param->kappa, tmp_sloppy, invert_param->matpc_type);
else
MatPCDagCuda(t, gaugeSloppy, r_sloppy, invert_param->kappa, tmp_sloppy, invert_param->matpc_type);
MatPCCuda(t, gaugeSloppy, r_sloppy, invert_param->kappa, tmp_sloppy, invert_param->matpc_type, dag_type);
// omega = (t, r) / (t, t)
omega_t2 = cDotProductNormACuda(t, r_sloppy); // 6
......@@ -122,10 +121,7 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy,
if (updateR) {
if (x.precision != x_sloppy.precision) copyCuda(x, x_sloppy);
if (dag_type == QUDA_DAG_NO)
MatPCCuda(r, gaugePrecise, x, invert_param->kappa, tmp, invert_param->matpc_type);
else
MatPCDagCuda(r, gaugePrecise, x, invert_param->kappa, tmp, invert_param->matpc_type);
MatPCCuda(r, gaugePrecise, x, invert_param->kappa, tmp, invert_param->matpc_type, dag_type);
r2 = xmyNormCuda(b, r);
if (x.precision != r_sloppy.precision) copyCuda(r_sloppy, r);
......@@ -168,10 +164,7 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy,
#if 0
// Calculate the true residual
if (dag_type == QUDA_DAG_NO)
MatPCCuda(r, gaugePrecise, x, invert_param->kappa, tmp, invert_param->matpc_type);
else
MatPCDagCuda(r, gaugePrecise, x, invert_param->kappa, tmp, invert_param->matpc_type);
MatPCCuda(r, gaugePrecise, x, invert_param->kappa, tmp, invert_param->matpc_type, dag_type);
double true_res = xmyNormCuda(src, r);
printf("Converged after %d iterations, r2 = %e, true_r2 = %e\n", k, sqrt(r2/b2), sqrt(true_res / b2));
......
......@@ -134,29 +134,14 @@ void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int parity,
freeParitySpinor(in);
}
void MatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param)
void MatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int dagger)
{
ParitySpinor in = allocateParitySpinor(Nh, inv_param->cuda_prec);
ParitySpinor out = allocateParitySpinor(Nh, inv_param->cuda_prec);
ParitySpinor tmp = allocateParitySpinor(Nh, inv_param->cuda_prec);
loadParitySpinor(in, h_in, inv_param->cpu_prec, inv_param->dirac_order);
MatPCCuda(out, cudaGaugePrecise, in, inv_param->kappa, tmp, inv_param->matpc_type);
retrieveParitySpinor(h_out, out, inv_param->cpu_prec, inv_param->dirac_order);
freeParitySpinor(tmp);
freeParitySpinor(out);
freeParitySpinor(in);
}
void MatPCDagQuda(void *h_out, void *h_in, QudaInvertParam *inv_param)
{
ParitySpinor in = allocateParitySpinor(Nh, inv_param->cuda_prec);
ParitySpinor out = allocateParitySpinor(Nh, inv_param->cuda_prec);
ParitySpinor tmp = allocateParitySpinor(Nh, inv_param->cuda_prec);
loadParitySpinor(in, h_in, inv_param->cpu_prec, inv_param->dirac_order);
MatPCDagCuda(out, cudaGaugePrecise, in, inv_param->kappa, tmp, inv_param->matpc_type);
MatPCCuda(out, cudaGaugePrecise, in, inv_param->kappa, tmp, inv_param->matpc_type, dagger);
retrieveParitySpinor(h_out, out, inv_param->cpu_prec, inv_param->dirac_order);
freeParitySpinor(tmp);
......@@ -179,29 +164,14 @@ void MatPCDagMatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param)
freeParitySpinor(in);
}
void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) {
FullSpinor in = allocateSpinorField(N, inv_param->cuda_prec);
FullSpinor out = allocateSpinorField(N, inv_param->cuda_prec);
loadSpinorField(in, h_in, inv_param->cpu_prec, inv_param->dirac_order);
dslashXpayCuda(out.odd, cudaGaugePrecise, in.even, 1, 0, in.odd, -inv_param->kappa);
dslashXpayCuda(out.even, cudaGaugePrecise, in.odd, 0, 0, in.even, -inv_param->kappa);
retrieveSpinorField(h_out, out, inv_param->cpu_prec, inv_param->dirac_order);
freeSpinorField(out);
freeSpinorField(in);
}
void MatDagQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) {
void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int dagger) {
FullSpinor in = allocateSpinorField(N, inv_param->cuda_prec);
FullSpinor out = allocateSpinorField(N, inv_param->cuda_prec);
loadSpinorField(in, h_in, inv_param->cpu_prec, inv_param->dirac_order);
dslashXpayCuda(out.odd, cudaGaugePrecise, in.even, 1, 1, in.odd, -inv_param->kappa);
dslashXpayCuda(out.even, cudaGaugePrecise, in.odd, 0, 1, in.even, -inv_param->kappa);
dslashXpayCuda(out.odd, cudaGaugePrecise, in.even, 1, dagger, in.odd, -inv_param->kappa);
dslashXpayCuda(out.even, cudaGaugePrecise, in.odd, 0, dagger, in.even, -inv_param->kappa);
retrieveSpinorField(h_out, out, inv_param->cpu_prec, inv_param->dirac_order);
......@@ -277,7 +247,7 @@ void invertQuda(void *h_x, void *h_b, QudaInvertParam *param)
case QUDA_CG_INVERTER:
if (param->solution_type != QUDA_MATPCDAG_MATPC_SOLUTION) {
copyCuda(out, in);
MatPCDagCuda(in, cudaGaugePrecise, out, kappa, tmp, param->matpc_type);
MatPCCuda(in, cudaGaugePrecise, out, kappa, tmp, param->matpc_type, QUDA_DAG_YES);
}
invertCgCuda(out, in, cudaGaugeSloppy, tmp, param);
break;
......
......@@ -70,12 +70,10 @@ extern "C" {
void invertQuda(void *h_x, void *h_b, QudaInvertParam *param);
void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int parity, int dagger);
void MatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param);
void MatPCDagQuda(void *h_out, void *h_in, QudaInvertParam *inv_param);
void MatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int dagger);
void MatPCDagMatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param);
void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param);
void MatDagQuda(void *h_out, void *h_in, QudaInvertParam *inv_param);
void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int dagger);
void endQuda(void);
......
......@@ -20,7 +20,7 @@ int main(int argc, char **argv)
Gauge_param.cuda_prec = QUDA_DOUBLE_PRECISION;
Gauge_param.reconstruct = QUDA_RECONSTRUCT_12;
Gauge_param.cuda_prec_sloppy = QUDA_DOUBLE_PRECISION;
Gauge_param.cuda_prec_sloppy = QUDA_SINGLE_PRECISION;
Gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_12;
Gauge_param.gauge_fix = QUDA_GAUGE_FIXED_NO;
......@@ -38,13 +38,13 @@ int main(int argc, char **argv)
double mass = -0.97;
inv_param.kappa = 1.0 / (2.0*(4 + mass));
inv_param.tol = 1e-7;
inv_param.maxiter = 5000;
inv_param.tol = 1e-12;
inv_param.maxiter = 10000;
inv_param.reliable_delta = 1e-2;
inv_param.mass_normalization = QUDA_KAPPA_NORMALIZATION;
inv_param.cpu_prec = QUDA_DOUBLE_PRECISION;
inv_param.cuda_prec = QUDA_DOUBLE_PRECISION;
inv_param.cuda_prec_sloppy = QUDA_DOUBLE_PRECISION;
inv_param.cuda_prec_sloppy = QUDA_SINGLE_PRECISION;
inv_param.solution_type = QUDA_MAT_SOLUTION;
inv_param.matpc_type = QUDA_MATPC_EVEN_EVEN;
inv_param.preserve_source = QUDA_PRESERVE_SOURCE_NO;
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment