blob: 6e65a3ecdb8746d42d6d84b9e2298ff639905e2b [file] [log] [blame]
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
"""
Data utilities for the TUPAC16 breast cancer project.
This is early, experimental code.
TODO: Cleanup & add proper comments to all functions.
"""
import multiprocessing as mp
import os
import threading
import numpy as np
import py4j
# Utils for reading data
def compute_channel_means(rdd, channels, size):
"""Compute the means of each color channel across the dataset."""
# TODO: Replace this with pyspark.ml.feature.VectorSlicer
# to cut vector into separate channel vectors, then grab the mean
# of those new columns, all using DataFrame functions, rather than
# RDD functions.
# from pyspark.ml.linalg import VectorUDT
# from pyspark.sql.functions import udf
# from pyspark.sql import functions as F
# as_ml = udf(lambda v: v.asML() if v is not None else None, VectorUDT())
# slicers[0].transform(train_df.withColumn("sample", as_ml("sample"))).select(F.avg("ch0"))
# slicers = [VectorSlicer(inputCol="sample", outputCol="ch{}".format(c), indices=range(c*pixels, c*pixels + pixels)) for c in range(CHANNELS)]
def helper(x):
x = x.sample.values
x = np.array(x)
x = (x.reshape((-1,channels,size,size)) # shape (N,C,H,W)
.transpose((0,2,3,1)) # shape (N,H,W,C)
.astype(np.float32))
mu = np.mean(x, axis=(0,1,2))
return mu
means = rdd.map(helper).collect()
means = np.array(means)
means = np.mean(means, axis=0)
return means
def gen_class_weights(df):
"""Generate class weights to even out the class distribution during training."""
class_counts_df = df.select("tumor_score").groupBy("tumor_score").count()
class_counts = {row["tumor_score"]:row["count"] for row in class_counts_df.collect()}
max_count = max(class_counts.values())
class_weights = {k-1:max_count/v for k,v in class_counts.items()}
return class_weights
def read_data(spark_session, filename_template, sample_size, channels, sample_prob, normalize_class_distribution, seed):
"""Read and return training & validation Spark DataFrames."""
# TODO: Clean this function up!!!
assert channels in (1, 3)
grayscale = False if channels == 3 else True
# Sample (Optional)
if sample_prob < 1:
try:
# Ex: `train_0.01_sample_256.parquet`
sampled_filename_template = filename_template.format("{}_sample_".format(sample_prob), sample_size, "_grayscale" if grayscale else "")
filename = os.path.join("data", sampled_filename_template)
df = spark_session.read.load(filename)
except: # Pre-sampled DataFrame not available
filename = os.path.join("data", filename_template.format("", sample_size, "_grayscale" if grayscale else ""))
df = spark_session.read.load(filename)
p = sample_prob # sample percentage
if normalize_class_distribution:
# stratified sample with even class proportions
n = df.count() # num examples
K = 3 # num classes
s = p * n # num examples in p% sample, as a fraction
s_k = s / K # num examples per class in evenly-distributed p% sample, as fraction
class_counts_df = df.select("tumor_score").groupBy("tumor_score").count()
class_counts = {row["tumor_score"]:row["count"] for row in class_counts_df.collect()}
ps = {k:s_k/v for k,v in class_counts.items()}
df = df.sampleBy("tumor_score", fractions=ps, seed=seed)
else:
# stratified sample maintaining the original class proportions
df = df.sampleBy("tumor_score", fractions={1: p, 2: p, 3: p}, seed=seed)
# TODO: Determine if coalesce actually provides a perf benefit on Spark 2.x
#train_df.cache(), val_df.cache() # cache here, or coalesce will hang
# tc = train_df.count()
# vc = val_df.count()
#
# # Reduce num partitions to ideal size (~128 MB/partition, determined empirically)
# current_tr_parts = train_df.rdd.getNumPartitions()
# current_val_parts = train_df.rdd.getNumPartitions()
# ex_mb = sample_size * sample_size * channels * 8 / 1024 / 1024 # size of one example in MB
# ideal_part_size_mb = 128 # 128 MB partitions sizes are empirically ideal
# ideal_exs_per_part = round(ideal_part_size_mb / ex_mb)
# tr_parts = round(tc / ideal_exs_per_part)
# val_parts = round(vc / ideal_exs_per_part)
# if current_tr_parts > tr_parts:
# train_df = train_df.coalesce(tr_parts)
# if current_val_parts > val_parts:
# val_df = val_df.coalesce(val_parts)
# train_df.cache(), val_df.cache()
else:
# Read in data
filename = os.path.join("data", filename_template.format("", sample_size, "_grayscale" if grayscale else ""))
df = spark_session.read.load(filename)
return df
def read_train_data(spark_session, sample_size, channels, sample_prob=1, normalize_class_distribution=False, seed=42):
"""Read training Spark DataFrame."""
filename = "train_{}{}{}_updated.parquet"
train_df = read_data(spark_session, filename, sample_size, channels, sample_prob, normalize_class_distribution, seed)
return train_df
def read_val_data(spark_session, sample_size, channels, sample_prob=1, normalize_class_distribution=False, seed=42):
"""Read validation Spark DataFrame."""
filename = "val_{}{}{}_updated.parquet"
train_df = read_data(spark_session, filename, sample_size, channels, sample_prob, normalize_class_distribution, seed)
return train_df
# Utils for creating asynchronous queuing batch generators
# TODO: Add comments to these functions
def fill_partition_num_queue(partition_num_queue, num_partitions, stop_event):
partition_num_queue.cancel_join_thread()
while not stop_event.is_set():
for i in range(num_partitions):
partition_num_queue.put(i)
def fill_partition_queue(partition_queue, partition_num_queue, rdd, stop_event):
partition_queue.cancel_join_thread()
while not stop_event.is_set():
# py4j has some issues with imports with first starting.
try:
partition_num = partition_num_queue.get()
partition = rdd.context.runJob(rdd, lambda x: x, [partition_num])
partition_queue.put(partition)
except (AttributeError, py4j.protocol.Py4JError, Exception) as err:
print("error: {}".format(err))
def fill_row_queue(row_queue, partition_queue, stop_event):
row_queue.cancel_join_thread()
while not stop_event.is_set():
rows = partition_queue.get()
for row in rows:
row_queue.put(row)
def gen_batch(row_queue, batch_size):
while True:
features = []
labels = []
for i in range(batch_size):
row = row_queue.get()
features.append(row.sample.values)
labels.append(row.tumor_score)
x_batch = np.array(features).astype(np.uint8)
y_batch = np.array(labels).astype(np.uint8)
yield x_batch, y_batch
def create_batch_generator(
rdd, batch_size=32, num_partition_threads=32, num_row_processes=16,
partition_num_queue_size=128, partition_queue_size=16, row_queue_size=2048):
"""
Create a multiprocess batch generator.
This creates a generator that uses processes and threads to create a
pipeline that asynchronously fetches data from Spark, filling a set
of queues, while yielding batches. The goal here is to amortize the
time needed to fetch data from Spark so that downstream consumers
are saturated.
"""
#rdd.cache()
partition_num_queue = mp.Queue(partition_num_queue_size)
partition_queue = mp.Queue(partition_queue_size)
row_queue = mp.Queue(row_queue_size)
num_partitions = rdd.getNumPartitions()
stop_event = mp.Event()
partition_num_process = mp.Process(target=fill_partition_num_queue, args=(partition_num_queue, num_partitions, stop_event), daemon=True)
partition_threads = [threading.Thread(target=fill_partition_queue, args=(partition_queue, partition_num_queue, rdd, stop_event), daemon=True) for _ in range(num_partition_threads)]
row_processes = [mp.Process(target=fill_row_queue, args=(row_queue, partition_queue, stop_event), daemon=True) for _ in range(num_row_processes)]
ps = [partition_num_process] + row_processes + partition_threads
queues = [partition_num_queue, partition_queue, row_queue]
for p in ps:
p.start()
generator = gen_batch(row_queue, batch_size)
return generator, ps, queues, stop_event
def stop(processes, stop_event):
"""Stop queuing processes."""
stop_event.set()
for p in processes:
if isinstance(p, mp.Process):
p.terminate()
mp.active_children() # Use to join the killed processes above.