blob: e8816fe2977306bf6b69c6a95fea78ca3d686c7e [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import datetime
import os
import tempfile
import urllib
import apache_beam as beam
from apache_beam import coders
from apache_beam.io import filesystems
from apache_beam.transforms import combiners
# A bare `import urllib` does not guarantee that the `urllib.parse` submodule
# has been loaded on Python 3, so attribute access on the package only works
# when some other module happened to import it first. Import the helpers
# explicitly and fall back to the Python 2 locations when that fails.
try:  # Python 3
  from urllib.parse import quote
  from urllib.parse import unquote_to_bytes
except ImportError:  # Python 2
  # pylint: disable=deprecated-urllib-function
  unquote_to_bytes = urllib.unquote
  quote = urllib.quote
class CacheManager(object):
  """Interface for objects that cache PCollections.

  Each cached PCollection is addressed by a sequence of labels: a prefix
  (either 'full' or 'sample') followed by a cache_label, which is a hash of
  the PCollection derivation.
  """

  def exists(self, *labels):
    """Tells whether a cache entry exists for the given labels."""
    raise NotImplementedError()

  def is_latest_version(self, version, *labels):
    """Tells whether `version` equals the latest version for the labels."""
    latest = self._latest_version(*labels)
    return version == latest

  def _latest_version(self, *labels):
    """Returns the most recent version number of the cache entry."""
    raise NotImplementedError()

  def read(self, *labels):
    """Returns the cached PCollection as a list, plus its version number.

    Returns:
      (List[PCollection])
      (int) the version number

    The version returned by read() may differ from the one reported by
    _latest_version(). This usually means that the cache has been evicted
    (so it is unavailable and read() returns version = -1) even though it had
    reached version n before the eviction.
    """
    raise NotImplementedError()

  def source(self, *labels):
    """Returns a beam.io.Source that reads the cached PCollection."""
    raise NotImplementedError()

  def sink(self, *labels):
    """Returns a beam.io.Sink that writes the cached PCollection."""
    raise NotImplementedError()

  def cleanup(self):
    """Removes every PCollection cache held by this manager."""
    raise NotImplementedError()
class FileBasedCacheManager(CacheManager):
"""Maps PCollections to local temp files for materialization."""
def __init__(self, cache_dir=None):
if cache_dir:
self._cache_dir = filesystems.FileSystems.join(
cache_dir,
datetime.datetime.now().strftime("cache-%y-%m-%d-%H_%M_%S"))
else:
self._cache_dir = tempfile.mkdtemp(
prefix='interactive-temp-', dir=os.environ.get('TEST_TMPDIR', None))
self._versions = collections.defaultdict(lambda: self._CacheVersion())
def exists(self, *labels):
return bool(self._match(*labels))
def _latest_version(self, *labels):
timestamp = 0
for path in self._match(*labels):
timestamp = max(timestamp, filesystems.FileSystems.last_updated(path))
result = self._versions["-".join(labels)].get_version(timestamp)
return result
def read(self, *labels):
if not self.exists(*labels):
return [], -1
def _read_helper():
coder = SafeFastPrimitivesCoder()
for path in self._match(*labels):
for line in filesystems.FileSystems.open(path):
yield coder.decode(line.strip())
result, version = list(_read_helper()), self._latest_version(*labels)
return result, version
def source(self, *labels):
return beam.io.ReadFromText(self._glob_path(*labels),
coder=SafeFastPrimitivesCoder())._source
def sink(self, *labels):
return beam.io.WriteToText(self._path(*labels),
coder=SafeFastPrimitivesCoder())._sink
def cleanup(self):
if filesystems.FileSystems.exists(self._cache_dir):
filesystems.FileSystems.delete([self._cache_dir])
def _glob_path(self, *labels):
return self._path(*labels) + '-*-of-*'
def _path(self, *labels):
return filesystems.FileSystems.join(self._cache_dir, *labels)
def _match(self, *labels):
match = filesystems.FileSystems.match([self._glob_path(*labels)])
assert len(match) == 1
return [metadata.path for metadata in match[0].metadata_list]
class _CacheVersion(object):
"""This class keeps track of the timestamp and the corresponding version."""
def __init__(self):
self.current_version = -1
self.current_timestamp = 0
def get_version(self, timestamp):
"""Updates version if necessary and returns the version number.
Args:
timestamp: (int) unix timestamp when the cache is updated. This value is
zero if the cache has been evicted or doesn't exist.
"""
# Do not update timestamp if the cache's been evicted.
if timestamp != 0 and timestamp != self.current_timestamp:
assert timestamp > self.current_timestamp
self.current_version = self.current_version + 1
self.current_timestamp = timestamp
return self.current_version
class ReadCache(beam.PTransform):
  """PTransform that loads a cached PCollection through a cache manager."""

  def __init__(self, cache_manager, label):
    self._label = label
    self._cache_manager = cache_manager

  def expand(self, pbegin):
    """Applies a 'Read' step over the 'full' cache entry for the label."""
    cache_source = self._cache_manager.source('full', self._label)
    read_step = 'Read' >> beam.io.Read(cache_source)
    # pylint: disable=expression-not-assigned
    return pbegin | read_step
class WriteCache(beam.PTransform):
  """PTransform that stores a PCollection through a cache manager.

  With sample=True, a fixed-size global sample of `sample_size` elements is
  written under the 'sample' prefix; otherwise the whole PCollection is
  written under the 'full' prefix.
  """

  def __init__(self, cache_manager, label, sample=False, sample_size=0):
    self._sample_size = sample_size
    self._sample = sample
    self._label = label
    self._cache_manager = cache_manager

  def expand(self, pcoll):
    """Optionally samples `pcoll`, then applies a 'Write' step to the cache."""
    if self._sample:
      prefix = 'sample'
      # The sampler emits a single list; flatten it back into elements.
      take_sample = combiners.Sample.FixedSizeGlobally(self._sample_size)
      flatten = beam.FlatMap(lambda sample: sample)
      pcoll |= 'Sample' >> (take_sample | flatten)
    else:
      prefix = 'full'
    cache_sink = self._cache_manager.sink(prefix, self._label)
    # pylint: disable=expression-not-assigned
    return pcoll | 'Write' >> beam.io.Write(cache_sink)
class SafeFastPrimitivesCoder(coders.Coder):
  """FastPrimitivesCoder with a quote/unquote step escaping special chars."""

  # pylint: disable=deprecated-urllib-function

  def encode(self, value):
    """Encodes `value`, quotes the result and returns it as UTF-8 bytes."""
    raw = coders.coders.FastPrimitivesCoder().encode(value)
    escaped = quote(raw)
    return escaped.encode('utf-8')

  def decode(self, value):
    """Unquotes `value` to bytes and decodes it with FastPrimitivesCoder."""
    raw = unquote_to_bytes(value)
    return coders.coders.FastPrimitivesCoder().decode(raw)