Julian Bilcke
we are going to hack into finetrainers
9fd1204
from enum import Enum
from typing import Union
from .accelerate import AccelerateParallelBackend
from .ptd import PytorchDTensorParallelBackend
from .utils import dist_max, dist_mean
ParallelBackendType = Union[AccelerateParallelBackend, PytorchDTensorParallelBackend]
class ParallelBackendEnum(str, Enum):
ACCELERATE = "accelerate"
PTD = "ptd"
def get_parallel_backend_cls(backend: ParallelBackendEnum) -> ParallelBackendType:
if backend == ParallelBackendEnum.ACCELERATE:
return AccelerateParallelBackend
if backend == ParallelBackendEnum.PTD:
return PytorchDTensorParallelBackend
raise ValueError(f"Unknown parallel backend: {backend}")