blob: 1ba96027e006cd04862fe1166b6609b920d9eb5a [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Implementation of an Artifact{Staging,Retrieval}Service.
The staging service here can be backed by any beam filesystem.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import sys
import threading
import zipfile
from google.protobuf import json_format
from apache_beam.io import filesystems
from apache_beam.portability.api import beam_artifact_api_pb2
from apache_beam.portability.api import beam_artifact_api_pb2_grpc
class AbstractArtifactService(
beam_artifact_api_pb2_grpc.ArtifactStagingServiceServicer,
beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceServicer):
_DEFAULT_CHUNK_SIZE = 2 << 20 # 2mb
def __init__(self, root, chunk_size=None):
self._root = root
self._chunk_size = chunk_size or self._DEFAULT_CHUNK_SIZE
def _sha256(self, string):
return hashlib.sha256(string.encode('utf-8')).hexdigest()
def _join(self, *args):
raise NotImplementedError(type(self))
def _dirname(self, path):
raise NotImplementedError(type(self))
def _temp_path(self, path):
return path + '.tmp'
def _open(self, path, mode):
raise NotImplementedError(type(self))
def _rename(self, src, dest):
raise NotImplementedError(type(self))
def _delete(self, path):
raise NotImplementedError(type(self))
def _artifact_path(self, retrieval_token, name):
return self._join(self._dirname(retrieval_token), self._sha256(name))
def _manifest_path(self, retrieval_token):
return retrieval_token
def _get_manifest_proxy(self, retrieval_token):
with self._open(self._manifest_path(retrieval_token), 'r') as fin:
return json_format.Parse(
fin.read().decode('utf-8'), beam_artifact_api_pb2.ProxyManifest())
def retrieval_token(self, staging_session_token):
return self._join(
self._root, self._sha256(staging_session_token), 'MANIFEST')
def PutArtifact(self, request_iterator, context=None):
first = True
for request in request_iterator:
if first:
first = False
metadata = request.metadata.metadata
retrieval_token = self.retrieval_token(
request.metadata.staging_session_token)
artifact_path = self._artifact_path(retrieval_token, metadata.name)
temp_path = self._temp_path(artifact_path)
fout = self._open(temp_path, 'w')
hasher = hashlib.sha256()
else:
hasher.update(request.data.data)
fout.write(request.data.data)
fout.close()
data_hash = hasher.hexdigest()
if metadata.sha256 and metadata.sha256 != data_hash:
self._delete(temp_path)
raise ValueError('Bad metadata hash: %s vs %s' % (
metadata.sha256, data_hash))
self._rename(temp_path, artifact_path)
return beam_artifact_api_pb2.PutArtifactResponse()
def CommitManifest(self, request, context=None):
retrieval_token = self.retrieval_token(request.staging_session_token)
proxy_manifest = beam_artifact_api_pb2.ProxyManifest(
manifest=request.manifest,
location=[
beam_artifact_api_pb2.ProxyManifest.Location(
name=metadata.name,
uri=self._artifact_path(retrieval_token, metadata.name))
for metadata in request.manifest.artifact])
with self._open(self._manifest_path(retrieval_token), 'w') as fout:
fout.write(json_format.MessageToJson(proxy_manifest).encode('utf-8'))
return beam_artifact_api_pb2.CommitManifestResponse(
retrieval_token=retrieval_token)
def GetManifest(self, request, context=None):
return beam_artifact_api_pb2.GetManifestResponse(
manifest=self._get_manifest_proxy(request.retrieval_token).manifest)
def GetArtifact(self, request, context=None):
for artifact in self._get_manifest_proxy(request.retrieval_token).location:
if artifact.name == request.name:
with self._open(artifact.uri, 'r') as fin:
# This value is not emitted, but lets us yield a single empty
# chunk on an empty file.
chunk = True
while chunk:
chunk = fin.read(self._chunk_size)
yield beam_artifact_api_pb2.ArtifactChunk(data=chunk)
break
else:
raise ValueError('Unknown artifact: %s' % request.name)
class ZipFileArtifactService(AbstractArtifactService):
"""Stores artifacts in a zip file.
This is particularly useful for storing artifacts as part of an UberJar for
submitting to an upstream runner's cluster.
Writing to zip files requires Python 3.6+.
"""
def __init__(self, path, internal_root, chunk_size=None):
if sys.version_info < (3, 6):
raise RuntimeError(
'Writing to zip files requires Python 3.6+, '
'but current version is %s' % sys.version)
super(ZipFileArtifactService, self).__init__(internal_root, chunk_size)
self._zipfile = zipfile.ZipFile(path, 'a')
self._lock = threading.Lock()
def _join(self, *args):
return '/'.join(args)
def _dirname(self, path):
return path.rsplit('/', 1)[0]
def _temp_path(self, path):
return path # ZipFile offers no move operation.
def _rename(self, src, dest):
assert src == dest
def _delete(self, path):
# ZipFile offers no delete operation: https://bugs.python.org/issue6818
pass
def _open(self, path, mode):
if path.startswith('/'):
raise ValueError(
'ZIP file entry %s invalid: '
'path must not contain a leading slash.' % path)
return self._zipfile.open(path, mode, force_zip64=True)
def PutArtifact(self, request_iterator, context=None):
# ZipFile only supports one writable channel at a time.
with self._lock:
return super(
ZipFileArtifactService, self).PutArtifact(request_iterator, context)
def CommitManifest(self, request, context=None):
# ZipFile only supports one writable channel at a time.
with self._lock:
return super(
ZipFileArtifactService, self).CommitManifest(request, context)
def GetManifest(self, request, context=None):
# ZipFile appears to not be threadsafe on some platforms.
with self._lock:
return super(ZipFileArtifactService, self).GetManifest(request, context)
def GetArtifact(self, request, context=None):
# ZipFile appears to not be threadsafe on some platforms.
with self._lock:
for chunk in super(ZipFileArtifactService, self).GetArtifact(
request, context):
yield chunk
def close(self):
self._zipfile.close()
class BeamFilesystemArtifactService(AbstractArtifactService):
def _join(self, *args):
return filesystems.FileSystems.join(*args)
def _dirname(self, path):
return filesystems.FileSystems.split(path)[0]
def _rename(self, src, dest):
filesystems.FileSystems.rename([src], [dest])
def _delete(self, path):
filesystems.FileSystems.delete([path])
def _open(self, path, mode='r'):
dir = self._dirname(path)
if not filesystems.FileSystems.exists(dir):
try:
filesystems.FileSystems.mkdirs(dir)
except Exception:
pass
if 'w' in mode:
return filesystems.FileSystems.create(path)
else:
return filesystems.FileSystems.open(path)