blob: 43a0c3903074446a3b19a73a53642646756decbe [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
import unittest
from typing import NamedTuple
from typing import Union
import pytest
import urllib3
import apache_beam as beam
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import BeamAssertException
# pylint: disable=ungrouped-imports
try:
from apache_beam.io.requestresponse import UserCodeExecutionException
from apache_beam.io.requestresponse import UserCodeQuotaException
from apache_beam.io.requestresponse_it_test import _PAYLOAD
from apache_beam.io.requestresponse_it_test import EchoITOptions
from apache_beam.transforms.enrichment import Enrichment
from apache_beam.transforms.enrichment import EnrichmentSourceHandler
except ImportError:
raise unittest.SkipTest('RequestResponseIO dependencies are not installed.')
class Request(NamedTuple):
id: str
payload: bytes
def _custom_join(left, right):
"""custom_join returns the id and resp_payload along with a timestamp"""
right['timestamp'] = time.time()
return beam.Row(**right)
class SampleHTTPEnrichment(EnrichmentSourceHandler[Request, beam.Row]):
"""Implements ``EnrichmentSourceHandler`` to call the ``EchoServiceGrpc``'s
HTTP handler.
"""
def __init__(self, url: str):
self.url = url + '/v1/echo' # append path to the mock API.
def __call__(self, request: Request, *args, **kwargs):
"""Overrides ``Caller``'s call method invoking the
``EchoServiceGrpc``'s HTTP handler with an `dict`, returning
either a successful ``tuple[dict,dict]`` or throwing either a
``UserCodeExecutionException``, ``UserCodeTimeoutException``,
or a ``UserCodeQuotaException``.
"""
try:
resp = urllib3.request(
"POST",
self.url,
json={
"id": request.id, "payload": str(request.payload, 'utf-8')
},
retries=False)
if resp.status < 300:
resp_body = resp.json()
resp_id = resp_body['id']
payload = resp_body['payload']
return (
request, beam.Row(id=resp_id, resp_payload=bytes(payload, 'utf-8')))
if resp.status == 429: # Too Many Requests
raise UserCodeQuotaException(resp.reason)
elif resp.status != 200:
raise UserCodeExecutionException(resp.status, resp.reason, request)
except urllib3.exceptions.HTTPError as e:
raise UserCodeExecutionException(e)
class ValidateFields(beam.DoFn):
"""ValidateFields validates if a PCollection of `beam.Row`
has certain fields."""
def __init__(self, n_fields: int, fields: list[str]):
self.n_fields = n_fields
self._fields = fields
def process(self, element: beam.Row, *args, **kwargs):
element_dict = element.as_dict()
if len(element_dict.keys()) != self.n_fields:
raise BeamAssertException(
"Expected %d fields in enriched PCollection:"
" id, payload and resp_payload" % self.n_fields)
for field in self._fields:
if field not in element_dict or element_dict[field] is None:
raise BeamAssertException(f"Expected a not None field: {field}")
@pytest.mark.uses_mock_api
class TestEnrichment(unittest.TestCase):
options: Union[EchoITOptions, None] = None
client: Union[SampleHTTPEnrichment, None] = None
@classmethod
def setUpClass(cls) -> None:
cls.options = EchoITOptions()
http_endpoint_address = cls.options.http_endpoint_address
if not http_endpoint_address or http_endpoint_address == '':
raise unittest.SkipTest('HTTP_ENDPOINT_ADDRESS is required.')
cls.client = SampleHTTPEnrichment(http_endpoint_address)
@classmethod
def _get_client_and_options(
cls) -> tuple[SampleHTTPEnrichment, EchoITOptions]:
assert cls.options is not None
assert cls.client is not None
return cls.client, cls.options
def test_http_enrichment(self):
"""Tests Enrichment Transform against the Mock-API HTTP endpoint
with the default cross join."""
client, options = TestEnrichment._get_client_and_options()
req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD)
fields = ['id', 'payload', 'resp_payload']
with TestPipeline(is_integration_test=True) as test_pipeline:
_ = (
test_pipeline
| 'Create PCollection' >> beam.Create([req])
| 'Enrichment Transform' >> Enrichment(client)
| 'Assert Fields' >> beam.ParDo(
ValidateFields(len(fields), fields=fields)))
def test_http_enrichment_custom_join(self):
"""Tests Enrichment Transform against the Mock-API HTTP endpoint
with a custom join function."""
client, options = TestEnrichment._get_client_and_options()
req = Request(id=options.never_exceed_quota_id, payload=_PAYLOAD)
fields = ['id', 'resp_payload', 'timestamp']
with TestPipeline(is_integration_test=True) as test_pipeline:
_ = (
test_pipeline
| 'Create PCollection' >> beam.Create([req])
| 'Enrichment Transform' >> Enrichment(client, join_fn=_custom_join)
| 'Assert Fields' >> beam.ParDo(
ValidateFields(len(fields), fields=fields)))
if __name__ == '__main__':
unittest.main()