blob: 29b4ba009a9ad18e792d16e54214c374fdfed151 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Preprocessing script.
This script walks over the directories and dump the frames into a csv file
"""
import os
import csv
import sys
import random
import scipy
import numpy as np
import dicom
from skimage import io, transform
from joblib import Parallel, delayed
import dill
def mkdir(fname):
try:
os.mkdir(fname)
except:
pass
def get_frames(root_path):
"""Get path to all the frame in view SAX and contain complete frames"""
ret = []
for root, _, files in os.walk(root_path):
root=root.replace('\\','/')
files=[s for s in files if ".dcm" in s]
if len(files) == 0 or not files[0].endswith(".dcm") or root.find("sax") == -1:
continue
prefix = files[0].rsplit('-', 1)[0]
fileset = set(files)
expected = ["%s-%04d.dcm" % (prefix, i + 1) for i in range(30)]
if all(x in fileset for x in expected):
ret.append([root + "/" + x for x in expected])
# sort for reproduciblity
return sorted(ret, key = lambda x: x[0])
def get_label_map(fname):
labelmap = {}
fi = open(fname)
fi.readline()
for line in fi:
arr = line.split(',')
labelmap[int(arr[0])] = line
return labelmap
def write_label_csv(fname, frames, label_map):
fo = open(fname, "w")
for lst in frames:
index = int(lst[0].split("/")[3])
if label_map is not None:
fo.write(label_map[index])
else:
fo.write("%d,0,0\n" % index)
fo.close()
def get_data(lst,preproc):
data = []
result = []
for path in lst:
f = dicom.read_file(path)
img = preproc(f.pixel_array.astype(float) / np.max(f.pixel_array))
dst_path = path.rsplit(".", 1)[0] + ".64x64.jpg"
scipy.misc.imsave(dst_path, img)
result.append(dst_path)
data.append(img)
data = np.array(data, dtype=np.uint8)
data = data.reshape(data.size)
data = np.array(data, dtype=np.str_)
data = data.reshape(data.size)
return [data,result]
def write_data_csv(fname, frames, preproc):
"""Write data to csv file"""
fdata = open(fname, "w")
dr = Parallel()(delayed(get_data)(lst,preproc) for lst in frames)
data,result = zip(*dr)
for entry in data:
fdata.write(','.join(entry)+'\r\n')
print("All finished, %d slices in total" % len(data))
fdata.close()
result = np.ravel(result)
return result
def crop_resize(img, size):
"""crop center and resize"""
if img.shape[0] < img.shape[1]:
img = img.T
# we crop image from center
short_egde = min(img.shape[:2])
yy = int((img.shape[0] - short_egde) / 2)
xx = int((img.shape[1] - short_egde) / 2)
crop_img = img[yy : yy + short_egde, xx : xx + short_egde]
# resize to 64, 64
resized_img = transform.resize(crop_img, (size, size))
resized_img *= 255
return resized_img.astype("uint8")
def local_split(train_index):
random.seed(0)
train_index = set(train_index)
all_index = sorted(train_index)
num_test = int(len(all_index) / 3)
random.shuffle(all_index)
train_set = set(all_index[num_test:])
test_set = set(all_index[:num_test])
return train_set, test_set
def split_csv(src_csv, split_to_train, train_csv, test_csv):
ftrain = open(train_csv, "w")
ftest = open(test_csv, "w")
cnt = 0
for l in open(src_csv):
if split_to_train[cnt]:
ftrain.write(l)
else:
ftest.write(l)
cnt = cnt + 1
ftrain.close()
ftest.close()
# Load the list of all the training frames, and shuffle them
# Shuffle the training frames
random.seed(10)
train_frames = get_frames("./data/train")
random.shuffle(train_frames)
validate_frames = get_frames("./data/validate")
# Write the corresponding label information of each frame into file.
write_label_csv("./train-label.csv", train_frames, get_label_map("./data/train.csv"))
write_label_csv("./validate-label.csv", validate_frames, None)
# Dump the data of each frame into a CSV file, apply crop to 64 preprocessor
train_lst = write_data_csv("./train-64x64-data.csv", train_frames, lambda x: crop_resize(x, 64))
valid_lst = write_data_csv("./validate-64x64-data.csv", validate_frames, lambda x: crop_resize(x, 64))
# Generate local train/test split, which you could use to tune your model locally.
train_index = np.loadtxt("./train-label.csv", delimiter=",")[:,0].astype("int")
train_set, test_set = local_split(train_index)
split_to_train = [x in train_set for x in train_index]
split_csv("./train-label.csv", split_to_train, "./local_train-label.csv", "./local_test-label.csv")
split_csv("./train-64x64-data.csv", split_to_train, "./local_train-64x64-data.csv", "./local_test-64x64-data.csv")