#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Utility methods for testing on GCP."""
from __future__ import absolute_import
import logging
import random
import time
from apache_beam.io import filesystems
from apache_beam.io.gcp.pubsub import PubsubMessage
from apache_beam.utils import retry
# Protect against environments where bigquery library is not available.
try:
  from google.api_core import exceptions as gexc
  from google.cloud import bigquery
except ImportError:
  gexc = None
  bigquery = None

_LOGGER = logging.getLogger(__name__)
class GcpTestIOError(retry.PermanentException):
  """Basic GCP IO error for testing. Function that raises this error should
  not be retried."""
  pass
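# A minimal sketch (not part of the original module) of how this exception is
# meant to interact with the retry decorator used below: because GcpTestIOError
# subclasses retry.PermanentException, raising it aborts immediately instead of
# being retried. The function and helper names here are hypothetical.
#
#   @retry.with_exponential_backoff(
#       num_retries=3, retry_filter=retry.retry_on_server_errors_filter)
#   def _lookup_resource(name):
#     if not _resource_exists(name):  # Hypothetical existence check.
#       raise GcpTestIOError('Resource does not exist: %s' % name)
#     return name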
@retry.with_exponential_backoff(
    num_retries=3,
    retry_filter=retry.retry_on_server_errors_filter)
def create_bq_dataset(project, dataset_base_name):
  """Creates an empty BigQuery dataset.

  Args:
    project: Project to work in.
    dataset_base_name: Prefix for dataset id.

  Returns:
    A ``google.cloud.bigquery.dataset.DatasetReference`` object pointing to the
    new dataset.
  """
  client = bigquery.Client(project=project)
  unique_dataset_name = '%s%s%d' % (dataset_base_name,
                                    str(int(time.time())),
                                    random.randint(0, 10000))
  dataset_ref = client.dataset(unique_dataset_name, project=project)
  dataset = bigquery.Dataset(dataset_ref)
  client.create_dataset(dataset)
  return dataset_ref
@retry.with_exponential_backoff(
    num_retries=3,
    retry_filter=retry.retry_on_server_errors_filter)
def delete_bq_dataset(project, dataset_ref):
  """Deletes a BigQuery dataset and its contents.

  Args:
    project: Project to work in.
    dataset_ref: A ``google.cloud.bigquery.dataset.DatasetReference`` object
      pointing to the dataset to delete.
  """
  client = bigquery.Client(project=project)
  client.delete_dataset(dataset_ref, delete_contents=True)
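# A minimal usage sketch, assuming application default credentials and a
# placeholder project name; an integration test could pair create_bq_dataset
# with delete_bq_dataset as setup and teardown:
#
#   dataset_ref = create_bq_dataset('my-project', 'beam_it_dataset')
#   try:
#     pass  # Run the pipeline under test against dataset_ref.dataset_id.
#   finally:
#     delete_bq_dataset('my-project', dataset_ref)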
@retry.with_exponential_backoff(
    num_retries=3,
    retry_filter=retry.retry_on_server_errors_filter)
def delete_bq_table(project, dataset_id, table_id):
  """Deletes a BigQuery table.

  Args:
    project: Name of the project.
    dataset_id: Name of the dataset where the table is.
    table_id: Name of the table.
  """
  _LOGGER.info('Clean up a BigQuery table with project: %s, dataset: %s, '
               'table: %s.', project, dataset_id, table_id)
  client = bigquery.Client(project=project)
  table_ref = client.dataset(dataset_id).table(table_id)
  try:
    client.delete_table(table_ref)
  except gexc.NotFound:
    raise GcpTestIOError('BigQuery table does not exist: %s' % table_ref)
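# A minimal usage sketch of delete_bq_table with placeholder project, dataset
# and table names. Since the NotFound case is re-raised as GcpTestIOError
# (a permanent, non-retried error), callers that only want best-effort cleanup
# can swallow it:
#
#   try:
#     delete_bq_table('my-project', 'my_dataset', 'output_table')
#   except GcpTestIOError:
#     pass  # Table was never created, e.g. the test failed before writing.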
@retry.with_exponential_backoff(
    num_retries=3,
    retry_filter=retry.retry_on_server_errors_filter)
def delete_directory(directory):
  """Deletes a directory in a filesystem.

  Args:
    directory: Full path to a directory supported by Beam filesystems (e.g.
      "gs://mybucket/mydir/", "s3://...", ...)
  """
  filesystems.FileSystems.delete([directory])
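# A minimal usage sketch of delete_directory with a placeholder bucket and
# path; any path scheme understood by apache_beam.io.filesystems should work:
#
#   delete_directory('gs://my-bucket/tmp/it-output/')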
def write_to_pubsub(pub_client,
                    topic_path,
                    messages,
                    with_attributes=False,
                    chunk_size=100,
                    delay_between_chunks=0.1):
  """Publishes messages to a Pub/Sub topic in fixed-size chunks.

  Args:
    pub_client: A Pub/Sub publisher client (e.g.
      ``google.cloud.pubsub_v1.PublisherClient``).
    topic_path: Full path of the topic to publish to.
    messages: Messages to publish: bytes payloads, or ``PubsubMessage``
      objects when ``with_attributes`` is True.
    with_attributes: Whether to publish each message's attributes as well.
    chunk_size: Number of messages published per chunk.
    delay_between_chunks: Seconds to sleep after each chunk.
  """
  for start in range(0, len(messages), chunk_size):
    message_chunk = messages[start:start + chunk_size]
    if with_attributes:
      futures = [
          pub_client.publish(topic_path, message.data, **message.attributes)
          for message in message_chunk
      ]
    else:
      futures = [
          pub_client.publish(topic_path, message) for message in message_chunk
      ]
    for future in futures:
      future.result()
    time.sleep(delay_between_chunks)
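# A minimal usage sketch of write_to_pubsub, assuming a
# google.cloud.pubsub_v1.PublisherClient and an existing topic; the project
# and topic names are placeholders:
#
#   from google.cloud import pubsub_v1
#
#   publisher = pubsub_v1.PublisherClient()
#   topic_path = publisher.topic_path('my-project', 'my-input-topic')
#   write_to_pubsub(publisher, topic_path, [b'event-1', b'event-2'])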
def read_from_pubsub(sub_client,
                     subscription_path,
                     with_attributes=False,
                     number_of_elements=None,
                     timeout=None):
  """Reads messages from a Pub/Sub subscription until a limit is reached.

  Pulls and acknowledges messages until ``number_of_elements`` have been
  collected or ``timeout`` seconds have elapsed, whichever comes first. At
  least one of the two must be specified.

  Args:
    sub_client: A Pub/Sub subscriber client (e.g.
      ``google.cloud.pubsub_v1.SubscriberClient``).
    subscription_path: Full path of the subscription to pull from.
    with_attributes: If True, collect ``PubsubMessage`` objects; otherwise
      collect only the message payloads.
    number_of_elements: Number of messages to collect before returning.
    timeout: Maximum number of seconds to keep pulling.

  Returns:
    A list of payload bytes, or of ``PubsubMessage`` objects when
    ``with_attributes`` is True.
  """
  if number_of_elements is None and timeout is None:
    raise ValueError("Either number_of_elements or timeout must be specified.")
  messages = []
  start_time = time.time()
  while ((number_of_elements is None or len(messages) < number_of_elements) and
         (timeout is None or (time.time() - start_time) < timeout)):
    try:
      response = sub_client.pull(
          subscription_path, max_messages=1000, retry=None, timeout=10)
    except (gexc.RetryError, gexc.DeadlineExceeded):
      continue
    ack_ids = [msg.ack_id for msg in response.received_messages]
    sub_client.acknowledge(subscription_path, ack_ids)
    for msg in response.received_messages:
      message = PubsubMessage._from_message(msg.message)
      if with_attributes:
        messages.append(message)
      else:
        messages.append(message.data)
  return messages
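# A minimal usage sketch of read_from_pubsub, assuming a
# google.cloud.pubsub_v1.SubscriberClient and an existing subscription; the
# project and subscription names are placeholders:
#
#   from google.cloud import pubsub_v1
#
#   subscriber = pubsub_v1.SubscriberClient()
#   subscription_path = subscriber.subscription_path(
#       'my-project', 'my-output-subscription')
#   results = read_from_pubsub(
#       subscriber, subscription_path, number_of_elements=2, timeout=60)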