#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import grpc

from .._protos.google.rpc import code_pb2
from .._protos.build.buildgrid import local_cas_pb2

from .._remote import BaseRemote
from .._exceptions import CASRemoteError

# The default limit for gRPC messages is 4 MiB.
# Limit payload to 1 MiB to leave sufficient headroom for metadata.
_MAX_PAYLOAD_BYTES = 1024 * 1024

# How many digests to put in a single gRPC message.
# A 256-bit hash requires 64 bytes of space (hexadecimal encoding).
# 80 bytes provide sufficient space for the hash, the size and protobuf overhead.
_MAX_DIGESTS = _MAX_PAYLOAD_BYTES // 80
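
# For illustration: with the values above, a single request can carry
# 1024 * 1024 // 80 = 13107 digests before a new request is started.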


class BlobNotFound(CASRemoteError):
    def __init__(self, blob, msg):
        self.blob = blob

        super().__init__(msg)


# Represents a single remote CAS cache.
#
class CASRemote(BaseRemote):
    def __init__(self, spec, cascache, **kwargs):
        super().__init__(spec, **kwargs)

        self.cascache = cascache
        self.local_cas_instance_name = None

    # _configure_protocols():
    #
    # Configure remote-specific protocols. This method should *never*
    # be called outside of init().
    #
    def _configure_protocols(self):
        if not self.spec:
            # Remote cache (handled by default instance in casd)
            self.local_cas_instance_name = ""
            return

        local_cas = self.cascache.get_local_cas()

        request = local_cas_pb2.GetInstanceNameForRemotesRequest()
        cas_endpoint = request.content_addressable_storage
        cas_endpoint.url = self.spec.url
        if self.spec.instance_name:
            cas_endpoint.instance_name = self.spec.instance_name
        if self.spec.server_cert:
            cas_endpoint.server_cert = self.spec.server_cert
        if self.spec.client_key:
            cas_endpoint.client_key = self.spec.client_key
        if self.spec.client_cert:
            cas_endpoint.client_cert = self.spec.client_cert

        try:
            response = local_cas.GetInstanceNameForRemotes(request)
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.UNIMPLEMENTED:
                raise CASRemoteError(
                    "Unsupported buildbox-casd version: GetInstanceNameForRemotes unimplemented"
                ) from e
            raise

        self.local_cas_instance_name = response.instance_name
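
    # Sketch of the exchange above (values are hypothetical, for illustration):
    #
    #   spec.url           = "https://cas.example.com:11001"
    #   spec.instance_name = "main"
    #   spec.server_cert   = b"<PEM-encoded certificate>"
    #
    # buildbox-casd registers this endpoint and returns a local instance name,
    # which is then used for all subsequent LocalCAS requests in this process.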

    # push_message():
    #
    # Push the given protobuf message to a remote.
    #
    # Args:
    #     message (Message): A protobuf message to push.
    #
    # Raises:
    #     (CASRemoteError): if there was an error
    #
    def push_message(self, message):
        message_buffer = message.SerializeToString()

        self.init()

        return self.cascache.add_object(buffer=message_buffer, instance_name=self.local_cas_instance_name)
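
    # Usage sketch (the `remote` object and the message type are assumptions,
    # not part of this module):
    #
    #   from google.protobuf import duration_pb2
    #
    #   message = duration_pb2.Duration(seconds=42)
    #   digest = remote.push_message(message)  # digest of the serialized blob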


# Represents a batch of blobs queued for fetching.
#
class _CASBatchRead:
    def __init__(self, remote):
        self._remote = remote
        self._requests = []
        self._request = None
        self._sent = False

    def add(self, digest):
        assert not self._sent

        # Open a fresh request when none exists yet or the current one is full
        if not self._request or len(self._request.blob_digests) >= _MAX_DIGESTS:
            self._request = local_cas_pb2.FetchMissingBlobsRequest()
            self._request.instance_name = self._remote.local_cas_instance_name
            self._requests.append(self._request)

        request_digest = self._request.blob_digests.add()
        request_digest.CopyFrom(digest)

    def send(self, *, missing_blobs=None):
        assert not self._sent
        self._sent = True

        if not self._requests:
            return

        local_cas = self._remote.cascache.get_local_cas()

        for request in self._requests:
            batch_response = local_cas.FetchMissingBlobs(request)

            for response in batch_response.responses:
                if response.status.code == code_pb2.NOT_FOUND:
                    if missing_blobs is None:
                        raise BlobNotFound(
                            response.digest.hash,
                            "Failed to download blob {}: {}".format(response.digest.hash, response.status.code),
                        )

                    missing_blobs.append(response.digest)
                    continue

                if response.status.code != code_pb2.OK:
                    raise CASRemoteError(
                        "Failed to download blob {}: {}".format(response.digest.hash, response.status.code)
                    )
                if response.digest.size_bytes != len(response.data):
                    raise CASRemoteError(
                        "Failed to download blob {}: expected {} bytes, received {} bytes".format(
                            response.digest.hash, response.digest.size_bytes, len(response.data)
                        )
                    )
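
    # Usage sketch (`remote` and `digests` are assumptions, not part of this
    # module):
    #
    #   batch = _CASBatchRead(remote)
    #   for digest in digests:
    #       batch.add(digest)
    #
    #   missing = []
    #   batch.send(missing_blobs=missing)  # collect misses instead of raising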


# Represents a batch of blobs queued for upload.
#
class _CASBatchUpdate:
    def __init__(self, remote):
        self._remote = remote
        self._requests = []
        self._request = None
        self._sent = False

    def add(self, digest):
        assert not self._sent

        if not self._request or len(self._request.blob_digests) >= _MAX_DIGESTS:
            self._request = local_cas_pb2.UploadMissingBlobsRequest()
            self._request.instance_name = self._remote.local_cas_instance_name
            self._requests.append(self._request)

        request_digest = self._request.blob_digests.add()
        request_digest.CopyFrom(digest)

    def send(self):
        assert not self._sent
        self._sent = True

        if not self._requests:
            return

        local_cas = self._remote.cascache.get_local_cas()

        for request in self._requests:
            batch_response = local_cas.UploadMissingBlobs(request)

            for response in batch_response.responses:
                if response.status.code != code_pb2.OK:
                    if response.status.code == code_pb2.RESOURCE_EXHAUSTED:
                        reason = "cache-too-full"
                    else:
                        reason = None

                    raise CASRemoteError(
                        "Failed to upload blob {}: {}".format(response.digest.hash, response.status.code),
                        reason=reason,
                    )
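
    # Usage sketch (as above, `remote` and `digests` are assumptions, and this
    # assumes CASRemoteError exposes the `reason` it was raised with):
    #
    #   batch = _CASBatchUpdate(remote)
    #   for digest in digests:
    #       batch.add(digest)
    #   try:
    #       batch.send()
    #   except CASRemoteError as e:
    #       if e.reason == "cache-too-full":
    #           ...  # e.g. expire blobs on the remote and retry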