# PyTorch to ONNX Example

Once you have a Union account, install `union`:

pip install union
Export the following environment variable to build and push images to your own container registry:
# replace with your registry name
export IMAGE_SPEC_REGISTRY="<your-container-registry>"
Then run the following commands to run the workflow:
git clone https://github.com/unionai/unionai-examples
cd unionai-examples
union run --remote tutorials/pytorch_onnx/pytorch_onnx.py wf  # adjust the path to this example's location in the repository
The source code for this tutorial can be found on GitHub.
from pathlib import Path
import flytekit
import numpy as np
import onnxruntime
import requests
import torch.nn.init as init
import torch.onnx
import torch.utils.model_zoo as model_zoo
import torchvision.transforms as transforms
from flytekit import task, workflow
from flytekit.types.file import JPEGImageFile, ONNXFile
from flytekitplugins.onnxpytorch import PyTorch2ONNX, PyTorch2ONNXConfig
from PIL import Image
from torch import nn
from typing_extensions import Annotated
Define a convolutional super-resolution model.
class SuperResolutionNet(nn.Module):
    """Sub-pixel convolutional network that upscales a single-channel image.

    Four convolutional layers expand the feature maps, then a PixelShuffle
    layer rearranges channels into an image enlarged by ``upscale_factor``
    along each spatial dimension.
    """

    def __init__(self, upscale_factor, inplace=False):
        super().__init__()
        # NOTE: the attribute names conv1..conv4 are part of the state-dict
        # contract with the pretrained checkpoint — do not rename them.
        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor**2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        self._initialize_weights()

    def forward(self, x):
        """Upscale a (N, 1, H, W) batch to (N, 1, H*f, W*f)."""
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.relu(self.conv3(out))
        return self.pixel_shuffle(self.conv4(out))

    def _initialize_weights(self):
        # Orthogonal initialization, scaled for the ReLU-activated layers;
        # the final conv feeds PixelShuffle directly, so it keeps the default gain.
        relu_gain = init.calculate_gain("relu")
        for layer in (self.conv1, self.conv2, self.conv3):
            init.orthogonal_(layer.weight, relu_gain)
        init.orthogonal_(self.conv4.weight)
Define a `train` task to train the model.
Note the annotated output.
This is a special annotation that tells Flytekit that the task's output is to be converted to an ONNX model with the given config.
@task
def train() -> (
    Annotated[
        PyTorch2ONNX,
        PyTorch2ONNXConfig(
            args=torch.randn(1, 1, 224, 224, requires_grad=True),  # dummy input that fixes the traced input shape
            export_params=True,  # store the trained parameter weights inside the ONNX file
            opset_version=10,  # the ONNX opset version to export the model to
            do_constant_folding=True,  # whether to execute constant folding for optimization
            input_names=["input"],  # the model's input names
            output_names=["output"],  # the model's output names
            dynamic_axes={
                "input": {0: "batch_size"},
                "output": {0: "batch_size"},
            },  # mark the batch dimension as variable length
        ),
    ]
):
    """Build the super-resolution model and load its pretrained weights.

    Returns the model wrapped in ``PyTorch2ONNX``; the ``Annotated`` config
    above tells Flytekit to convert the returned model to an ONNX file.
    """
    # Create the super-resolution model by using the above model definition.
    torch_model = SuperResolutionNet(upscale_factor=3)

    # Pretrained checkpoint published alongside the PyTorch ONNX export tutorial.
    model_url = "https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth"

    # Remap checkpoint storages onto the CPU unless a GPU is available;
    # None lets torch restore tensors to the devices they were saved from.
    map_location = None if torch.cuda.is_available() else torch.device("cpu")
    torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))

    return PyTorch2ONNX(model=torch_model)
Define an `onnx_predict` task to generate a super-resolution image from the model, given an input image.
@task
def onnx_predict(model_file: ONNXFile) -> JPEGImageFile:
    """Upscale a sample cat image with the exported ONNX model.

    Only the luma (Y) channel is run through the network; the chroma
    channels are bicubically resized and merged back in afterwards.
    """
    session = onnxruntime.InferenceSession(model_file.download())

    # Fetch and resize the sample input image.
    response = requests.get(
        "https://raw.githubusercontent.com/flyteorg/static-resources/main/flytekit/onnx/cat.jpg",
        stream=True,
    )
    image = Image.open(response.raw)
    image = transforms.Resize([224, 224])(image)

    # Split into luma + chroma channels; the model operates on Y only.
    luma, chroma_b, chroma_r = image.convert("YCbCr").split()

    luma_tensor = transforms.ToTensor()(luma)
    luma_tensor.unsqueeze_(0)  # add the batch dimension

    # Hand ONNX Runtime a plain NumPy array for the model's first input.
    if luma_tensor.requires_grad:
        input_array = luma_tensor.detach().cpu().numpy()
    else:
        input_array = luma_tensor.cpu().numpy()
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_array})

    # Convert the network output back to an 8-bit grayscale image.
    upscaled_luma = Image.fromarray(np.uint8((outputs[0][0] * 255.0).clip(0, 255)[0]), mode="L")

    # Post-process as in the PyTorch implementation: resize the chroma
    # channels to match the upscaled luma, merge, and convert back to RGB.
    merged = Image.merge(
        "YCbCr",
        [
            upscaled_luma,
            chroma_b.resize(upscaled_luma.size, Image.BICUBIC),
            chroma_r.resize(upscaled_luma.size, Image.BICUBIC),
        ],
    ).convert("RGB")

    output_path = Path(flytekit.current_context().working_directory) / "cat_superres_with_ort.jpg"
    merged.save(output_path)
    return JPEGImageFile(path=str(output_path))
Define a workflow to run the above tasks.
@workflow
def wf() -> JPEGImageFile:
    """Export the pretrained model to ONNX, then run super-resolution inference on it."""
    onnx_model = train()
    return onnx_predict(model_file=onnx_model)
Run the workflow locally.
if __name__ == "__main__":
    # Run the workflow locally and report the resulting image file.
    result = wf()
    print(f"Prediction: {result}")