Kubeflow PyTorch
Name: flytekitplugins-kfpytorch
Version: 0.0.0+develop
Author: admin@flyte.org
Provides: flytekitplugins.kfpytorch
Requires: cloudpickle, flyteidl>=1.5.1, flytekit>=1.6.1, kubernetes
Python: >=3.9
License: Apache 2.0
Source Code: https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-kf-pytorch
- Intended Audience :: Science/Research
- Intended Audience :: Developers
- License :: OSI Approved :: Apache Software License
- Programming Language :: Python :: 3.9
- Programming Language :: Python :: 3.10
- Topic :: Scientific/Engineering
- Topic :: Scientific/Engineering :: Artificial Intelligence
- Topic :: Software Development
- Topic :: Software Development :: Libraries
- Topic :: Software Development :: Libraries :: Python Modules
This plugin uses the Kubeflow PyTorch Operator and provides a simplified interface for executing distributed training using various PyTorch backends.
The plugin can also execute torch elastic training, which is equivalent to running torchrun. Elastic training can be executed in a single Pod (without requiring the PyTorch operator; see below) as well as in a distributed multi-node manner.
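For illustration, a minimal sketch of an elastic training task using the plugin's Elastic task config (the node and process counts here are illustrative, not recommended values):
from flytekit import task
from flytekitplugins.kfpytorch import Elastic

@task(
    task_config=Elastic(
        nnodes=2,          # illustrative: >1 runs multi-node via the PyTorch operator
        nproc_per_node=4,  # illustrative: worker processes per node, as with torchrun
    ),
)
def elastic_training_job():
    ...
With nnodes=1, the job runs in a single Pod without requiring the PyTorch operator.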
To install the plugin, run the following command:
pip install flytekitplugins-kfpytorch
To set up the PyTorch operator in the Flyte deployment’s backend, follow the PyTorch Operator Setup guide.
An example showcasing the PyTorch operator can be found in the documentation.
Code Example
from flytekit import Resources, task
from flytekitplugins.kfpytorch import PyTorch, Worker

@task(
    task_config=PyTorch(
        worker=Worker(replicas=5),
    ),
    image="test_image",
    resources=Resources(cpu="1", mem="1Gi"),
)
def pytorch_job():
    ...
You can specify the run policy and restart policy of a PyTorch job. The default restart policy for both the master and worker groups is to never restart; you can set it to another policy, as shown below.
from flytekit import Resources, task
from flytekitplugins.kfpytorch import CleanPodPolicy, PyTorch, RestartPolicy, RunPolicy, Worker

@task(
    task_config=PyTorch(
        worker=Worker(replicas=5, restart_policy=RestartPolicy.FAILURE),
        run_policy=RunPolicy(
            clean_pod_policy=CleanPodPolicy.ALL,
        ),
    ),
    image="test_image",
    resources=Resources(cpu="1", mem="1Gi"),
)
def pytorch_job():
    ...
Upgrade PyTorch Plugin from V0 to V1
The PyTorch plugin has been updated from v0 to v1 to enable more configuration options. To migrate from v0 to v1, make the following changes:
- Update flytepropeller to v1.6.0
- Update flytekit version to v1.6.2
- Update your code from:
task_config=Pytorch(num_workers=10),
to:
task_config=PyTorch(worker=Worker(replicas=10)),
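For reference, a migrated task definition might look like the following (a minimal sketch; the task name and replica count are illustrative):
from flytekit import task
from flytekitplugins.kfpytorch import PyTorch, Worker

@task(
    # v1: PyTorch(worker=Worker(replicas=...)) replaces Pytorch(num_workers=...)
    task_config=PyTorch(worker=Worker(replicas=10)),
)
def pytorch_job():
    ...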