Skip to content

Commit

Permalink
Merge pull request #60 from aws/new-modalities
Browse files Browse the repository at this point in the history
Add new modalities
  • Loading branch information
philschmid authored Mar 25, 2022
2 parents 7cb5009 + 419b278 commit 2f1fae5
Show file tree
Hide file tree
Showing 20 changed files with 248 additions and 29 deletions.
14 changes: 13 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
</div>




# SageMaker Hugging Face Inference Toolkit

[![Latest Version](https://img.shields.io/pypi/v/sagemaker_huggingface_inference_toolkit.svg)](https://pypi.python.org/pypi/sagemaker_huggingface_inference_toolkit) [![Supported Python Versions](https://img.shields.io/pypi/pyversions/sagemaker_huggingface_inference_toolkit.svg)](https://pypi.python.org/pypi/sagemaker_huggingface_inference_toolkit) [![Code Style: Black](https://img.shields.io/badge/code_style-black-000000.svg)](https://github.com/python/black)
Expand Down Expand Up @@ -111,7 +113,7 @@ HF_API_TOKEN="api_XXXXXXXXXXXXXXXXXXXXXXXXXXXXX"

## 🧑🏻‍💻 User defined code/modules

The Hugging Face Inference Toolkit allows users to override the default methods of the `HuggingFaceHandlerService`. To do so, they need to create a folder named `code/` with an `inference.py` file in it.
The Hugging Face Inference Toolkit allows users to override the default methods of the `HuggingFaceHandlerService`. To do so, they need to create a folder named `code/` with an `inference.py` file in it. You can find an example in [sagemaker/17_custom_inference_script](https://github.com/huggingface/notebooks/blob/master/sagemaker/17_custom_inference_script/sagemaker-notebook.ipynb).
For example:
```bash
model.tar.gz/
Expand Down Expand Up @@ -144,3 +146,13 @@ requests to us.
## 📜 License

SageMaker Hugging Face Inference Toolkit is licensed under the Apache 2.0 License.

---

## 🧑🏻‍💻 Development Environment

Install all test and development packages with

```bash
pip3 install -e ".[test,dev]"
```
18 changes: 15 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,34 @@
# We don't declare our dependency on transformers here because we build with
# different packages for different variants

VERSION = "1.3.1"
VERSION = "2.0.0"


# Ubuntu packages
# libsndfile1-dev: torchaudio requires the development version of the libsndfile package which can be installed via a system package manager. On Ubuntu it can be installed as follows: apt install libsndfile1-dev
# ffmpeg: ffmpeg is required for audio processing. On Ubuntu it can be installed as follows: apt install ffmpeg
# libavcodec-extra: libavcodec-extra includes additional codecs for ffmpeg

install_requires = [
"sagemaker-inference>=1.5.11",
"huggingface_hub>=0.0.8",
"retrying",
"numpy",
# vision
"Pillow",
# speech + torchaudio
"librosa",
"pyctcdecode>=0.3.0",
"phonemizer",
]

extras = {}

# Hugging Face specific dependencies
extras["transformers"] = ["transformers[sklearn,sentencepiece]>=4.5.1"]
extras["transformers"] = ["transformers[sklearn,sentencepiece]>=4.17.0"]

# framework specific dependencies
extras["torch"] = ["torch>=1.8.0"]
extras["torch"] = ["torch>=1.8.0", "torchaudio"]
extras["tensorflow"] = ["tensorflow>=2.4.0"]

# MMS Server dependencies
Expand Down
37 changes: 37 additions & 0 deletions src/sagemaker_huggingface_inference_toolkit/content_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2021 The HuggingFace Team, Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains constants that define MIME content types."""
# --- Generic payload MIME types ---
JSON = "application/json"
CSV = "text/csv"
OCTET_STREAM = "application/octet-stream"
ANY = "*/*"
NPY = "application/x-npy"
# Types grouped under the UTF-8 list
UTF8_TYPES = [
    JSON,
    CSV,
]

# --- Image (vision) MIME types ---
JPEG = "image/jpeg"
PNG = "image/png"
TIFF = "image/tiff"
BMP = "image/bmp"
GIF = "image/gif"
WEBP = "image/webp"
X_IMAGE = "image/x-image"
# Every image type declared above, in declaration order
VISION_TYPES = [
    JPEG,
    PNG,
    TIFF,
    BMP,
    GIF,
    WEBP,
    X_IMAGE,
]

# --- Audio (speech) MIME types ---
FLAC = "audio/x-flac"
MP3 = "audio/mpeg"
WAV = "audio/wave"
OGG = "audio/ogg"
X_AUDIO = "audio/x-audio"
# Every audio type declared above, in declaration order
AUDIO_TYPES = [
    FLAC,
    MP3,
    WAV,
    OGG,
    X_AUDIO,
]
51 changes: 47 additions & 4 deletions src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import csv
import datetime
import json
from io import StringIO
from io import BytesIO, StringIO

import numpy as np
from sagemaker_inference import content_types, errors
from sagemaker_inference.decoder import _npy_to_numpy, _npz_to_sparse
from sagemaker_inference import errors
from sagemaker_inference.decoder import _npy_to_numpy
from sagemaker_inference.encoder import _array_to_npy

from mms.service import PredictionException
from PIL import Image
from sagemaker_huggingface_inference_toolkit import content_types


def decode_json(content):
Expand Down Expand Up @@ -51,6 +54,28 @@ def decode_csv(string_like): # type: (str) -> np.array
return {"inputs": request_list}


def decode_image(bpayload: bytearray):
    """Turn the raw bytes of an encoded image (.jpeg / .png / .tiff ...) into an inputs dict.

    Args:
        bpayload (bytearray): raw bytes of the encoded image.
    Returns:
        (dict): dictionary with a single "inputs" key holding the image
            opened with PIL and converted to RGB mode.
    """
    # Wrap the payload in an in-memory file object so PIL can read it.
    buffer = BytesIO(bpayload)
    rgb_image = Image.open(buffer).convert("RGB")
    return {"inputs": rgb_image}


def decode_audio(bpayload: bytearray):
    """Turn the raw bytes of an audio file (.wav / .flac / .mp3) into an inputs dict.

    Args:
        bpayload (bytearray): raw bytes of the audio file.
    Returns:
        (dict): dictionary with a single "inputs" key holding an immutable
            ``bytes`` copy of the payload; the audio is not decoded here.
    """
    raw_audio = bytes(bpayload)
    return {"inputs": raw_audio}


# https://github.com/automl/SMAC3/issues/453
class _JSONEncoder(json.JSONEncoder):
"""
Expand All @@ -66,6 +91,11 @@ def default(self, obj):
return obj.tolist()
elif isinstance(obj, datetime.datetime):
return obj.__str__()
elif isinstance(obj, Image.Image):
with BytesIO() as out:
obj.save(out, format="PNG")
png_string = out.getvalue()
return base64.b64encode(png_string).decode("utf-8")
else:
return super(_JSONEncoder, self).default(obj)

Expand Down Expand Up @@ -111,8 +141,21 @@ def encode_csv(content): # type: (str) -> np.array
# Lookup table from MIME content type to the callable that decodes a payload of
# that type. Every image type shares decode_image and every audio type shares
# decode_audio.
_decoder_map = {
content_types.NPY: _npy_to_numpy,
content_types.CSV: decode_csv,
content_types.NPZ: _npz_to_sparse,
content_types.JSON: decode_json,
# image mime-types
content_types.JPEG: decode_image,
content_types.PNG: decode_image,
content_types.TIFF: decode_image,
content_types.BMP: decode_image,
content_types.GIF: decode_image,
content_types.WEBP: decode_image,
content_types.X_IMAGE: decode_image,
# audio mime-types
content_types.FLAC: decode_audio,
content_types.MP3: decode_audio,
content_types.WAV: decode_audio,
content_types.OGG: decode_audio,
content_types.X_AUDIO: decode_audio,
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
import time
from abc import ABC

from sagemaker_inference import content_types, environment, utils
from sagemaker_inference import environment, utils
from transformers.pipelines import SUPPORTED_TASKS

from mms.service import PredictionException
from sagemaker_huggingface_inference_toolkit import decoder_encoder
from sagemaker_huggingface_inference_toolkit import content_types, decoder_encoder
from sagemaker_huggingface_inference_toolkit.transformers_utils import (
_is_gpu_available,
get_pipeline,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,21 @@ def get_pipeline(task: str, device: int, model_dir: Path, **kwargs) -> Pipeline:
raise EnvironmentError(
"The task for this model is not set: Please set one: https://huggingface.co/docs#how-is-a-models-type-of-inference-api-and-widget-determined"
)
# define tokenizer or feature extractor as kwargs to load the pipeline correctly
if task in {
"automatic-speech-recognition",
"image-segmentation",
"image-classification",
"audio-classification",
"object-detection",
"zero-shot-image-classification",
}:
kwargs["feature_extractor"] = model_dir
else:
kwargs["tokenizer"] = model_dir

hf_pipeline = pipeline(task=task, model=model_dir, tokenizer=model_dir, device=device, **kwargs)
# load pipeline
hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs)

# wrap specific pipelines to support a better UX
if task == "conversational":
Expand Down
45 changes: 43 additions & 2 deletions tests/integ/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import os

from integ.utils import (
validate_automatic_speech_recognition,
validate_classification,
validate_feature_extraction,
validate_fill_mask,
validate_ner,
validate_question_answering,
validate_summarization,
validate_text2text_generation,
validate_text_classification,
validate_text_generation,
validate_translation,
validate_zero_shot_classification,
Expand Down Expand Up @@ -53,6 +56,14 @@
"pytorch": "gpt2",
"tensorflow": "gpt2",
},
"image-classification": {
"pytorch": "google/vit-base-patch16-224",
"tensorflow": "google/vit-base-patch16-224",
},
"automatic-speech-recognition": {
"pytorch": "facebook/wav2vec2-base-100h",
"tensorflow": "facebook/wav2vec2-base-960h",
},
}

task2input = {
Expand All @@ -78,6 +89,8 @@
"inputs": "question: What is 42 context: 42 is the answer to life, the universe and everything."
},
"text-generation": {"inputs": "My name is philipp and I am"},
"image-classification": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(),
"automatic-speech-recognition": open(os.path.join(os.getcwd(), "tests/resources/audio/sample1.flac"), "rb").read(),
}

task2output = {
Expand All @@ -98,6 +111,16 @@
"feature-extraction": None,
"fill-mask": None,
"text-generation": None,
"image-classification": [
{"score": 0.8858247399330139, "label": "tiger, Panthera tigris"},
{"score": 0.10940514504909515, "label": "tiger cat"},
{"score": 0.0006216464680619538, "label": "jaguar, panther, Panthera onca, Felis onca"},
{"score": 0.0004262699221726507, "label": "dhole, Cuon alpinus"},
{"score": 0.00030842673731967807, "label": "lion, king of beasts, Panthera leo"},
],
"automatic-speech-recognition": {
"text": "GOING ALONG SLUSHY COUNTRY ROADS AND SPEAKING TO DAMP OAUDIENCES IN DROFTY SCHOOL ROOMS DAY AFTER DAY FOR A FORT NIGHT HE'LL HAVE TO PUT IN AN APPEARANCE AT SOME PLACE OF WORSHIP ON SUNDAY MORNING AND HE CAN COME TO US IMMEDIATELY AFTERWARDS"
},
}

task2performance = {
Expand Down Expand Up @@ -181,10 +204,26 @@
"average_request_time": 3,
},
},
"image-classification": {
"cpu": {
"average_request_time": 4,
},
"gpu": {
"average_request_time": 1,
},
},
"automatic-speech-recognition": {
"cpu": {
"average_request_time": 6,
},
"gpu": {
"average_request_time": 6,
},
},
}

task2validation = {
"text-classification": validate_text_classification,
"text-classification": validate_classification,
"zero-shot-classification": validate_zero_shot_classification,
"feature-extraction": validate_feature_extraction,
"ner": validate_ner,
Expand All @@ -194,4 +233,6 @@
"translation_xx_to_yy": validate_translation,
"text2text-generation": validate_text2text_generation,
"text-generation": validate_text_generation,
"image-classification": validate_classification,
"automatic-speech-recognition": validate_automatic_speech_recognition,
}
Loading

0 comments on commit 2f1fae5

Please sign in to comment.