Advanced Computing Platform for Theoretical Physics

Commit 72aac654 authored by mikeaclark's avatar mikeaclark
Browse files

Clean up of spinor_quda.cpp

git-svn-id: http://lattice.bu.edu/qcdalg/cuda/quda@405 be54200a-260c-0410-bdd7-ce6af2a381ab
parent 937f5509
...@@ -537,7 +537,7 @@ int dslashCudaSharedBytes() { ...@@ -537,7 +537,7 @@ int dslashCudaSharedBytes() {
// Apply the even-odd preconditioned Dirac operator // Apply the even-odd preconditioned Dirac operator
void MatPCCuda(ParitySpinor out, FullGauge gauge, ParitySpinor in, double kappa, void MatPCCuda(ParitySpinor out, FullGauge gauge, ParitySpinor in, double kappa,
ParitySpinor tmp, MatPCType matpc_type) { ParitySpinor tmp, MatPCType matpc_type, int dagger) {
checkSpinor(in, out); checkSpinor(in, out);
checkSpinor(in, tmp); checkSpinor(in, tmp);
...@@ -546,64 +546,27 @@ void MatPCCuda(ParitySpinor out, FullGauge gauge, ParitySpinor in, double kappa, ...@@ -546,64 +546,27 @@ void MatPCCuda(ParitySpinor out, FullGauge gauge, ParitySpinor in, double kappa,
if (in.precision == QUDA_DOUBLE_PRECISION) { if (in.precision == QUDA_DOUBLE_PRECISION) {
if (matpc_type == QUDA_MATPC_EVEN_EVEN) { if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
dslashDCuda(tmp, gauge, in, 1, 0); dslashDCuda(tmp, gauge, in, 1, dagger);
dslashXpayDCuda(out, gauge, tmp, 0, 0, in, kappa2); dslashXpayDCuda(out, gauge, tmp, 0, dagger, in, kappa2);
} else { } else {
dslashDCuda(tmp, gauge, in, 0, 0); dslashDCuda(tmp, gauge, in, 0, dagger);
dslashXpayDCuda(out, gauge, tmp, 1, 0, in, kappa2); dslashXpayDCuda(out, gauge, tmp, 1, dagger, in, kappa2);
} }
} else if (in.precision == QUDA_SINGLE_PRECISION) { } else if (in.precision == QUDA_SINGLE_PRECISION) {
if (matpc_type == QUDA_MATPC_EVEN_EVEN) { if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
dslashSCuda(tmp, gauge, in, 1, 0); dslashSCuda(tmp, gauge, in, 1, dagger);
dslashXpaySCuda(out, gauge, tmp, 0, 0, in, kappa2); dslashXpaySCuda(out, gauge, tmp, 0, dagger, in, kappa2);
} else { } else {
dslashSCuda(tmp, gauge, in, 0, 0); dslashSCuda(tmp, gauge, in, 0, dagger);
dslashXpaySCuda(out, gauge, tmp, 1, 0, in, kappa2); dslashXpaySCuda(out, gauge, tmp, 1, dagger, in, kappa2);
} }
} else if (in.precision == QUDA_HALF_PRECISION) { } else if (in.precision == QUDA_HALF_PRECISION) {
if (matpc_type == QUDA_MATPC_EVEN_EVEN) { if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
dslashHCuda(tmp, gauge, in, 1, 0); dslashHCuda(tmp, gauge, in, 1, dagger);
dslashXpayHCuda(out, gauge, tmp, 0, 0, in, kappa2); dslashXpayHCuda(out, gauge, tmp, 0, dagger, in, kappa2);
} else { } else {
dslashHCuda(tmp, gauge, in, 0, 0); dslashHCuda(tmp, gauge, in, 0, dagger);
dslashXpayHCuda(out, gauge, tmp, 1, 0, in, kappa2); dslashXpayHCuda(out, gauge, tmp, 1, dagger, in, kappa2);
}
}
}
// Apply the even-odd preconditioned Dirac operator
void MatPCDagCuda(ParitySpinor out, FullGauge gauge, ParitySpinor in, double kappa,
ParitySpinor tmp, MatPCType matpc_type) {
checkSpinor(in, out);
checkSpinor(in, tmp);
double kappa2 = -kappa*kappa;
if (in.precision == QUDA_DOUBLE_PRECISION) {
if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
dslashDCuda(tmp, gauge, in, 1, 1);
dslashXpayDCuda(out, gauge, tmp, 0, 1, in, kappa2);
} else {
dslashDCuda(tmp, gauge, in, 0, 1);
dslashXpayDCuda(out, gauge, tmp, 1, 1, in, kappa2);
}
} else if (in.precision == QUDA_SINGLE_PRECISION) {
if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
dslashSCuda(tmp, gauge, in, 1, 1);
dslashXpaySCuda(out, gauge, tmp, 0, 1, in, kappa2);
} else {
dslashSCuda(tmp, gauge, in, 0, 1);
dslashXpaySCuda(out, gauge, tmp, 1, 1, in, kappa2);
}
} else {
if (matpc_type == QUDA_MATPC_EVEN_EVEN) {
dslashHCuda(tmp, gauge, in, 1, 1);
dslashXpayHCuda(out, gauge, tmp, 0, 1, in, kappa2);
} else {
dslashHCuda(tmp, gauge, in, 0, 1);
dslashXpayHCuda(out, gauge, tmp, 1, 1, in, kappa2);
} }
} }
...@@ -611,43 +574,27 @@ void MatPCDagCuda(ParitySpinor out, FullGauge gauge, ParitySpinor in, double kap ...@@ -611,43 +574,27 @@ void MatPCDagCuda(ParitySpinor out, FullGauge gauge, ParitySpinor in, double kap
void MatPCDagMatPCCuda(ParitySpinor out, FullGauge gauge, ParitySpinor in, void MatPCDagMatPCCuda(ParitySpinor out, FullGauge gauge, ParitySpinor in,
double kappa, ParitySpinor tmp, MatPCType matpc_type) { double kappa, ParitySpinor tmp, MatPCType matpc_type) {
MatPCCuda(out, gauge, in, kappa, tmp, matpc_type); MatPCCuda(out, gauge, in, kappa, tmp, matpc_type, 0);
MatPCDagCuda(out, gauge, out, kappa, tmp, matpc_type); MatPCCuda(out, gauge, out, kappa, tmp, matpc_type, 1);
} }
// Apply the full operator // Apply the full operator
void MatCuda(FullSpinor out, FullGauge gauge, FullSpinor in, double kappa) { void MatCuda(FullSpinor out, FullGauge gauge, FullSpinor in, double kappa, int dagger) {
checkSpinor(in.even, out.even); checkSpinor(in.even, out.even);
if (in.even.precision == QUDA_DOUBLE_PRECISION) { if (in.even.precision == QUDA_DOUBLE_PRECISION) {
dslashXpayDCuda(out.odd, gauge, in.even, 1, 0, in.odd, -kappa); dslashXpayDCuda(out.odd, gauge, in.even, 1, dagger, in.odd, -kappa);
dslashXpayDCuda(out.even, gauge, in.odd, 0, 0, in.even, -kappa); dslashXpayDCuda(out.even, gauge, in.odd, 0, dagger, in.even, -kappa);
} else if (in.even.precision == QUDA_SINGLE_PRECISION) { } else if (in.even.precision == QUDA_SINGLE_PRECISION) {
dslashXpaySCuda(out.odd, gauge, in.even, 1, 0, in.odd, -kappa); dslashXpaySCuda(out.odd, gauge, in.even, 1, dagger, in.odd, -kappa);
dslashXpaySCuda(out.even, gauge, in.odd, 0, 0, in.even, -kappa); dslashXpaySCuda(out.even, gauge, in.odd, 0, dagger, in.even, -kappa);
} else if (in.even.precision == QUDA_HALF_PRECISION) { } else if (in.even.precision == QUDA_HALF_PRECISION) {
dslashXpayHCuda(out.odd, gauge, in.even, 1, 0, in.odd, -kappa); dslashXpayHCuda(out.odd, gauge, in.even, 1, dagger, in.odd, -kappa);
dslashXpayHCuda(out.even, gauge, in.odd, 0, 0, in.even, -kappa); dslashXpayHCuda(out.even, gauge, in.odd, 0, dagger, in.even, -kappa);
} }
} }
// Apply the full operator dagger
void MatDaggerCuda(FullSpinor out, FullGauge gauge, FullSpinor in, double kappa) {
checkSpinor(in.even, out.even);
if (in.even.precision == QUDA_SINGLE_PRECISION) {
dslashXpayDCuda(out.odd, gauge, in.even, 1, 1, in.odd, -kappa);
dslashXpayDCuda(out.even, gauge, in.odd, 0, 1, in.even, -kappa);
} else if (in.even.precision == QUDA_SINGLE_PRECISION) {
dslashXpaySCuda(out.odd, gauge, in.even, 1, 1, in.odd, -kappa);
dslashXpaySCuda(out.even, gauge, in.odd, 0, 1, in.even, -kappa);
} else if (in.even.precision == QUDA_HALF_PRECISION) {
dslashXpayHCuda(out.odd, gauge, in.even, 1, 1, in.odd, -kappa);
dslashXpayHCuda(out.even, gauge, in.odd, 0, 1, in.even, -kappa);
}
}
/* /*
// Apply the even-odd preconditioned Dirac operator // Apply the even-odd preconditioned Dirac operator
......
...@@ -59,13 +59,11 @@ extern "C" { ...@@ -59,13 +59,11 @@ extern "C" {
ParitySpinor x, double a); ParitySpinor x, double a);
// Full Wilson matrix // Full Wilson matrix
void MatCuda(FullSpinor out, FullGauge gauge, FullSpinor in, double kappa); void MatCuda(FullSpinor out, FullGauge gauge, FullSpinor in, double kappa, int daggerBit);
void MatDagCuda(FullSpinor out, FullGauge gauge, FullSpinor in, double kappa);
void MatPCCuda(ParitySpinor outEven, FullGauge gauge, ParitySpinor inEven, void MatPCCuda(ParitySpinor outEven, FullGauge gauge, ParitySpinor inEven,
double kappa, ParitySpinor tmp, MatPCType matpc_type); double kappa, ParitySpinor tmp, MatPCType matpc_type, int daggerBit);
void MatPCDagCuda(ParitySpinor outEven, FullGauge gauge, ParitySpinor inEven,
double kappa, ParitySpinor tmp, MatPCType matpc_type);
void MatPCDagMatPCCuda(ParitySpinor outEven, FullGauge gauge, ParitySpinor inEven, void MatPCDagMatPCCuda(ParitySpinor outEven, FullGauge gauge, ParitySpinor inEven,
double kappa, ParitySpinor tmp, MatPCType matpc_type); double kappa, ParitySpinor tmp, MatPCType matpc_type);
......
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include <gauge_quda.h> #include <gauge_quda.h>
// What test are we doing (0 = dslash, 1 = MatPC, 2 = Mat) // What test are we doing (0 = dslash, 1 = MatPC, 2 = Mat)
int test_type = 1; int test_type = 2;
QudaGaugeParam gaugeParam; QudaGaugeParam gaugeParam;
QudaInvertParam inv_param; QudaInvertParam inv_param;
...@@ -25,7 +25,7 @@ void *spinorEven, *spinorOdd; ...@@ -25,7 +25,7 @@ void *spinorEven, *spinorOdd;
double kappa = 1.0; double kappa = 1.0;
int ODD_BIT = 0; int ODD_BIT = 0;
int DAGGER_BIT = 0; int DAGGER_BIT = 0;
int TRANSFER = 0; // include transfer time in the benchmark? int TRANSFER = 1; // include transfer time in the benchmark?
void init() { void init() {
...@@ -126,12 +126,12 @@ double dslashCUDA() { ...@@ -126,12 +126,12 @@ double dslashCUDA() {
else dslashCuda(cudaSpinor.odd, gauge, cudaSpinor.even, ODD_BIT, DAGGER_BIT); else dslashCuda(cudaSpinor.odd, gauge, cudaSpinor.even, ODD_BIT, DAGGER_BIT);
break; break;
case 1: case 1:
if (TRANSFER) MatPCQuda(spinorOdd, spinorEven, &inv_param); if (TRANSFER) MatPCQuda(spinorOdd, spinorEven, &inv_param, DAGGER_BIT);
else MatPCCuda(cudaSpinor.odd, gauge, cudaSpinor.even, kappa, tmp, QUDA_MATPC_EVEN_EVEN); else MatPCCuda(cudaSpinor.odd, gauge, cudaSpinor.even, kappa, tmp, QUDA_MATPC_EVEN_EVEN, DAGGER_BIT);
break; break;
case 2: case 2:
if (TRANSFER) MatQuda(spinorGPU, spinor, &inv_param); if (TRANSFER) MatQuda(spinorGPU, spinor, &inv_param, DAGGER_BIT);
else MatCuda(cudaSpinorOut, gauge, cudaSpinor, kappa); else MatCuda(cudaSpinorOut, gauge, cudaSpinor, kappa, DAGGER_BIT);
} }
} }
......
...@@ -36,7 +36,11 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy, ...@@ -36,7 +36,11 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy,
zeroCuda(x_sloppy); zeroCuda(x_sloppy);
copyCuda(b, src); copyCuda(b, src);
copyCuda(r_sloppy, src_sloppy); copyCuda(r_sloppy, src);
/*MatPCDagCuda(y, gaugePrecise, src, invert_param->kappa, tmp, invert_param->matpc_type);
copyCuda(src_sloppy, y);*/ // uncomment for BiCRstab
zeroCuda(y); zeroCuda(y);
double b2 = normCuda(b); double b2 = normCuda(b);
...@@ -73,7 +77,7 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy, ...@@ -73,7 +77,7 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy,
while (r2 > stop && k<invert_param->maxiter) { while (r2 > stop && k<invert_param->maxiter) {
if (k==0) { if (k==0) {
rho = make_cuDoubleComplex(r2, 0.0); rho = make_cuDoubleComplex(r2, 0.0); // cDotProductCuda(src_sloppy, r_sloppy); // BiCRstab
copyCuda(p, r_sloppy); copyCuda(p, r_sloppy);
} else { } else {
alpha_omega = cuCdiv(alpha, omega); alpha_omega = cuCdiv(alpha, omega);
...@@ -85,11 +89,9 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy, ...@@ -85,11 +89,9 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy,
cxpaypbzCuda(r_sloppy, beta_omega, v, beta, p); // 8 cxpaypbzCuda(r_sloppy, beta_omega, v, beta, p); // 8
} }
if (dag_type == QUDA_DAG_NO) MatPCCuda(v, gaugeSloppy, p, invert_param->kappa, tmp_sloppy, invert_param->matpc_type, dag_type);
MatPCCuda(v, gaugeSloppy, p, invert_param->kappa, tmp_sloppy, invert_param->matpc_type);
else
MatPCDagCuda(v, gaugeSloppy, p, invert_param->kappa, tmp_sloppy, invert_param->matpc_type);
// rv = (r0,v)
rv = cDotProductCuda(src_sloppy, v); rv = cDotProductCuda(src_sloppy, v);
alpha = cuCdiv(rho, rv); alpha = cuCdiv(rho, rv);
...@@ -99,10 +101,7 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy, ...@@ -99,10 +101,7 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy,
caxpyCuda(alpha, v, r_sloppy); // 4 caxpyCuda(alpha, v, r_sloppy); // 4
alpha.x *= -1.0; alpha.y *= -1.0; alpha.x *= -1.0; alpha.y *= -1.0;
if (dag_type == QUDA_DAG_NO) MatPCCuda(t, gaugeSloppy, r_sloppy, invert_param->kappa, tmp_sloppy, invert_param->matpc_type, dag_type);
MatPCCuda(t, gaugeSloppy, r_sloppy, invert_param->kappa, tmp_sloppy, invert_param->matpc_type);
else
MatPCDagCuda(t, gaugeSloppy, r_sloppy, invert_param->kappa, tmp_sloppy, invert_param->matpc_type);
// omega = (t, r) / (t, t) // omega = (t, r) / (t, t)
omega_t2 = cDotProductNormACuda(t, r_sloppy); // 6 omega_t2 = cDotProductNormACuda(t, r_sloppy); // 6
...@@ -122,10 +121,7 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy, ...@@ -122,10 +121,7 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy,
if (updateR) { if (updateR) {
if (x.precision != x_sloppy.precision) copyCuda(x, x_sloppy); if (x.precision != x_sloppy.precision) copyCuda(x, x_sloppy);
if (dag_type == QUDA_DAG_NO) MatPCCuda(r, gaugePrecise, x, invert_param->kappa, tmp, invert_param->matpc_type, dag_type);
MatPCCuda(r, gaugePrecise, x, invert_param->kappa, tmp, invert_param->matpc_type);
else
MatPCDagCuda(r, gaugePrecise, x, invert_param->kappa, tmp, invert_param->matpc_type);
r2 = xmyNormCuda(b, r); r2 = xmyNormCuda(b, r);
if (x.precision != r_sloppy.precision) copyCuda(r_sloppy, r); if (x.precision != r_sloppy.precision) copyCuda(r_sloppy, r);
...@@ -168,10 +164,7 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy, ...@@ -168,10 +164,7 @@ void invertBiCGstabCuda(ParitySpinor x, ParitySpinor src, FullGauge gaugeSloppy,
#if 0 #if 0
// Calculate the true residual // Calculate the true residual
if (dag_type == QUDA_DAG_NO) MatPCCuda(r, gaugePrecise, x, invert_param->kappa, tmp, invert_param->matpc_type, dag_type);
MatPCCuda(r, gaugePrecise, x, invert_param->kappa, tmp, invert_param->matpc_type);
else
MatPCDagCuda(r, gaugePrecise, x, invert_param->kappa, tmp, invert_param->matpc_type);
double true_res = xmyNormCuda(src, r); double true_res = xmyNormCuda(src, r);
printf("Converged after %d iterations, r2 = %e, true_r2 = %e\n", k, sqrt(r2/b2), sqrt(true_res / b2)); printf("Converged after %d iterations, r2 = %e, true_r2 = %e\n", k, sqrt(r2/b2), sqrt(true_res / b2));
......
...@@ -134,29 +134,14 @@ void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int parity, ...@@ -134,29 +134,14 @@ void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int parity,
freeParitySpinor(in); freeParitySpinor(in);
} }
void MatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) void MatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int dagger)
{ {
ParitySpinor in = allocateParitySpinor(Nh, inv_param->cuda_prec); ParitySpinor in = allocateParitySpinor(Nh, inv_param->cuda_prec);
ParitySpinor out = allocateParitySpinor(Nh, inv_param->cuda_prec); ParitySpinor out = allocateParitySpinor(Nh, inv_param->cuda_prec);
ParitySpinor tmp = allocateParitySpinor(Nh, inv_param->cuda_prec); ParitySpinor tmp = allocateParitySpinor(Nh, inv_param->cuda_prec);
loadParitySpinor(in, h_in, inv_param->cpu_prec, inv_param->dirac_order); loadParitySpinor(in, h_in, inv_param->cpu_prec, inv_param->dirac_order);
MatPCCuda(out, cudaGaugePrecise, in, inv_param->kappa, tmp, inv_param->matpc_type); MatPCCuda(out, cudaGaugePrecise, in, inv_param->kappa, tmp, inv_param->matpc_type, dagger);
retrieveParitySpinor(h_out, out, inv_param->cpu_prec, inv_param->dirac_order);
freeParitySpinor(tmp);
freeParitySpinor(out);
freeParitySpinor(in);
}
void MatPCDagQuda(void *h_out, void *h_in, QudaInvertParam *inv_param)
{
ParitySpinor in = allocateParitySpinor(Nh, inv_param->cuda_prec);
ParitySpinor out = allocateParitySpinor(Nh, inv_param->cuda_prec);
ParitySpinor tmp = allocateParitySpinor(Nh, inv_param->cuda_prec);
loadParitySpinor(in, h_in, inv_param->cpu_prec, inv_param->dirac_order);
MatPCDagCuda(out, cudaGaugePrecise, in, inv_param->kappa, tmp, inv_param->matpc_type);
retrieveParitySpinor(h_out, out, inv_param->cpu_prec, inv_param->dirac_order); retrieveParitySpinor(h_out, out, inv_param->cpu_prec, inv_param->dirac_order);
freeParitySpinor(tmp); freeParitySpinor(tmp);
...@@ -179,29 +164,14 @@ void MatPCDagMatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) ...@@ -179,29 +164,14 @@ void MatPCDagMatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param)
freeParitySpinor(in); freeParitySpinor(in);
} }
void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) { void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int dagger) {
FullSpinor in = allocateSpinorField(N, inv_param->cuda_prec);
FullSpinor out = allocateSpinorField(N, inv_param->cuda_prec);
loadSpinorField(in, h_in, inv_param->cpu_prec, inv_param->dirac_order);
dslashXpayCuda(out.odd, cudaGaugePrecise, in.even, 1, 0, in.odd, -inv_param->kappa);
dslashXpayCuda(out.even, cudaGaugePrecise, in.odd, 0, 0, in.even, -inv_param->kappa);
retrieveSpinorField(h_out, out, inv_param->cpu_prec, inv_param->dirac_order);
freeSpinorField(out);
freeSpinorField(in);
}
void MatDagQuda(void *h_out, void *h_in, QudaInvertParam *inv_param) {
FullSpinor in = allocateSpinorField(N, inv_param->cuda_prec); FullSpinor in = allocateSpinorField(N, inv_param->cuda_prec);
FullSpinor out = allocateSpinorField(N, inv_param->cuda_prec); FullSpinor out = allocateSpinorField(N, inv_param->cuda_prec);
loadSpinorField(in, h_in, inv_param->cpu_prec, inv_param->dirac_order); loadSpinorField(in, h_in, inv_param->cpu_prec, inv_param->dirac_order);
dslashXpayCuda(out.odd, cudaGaugePrecise, in.even, 1, 1, in.odd, -inv_param->kappa); dslashXpayCuda(out.odd, cudaGaugePrecise, in.even, 1, dagger, in.odd, -inv_param->kappa);
dslashXpayCuda(out.even, cudaGaugePrecise, in.odd, 0, 1, in.even, -inv_param->kappa); dslashXpayCuda(out.even, cudaGaugePrecise, in.odd, 0, dagger, in.even, -inv_param->kappa);
retrieveSpinorField(h_out, out, inv_param->cpu_prec, inv_param->dirac_order); retrieveSpinorField(h_out, out, inv_param->cpu_prec, inv_param->dirac_order);
...@@ -277,7 +247,7 @@ void invertQuda(void *h_x, void *h_b, QudaInvertParam *param) ...@@ -277,7 +247,7 @@ void invertQuda(void *h_x, void *h_b, QudaInvertParam *param)
case QUDA_CG_INVERTER: case QUDA_CG_INVERTER:
if (param->solution_type != QUDA_MATPCDAG_MATPC_SOLUTION) { if (param->solution_type != QUDA_MATPCDAG_MATPC_SOLUTION) {
copyCuda(out, in); copyCuda(out, in);
MatPCDagCuda(in, cudaGaugePrecise, out, kappa, tmp, param->matpc_type); MatPCCuda(in, cudaGaugePrecise, out, kappa, tmp, param->matpc_type, QUDA_DAG_YES);
} }
invertCgCuda(out, in, cudaGaugeSloppy, tmp, param); invertCgCuda(out, in, cudaGaugeSloppy, tmp, param);
break; break;
......
...@@ -70,12 +70,10 @@ extern "C" { ...@@ -70,12 +70,10 @@ extern "C" {
void invertQuda(void *h_x, void *h_b, QudaInvertParam *param); void invertQuda(void *h_x, void *h_b, QudaInvertParam *param);
void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int parity, int dagger); void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int parity, int dagger);
void MatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param); void MatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int dagger);
void MatPCDagQuda(void *h_out, void *h_in, QudaInvertParam *inv_param);
void MatPCDagMatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param); void MatPCDagMatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param);
void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param); void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int dagger);
void MatDagQuda(void *h_out, void *h_in, QudaInvertParam *inv_param);
void endQuda(void); void endQuda(void);
......
...@@ -20,7 +20,7 @@ int main(int argc, char **argv) ...@@ -20,7 +20,7 @@ int main(int argc, char **argv)
Gauge_param.cuda_prec = QUDA_DOUBLE_PRECISION; Gauge_param.cuda_prec = QUDA_DOUBLE_PRECISION;
Gauge_param.reconstruct = QUDA_RECONSTRUCT_12; Gauge_param.reconstruct = QUDA_RECONSTRUCT_12;
Gauge_param.cuda_prec_sloppy = QUDA_DOUBLE_PRECISION; Gauge_param.cuda_prec_sloppy = QUDA_SINGLE_PRECISION;
Gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_12; Gauge_param.reconstruct_sloppy = QUDA_RECONSTRUCT_12;
Gauge_param.gauge_fix = QUDA_GAUGE_FIXED_NO; Gauge_param.gauge_fix = QUDA_GAUGE_FIXED_NO;
...@@ -38,13 +38,13 @@ int main(int argc, char **argv) ...@@ -38,13 +38,13 @@ int main(int argc, char **argv)
double mass = -0.97; double mass = -0.97;
inv_param.kappa = 1.0 / (2.0*(4 + mass)); inv_param.kappa = 1.0 / (2.0*(4 + mass));
inv_param.tol = 1e-7; inv_param.tol = 1e-12;
inv_param.maxiter = 5000; inv_param.maxiter = 10000;
inv_param.reliable_delta = 1e-2; inv_param.reliable_delta = 1e-2;
inv_param.mass_normalization = QUDA_KAPPA_NORMALIZATION; inv_param.mass_normalization = QUDA_KAPPA_NORMALIZATION;
inv_param.cpu_prec = QUDA_DOUBLE_PRECISION; inv_param.cpu_prec = QUDA_DOUBLE_PRECISION;
inv_param.cuda_prec = QUDA_DOUBLE_PRECISION; inv_param.cuda_prec = QUDA_DOUBLE_PRECISION;
inv_param.cuda_prec_sloppy = QUDA_DOUBLE_PRECISION; inv_param.cuda_prec_sloppy = QUDA_SINGLE_PRECISION;
inv_param.solution_type = QUDA_MAT_SOLUTION; inv_param.solution_type = QUDA_MAT_SOLUTION;
inv_param.matpc_type = QUDA_MATPC_EVEN_EVEN; inv_param.matpc_type = QUDA_MATPC_EVEN_EVEN;
inv_param.preserve_source = QUDA_PRESERVE_SOURCE_NO; inv_param.preserve_source = QUDA_PRESERVE_SOURCE_NO;
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment