IMP Manual  for IMP version 2.24.0
GPU support {#gpu}
===========

%IMP currently has rudimentary support for running on a graphics
processing unit (GPU) or on similar accelerators such as tensor processing
units (TPUs). This support uses the [JAX](https://docs.jax.dev/) Python library.

To use the JAX support in optimization, first install the JAX library,
for example with `pip install jax`. Then set up the system as usual and
replace any calls to IMP::core::MonteCarlo::optimize or
IMP::atom::MolecularDynamics::optimize with `_optimize_jax()`.

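A quick way to confirm that the installation worked, and to see which hardware JAX will use, is to query JAX directly. The sketch below is not part of %IMP: `describe_jax_backend` is a hypothetical helper name, and it relies only on the standard `jax.default_backend()` and `jax.devices()` calls. If it reports `cpu` on a machine with a GPU, the accelerator-enabled JAX build is probably not installed.

```python
import importlib.util


def describe_jax_backend():
    """Return a summary of the JAX backend, or None if JAX is not installed.

    Hypothetical helper for illustration only; it is not part of IMP.
    """
    if importlib.util.find_spec("jax") is None:
        return None
    import jax
    return {
        # Name of the default platform, e.g. "cpu", "gpu" or "tpu"
        "backend": jax.default_backend(),
        # Devices visible to JAX on the default backend
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    info = describe_jax_backend()
    if info is None:
        print("JAX is not installed; try `pip install jax`")
    else:
        print("JAX backend:", info["backend"])
        print("JAX devices:", ", ".join(info["devices"]))
```
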
The JAX code is still in active development and many caveats apply:

 - Only a small number of scoring function terms and optimizers currently
 have JAX implementations. Trying to use others will raise a
 Python `NotImplementedError` exception.
 - Some IMP::ScoreState (aka constraint) classes do not yet work. This
 affects common applications such as rigid bodies and close pair containers.
 - There is currently no PMI support for JAX.

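Because unsupported terms raise `NotImplementedError`, one possible pattern is to try the JAX path first and fall back to the native C++ code. The sketch below uses a stand-in class rather than a real %IMP optimizer so that it runs on its own. `run_optimizer` and `FakeOptimizer` are hypothetical names, and the code assumes that `_optimize_jax()` accepts the same step count as `optimize()` and raises the exception when called; this page confirms neither, so check the actual behavior before relying on it.

```python
def run_optimizer(opt, nsteps):
    """Try the JAX code path, falling back to the regular optimize() call.

    Assumption (not confirmed by the manual): _optimize_jax() takes the
    number of steps, like optimize(), and unsupported terms raise
    NotImplementedError at call time.
    Returns a (path_used, result) tuple.
    """
    try:
        return "jax", opt._optimize_jax(nsteps)
    except NotImplementedError as err:
        print("JAX path unavailable (%s); using native code" % err)
        return "native", opt.optimize(nsteps)


class FakeOptimizer:
    """Stand-in for an IMP optimizer, for demonstration only."""

    def __init__(self, jax_supported):
        self.jax_supported = jax_supported

    def optimize(self, nsteps):
        return 0.0  # placeholder score

    def _optimize_jax(self, nsteps):
        if not self.jax_supported:
            raise NotImplementedError("no JAX implementation for this term")
        return 0.0  # placeholder score


if __name__ == "__main__":
    print(run_optimizer(FakeOptimizer(jax_supported=True), 10)[0])
    print(run_optimizer(FakeOptimizer(jax_supported=False), 10)[0])
```
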
To add JAX support for a particular IMP::Restraint, IMP::PairScore,
IMP::core::MonteCarloMover, or IMP::OptimizerState,
implement the `_get_jax()` method. See the
[IMP.example module](https://github.com/salilab/imp/blob/develop/modules/example/pyext/IMP_example.jax.i)
for some examples.

Note that the JAX code will also run on a CPU. In some circumstances the
JAX code will run faster than the native %IMP C++ code on a CPU, so it may
be worth benchmarking both approaches.

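A minimal timing harness for such a comparison might look like the sketch below. It is generic Python rather than anything provided by %IMP: `best_time` and `compare` are hypothetical helpers, and the two lambdas are placeholder workloads standing in for the native and JAX optimizer calls. Warm-up calls are excluded because JAX typically compiles a function the first time it runs. If the JAX call returns device arrays, calling `block_until_ready()` on them inside the timed function may be needed for accurate timings.

```python
import time


def best_time(func, repeats=5, warmup=1):
    """Return the fastest wall-clock time (in seconds) of `repeats` calls.

    Warm-up calls run first and are not timed, so that one-off costs
    (such as JAX compiling a function on its first call) are excluded.
    """
    for _ in range(warmup):
        func()
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        func()
        times.append(time.perf_counter() - start)
    return min(times)


def compare(native_func, jax_func, **kwargs):
    """Time both callables; return (native_seconds, jax_seconds)."""
    return best_time(native_func, **kwargs), best_time(jax_func, **kwargs)


if __name__ == "__main__":
    # Placeholder workloads; in practice these would wrap the native
    # optimize() call and the _optimize_jax() call on the same system.
    t_native, t_jax = compare(lambda: sum(range(100000)),
                              lambda: sum(range(50000)))
    print("native: %.6f s, JAX: %.6f s" % (t_native, t_jax))
```
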
%IMP also has some very basic C++ support for NVIDIA GPUs using the
CUDA toolkit, although this is unlikely to be developed further.
To build %IMP from source code with CUDA support (there are currently no
prebuilt %IMP binaries that use CUDA), ensure that the `nvcc` compiler
from NVIDIA's [CUDA toolkit](https://developer.nvidia.com/cuda-downloads)
is available, and add `-DIMP_CUDA` to your
[CMake invocation](@ref cmake_config).