| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
"""Test cases for :mod:`artifact_service`."""
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import collections |
| import hashlib |
| import os |
| import random |
| import shutil |
| import sys |
| import tempfile |
| import time |
| import unittest |
| |
| import grpc |
| |
| from apache_beam.portability.api import beam_artifact_api_pb2 |
| from apache_beam.portability.api import beam_artifact_api_pb2_grpc |
| from apache_beam.runners.portability import artifact_service |
| from apache_beam.utils.thread_pool_executor import UnboundedThreadPoolExecutor |
| |
| |
class AbstractArtifactServiceTest(unittest.TestCase):
  """Shared test suite for artifact staging/retrieval service implementations.

  Concrete subclasses supply the service under test via
  :meth:`create_service`.  The abstract base itself is deleted from the
  module namespace at the bottom of this file so the test runner does not
  try to execute it directly.
  """

  def setUp(self):
    # Each test gets a fresh temp directory and a service rooted in it.
    self._staging_dir = tempfile.mkdtemp()
    self._service = self.create_service(self._staging_dir)

  def tearDown(self):
    if self._staging_dir:
      shutil.rmtree(self._staging_dir)

  def create_service(self, staging_dir):
    """Subclass hook: returns the artifact service rooted at staging_dir."""
    raise NotImplementedError(type(self))

  @staticmethod
  def put_metadata(staging_token, name, sha256=None):
    """Builds the metadata request that opens a PutArtifact request stream."""
    return beam_artifact_api_pb2.PutArtifactRequest(
        metadata=beam_artifact_api_pb2.PutArtifactMetadata(
            staging_session_token=staging_token,
            metadata=beam_artifact_api_pb2.ArtifactMetadata(
                name=name,
                sha256=sha256)))

  @staticmethod
  def put_data(chunk):
    """Builds a data-chunk request for a PutArtifact request stream."""
    return beam_artifact_api_pb2.PutArtifactRequest(
        data=beam_artifact_api_pb2.ArtifactChunk(
            data=chunk))

  @staticmethod
  def retrieve_artifact(retrieval_service, retrieval_token, name):
    """Fetches the named artifact and concatenates its streamed chunks."""
    return b''.join(chunk.data for chunk in retrieval_service.GetArtifact(
        beam_artifact_api_pb2.GetArtifactRequest(
            retrieval_token=retrieval_token,
            name=name)))

  def test_basic(self):
    # Exercise the service directly, with no RPC layer in between.
    self._run_staging(self._service, self._service)

  def test_with_grpc(self):
    # Same scenario as test_basic, but through an in-process gRPC server.
    server = grpc.server(UnboundedThreadPoolExecutor())
    try:
      beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
          self._service, server)
      beam_artifact_api_pb2_grpc.add_ArtifactRetrievalServiceServicer_to_server(
          self._service, server)
      # Port 0 lets the OS pick any free port.
      port = server.add_insecure_port('[::]:0')
      server.start()
      channel = grpc.insecure_channel('localhost:%d' % port)
      self._run_staging(
          beam_artifact_api_pb2_grpc.ArtifactStagingServiceStub(
              channel),
          beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceStub(
              channel))
      channel.close()
    finally:
      server.stop(1)

  def _run_staging(self, staging_service, retrieval_service):
    """Stages several artifacts, commits a manifest, then retrieves them.

    staging_service/retrieval_service may be the service itself or gRPC
    stubs wrapping it (see test_basic vs. test_with_grpc).
    """

    # Token includes a slash, whitespace and a NUL byte to exercise any
    # escaping/encoding the service applies to session tokens.
    staging_session_token = '/session_staging_token \n\0*'

    # First stage some files.
    staging_service.PutArtifact(iter([
        self.put_metadata(staging_session_token, 'name'),
        self.put_data(b'data')]))

    staging_service.PutArtifact(iter([
        self.put_metadata(staging_session_token, 'many_chunks'),
        self.put_data(b'a'),
        self.put_data(b'b'),
        self.put_data(b'c')]))

    staging_service.PutArtifact(iter([
        self.put_metadata(staging_session_token, 'long'),
        self.put_data(b'a' * 1000)]))

    # Staging with the correct sha256 of the full payload must succeed.
    staging_service.PutArtifact(iter([
        self.put_metadata(staging_session_token,
                          'with_hash',
                          hashlib.sha256(b'data...').hexdigest()),
        self.put_data(b'data'),
        self.put_data(b'...')]))

    # A mismatched hash must be rejected at staging time.
    with self.assertRaises(Exception):
      staging_service.PutArtifact(iter([
          self.put_metadata(staging_session_token,
                            'bad_hash',
                            'bad_hash'),
          self.put_data(b'data')]))

    # Note: 'bad_hash' is deliberately absent from the committed manifest.
    manifest = beam_artifact_api_pb2.Manifest(artifact=[
        beam_artifact_api_pb2.ArtifactMetadata(name='name'),
        beam_artifact_api_pb2.ArtifactMetadata(name='many_chunks'),
        beam_artifact_api_pb2.ArtifactMetadata(name='long'),
        beam_artifact_api_pb2.ArtifactMetadata(name='with_hash'),
    ])

    retrieval_token = staging_service.CommitManifest(
        beam_artifact_api_pb2.CommitManifestRequest(
            staging_session_token=staging_session_token,
            manifest=manifest)).retrieval_token

    # Now attempt to retrieve them.

    retrieved_manifest = retrieval_service.GetManifest(
        beam_artifact_api_pb2.GetManifestRequest(
            retrieval_token=retrieval_token)).manifest
    self.assertEqual(manifest, retrieved_manifest)

    self.assertEqual(
        b'data',
        self.retrieve_artifact(retrieval_service, retrieval_token, 'name'))

    # Chunk boundaries on retrieval need not match those used at staging;
    # retrieve_artifact concatenates whatever chunks come back.
    self.assertEqual(
        b'abc',
        self.retrieve_artifact(
            retrieval_service, retrieval_token, 'many_chunks'))

    self.assertEqual(
        b'a' * 1000,
        self.retrieve_artifact(retrieval_service, retrieval_token, 'long'))

    self.assertEqual(
        b'data...',
        self.retrieve_artifact(retrieval_service, retrieval_token, 'with_hash'))

    # 'bad_hash' was never committed; 'missing' was never staged at all.
    with self.assertRaises(Exception):
      self.retrieve_artifact(retrieval_service, retrieval_token, 'bad_hash')

    with self.assertRaises(Exception):
      self.retrieve_artifact(retrieval_service, retrieval_token, 'missing')

  def test_concurrent_requests(self):
    """Stages, commits and retrieves 100 artifacts from concurrent threads.

    Artifacts are spread over several sessions with names that repeat
    across sessions, so the service must keep sessions isolated under
    concurrent access.
    """

    num_sessions = 7
    artifacts = collections.defaultdict(list)

    def name(index):
      # Overlapping names across sessions.
      return 'name%d' % (index // num_sessions)

    def session(index):
      return 'session%d' % (index % num_sessions)

    def delayed_data(data, index, max_msecs=1):
      # The random sleep encourages interleaving between threads; the
      # returned payload itself is deterministic for a given index.
      time.sleep(max_msecs / 1000.0 * random.random())
      return ('%s_%d' % (data, index)).encode('ascii')

    def put(index):
      artifacts[session(index)].append(
          beam_artifact_api_pb2.ArtifactMetadata(name=name(index)))
      self._service.PutArtifact([
          self.put_metadata(session(index), name(index)),
          self.put_data(delayed_data('a', index)),
          self.put_data(delayed_data('b' * 20, index, 2))])
      return session(index)

    def commit(session):
      return session, self._service.CommitManifest(
          beam_artifact_api_pb2.CommitManifestRequest(
              staging_session_token=session,
              manifest=beam_artifact_api_pb2.Manifest(
                  artifact=artifacts[session]))).retrieval_token

    def check(index):
      # Re-derives the expected payload and compares with what was staged.
      self.assertEqual(
          delayed_data('a', index) + delayed_data('b' * 20, index, 2),
          self.retrieve_artifact(
              self._service, tokens[session(index)], name(index)))

    # pylint: disable=range-builtin-not-iterating
    pool = UnboundedThreadPoolExecutor()
    sessions = set(pool.map(put, range(100)))
    tokens = dict(pool.map(commit, sessions))
    # List forces materialization.
    _ = list(pool.map(check, range(100)))
| |
| |
@unittest.skipIf(sys.version_info < (3, 6), "Requires Python 3.6+")
class ZipFileArtifactServiceTest(AbstractArtifactServiceTest):
  """Runs the shared suite against the zip-file-backed artifact service."""

  def create_service(self, staging_dir):
    # A small chunk_size forces the multi-chunk code paths to be exercised.
    zip_path = os.path.join(staging_dir, 'test.zip')
    return artifact_service.ZipFileArtifactService(
        zip_path, 'root', chunk_size=10)
| |
| |
class BeamFilesystemArtifactServiceTest(AbstractArtifactServiceTest):
  """Runs the shared suite against the filesystem-backed artifact service."""

  def create_service(self, staging_dir):
    # chunk_size=10 keeps chunks small so multi-chunk paths are exercised.
    service = artifact_service.BeamFilesystemArtifactService(
        staging_dir, chunk_size=10)
    return service
| |
| |
# Remove the abstract base class from the module namespace so unittest
# discovery doesn't try to run it directly (its create_service raises
# NotImplementedError); only the concrete subclasses above are collected.
del AbstractArtifactServiceTest
| |
| |
# Allows running this test module directly, e.g. `python <this file>`.
if __name__ == '__main__':
  unittest.main()