Prerequisites

CGX is a PyTorch extension and requires PyTorch 1.10.0 or later (pytorch>=1.10.0).
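The version requirement can be checked before building. The following is a minimal sketch, not part of CGX: the helper names parse_release and meets_minimum are hypothetical, and the check assumes a version string of the form PyTorch reports in torch.__version__ (for example 1.10.0+cu113). It compares numeric components because a plain string comparison would rank "1.9.0" above "1.10.0".

```python
import re


def parse_release(version):
    """Return the leading numeric release of a version string as a 3-tuple of ints."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    parts = [int(p) for p in match.group().split(".")]
    # Pad to three components so that "1.10" compares equal to "1.10.0".
    return tuple((parts + [0, 0, 0])[:3])


def meets_minimum(version, minimum="1.10.0"):
    """True if `version` is at least `minimum`.

    Local suffixes such as "+cu113" and pre-release tags such as "a0"
    are ignored, so a pre-release is treated like its final release.
    """
    return parse_release(version) >= parse_release(minimum)
```

With PyTorch installed, calling meets_minimum(torch.__version__) after import torch reports whether the installed version satisfies the requirement.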

For faster builds, we recommend installing ninja (pip install ninja).

Compression is supported only for GPU-based buffers, so either CUDA or ROCm is required. If CUDA or ROCm is installed in a non-standard path, set CUDA_HOME or CUDA_PATH (for CUDA), or ROCM_HOME or ROCM_PATH (for ROCm), accordingly.
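To see which of these variables is currently set, a short check like the one below can help. It is only a sketch: the function name find_gpu_toolkit is hypothetical, and the lookup order (CUDA variables first, then ROCm) is an assumption, not necessarily the order the CGX build uses.

```python
import os

# Variable names come from the text above; the lookup order is an assumption.
TOOLKIT_VARS = ("CUDA_HOME", "CUDA_PATH", "ROCM_HOME", "ROCM_PATH")


def find_gpu_toolkit(env=None):
    """Return (variable_name, path) for the first toolkit variable that is set.

    Returns None when none of the variables is set, in which case the
    toolkit is expected to be in a standard location.
    """
    env = os.environ if env is None else env
    for name in TOOLKIT_VARS:
        path = env.get(name)
        if path:
            return name, path
    return None
```

Calling find_gpu_toolkit() with no arguments inspects the current process environment; passing a dictionary makes it easy to try out a configuration without exporting anything.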

Because CGX is built on MPI, it requires OpenMPI with GPU support (other MPI implementations have not been tested). The library also supports NCCL-based communication, so the NVIDIA NCCL library is required as well.

Set the MPI_HOME environment variable to the MPI installation directory. For AMD GPUs, set CGX_CUDA to 0. Set NCCL_HOME to the NCCL installation directory, or set NCCL_INCLUDE and NCCL_LIB instead. Set QSGD_DETERMENISTIC=0 to use the stochastic version of QSGD.
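The MPI and NCCL settings described above can be verified before starting a build. The following is a minimal sketch under stated assumptions: check_build_env is a hypothetical helper, not part of CGX, and it only tests whether the variables are set, not whether the paths are valid or how the actual build script reads them.

```python
import os


def check_build_env(env=None):
    """Return a list of problems; an empty list means the checks passed."""
    env = os.environ if env is None else env
    problems = []

    # MPI_HOME should point at the OpenMPI installation.
    if not env.get("MPI_HOME"):
        problems.append("MPI_HOME is not set")

    # NCCL is located either through NCCL_HOME or through the
    # NCCL_INCLUDE / NCCL_LIB pair.
    has_nccl_home = bool(env.get("NCCL_HOME"))
    has_nccl_pair = bool(env.get("NCCL_INCLUDE")) and bool(env.get("NCCL_LIB"))
    if not (has_nccl_home or has_nccl_pair):
        problems.append("set NCCL_HOME, or both NCCL_INCLUDE and NCCL_LIB")

    return problems
```

Running check_build_env() with no arguments inspects the current environment, so printing the returned list gives a quick summary of what is still missing.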

Download

pip install pytorch_cgx

Build from source

git clone https://github.com/IST-DASLab/torch_cgx
cd torch_cgx
python setup.py install