Out-of-memory error using Joblib with JAX
Error Message:
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 178053120 bytes.
...
BrokenProcessPool: A task has failed to un-serialize. Please ensure that the arguments of the function are all picklable.
Summary: When using Joblib with JAX, I hit an out-of-memory error (RESOURCE_EXHAUSTED) followed by a BrokenProcessPool error. The BrokenProcessPool message suggests that one or more tasks failed to un-serialize because of non-picklable arguments.
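One possible cause worth ruling out. This is an assumption, since the code in a2174331 isn't reproduced here and I haven't confirmed the tasks run on a GPU. By default, JAX preallocates 75% of total GPU memory in every process that initializes the GPU backend. Several Joblib workers, plus the parent if it has already touched JAX, could therefore exhaust the device between them. A worker that dies this way (for example while unpickling JAX-array arguments onto the device) might then surface as the BrokenProcessPool "un-serialize" message rather than as the original OOM.

The sketch below sets JAX's documented `XLA_PYTHON_CLIENT_PREALLOCATE` flag before JAX is imported in each process and returns plain Python floats instead of device arrays. `configure_xla_memory`, `run_task`, and `run_parallel` are hypothetical names, and the toy `jnp.arange` computation stands in for the real workload.

```python
import os

# Assumption: the OOM comes from several processes each preallocating GPU memory.
# JAX reads these documented XLA client flags when it initializes its GPU backend,
# so they must be in the environment before the first `import jax` in a process.
XLA_MEMORY_ENV = {
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",  # allocate on demand, not 75% up front
    # Alternative: cap each process instead, e.g. "XLA_PYTHON_CLIENT_MEM_FRACTION": ".20"
}


def configure_xla_memory(env=os.environ):
    """Add the XLA memory flags to `env` without overriding values already set."""
    for key, value in XLA_MEMORY_ENV.items():
        env.setdefault(key, value)
    return env


def run_task(n):
    """Hypothetical task: JAX is imported inside the worker, after the flags are set."""
    configure_xla_memory()
    import jax.numpy as jnp

    # Return a plain float (not a device array) so the result pickles cleanly.
    return float(jnp.sum(jnp.arange(n)))


def run_parallel(sizes, n_jobs=2):
    """Hypothetical driver: pass only picklable, host-side arguments to workers."""
    configure_xla_memory()  # newly spawned worker processes inherit this environment
    from joblib import Parallel, delayed

    return Parallel(n_jobs=n_jobs)(delayed(run_task)(n) for n in sizes)
```

Call `run_parallel(range(1, 5))` from inside an `if __name__ == "__main__":` block. If on-demand allocation turns out to be slow or fragments memory, the other documented option is to cap each process with `XLA_PYTHON_CLIENT_MEM_FRACTION`, keeping the sum across all workers below 1.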
@wangruisai any suggestions?
For code details, see a2174331.