blob: 17f1e962b9a0a994c9cfbc88b14ba97e699768ab [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Test cases for :module:`artifact_service_client`."""
# pytype: skip-file
import contextlib
import io
import threading
import unittest
from urllib.parse import quote
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_artifact_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.portability import artifact_service
from apache_beam.utils import proto_utils
class InMemoryFileManager(object):
  """A dict-backed fake file system used by the tests in this module.

  Paths map to byte strings. Readers and writers are in-memory ``BytesIO``
  streams, so no real files are touched.
  """
  def __init__(self, contents=()):
    # dict() builds a fresh mapping, so later writes never show up in the
    # caller's original mapping.
    self._contents = dict(contents)

  def get(self, path):
    """Returns the bytes stored under ``path`` (KeyError if absent)."""
    return self._contents[path]

  def file_reader(self, path):
    """Opens ``path`` for reading as a fresh in-memory binary stream."""
    stored = self._contents[path]
    return io.BytesIO(stored)

  def file_writer(self, name):
    """Opens a new file named ``name`` for writing.

    Returns:
      A ``(context_manager, path)`` pair where ``path`` is
      ``'prefix:' + name``. Entering the context manager yields a writable
      binary stream; when the ``with`` block exits without an exception,
      everything written to the stream is stored under ``path``.
    """
    target = 'prefix:' + name

    @contextlib.contextmanager
    def open_for_write():
      sink = io.BytesIO()
      yield sink
      # Only reached on a clean exit: an exception raised inside the
      # with-block is re-raised at the yield and skips the store below.
      sink.seek(0)
      self._contents[target] = sink.read()

    return open_for_write(), target
class ArtifactServiceTest(unittest.TestCase):
  """Tests for the artifact retrieval and staging services."""

  def file_artifact(self, path):
    """Returns a FILE-typed ArtifactInformation referring to ``path``."""
    payload = beam_runner_api_pb2.ArtifactFilePayload(path=path)
    return beam_runner_api_pb2.ArtifactInformation(
        type_urn=common_urns.artifact_types.FILE.urn,
        type_payload=payload.SerializeToString())

  def embedded_artifact(self, data, name=None):
    """Returns an EMBEDDED-typed ArtifactInformation holding ``data`` inline.

    If ``name`` is truthy the artifact also carries a STAGING_TO role whose
    staged name is ``name``; otherwise ``None`` is passed for both role
    fields.
    """
    type_payload = beam_runner_api_pb2.EmbeddedFilePayload(
        data=data).SerializeToString()
    role_urn = None
    role_payload = None
    if name:
      role_urn = common_urns.artifact_roles.STAGING_TO.urn
      role_payload = beam_runner_api_pb2.ArtifactStagingToRolePayload(
          staged_name=name).SerializeToString()
    return beam_runner_api_pb2.ArtifactInformation(
        type_urn=common_urns.artifact_types.EMBEDDED.urn,
        type_payload=type_payload,
        role_urn=role_urn,
        role_payload=role_payload)

  def test_file_retrieval(self):
    manager = InMemoryFileManager({
        'path/to/a': b'a', 'path/to/b': b'b' * 37
    })
    service = artifact_service.ArtifactRetrievalService(
        manager.file_reader, chunk_size=10)
    artifact_a = self.file_artifact('path/to/a')

    # A file artifact is expected to resolve to itself.
    resolve_response = service.ResolveArtifacts(
        beam_artifact_api_pb2.ResolveArtifactsRequest(artifacts=[artifact_a]))
    self.assertEqual(
        resolve_response,
        beam_artifact_api_pb2.ResolveArtifactsResponse(
            replacements=[artifact_a]))

    # One byte with chunk_size=10 is expected in a single response.
    small_chunks = list(
        service.GetArtifact(
            beam_artifact_api_pb2.GetArtifactRequest(artifact=artifact_a)))
    self.assertEqual(
        small_chunks, [beam_artifact_api_pb2.GetArtifactResponse(data=b'a')])

    # 37 bytes with chunk_size=10 are expected as 10 + 10 + 10 + 7.
    big_chunks = list(
        service.GetArtifact(
            beam_artifact_api_pb2.GetArtifactRequest(
                artifact=self.file_artifact('path/to/b'))))
    self.assertEqual(
        big_chunks,
        [
            beam_artifact_api_pb2.GetArtifactResponse(data=b'b' * size)
            for size in (10, 10, 10, 7)
        ])

  def test_embedded_retrieval(self):
    service = artifact_service.ArtifactRetrievalService(None)
    request = beam_artifact_api_pb2.GetArtifactRequest(
        artifact=self.embedded_artifact(b'some_data'))
    self.assertEqual(
        list(service.GetArtifact(request)),
        [beam_artifact_api_pb2.GetArtifactResponse(data=b'some_data')])

  def test_url_retrieval(self):
    service = artifact_service.ArtifactRetrievalService(None)
    # Point a URL artifact at this very test file.
    url = 'file:' + quote(__file__)
    url_artifact = beam_runner_api_pb2.ArtifactInformation(
        type_urn=common_urns.artifact_types.URL.urn,
        type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
            url=url).SerializeToString())
    responses = service.GetArtifact(
        beam_artifact_api_pb2.GetArtifactRequest(artifact=url_artifact))
    fetched = b''.join(response.data for response in responses)
    with open(__file__, 'rb') as source_file:
      on_disk = source_file.read()
    self.assertEqual(fetched, on_disk)

  def test_push_artifacts(self):
    unresolved = beam_runner_api_pb2.ArtifactInformation(type_urn='unresolved')
    resolved_a = self.embedded_artifact(data=b'a', name='a.txt')
    resolved_b = self.embedded_artifact(data=b'bb', name='b.txt')
    dep_big = self.embedded_artifact(data=b'big ' * 100, name='big.txt')

    class TestArtifacts(object):
      """Client-side artifact source used by ``offer_artifacts`` below.

      Expands the 'unresolved' placeholder into two embedded artifacts and
      streams embedded content back in 13-byte pieces.
      """
      def ResolveArtifacts(self, request):
        def expand(artifact):
          if artifact.type_urn == 'unresolved':
            return [resolved_a, resolved_b]
          return [artifact]

        return beam_artifact_api_pb2.ResolveArtifactsResponse(
            replacements=[
                replacement for artifact in request.artifacts
                for replacement in expand(artifact)
            ])

      def GetArtifact(self, request):
        # Generator function: the error below surfaces on first iteration.
        if request.artifact.type_urn != common_urns.artifact_types.EMBEDDED.urn:
          raise NotImplementedError
        payload_bytes = proto_utils.parse_Bytes(
            request.artifact.type_payload,
            beam_runner_api_pb2.EmbeddedFilePayload).data
        piece = 13
        for start in range(0, len(payload_bytes), piece):
          yield beam_artifact_api_pb2.GetArtifactResponse(
              data=payload_bytes[start:start + piece])

    file_manager = InMemoryFileManager()
    server = artifact_service.ArtifactStagingService(file_manager.file_writer)
    server.register_job('staging_token', {'env': [unresolved, dep_big]})

    # "Push" artifacts as if from a client, on a background daemon thread.
    def push():
      artifact_service.offer_artifacts(server, TestArtifacts(), 'staging_token')

    pusher = threading.Thread(target=push, daemon=True)
    pusher.start()

    # Presumably waits (bounded by timeout) for the push to resolve all
    # dependencies of the job -- NOTE(review): confirm in artifact_service.
    resolved_deps = server.resolved_deps('staging_token', timeout=5)['env']
    remaining = {
        'a.txt': b'a',
        'b.txt': b'bb',
        'big.txt': b'big ' * 100,
    }
    for dep in resolved_deps:
      self.assertEqual(dep.type_urn, common_urns.artifact_types.FILE.urn)
      self.assertEqual(dep.role_urn, common_urns.artifact_roles.STAGING_TO.urn)
      file_payload = proto_utils.parse_Bytes(
          dep.type_payload, beam_runner_api_pb2.ArtifactFilePayload)
      staging_payload = proto_utils.parse_Bytes(
          dep.role_payload, beam_runner_api_pb2.ArtifactStagingToRolePayload)
      staged_name = staging_payload.staged_name
      self.assertTrue(
          file_payload.path.endswith(staged_name), file_payload.path)
      # pop() raises KeyError on an unexpected or repeated staged name.
      self.assertEqual(
          file_manager.get(file_payload.path), remaining.pop(staged_name))
    # Every expected artifact must have been staged.
    self.assertEqual(remaining, {})
if __name__ == '__main__':
  # Running this module directly executes the TestCases defined in it.
  unittest.main(module='__main__')