Merge pull request #913 from NLGithubWP/update-processing
Update data processing in largedataset_cnn example
diff --git a/examples/largedataset_cnn/process_data.py b/examples/largedataset_cnn/process_data.py
index 04a1750..1fdf185 100644
--- a/examples/largedataset_cnn/process_data.py
+++ b/examples/largedataset_cnn/process_data.py
@@ -35,13 +35,14 @@
return im
+
def process_data(dataset_root, classes):
- # load class names
+ # Load class names
with open(classes, 'r', encoding='utf-8') as f:
classes = f.readlines()
classes = list(map(lambda x: x.strip(), classes))
- # make input_paths and labels
+ # Make input_paths and labels
input_paths, labels = [], []
for class_name in os.listdir(dataset_root):
class_root = os.path.join(dataset_root, class_name)
@@ -54,17 +55,17 @@
input_paths.append(path)
labels.append(class_id)
- # convert to numpy array
+ # Convert to numpy array
input_paths = np.array(input_paths)
labels = np.array(labels, dtype=np.int32)
- # shuffle dataset
+ # Shuffle dataset
np.random.seed(0)
perm = np.random.permutation(len(input_paths))
input_paths = input_paths[perm]
labels = labels[perm]
- # split dataset for training and validation
+ # Split dataset for training and validation
border = int(len(input_paths) * 0.8)
train_labels = labels[:border]
val_labels = labels[border:]
@@ -78,13 +79,13 @@
return train_input_paths, train_labels, val_input_paths, val_labels
def loaddata():
- dataset_root = '/Dataset/Data/' # need to set this path an argument
+ dataset_root = '/Dataset/Data/'
classes = '/Dataset/classes.txt'
return process_data(dataset_root, classes)
if __name__ == '__main__':
- # Test loaddata() function in main
+ # test script in main
train_input_paths, train_labels, val_input_paths, val_labels = loaddata()
print(train_input_paths.shape)