blob: 011ed96b9ae386f869a6f959e024ec0963fa580f [file]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Sentiment-analysis smoke test for the YAML RunInference transform backed by
# a HuggingFace `text-classification` pipeline: two sentences are classified
# and the resulting labels are checked with AssertEqual.
pipelines:
  - pipeline:
      type: chain
      transforms:
        - type: Create
          config:
            elements:
              - text: "I love Apache Beam!"
              - text: "I hate this error."
        - type: RunInference
          config:
            model_handler:
              type: "HuggingFacePipelineModelHandler"
              config:
                task: "text-classification"
                # Pin the model explicitly rather than relying on the default
                # transformers picks for this task, so the exact labels
                # asserted below (POSITIVE / NEGATIVE) stay stable across
                # library releases.
                # NOTE(review): this is believed to be the current transformers
                # default for text-classification — confirm if it changes.
                model: "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
                inference_fn:
                  callable: |
                    def real_inference(batch, pipeline, inference_args):
                      # Run the HuggingFace pipeline over the whole batch.
                      predictions = pipeline(batch, **inference_args)
                      # If it's a single dictionary (batch size of 1), wrap it
                      # in a list so the comprehensions below always iterate
                      # over per-element results.
                      if isinstance(predictions, dict):
                        predictions = [predictions]
                      # Return one list per output column, each holding one
                      # entry per batch element.
                      return {
                          'label': [p['label'] for p in predictions],
                          'score': [p['score'] for p in predictions]
                      }
                # Feed only the raw sentence string to the model, not the
                # whole row.
                preprocess:
                  callable: 'lambda x: x.text'
        - type: MapToFields
          config:
            language: python
            fields:
              text: text
              # `x.inference` is the RunInference result attached to each
              # row; its `.inference` member presumably carries the
              # per-element {'label', 'score'} mapping built above.
              sentiment:
                callable: 'lambda x: x.inference.inference["label"]'
        - type: AssertEqual
          config:
            elements:
              - text: "I love Apache Beam!"
                sentiment: "POSITIVE"
              - text: "I hate this error."
                sentiment: "NEGATIVE"
    options:
      # RunInference is gated behind the experimental ML feature flag.
      yaml_experimental_features: ['ML']