Advanced Computing Platform for Theoretical Physics

Commit a0b5fa09 authored by Hao Xie

cuda-aware mpi on delta103

parent cf627239
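mpi4jax can only hand JAX device (GPU) arrays to MPI directly, without staging them through host memory, when the underlying MPI build is CUDA-aware, which is what this commit relies on. A quick sanity check is a tiny allreduce of a GPU-resident array, along the lines of the sketch below; the file name, the four-GPU layout and the launch command are my own assumptions, not part of the commit:

# check_cuda_aware.py -- hypothetical smoke test, not part of this commit
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# pin each MPI process to its own GPU before JAX initializes its backend,
# mirroring what the training script below does
import os
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)

import jax.numpy as jnp
import mpi4jax

# a small GPU-resident array; with CUDA-aware MPI this buffer can be passed
# to the MPI library without an explicit device-to-host copy
x = jnp.ones(4) * rank
y, _ = mpi4jax.allreduce(x, op=MPI.SUM, comm=comm)

# every rank should end up with 0 + 1 + ... + (size - 1) in each entry
print("rank", rank, ":", y)

Launched for example as mpirun -n 4 python check_cuda_aware.py on a node with four GPUs; if the MPI build is not CUDA-aware, this kind of transfer either fails or has to go through host memory, depending on how mpi4jax is configured.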
@@ -12,80 +12,81 @@ rank = comm.Get_rank()
 import os
 os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
-from functools import partial
-import time
-import numpy as np
-import numpy.random as npr
 import jax
-from jax import jit, grad, pmap
 from jax.scipy.special import logsumexp
-from jax.tree_util import tree_map
-from jax import lax
 import jax.numpy as jnp
 import datasets
+from jax.flatten_util import ravel_pytree
+import numpy as np
+from functools import partial
+import time
+import mpi4jax
-def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
-    return [(scale * rng.randn(m, n), scale * rng.randn(n))
-            for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
+def init_random_params(scale, layer_sizes, key):
+    params = []
+    for m, n in zip(layer_sizes[:-1], layer_sizes[1:]):
+        # split off separate subkeys for the weight matrix and the bias vector
+        key, w_key, b_key = jax.random.split(key, 3)
+        params.append((scale * jax.random.normal(w_key, (m, n)),
+                       scale * jax.random.normal(b_key, (n,))))
+    return params
 def predict(params, inputs):
     activations = inputs
     for w, b in params[:-1]:
         outputs = jnp.dot(activations, w) + b
         activations = jnp.tanh(outputs)
     final_w, final_b = params[-1]
     logits = jnp.dot(activations, final_w) + final_b
     return logits - logsumexp(logits, axis=1, keepdims=True)
 
 def loss(params, batch):
     inputs, targets = batch
     preds = predict(params, inputs)
     return -jnp.mean(jnp.sum(preds * targets, axis=1))
 
-@jit
+@jax.jit
 def accuracy(params, batch):
     inputs, targets = batch
     target_class = jnp.argmax(targets, axis=1)
     predicted_class = jnp.argmax(predict(params, inputs), axis=1)
     return jnp.mean(predicted_class == target_class)
 if __name__ == "__main__":
     layer_sizes = [784, 1024, 1024, 10]
     param_scale = 0.1
     step_size = 0.001
     num_epochs = 10
     batch_size = 1024
 
-    train_images, train_labels, test_images, test_labels = datasets.mnist()
+    # (60000, 784), (60000, 10), (10000, 784), (10000, 10)
+    train_images, train_labels, test_images, test_labels = [jnp.array(arr) for arr in datasets.mnist()]
     num_train = train_images.shape[0]
     num_complete_batches, leftover = divmod(num_train, batch_size)
     num_batches = num_complete_batches + bool(leftover)
 
     def data_stream():
-        rng = npr.RandomState(rank)
+        rng = np.random.RandomState(0)
  • Ha, I realized that I wrote a bug here: npr.RandomState(rank). Indeed, the different MPI processes should perform the SAME permutation (see the sketch after the diff below).

    Glad you spotted that bug! 👍

         while True:
             perm = rng.permutation(num_train)
             for i in range(num_batches):
                 batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                 images, labels = train_images[batch_idx], train_labels[batch_idx]
                 batch_size_per_device, ragged = divmod(images.shape[0], size)
                 if ragged:
                     msg = "batch size must be divisible by device count, got {} and {}."
-                    raise ValueError(msg.format(batch_size, size))
+                    raise ValueError(msg.format(images.shape[0], size))
                 images = images[rank*batch_size_per_device:(rank+1)*batch_size_per_device]
                 labels = labels[rank*batch_size_per_device:(rank+1)*batch_size_per_device]
                 yield images, labels
     batches = data_stream()
 
     @jax.jit
     def mpi_update(params, batch):
-        grads = grad(loss)(params, batch)
+        grads = jax.grad(loss)(params, batch)
         grads, unravel = ravel_pytree(grads)
         avg_grads, _ = mpi4jax.allreduce(grads, op=MPI.SUM, comm=comm)
@@ -94,17 +95,18 @@ if __name__ == "__main__":
         return [(w - step_size * dw, b - step_size * db)
                 for (w, b), (dw, db) in zip(params, grads)]
 
-    params = init_random_params(param_scale, layer_sizes)
+    key = jax.random.PRNGKey(42)
+    params = init_random_params(param_scale, layer_sizes, key)
 
     for epoch in range(num_epochs):
         start_time = time.time()
-        for _ in range(num_batches):
+        for i in range(num_batches):
             params = mpi_update(params, next(batches))
         epoch_time = time.time() - start_time
 
         if rank == 0:
             train_acc = accuracy(params, (train_images, train_labels))
             test_acc = accuracy(params, (test_images, test_labels))
             print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
             print("Training set accuracy {}".format(train_acc))
             print("Test set accuracy {}".format(test_acc))