Advanced Computing Platform for Theoretical Physics

invert_quda.cpp 18.5 KB
Newer Older
mikeaclark's avatar
mikeaclark committed
1
2
3
4
5
6
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <cuda_runtime.h>

#include <invert_quda.h>
mikeaclark's avatar
mikeaclark committed
7
#include <quda.h>
mikeaclark's avatar
mikeaclark committed
8
#include <util_quda.h>
9
10
#include <spinor_quda.h>
#include <gauge_quda.h>
mikeaclark's avatar
mikeaclark committed
11

12
13
#include <blas_reference.h>

14
15
// Device-resident gauge field handles. loadGaugeQuda() fills the precise
// field at cuda_prec; the sloppy field is either a separate upload at
// cuda_prec_sloppy / reconstruct_sloppy or a shallow copy of the precise
// handle when those settings match.
FullGauge cudaGaugePrecise; // precise gauge field
FullGauge cudaGaugeSloppy;  // sloppy gauge field

// Device-resident clover term handles, filled by loadCloverQuda().
FullClover cudaCloverPrecise; // clover term
FullClover cudaCloverSloppy;  // clover term at clover_cuda_prec_sloppy

// Device-resident inverse clover term handles, filled by loadCloverQuda().
FullClover cudaCloverInvPrecise; // inverted clover term
FullClover cudaCloverInvSloppy;  // inverted clover term at clover_cuda_prec_sloppy

mikeaclark's avatar
   
mikeaclark committed
23
24
25
// Print every field of a QudaGaugeParam to stdout, one "name = value" per
// line, preceded by a "Gauge Params:" header.
void printGaugeParam(QudaGaugeParam *param) {
  const QudaGaugeParam &g = *param;

  printf("Gauge Params:\n");
  for (int dim = 0; dim < 4; ++dim)
    printf("X[%d] = %d\n", dim, g.X[dim]);

  printf("anisotropy = %e\n", g.anisotropy);
  printf("gauge_order = %d\n", g.gauge_order);
  printf("cpu_prec = %d\n", g.cpu_prec);

  // precise device field settings
  printf("cuda_prec = %d\n", g.cuda_prec);
  printf("reconstruct = %d\n", g.reconstruct);

  // sloppy device field settings
  printf("cuda_prec_sloppy = %d\n", g.cuda_prec_sloppy);
  printf("reconstruct_sloppy = %d\n", g.reconstruct_sloppy);

  printf("gauge_fix = %d\n", g.gauge_fix);
  printf("t_boundary = %d\n", g.t_boundary);
  printf("packed_size = %d\n", g.packed_size);
  printf("gaugeGiB = %e\n", g.gaugeGiB);
}

// Print the fields of a QudaInvertParam to stdout, one "name = value" per
// line. The clover-specific fields are printed only when dslash_type is
// QUDA_CLOVER_WILSON_DSLASH.
void printInvertParam(QudaInvertParam *param) {
  const QudaInvertParam &p = *param;

  printf("kappa = %e\n", p.kappa);
  printf("mass_normalization = %d\n", p.mass_normalization);
  printf("dslash_type = %d\n", p.dslash_type);
  printf("inv_type = %d\n", p.inv_type);
  printf("tol = %e\n", p.tol);
  printf("iter = %d\n", p.iter);
  printf("maxiter = %d\n", p.maxiter);
  printf("matpc_type = %d\n", p.matpc_type);
  printf("solution_type = %d\n", p.solution_type);
  printf("preserve_source = %d\n", p.preserve_source);
  printf("cpu_prec = %d\n", p.cpu_prec);
  printf("cuda_prec = %d\n", p.cuda_prec);
  printf("cuda_prec_sloppy = %d\n", p.cuda_prec_sloppy);
  printf("dirac_order = %d\n", p.dirac_order);
  printf("spinorGiB = %e\n", p.spinorGiB);

  const bool isClover = (p.dslash_type == QUDA_CLOVER_WILSON_DSLASH);
  if (isClover) {
    printf("clover_cpu_prec = %d\n", p.clover_cpu_prec);
    printf("clover_cuda_prec = %d\n", p.clover_cuda_prec);
    printf("clover_cuda_prec_sloppy = %d\n", p.clover_cuda_prec_sloppy);
    printf("clover_order = %d\n", p.clover_order);
    printf("cloverGiB = %e\n", p.cloverGiB);
  }

  // solver output / performance fields
  printf("gflops = %e\n", p.gflops);
  printf("secs = %f\n", p.secs);
  printf("verbosity = %d\n", p.verbosity);
}

mikeaclark's avatar
mikeaclark committed
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
// Select and initialize a CUDA device, then reset every global gauge/clover
// handle to the "not loaded" (NULL) state.
//   dev: index of the device to use; a negative value selects the
//        highest-numbered device found.
// Exits the process if the device query fails, no CUDA device exists,
// dev is out of range, the device reports compute major < 1, or the
// device cannot be selected.
void initQuda(int dev)
{
  int deviceCount = 0; // stays 0 if the query does not write it
  cudaError_t err = cudaGetDeviceCount(&deviceCount);
  if (err != cudaSuccess) {
    fprintf(stderr, "cudaGetDeviceCount failed: %s\n", cudaGetErrorString(err));
    exit(EXIT_FAILURE);
  }
  if (deviceCount == 0) {
    fprintf(stderr, "No devices supporting CUDA.\n");
    exit(EXIT_FAILURE);
  }

  for (int i=0; i<deviceCount; i++) {
    cudaDeviceProp deviceProp;
    cudaGetDeviceProperties(&deviceProp, i);
    fprintf(stderr, "found device %d: %s\n", i, deviceProp.name);
  }

  if (dev < 0) {
    dev = deviceCount - 1;
  }
  // A non-negative request must name an existing device; otherwise the
  // property query below would read garbage / fail.
  if (dev >= deviceCount) {
    fprintf(stderr, "Device %d requested but only %d device(s) found.\n", dev, deviceCount);
    exit(EXIT_FAILURE);
  }

  cudaDeviceProp deviceProp;
  err = cudaGetDeviceProperties(&deviceProp, dev);
  if (err != cudaSuccess) {
    fprintf(stderr, "cudaGetDeviceProperties(%d) failed: %s\n", dev, cudaGetErrorString(err));
    exit(EXIT_FAILURE);
  }
  if (deviceProp.major < 1) {
    fprintf(stderr, "Device %d does not support CUDA.\n", dev);
    exit(EXIT_FAILURE);
  }

  fprintf(stderr, "Using device %d: %s\n", dev, deviceProp.name);
  err = cudaSetDevice(dev);
  if (err != cudaSuccess) {
    fprintf(stderr, "cudaSetDevice(%d) failed: %s\n", dev, cudaGetErrorString(err));
    exit(EXIT_FAILURE);
  }

  // Mark all device fields as not yet loaded.
  cudaGaugePrecise.even = NULL;
  cudaGaugePrecise.odd = NULL;

  cudaGaugeSloppy.even = NULL;
  cudaGaugeSloppy.odd = NULL;

  cudaCloverPrecise.even.clover = NULL;
  cudaCloverPrecise.odd.clover = NULL;

  cudaCloverSloppy.even.clover = NULL;
  cudaCloverSloppy.odd.clover = NULL;

  cudaCloverInvPrecise.even.clover = NULL;
  cudaCloverInvPrecise.odd.clover = NULL;

  cudaCloverInvSloppy.even.clover = NULL;
  cudaCloverInvSloppy.odd.clover = NULL;
}

mikeaclark's avatar
mikeaclark committed
119
void loadGaugeQuda(void *h_gauge, QudaGaugeParam *param)
mikeaclark's avatar
mikeaclark committed
120
121
{
  gauge_param = param;
122

123
  gauge_param->packed_size = (gauge_param->reconstruct == QUDA_RECONSTRUCT_8) ? 8 : 12;
124

125
  createGaugeField(&cudaGaugePrecise, h_gauge, gauge_param->reconstruct, 
126
		   gauge_param->cuda_prec, gauge_param->X, gauge_param->anisotropy, gauge_param->blockDim);
127
  gauge_param->gaugeGiB = 2.0*cudaGaugePrecise.bytes/ (1 << 30);
128
129
  if (gauge_param->cuda_prec_sloppy != gauge_param->cuda_prec ||
      gauge_param->reconstruct_sloppy != gauge_param->reconstruct) {
130
    createGaugeField(&cudaGaugeSloppy, h_gauge, gauge_param->reconstruct_sloppy, 
131
132
		     gauge_param->cuda_prec_sloppy, gauge_param->X, gauge_param->anisotropy,
		     gauge_param->blockDim_sloppy);
133
    gauge_param->gaugeGiB += 2.0*cudaGaugeSloppy.bytes/ (1 << 30);
134
135
136
  } else {
    cudaGaugeSloppy = cudaGaugePrecise;
  }
mikeaclark's avatar
mikeaclark committed
137
138
}

rbabich's avatar
rbabich committed
139
void loadCloverQuda(void *h_clover, void *h_clovinv, QudaInvertParam *inv_param)
rbabich's avatar
rbabich committed
140
{
rbabich's avatar
rbabich committed
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
  if (!cudaGaugePrecise.even) {
    printf("QUDA error: loadGaugeQuda() must precede call to loadCloverQuda()\n");
    exit(-1);
  }
  if (!h_clover && !h_clovinv) {
    printf("QUDA error: loadCloverQuda() called with neither clover term nor inverse\n");
    exit(-1);
  }
  if (inv_param->clover_cpu_prec == QUDA_HALF_PRECISION) {
    printf("QUDA error: half precision not supported on CPU\n");
    exit(-1);
  }

  int X[4];
  for (int i=0; i<4; i++) {
    X[i] = gauge_param->X[i];
  }
  X[0] /= 2; // dimensions of the even-odd sublattice

  inv_param->cloverGiB = 0;

  if (h_clover) {
    cudaCloverPrecise = allocateCloverField(X, inv_param->clover_cuda_prec);
    loadCloverField(cudaCloverPrecise, h_clover, inv_param->clover_cuda_prec, inv_param->clover_order);
    inv_param->cloverGiB += 2.0*cudaCloverPrecise.even.bytes / (1<<30);

    if (inv_param->matpc_type == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC ||
	inv_param->matpc_type == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
      if (inv_param->clover_cuda_prec != inv_param->clover_cuda_prec_sloppy) {
	cudaCloverSloppy = allocateCloverField(X, inv_param->clover_cuda_prec_sloppy);
	loadCloverField(cudaCloverSloppy, h_clover, inv_param->clover_cuda_prec_sloppy, inv_param->clover_order);
	inv_param->cloverGiB += 2.0*cudaCloverInvSloppy.even.bytes / (1<<30);
      } else {
	cudaCloverSloppy = cudaCloverPrecise;
      }
    } // sloppy precision clover term not needed otherwise
  }

  cudaCloverInvPrecise = allocateCloverField(X, inv_param->clover_cuda_prec);
  if (!h_clovinv) {
    printf("QUDA error: clover term inverse not implemented yet\n");
    exit(-1);
  } else {
    loadCloverField(cudaCloverInvPrecise, h_clovinv, inv_param->clover_cuda_prec, inv_param->clover_order);
  }
  inv_param->cloverGiB += 2.0*cudaCloverInvPrecise.even.bytes / (1<<30);
rbabich's avatar
rbabich committed
187

rbabich's avatar
rbabich committed
188
189
190
191
192
193
194
195
  if (inv_param->clover_cuda_prec != inv_param->clover_cuda_prec_sloppy) {
    cudaCloverInvSloppy = allocateCloverField(X, inv_param->clover_cuda_prec_sloppy);
    loadCloverField(cudaCloverInvSloppy, h_clovinv, inv_param->clover_cuda_prec_sloppy, inv_param->clover_order);
    inv_param->cloverGiB += 2.0*cudaCloverInvSloppy.even.bytes / (1<<30);
  } else {
    cudaCloverInvSloppy = cudaCloverInvPrecise;
  }
}
rbabich's avatar
rbabich committed
196

rbabich's avatar
rbabich committed
197
198
199
200
201
202
203
204
205
// Discard the clover term (precise and sloppy) but keep the inverse.
//   inv_param: its cloverGiB is reduced by the device memory released
// cudaCloverSloppy may be a shallow copy of cudaCloverPrecise (set by
// loadCloverQuda when the precisions match). In that case the shared
// storage is freed, and its size subtracted, exactly once. Both handles are
// reset to NULL afterwards, so a later endQuda() does not free them again.
void discardCloverQuda(QudaInvertParam *inv_param)
{
  // Decide on aliasing before anything is freed.
  const bool sloppyShared =
    (cudaCloverSloppy.even.clover == cudaCloverPrecise.even.clover);

  if (cudaCloverSloppy.even.clover && !sloppyShared) {
    inv_param->cloverGiB -= 2.0*cudaCloverSloppy.even.bytes / (1<<30);
    freeCloverField(&cudaCloverSloppy);
  }
  if (cudaCloverPrecise.even.clover) {
    inv_param->cloverGiB -= 2.0*cudaCloverPrecise.even.bytes / (1<<30);
    freeCloverField(&cudaCloverPrecise);
  }

  cudaCloverPrecise.even.clover = NULL;
  cudaCloverPrecise.odd.clover = NULL;
  cudaCloverSloppy.even.clover = NULL;
  cudaCloverSloppy.odd.clover = NULL;
}

rbabich's avatar
rbabich committed
208
void endQuda(void)
mikeaclark's avatar
mikeaclark committed
209
210
{
  freeSpinorBuffer();
211
212
  freeGaugeField(&cudaGaugePrecise);
  freeGaugeField(&cudaGaugeSloppy);
rbabich's avatar
rbabich committed
213
214
215
216
  if (cudaCloverPrecise.even.clover) freeCloverField(&cudaCloverPrecise);
  if (cudaCloverSloppy.even.clover) freeCloverField(&cudaCloverSloppy);
  if (cudaCloverInvPrecise.even.clover) freeCloverField(&cudaCloverInvPrecise);
  if (cudaCloverInvSloppy.even.clover) freeCloverField(&cudaCloverInvSloppy);
mikeaclark's avatar
mikeaclark committed
217
218
}

219
220
221
222
223
224
225
// Terminate the process if the host-side (cpu_prec) precision is half,
// which is not supported on the CPU; otherwise do nothing.
void checkPrecision(QudaInvertParam *param) {
  const bool cpuIsHalf = (param->cpu_prec == QUDA_HALF_PRECISION);
  if (!cpuIsHalf) return;

  printf("Half precision not supported on cpu\n");
  exit(-1);
}

226
227
void dslashQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int parity, int dagger)
{
228
229
230
231
  checkPrecision(inv_param);

  ParitySpinor in = allocateParitySpinor(cudaGaugePrecise.X, inv_param->cuda_prec);
  ParitySpinor out = allocateParitySpinor(cudaGaugePrecise.X, inv_param->cuda_prec);
rbabich's avatar
rbabich committed
232

233
  loadParitySpinor(in, h_in, inv_param->cpu_prec, inv_param->dirac_order);
rbabich's avatar
rbabich committed
234
235
236
237
238
239
240
241
  if (inv_param->dslash_type == QUDA_WILSON_DSLASH) {
    dslashCuda(out, cudaGaugePrecise, in, parity, dagger);
  } else if (inv_param->dslash_type == QUDA_CLOVER_WILSON_DSLASH) {
    cloverDslashCuda(out, cudaGaugePrecise, cudaCloverInvPrecise, in, parity, dagger);
  } else {
    printf("QUDA error: unsupported dslash_type\n");
    exit(-1);
  }
242
  retrieveParitySpinor(h_out, out, inv_param->cpu_prec, inv_param->dirac_order);
243
244
245
246
247

  freeParitySpinor(out);
  freeParitySpinor(in);
}

mikeaclark's avatar
mikeaclark committed
248
void MatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int dagger)
249
{
250
251
252
253
254
  checkPrecision(inv_param);

  ParitySpinor in = allocateParitySpinor(cudaGaugePrecise.X, inv_param->cuda_prec);
  ParitySpinor out = allocateParitySpinor(cudaGaugePrecise.X, inv_param->cuda_prec);
  ParitySpinor tmp = allocateParitySpinor(cudaGaugePrecise.X, inv_param->cuda_prec);
255
  
256
  loadParitySpinor(in, h_in, inv_param->cpu_prec, inv_param->dirac_order);
rbabich's avatar
rbabich committed
257
258
259
260
261
262
263
264
265
  if (inv_param->dslash_type == QUDA_WILSON_DSLASH) {
    MatPCCuda(out, cudaGaugePrecise, in, inv_param->kappa, tmp, inv_param->matpc_type, dagger);
  } else if (inv_param->dslash_type == QUDA_CLOVER_WILSON_DSLASH) {
    cloverMatPCCuda(out, cudaGaugePrecise, cudaCloverPrecise, cudaCloverInvPrecise, in, inv_param->kappa,
		    tmp, inv_param->matpc_type, dagger);
  } else {
    printf("QUDA error: unsupported dslash_type\n");
    exit(-1);
  }
266
  retrieveParitySpinor(h_out, out, inv_param->cpu_prec, inv_param->dirac_order);
267
268
269
270
271
272
273
274

  freeParitySpinor(tmp);
  freeParitySpinor(out);
  freeParitySpinor(in);
}

void MatPCDagMatPCQuda(void *h_out, void *h_in, QudaInvertParam *inv_param)
{
275
276
277
278
279
  checkPrecision(inv_param);

  ParitySpinor in = allocateParitySpinor(cudaGaugePrecise.X, inv_param->cuda_prec);
  ParitySpinor out = allocateParitySpinor(cudaGaugePrecise.X, inv_param->cuda_prec);
  ParitySpinor tmp = allocateParitySpinor(cudaGaugePrecise.X, inv_param->cuda_prec);
280
  
281
  loadParitySpinor(in, h_in, inv_param->cpu_prec, inv_param->dirac_order);  
rbabich's avatar
rbabich committed
282
283
284
285
286
287
288
289
290
  if (inv_param->dslash_type == QUDA_WILSON_DSLASH) {
    MatPCDagMatPCCuda(out, cudaGaugePrecise, in, inv_param->kappa, tmp, inv_param->matpc_type);
  } else if (inv_param->dslash_type == QUDA_CLOVER_WILSON_DSLASH) {
    cloverMatPCDagMatPCCuda(out, cudaGaugePrecise, cudaCloverPrecise, cudaCloverInvPrecise, in, inv_param->kappa,
			    tmp, inv_param->matpc_type);
  } else {
    printf("QUDA error: unsupported dslash_type\n");
    exit(-1);
  }
291
  retrieveParitySpinor(h_out, out, inv_param->cpu_prec, inv_param->dirac_order);
292
293
294
295
296
297

  freeParitySpinor(tmp);
  freeParitySpinor(out);
  freeParitySpinor(in);
}

mikeaclark's avatar
mikeaclark committed
298
void MatQuda(void *h_out, void *h_in, QudaInvertParam *inv_param, int dagger) {
299
300
301
302
  checkPrecision(inv_param);

  FullSpinor in = allocateSpinorField(cudaGaugePrecise.X, inv_param->cuda_prec);
  FullSpinor out = allocateSpinorField(cudaGaugePrecise.X, inv_param->cuda_prec);
303

304
  loadSpinorField(in, h_in, inv_param->cpu_prec, inv_param->dirac_order);
305

rbabich's avatar
rbabich committed
306
307
308
309
310
311
312
313
314
315
  if (inv_param->dslash_type == QUDA_WILSON_DSLASH) {
    MatCuda(out, cudaGaugePrecise, in, -inv_param->kappa, dagger);
  } else if (inv_param->dslash_type == QUDA_CLOVER_WILSON_DSLASH) {
    ParitySpinor tmp = allocateParitySpinor(cudaGaugePrecise.X, inv_param->cuda_prec);
    cloverMatCuda(out, cudaGaugePrecise, cudaCloverPrecise, in, inv_param->kappa, tmp, dagger);
    freeParitySpinor(tmp);
  } else {
    printf("QUDA error: unsupported dslash_type\n");
    exit(-1);
  }
316
  retrieveSpinorField(h_out, out, inv_param->cpu_prec, inv_param->dirac_order);
317

318
319
  freeSpinorField(out);
  freeSpinorField(in);
320
321
}

mikeaclark's avatar
mikeaclark committed
322
void invertQuda(void *h_x, void *h_b, QudaInvertParam *param)
mikeaclark's avatar
mikeaclark committed
323
{
mikeaclark's avatar
mikeaclark committed
324
  invert_param = param;
mikeaclark's avatar
mikeaclark committed
325

326
  checkPrecision(param);
mikeaclark's avatar
mikeaclark committed
327

328
  int slenh = cudaGaugePrecise.volume*spinorSiteSize;
329
  param->spinorGiB = (double)slenh*(param->cuda_prec == QUDA_DOUBLE_PRECISION) ? sizeof(double): sizeof(float);
mikeaclark's avatar
mikeaclark committed
330
  if (param->preserve_source == QUDA_PRESERVE_SOURCE_NO)
331
    param->spinorGiB *= (param->inv_type == QUDA_CG_INVERTER ? 5 : 7)/(1<<30);
mikeaclark's avatar
mikeaclark committed
332
  else
333
    param->spinorGiB *= (param->inv_type == QUDA_CG_INVERTER ? 8 : 9)/(1<<30);
mikeaclark's avatar
mikeaclark committed
334

mikeaclark's avatar
mikeaclark committed
335
336
337
  param->secs = 0;
  param->gflops = 0;
  param->iter = 0;
mikeaclark's avatar
mikeaclark committed
338

339
  double kappa = param->kappa;
mikeaclark's avatar
mikeaclark committed
340
  if (param->dirac_order == QUDA_CPS_WILSON_DIRAC_ORDER) kappa /= cudaGaugePrecise.anisotropy;
mikeaclark's avatar
mikeaclark committed
341
342

  FullSpinor b, x;
343
344
345
  ParitySpinor in = allocateParitySpinor(cudaGaugePrecise.X, invert_param->cuda_prec); // source vector
  ParitySpinor out = allocateParitySpinor(cudaGaugePrecise.X, invert_param->cuda_prec); // solution vector
  ParitySpinor tmp = allocateParitySpinor(cudaGaugePrecise.X, invert_param->cuda_prec); // temporary used when applying operator
mikeaclark's avatar
mikeaclark committed
346

mikeaclark's avatar
mikeaclark committed
347
348
  if (param->solution_type == QUDA_MAT_SOLUTION) {
    if (param->preserve_source == QUDA_PRESERVE_SOURCE_YES) {
349
      b = allocateSpinorField(cudaGaugePrecise.X, invert_param->cuda_prec);
mikeaclark's avatar
mikeaclark committed
350
351
352
353
354
    } else {
      b.even = out;
      b.odd = tmp;
    }

rbabich's avatar
rbabich committed
355
356
357
358
359
360
361
362
    if (param->matpc_type == QUDA_MATPC_EVEN_EVEN ||
	param->matpc_type == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
      x.odd = tmp;
      x.even = out;
    } else {
      x.even = tmp;
      x.odd = out;
    }
mikeaclark's avatar
mikeaclark committed
363

364
    loadSpinorField(b, h_b, param->cpu_prec, param->dirac_order);
365

mikeaclark's avatar
mikeaclark committed
366
    // multiply the source to get the mass normalization
mikeaclark's avatar
mikeaclark committed
367
    if (param->mass_normalization == QUDA_MASS_NORMALIZATION) {
368
369
      axCuda(2.0*kappa, b.even);
      axCuda(2.0*kappa, b.odd);
mikeaclark's avatar
mikeaclark committed
370
371
    }

mikeaclark's avatar
mikeaclark committed
372
373
374
    // cps uses a different anisotropy normalization
    if (param->dirac_order == QUDA_CPS_WILSON_DIRAC_ORDER) {
      axCuda(1.0/gauge_param->anisotropy, b.even);
rbabich's avatar
rbabich committed
375
      axCuda(1.0/gauge_param->anisotropy, b.odd);
mikeaclark's avatar
mikeaclark committed
376
377
    }

rbabich's avatar
rbabich committed
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
    if (param->dslash_type == QUDA_WILSON_DSLASH) {
      if (param->matpc_type == QUDA_MATPC_EVEN_EVEN) {
	// in = b_e + k D_eo b_o
	dslashXpayCuda(in, cudaGaugePrecise, b.odd, 0, 0, b.even, kappa);
      } else if (param->matpc_type == QUDA_MATPC_ODD_ODD) {
	// in = b_o + k D_oe b_e
	dslashXpayCuda(in, cudaGaugePrecise, b.even, 1, 0, b.odd, kappa);
      } else {
	printf("QUDA error: matpc_type not valid for plain Wilson\n");
	exit(-1);
      }
    } else if (param->dslash_type == QUDA_CLOVER_WILSON_DSLASH) {
      if (param->matpc_type == QUDA_MATPC_EVEN_EVEN) {
	// in = A_ee^-1 (b_e + k D_eo A_oo^-1 b_o)
	ParitySpinor aux = tmp; // aliases b.odd when PRESERVE_SOURCE_NO is set
	cloverCuda(in, cudaGaugePrecise, cudaCloverInvPrecise, b.odd, 1);
	dslashXpayCuda(aux, cudaGaugePrecise, in, 0, 0, b.even, kappa);
	cloverCuda(in, cudaGaugePrecise, cudaCloverInvPrecise, aux, 0);
      } else if (param->matpc_type == QUDA_MATPC_ODD_ODD) {
	// in = A_oo^-1 (b_o + k D_oe A_ee^-1 b_e)
	ParitySpinor aux = tmp; // aliases b.odd when PRESERVE_SOURCE_NO is set
	cloverCuda(in, cudaGaugePrecise, cudaCloverInvPrecise, b.even, 0);
	dslashXpayCuda(aux, cudaGaugePrecise, in, 1, 0, b.odd, kappa);
	cloverCuda(in, cudaGaugePrecise, cudaCloverInvPrecise, aux, 1);
      } else if (param->matpc_type == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
	// in = b_e + k D_eo A_oo^-1 b_o
	ParitySpinor aux = out; // aliases b.even when PRESERVE_SOURCE_NO is set
	cloverCuda(in, cudaGaugePrecise, cudaCloverInvPrecise, b.odd, 1);
	dslashXpayCuda(aux, cudaGaugePrecise, in, 0, 0, b.even, kappa);
	copyCuda(in, aux);
      } else if (param->matpc_type == QUDA_MATPC_ODD_ODD_ASYMMETRIC) {
	// in = b_o + k D_oe A_ee^-1 b_e
	ParitySpinor aux = out; // aliases b.even when PRESERVE_SOURCE_NO is set
	cloverCuda(in, cudaGaugePrecise, cudaCloverInvPrecise, b.even, 0);
	dslashXpayCuda(aux, cudaGaugePrecise, in, 1, 0, b.odd, kappa);
	copyCuda(in, aux);
      } else {
	printf("QUDA error: invalid matpc_type\n");
	exit(-1);
      }
mikeaclark's avatar
mikeaclark committed
418
    } else {
rbabich's avatar
rbabich committed
419
420
      printf("QUDA error: unsupported dslash_type\n");
      exit(-1);
mikeaclark's avatar
mikeaclark committed
421
422
    }

mikeaclark's avatar
mikeaclark committed
423
424
  } else if (param->solution_type == QUDA_MATPC_SOLUTION || 
	     param->solution_type == QUDA_MATPCDAG_MATPC_SOLUTION){
425
    loadParitySpinor(in, h_b, param->cpu_prec, param->dirac_order);
mikeaclark's avatar
mikeaclark committed
426
427

    // multiply the source to get the mass normalization
mikeaclark's avatar
mikeaclark committed
428
429
    if (param->mass_normalization == QUDA_MASS_NORMALIZATION)
      if (param->solution_type == QUDA_MATPC_SOLUTION) 
430
	axCuda(4.0*kappa*kappa, in);
mikeaclark's avatar
mikeaclark committed
431
      else
432
	axCuda(16.0*pow(kappa,4), in);
mikeaclark's avatar
mikeaclark committed
433
434
435
436
437
438
439
440

    // cps uses a different anisotropy normalization
    if (param->dirac_order == QUDA_CPS_WILSON_DIRAC_ORDER)
      if (param->solution_type == QUDA_MATPC_SOLUTION) 
	axCuda(pow(1.0/gauge_param->anisotropy, 2), in);
      else 
	axCuda(pow(1.0/gauge_param->anisotropy, 4), in);

mikeaclark's avatar
mikeaclark committed
441
442
  }

mikeaclark's avatar
mikeaclark committed
443
  switch (param->inv_type) {
mikeaclark's avatar
mikeaclark committed
444
  case QUDA_CG_INVERTER:
mikeaclark's avatar
mikeaclark committed
445
    if (param->solution_type != QUDA_MATPCDAG_MATPC_SOLUTION) {
446
      copyCuda(out, in);
mikeaclark's avatar
mikeaclark committed
447
      MatPCCuda(in, cudaGaugePrecise, out, kappa, tmp, param->matpc_type, QUDA_DAG_YES);
mikeaclark's avatar
mikeaclark committed
448
    }
rbabich's avatar
rbabich committed
449
    invertCgCuda(out, in, tmp, param);
mikeaclark's avatar
mikeaclark committed
450
451
    break;
  case QUDA_BICGSTAB_INVERTER:
mikeaclark's avatar
mikeaclark committed
452
    if (param->solution_type == QUDA_MATPCDAG_MATPC_SOLUTION) {
rbabich's avatar
rbabich committed
453
      invertBiCGstabCuda(out, in, tmp, param, QUDA_DAG_YES);
454
      copyCuda(in, out);
mikeaclark's avatar
mikeaclark committed
455
    }
rbabich's avatar
rbabich committed
456
    invertBiCGstabCuda(out, in, tmp, param, QUDA_DAG_NO);
mikeaclark's avatar
mikeaclark committed
457
458
    break;
  default:
mikeaclark's avatar
mikeaclark committed
459
    printf("Inverter type %d not implemented\n", param->inv_type);
mikeaclark's avatar
mikeaclark committed
460
461
462
    exit(-1);
  }

mikeaclark's avatar
mikeaclark committed
463
  if (param->solution_type == QUDA_MAT_SOLUTION) {
mikeaclark's avatar
mikeaclark committed
464

mikeaclark's avatar
mikeaclark committed
465
    if (param->preserve_source == QUDA_PRESERVE_SOURCE_NO) {
mikeaclark's avatar
mikeaclark committed
466
467
      // qdp dirac fields are even-odd ordered
      b.even = in;
468
      loadSpinorField(b, h_b, param->cpu_prec, param->dirac_order);
mikeaclark's avatar
mikeaclark committed
469
470
    }

rbabich's avatar
rbabich committed
471
472
473
474
475
476
477
478
    if (param->dslash_type == QUDA_WILSON_DSLASH) {
      if (param->matpc_type == QUDA_MATPC_EVEN_EVEN) {
	// x_o = b_o + k D_oe x_e
	dslashXpayCuda(x.odd, cudaGaugePrecise, out, 1, 0, b.odd, kappa);
      } else {
	// x_e = b_e + k D_eo x_o
	dslashXpayCuda(x.even, cudaGaugePrecise, out, 0, 0, b.even, kappa);
      }
mikeaclark's avatar
mikeaclark committed
479
    } else {
rbabich's avatar
rbabich committed
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
      if (param->matpc_type == QUDA_MATPC_EVEN_EVEN) {
	// x_o = A_oo^-1 (b_o + k D_oe x_e)
	ParitySpinor aux = b.even;
	dslashXpayCuda(aux, cudaGaugePrecise, out, 1, 0, b.odd, kappa);
	cloverCuda(x.odd, cudaGaugePrecise, cudaCloverInvPrecise, aux, 1);
      } else if (param->matpc_type == QUDA_MATPC_ODD_ODD) {
	// x_e = A_ee^-1 (b_e + k D_eo x_o)
	ParitySpinor aux = b.odd;
	dslashXpayCuda(aux, cudaGaugePrecise, out, 0, 0, b.even, kappa);
	cloverCuda(x.even, cudaGaugePrecise, cudaCloverInvPrecise, aux, 0);
      } else if (param->matpc_type == QUDA_MATPC_EVEN_EVEN_ASYMMETRIC) {
	// x_o = b_o + k D_oe x_e
	dslashXpayCuda(x.odd, cudaGaugePrecise, out, 1, 0, b.odd, kappa);
      } else {
	// x_e = b_e + k D_eo x_o
	dslashXpayCuda(x.even, cudaGaugePrecise, out, 0, 0, b.even, kappa);
      }
mikeaclark's avatar
mikeaclark committed
497
498
    }

499
    retrieveSpinorField(h_x, x, param->cpu_prec, param->dirac_order);
mikeaclark's avatar
mikeaclark committed
500

mikeaclark's avatar
mikeaclark committed
501
    if (param->preserve_source == QUDA_PRESERVE_SOURCE_YES) freeSpinorField(b);
mikeaclark's avatar
mikeaclark committed
502

mikeaclark's avatar
mikeaclark committed
503
  } else {
504
    retrieveParitySpinor(h_x, out, param->cpu_prec, param->dirac_order);
mikeaclark's avatar
mikeaclark committed
505
506
507
508
509
510
511
512
  }

  freeParitySpinor(tmp);
  freeParitySpinor(in);
  freeParitySpinor(out);

  return;
}