Advanced Computing Platform for Theoretical Physics

Commit 9361c664 authored by Hao Xie

MNIST using multi-host pmap

parent a0b5fa09
import atexit
from functools import partial

from absl import logging
import jax
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_extension as xc


def _reset_backend_state():
    # Clear any backends JAX has already initialized so that the distributed
    # GPU backend registered below becomes the one actually used.
    xb._backends = {}
    xb._backends_errors = {}
    xb._default_backend = None
    xb.get_backend.cache_clear()


def connect_to_gpu_cluster(server_ip, server_port, num_hosts, host_idx):
    _reset_backend_state()

    # Host 0 runs the distributed runtime service that all hosts connect to.
    service = None
    if host_idx == 0:
        addr = f'{server_ip}:{server_port}'
        logging.info('starting service on %s', addr)
        service = xc.get_distributed_runtime_service(addr, num_hosts)
        atexit.register(service.shutdown)

    # Every host (including host 0) connects to the service as a client.
    server_addr = f'{server_ip}:{server_port}'
    logging.info('connecting to service on %s', server_addr)
    dist_client = xc.get_distributed_runtime_client(server_addr, host_idx)
    dist_client.connect()
    atexit.register(dist_client.shutdown)

    # Register the distributed GPU backend so jax.devices() sees all hosts.
    factory = partial(jax._src.lib.xla_client.make_gpu_client, dist_client, host_idx)
    xb.register_backend_factory('gpu', factory, priority=300)

    return service
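A minimal usage sketch for this module, assuming a 2-host cluster with the rank-0 server at the hypothetical address 10.0.0.1: each process calls connect_to_gpu_cluster() before any other JAX operation, after which jax.devices() reports the GPUs of the whole cluster. The scripts below follow exactly this pattern.

from multihost import connect_to_gpu_cluster
import jax

# Host 0 also runs the coordination service; on the second host pass host_idx=1.
service = connect_to_gpu_cluster('10.0.0.1', 9999, num_hosts=2, host_idx=0)

print(jax.device_count())        # GPUs across the whole cluster
print(jax.local_device_count())  # GPUs on this host only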
"""An MNIST example with jax.pmap in a multi-host environment."""
# https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py
import time
from functools import partial

from absl import app, flags
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.tree_util import tree_map
import numpy as np

import datasets
from multihost import connect_to_gpu_cluster

flags.DEFINE_string('server_ip', '', help='IP of rank 0 server.')
flags.DEFINE_integer('server_port', 9999, help='port of rank 0 server.')
flags.DEFINE_integer('num_hosts', 1, help='number of nodes in GPU cluster.')
flags.DEFINE_integer('host_idx', 0, help='index of current node.')
FLAGS = flags.FLAGS
FLAGS.alsologtostderr = True
def init_random_params(scale, layer_sizes, key):
    params = []
    for m, n in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, subkey = jax.random.split(key)
        params.append((scale * jax.random.normal(subkey, (m, n)),
                       scale * jax.random.normal(subkey, (n,))))
    return params


def predict(params, inputs):
    activations = inputs
    for w, b in params[:-1]:
        outputs = jnp.dot(activations, w) + b
        activations = jnp.tanh(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(activations, final_w) + final_b
    # Return log-probabilities (log-softmax over the classes).
    return logits - logsumexp(logits, axis=1, keepdims=True)


def loss(params, batch):
    # Cross-entropy loss with one-hot targets.
    inputs, targets = batch
    preds = predict(params, inputs)
    return -jnp.mean(jnp.sum(preds * targets, axis=1))


@jax.jit
def accuracy(params, batch):
    inputs, targets = batch
    target_class = jnp.argmax(targets, axis=1)
    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
    return jnp.mean(predicted_class == target_class)
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Each training node needs to connect to the rank-0 server.
    service = connect_to_gpu_cluster(FLAGS.server_ip, FLAGS.server_port,
                                     FLAGS.num_hosts, FLAGS.host_idx)
    print('Cluster connected with %d GPUs in total: \n%s' %
          (jax.device_count(), jax.devices()))
    print("This is process %d with %d local GPUs: \n%s" %
          (jax.process_index(), jax.local_device_count(), jax.local_devices()))

    num_local_devices, process_count = jax.local_device_count(), jax.process_count()
    size = process_count * num_local_devices
    process_idx = jax.process_index()

    layer_sizes = [784, 1024, 1024, 10]
    param_scale = 0.1
    step_size = 0.001
    num_epochs = 10
    batch_size = 1024

    # Shapes: (60000, 784), (60000, 10), (10000, 784), (10000, 10)
    train_images, train_labels, test_images, test_labels = [
        jnp.array(arr) for arr in datasets.mnist()]
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    def data_stream():
        rng = np.random.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                images, labels = train_images[batch_idx], train_labels[batch_idx]
                batch_size_per_device, ragged = divmod(images.shape[0], size)
                if ragged:
                    msg = "batch size must be divisible by device count, got {} and {}."
                    raise ValueError(msg.format(images.shape[0], size))
                # Split each global batch over (process, local device) and keep
                # only the shards belonging to the current process.
                images = images.reshape(process_count, num_local_devices,
                                        batch_size_per_device, *images.shape[1:])[process_idx]
                labels = labels.reshape(process_count, num_local_devices,
                                        batch_size_per_device, *labels.shape[1:])[process_idx]
                yield images, labels

    batches = data_stream()

    @partial(jax.pmap, axis_name="p")
    def mpi_update(params, batch):
        grads = jax.grad(loss)(params, batch)
        # All-reduce (mean) the gradients across all devices on all hosts.
        grads = jax.lax.pmean(grads, axis_name="p")
        return [(w - step_size * dw, b - step_size * db)
                for (w, b), (dw, db) in zip(params, grads)]

    key = jax.random.PRNGKey(42)
    params = init_random_params(param_scale, layer_sizes, key)
    # Replicate the parameters onto every local device.
    params = tree_map(lambda arr: jnp.broadcast_to(arr, (num_local_devices, *arr.shape)), params)

    for epoch in range(num_epochs):
        start_time = time.time()
        for i in range(num_batches):
            params = mpi_update(params, next(batches))
        epoch_time = time.time() - start_time

        # Evaluate on a single copy of the (identical) replicated parameters.
        params_single_copy = tree_map(lambda arr: arr[0], params)
        train_acc = accuracy(params_single_copy, (train_images, train_labels))
        test_acc = accuracy(params_single_copy, (test_images, test_labels))
        print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
        print("Training set accuracy {}".format(train_acc))
        print("Test set accuracy {}".format(test_acc))


if __name__ == '__main__':
    app.run(main)
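The trickiest step above is the reshape in data_stream, which splits each global batch over (process, local device) and keeps only the current process's shards. A small standalone sketch of that indexing, using toy values (2 processes, 2 local devices, 8 examples of 3 features are assumptions, not the script's real sizes):

import numpy as np

process_count, num_local_devices, batch = 2, 2, 8           # assumed toy values
images = np.arange(batch * 3).reshape(batch, 3)             # toy batch: 8 examples, 3 features
per_device = batch // (process_count * num_local_devices)   # 2 examples per device

# Global batch -> (process, local_device, per_device, features); each process
# then indexes out its own slice, exactly like data_stream does with process_idx.
shards = images.reshape(process_count, num_local_devices, per_device, *images.shape[1:])
print(shards[0].shape)  # (2, 2, 3): what process 0 feeds to its two local devices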
from multihost import connect_to_gpu_cluster
from absl import app, flags, logging
import jax
import jax.numpy as jnp

flags.DEFINE_string('server_ip', '', help='IP of rank 0 server.')
flags.DEFINE_integer('server_port', 9999, help='port of rank 0 server.')
flags.DEFINE_integer('num_hosts', 1, help='number of nodes in GPU cluster.')
flags.DEFINE_integer('host_idx', 0, help='index of current node.')
FLAGS = flags.FLAGS
FLAGS.alsologtostderr = True


def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Each training node needs to connect to the rank-0 server.
    service = connect_to_gpu_cluster(FLAGS.server_ip, FLAGS.server_port,
                                     FLAGS.num_hosts, FLAGS.host_idx)
    logging.info('Cluster connected with %d GPUs in total: \n%s',
                 jax.device_count(), jax.devices())
    logging.info('This is process %d with %d local GPUs: \n%s',
                 jax.process_index(), jax.local_device_count(), jax.local_devices())

    # 4 GPUs on each host: each process pmaps over its local devices, while the
    # psum below reduces across all devices on all hosts.
    x = jnp.arange(4) if jax.process_index() == 0 else jnp.arange(4, 8)
    f = lambda x: x + jax.lax.psum(x, axis_name="p")
    y = jax.pmap(f, axis_name="p")(x)
    print(f"x: {x}\ny: {y}")


if __name__ == '__main__':
    app.run(main)
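For concreteness, in the 2-host, 4-GPUs-per-host setup this test script suggests, the psum over axis "p" reduces across all 8 devices (0 + 1 + ... + 7 = 28), so the output should look roughly like:

x: [0 1 2 3]
y: [28 29 30 31]    (printed by process 0)
x: [4 5 6 7]
y: [32 33 34 35]    (printed by process 1)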