_cas: Add support for remote cache
diff --git a/src/buildstream/_cas/cascache.py b/src/buildstream/_cas/cascache.py
index b80460a..609feb5 100644
--- a/src/buildstream/_cas/cascache.py
+++ b/src/buildstream/_cas/cascache.py
@@ -37,7 +37,7 @@
 from .._exceptions import CASCacheError
 
 from .casdprocessmanager import CASDProcessManager
-from .casremote import _CASBatchRead, _CASBatchUpdate, BlobNotFound
+from .casremote import CASRemote, _CASBatchRead, _CASBatchUpdate, BlobNotFound
 
 _BUFFER_SIZE = 65536
 
@@ -69,6 +69,7 @@
         *,
         casd=True,
         cache_quota=None,
+        remote_cache_spec=None,
         protect_session_blobs=True,
         log_level=CASLogLevel.WARNING,
         log_directory=None
@@ -80,18 +81,25 @@
         self._cache_usage_monitor = None
         self._cache_usage_monitor_forbidden = False
 
+        self._remote_cache = bool(remote_cache_spec)
+
         self._casd_process_manager = None
         self._casd_channel = None
         if casd:
             assert log_directory is not None, "log_directory is required when casd is True"
             log_dir = os.path.join(log_directory, "_casd")
             self._casd_process_manager = CASDProcessManager(
-                path, log_dir, log_level, cache_quota, protect_session_blobs
+                path, log_dir, log_level, cache_quota, remote_cache_spec, protect_session_blobs
             )
 
             self._casd_channel = self._casd_process_manager.create_channel()
             self._cache_usage_monitor = _CASCacheUsageMonitor(self._casd_channel)
             self._cache_usage_monitor.start()
+        else:
+            assert not self._remote_cache
+
+        self._default_remote = CASRemote(None, self)
+        self._default_remote.init()
 
     # get_cas():
     #
@@ -142,6 +150,9 @@
             self._casd_process_manager.release_resources(messenger)
             self._casd_process_manager = None
 
+    def get_default_remote(self):
+        return self._default_remote
+
     # contains_files():
     #
     # Check whether file digests exist in the local CAS cache
@@ -168,13 +179,17 @@
     def contains_directory(self, digest, *, with_files):
         local_cas = self.get_local_cas()
 
+        # Without a remote cache, `FetchTree` simply checks the local cache.
         request = local_cas_pb2.FetchTreeRequest()
         request.root_digest.CopyFrom(digest)
-        request.fetch_file_blobs = with_files
+        # Always fetch Directory protos as they are needed to enumerate subdirectories and files.
+        # Don't implicitly fetch file blobs from the remote cache as we don't need them.
+        request.fetch_file_blobs = with_files and not self._remote_cache
 
         try:
             local_cas.FetchTree(request)
-            return True
+            if not self._remote_cache:
+                return True
         except grpc.RpcError as e:
             if e.code() == grpc.StatusCode.NOT_FOUND:
                 return False
@@ -182,6 +197,10 @@
                 raise CASCacheError("Unsupported buildbox-casd version: FetchTree unimplemented") from e
             raise
 
+        # Check whether everything is available in the remote cache.
+        missing_blobs = self.missing_blobs_for_directory(digest, remote=self._default_remote)
+        return not missing_blobs
+
     # checkout():
     #
     # Checkout the specified directory digest.
@@ -191,7 +210,17 @@
     #     tree (Digest): The directory digest to extract
     #     can_link (bool): Whether we can create hard links in the destination
     #
-    def checkout(self, dest, tree, *, can_link=False):
+    def checkout(self, dest, tree, *, can_link=False, _fetch=True):
+        if _fetch and self._remote_cache:
+            # We need the files in the local cache
+            local_cas = self.get_local_cas()
+
+            request = local_cas_pb2.FetchTreeRequest()
+            request.root_digest.CopyFrom(tree)
+            request.fetch_file_blobs = True
+
+            local_cas.FetchTree(request)
+
         os.makedirs(dest, exist_ok=True)
 
         directory = remote_execution_pb2.Directory()
@@ -229,7 +258,7 @@
 
         for dirnode in directory.directories:
             fullpath = os.path.join(dest, dirnode.name)
-            self.checkout(fullpath, dirnode.digest, can_link=can_link)
+            self.checkout(fullpath, dirnode.digest, can_link=can_link, _fetch=False)
 
         for symlinknode in directory.symlinks:
             # symlink
@@ -286,6 +315,11 @@
 
         objpath = self.objpath(digest)
 
+        if self._remote_cache and not os.path.exists(objpath):
+            batch = _CASBatchRead(self._default_remote)
+            batch.add(digest)
+            batch.send()
+
         return open(objpath, mode=mode)
 
     # add_object():
@@ -399,7 +433,7 @@
         if tree_response.status.code == code_pb2.RESOURCE_EXHAUSTED:
             raise CASCacheError("Cache too full", reason="cache-too-full")
         if tree_response.status.code != code_pb2.OK:
-            raise CASCacheError("Failed to capture tree {}: {}".format(path, tree_response.status.code))
+            raise CASCacheError("Failed to capture tree {}: {}".format(path, tree_response.status))
 
         treepath = self.objpath(tree_response.tree_digest)
         tree = remote_execution_pb2.Tree()
@@ -469,10 +503,20 @@
     # Generator that returns the Digests of all blobs in the tree specified by
     # the Digest of the toplevel Directory object.
     #
-    def required_blobs_for_directory(self, directory_digest, *, excluded_subdirs=None):
+    def required_blobs_for_directory(self, directory_digest, *, excluded_subdirs=None, _fetch_tree=True):
         if not excluded_subdirs:
             excluded_subdirs = []
 
+        if self._remote_cache and _fetch_tree:
+            # Ensure we have the directory protos in the local cache
+            local_cas = self.get_local_cas()
+
+            request = local_cas_pb2.FetchTreeRequest()
+            request.root_digest.CopyFrom(directory_digest)
+            request.fetch_file_blobs = False
+
+            local_cas.FetchTree(request)
+
         # parse directory, and recursively add blobs
 
         yield directory_digest
@@ -487,7 +531,7 @@
 
         for dirnode in directory.directories:
             if dirnode.name not in excluded_subdirs:
-                yield from self.required_blobs_for_directory(dirnode.digest)
+                yield from self.required_blobs_for_directory(dirnode.digest, _fetch_tree=False)
 
     ################################################
     #             Local Private Methods            #
@@ -569,6 +613,10 @@
     # Returns: The Digests of the blobs that were not available on the remote CAS
     #
     def fetch_blobs(self, remote, digests, *, allow_partial=False):
+        if self._remote_cache:
+            # Determine blobs missing in the remote cache and only fetch those
+            digests = self.missing_blobs(digests)
+
         missing_blobs = [] if allow_partial else None
 
         remote.init()
@@ -581,6 +629,15 @@
 
         batch.send(missing_blobs=missing_blobs)
 
+        if self._remote_cache:
+            # Upload fetched blobs to the remote cache as we can't transfer
+            # blobs directly from another remote to the remote cache
+            batch = _CASBatchUpdate(self._default_remote)
+            for digest in digests:
+                if missing_blobs is None or digest not in missing_blobs:  # pylint: disable=unsupported-membership-test
+                    batch.add(digest)
+            batch.send()
+
         return missing_blobs
 
     # send_blobs():
@@ -592,6 +649,17 @@
     #    digests (list): The Digests of Blobs to upload
     #
     def send_blobs(self, remote, digests):
+        if self._remote_cache:
+            # First fetch missing blobs from the remote cache as we can't
+            # transfer blobs directly from the remote cache to another remote.
+
+            remote_missing_blobs = self.missing_blobs(digests, remote=remote)
+
+            batch = _CASBatchRead(self._default_remote)
+            for digest in remote_missing_blobs:
+                batch.add(digest)
+            batch.send()
+
         batch = _CASBatchUpdate(remote)
 
         for digest in digests:
diff --git a/src/buildstream/_cas/casdprocessmanager.py b/src/buildstream/_cas/casdprocessmanager.py
index 11a16f2..0a7d768 100644
--- a/src/buildstream/_cas/casdprocessmanager.py
+++ b/src/buildstream/_cas/casdprocessmanager.py
@@ -51,10 +51,11 @@
 #     log_dir (str): The directory for the logs
 #     log_level (LogLevel): Log level to give to buildbox-casd for logging
 #     cache_quota (int): User configured cache quota
+#     remote_cache_spec (RemoteSpec): Remote cache server passed to buildbox-casd, or None for no remote cache
 #     protect_session_blobs (bool): Disable expiry for blobs used in the current session
 #
 class CASDProcessManager:
-    def __init__(self, path, log_dir, log_level, cache_quota, protect_session_blobs):
+    def __init__(self, path, log_dir, log_level, cache_quota, remote_cache_spec, protect_session_blobs):
         self._log_dir = log_dir
 
         self._socket_path = self._make_socket_path(path)
@@ -71,6 +72,16 @@
         if protect_session_blobs:
             casd_args.append("--protect-session-blobs")
 
+        if remote_cache_spec:
+            casd_args.append("--cas-remote={}".format(remote_cache_spec.url))
+            if remote_cache_spec.instance_name:
+                casd_args.append("--cas-instance={}".format(remote_cache_spec.instance_name))
+            if remote_cache_spec.server_cert:
+                casd_args.append("--cas-server-cert={}".format(remote_cache_spec.server_cert))
+            if remote_cache_spec.client_key:
+                casd_args.append("--cas-client-key={}".format(remote_cache_spec.client_key))
+                casd_args.append("--cas-client-cert={}".format(remote_cache_spec.client_cert))
+
         casd_args.append(path)
 
         self._start_time = time.time()
diff --git a/src/buildstream/_cas/casremote.py b/src/buildstream/_cas/casremote.py
index 3799f95..b9ae3d7 100644
--- a/src/buildstream/_cas/casremote.py
+++ b/src/buildstream/_cas/casremote.py
@@ -55,6 +55,11 @@
     # be called outside of init().
     #
     def _configure_protocols(self):
+        if not self.spec:
+            # No spec means this is the remote cache, which is handled by the default instance in casd
+            self.local_cas_instance_name = ""
+            return
+
         local_cas = self.cascache.get_local_cas()
         request = local_cas_pb2.GetInstanceNameForRemotesRequest()
         cas_endpoint = request.content_addressable_storage
diff --git a/src/buildstream/_cas/casserver.py b/src/buildstream/_cas/casserver.py
index 5ddd446..3a89736 100644
--- a/src/buildstream/_cas/casserver.py
+++ b/src/buildstream/_cas/casserver.py
@@ -109,7 +109,7 @@
     logger.addHandler(handler)
 
     casd_manager = CASDProcessManager(
-        os.path.abspath(repo), os.path.join(os.path.abspath(repo), "logs"), log_level, quota, False
+        os.path.abspath(repo), os.path.join(os.path.abspath(repo), "logs"), log_level, quota, None, False
     )
     casd_channel = casd_manager.create_channel()
 
diff --git a/src/buildstream/_remote.py b/src/buildstream/_remote.py
index 0d47921..42314eb 100644
--- a/src/buildstream/_remote.py
+++ b/src/buildstream/_remote.py
@@ -52,7 +52,10 @@
         return False
 
     def __str__(self):
-        return self.spec.url
+        if self.spec:
+            return self.spec.url
+        else:
+            return "(default remote)"
 
     ####################################################
     #                   Remote API                     #
@@ -68,7 +71,9 @@
             if self._initialized:
                 return
 
-            self.channel = self.spec.open_channel()
+            if self.spec:
+                self.channel = self.spec.open_channel()
+
             self._configure_protocols()
             self._initialized = True