Add support for TensorBoard

Only a fixed list of safe keyword options may be passed to TensorBoard(),
and only superusers may use callbacks.  No other callbacks are allowed for now.

Co-authored-by: Domino Valdano <dvaldano@pivotal.io>
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index cd2d075..796f743 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -635,7 +635,7 @@
 
     # Fit segment model on data
     #TODO consider not doing this every time
-    fit_params = parse_and_validate_fit_params(fit_params)
+    fit_params = parse_and_validate_fit_params(fit_params, current_seg_id)
     segment_model.fit(x_train, y_train, **fit_params)
 
     # Aggregating number of images, loss and accuracy
@@ -734,7 +734,7 @@
                                                       custom_function_map)
 
         set_model_weights(segment_model, serialized_weights)
-        fit_params = parse_and_validate_fit_params(fit_params)
+        fit_params = parse_and_validate_fit_params(fit_params, current_seg_id)
 
         for i in range(len(GD['x_train'])):
             # Fit segment model on data
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in
index f94cb44..4434e75 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_model_selection.py_in
@@ -364,7 +364,13 @@
             _assert(self.num_configs is None and self.random_state is None,
                     "DL: 'num_configs' and 'random_state' must be NULL for grid search")
             for distribution_type in self.accepted_distributions:
-                _assert(distribution_type not in compile_params_grid and distribution_type not in fit_params_grid,
+                # If the distribution is used, it will be in the following format:
+                # [123, 456, '<dist name>']
+                # Matching single quotes and the closing bracket minimizes false
+                # positives from the log_dir parameter.
+                tmp_dist = "'{0}']".format(distribution_type)
+                _assert(tmp_dist not in compile_params_grid and
+                        tmp_dist not in fit_params_grid,
                         "DL: Cannot search from a distribution with grid search")
         elif ModelSelectionSchema.RANDOM_SEARCH.startswith(self.search_type.lower()):
             _assert(self.num_configs is not None and self.num_configs > 0, "DL: 'num_configs' cannot be NULL and "
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index cbd882a..c23f8d3 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -31,11 +31,14 @@
 
 from utilities.utilities import _assert
 from utilities.utilities import is_platform_pg
+from utilities.utilities import current_user
+from utilities.utilities import is_superuser
 
 import tensorflow as tf
 from tensorflow.keras import backend as K
 from tensorflow.keras import utils as keras_utils
 from tensorflow.keras.optimizers import *
+from tensorflow.keras.callbacks import TensorBoard
 
 import tensorflow.keras.optimizers as opt
 import tensorflow.keras.losses as losses
@@ -141,9 +144,10 @@
 
 """
 Used to convert compile_params and fit_params to actual argument dictionaries
+If strip_quotes is True, each value in the dictionary will be stripped of quotes
 """
 
-def convert_string_of_args_to_dict(str_of_args):
+def convert_string_of_args_to_dict(str_of_args, strip_quotes=True):
     """Uses parenthases matching algorithm to intelligently convert
     a string with valid python code into an argument dictionary"""
     stack = []
@@ -164,16 +168,23 @@
                 stack.pop(-1)
             result_str += char
         elif not stack and char == "=":
-            key_str = result_str
+            key_str = result_str.strip()
             result_str = ""
         elif not stack and char == ",":
             value_str = result_str
             result_str = ""
-            compile_dict[key_str.strip()]=value_str.strip().strip('\'')
+            key_str = key_str.strip()
+            value_str = value_str.strip()
+            if strip_quotes:
+                value_str = value_str.strip('"\'')
+            compile_dict[key_str]=value_str
         else:
             result_str += char
     value_str = result_str
-    compile_dict[key_str.strip()]=value_str.strip().strip('\'')
+    value_str = value_str.strip()
+    if strip_quotes:
+        value_str = value_str.strip('"\'')
+    compile_dict[key_str]=value_str
     return compile_dict
 
 def get_metrics_from_compile_param(str_of_args):
@@ -286,27 +297,50 @@
 
 
 # Parse the fit parameters into a dictionary.
-def parse_and_validate_fit_params(fit_param_str):
+def parse_and_validate_fit_params(fit_param_str, current_seg_id=-1):
 
     if fit_param_str:
-        fit_params_dict = convert_string_of_args_to_dict(fit_param_str)
-
-        literal_eval_fit_params = ['batch_size','epochs','verbose',
+        fit_params_dict = convert_string_of_args_to_dict(fit_param_str, strip_quotes=False)
+        literal_eval_fit_params = ['batch_size','epochs','verbose', 'shuffle',
                                    'class_weight','initial_epoch','steps_per_epoch']
-        accepted_fit_params = literal_eval_fit_params + ['shuffle']
+        accepted_fit_params = literal_eval_fit_params + ['callbacks']
 
         fit_params_dict = validate_and_literal_eval_keys(fit_params_dict,
                                                          literal_eval_fit_params,
                                                          accepted_fit_params)
-        if 'shuffle' in fit_params_dict:
-            shuffle_value = fit_params_dict['shuffle']
-            if shuffle_value == 'True' or shuffle_value == 'False':
-                fit_params_dict['shuffle'] = bool(shuffle_value)
+        # 'callbacks' is accepted but not in the literal_eval list; parse_callbacks() vets it instead.
+        if 'callbacks' in fit_params_dict:
+            fit_params_dict['callbacks'] = parse_callbacks(fit_params_dict['callbacks'], current_seg_id)
 
         return fit_params_dict
     else:
         return {}
 
+# Parse the callbacks fit param and create the list holding the TensorBoard object
+def parse_callbacks(callbacks, current_seg_id=-1):
+    """Return [TensorBoard(...)] for a string like "[TensorBoard(log_dir='/tmp/x')]".
+    The string is only inspected through its AST, never executed; current_seg_id is
+    appended to log_dir.  Any invalid input is reported through plpy.error."""
+    if not is_superuser(current_user()):
+        plpy.error("Only a superuser may use callbacks.")
+    try:
+        body = ast.parse(callbacks.strip("'"), mode='eval').body
+        call = body.elts[0] if isinstance(body, ast.List) and len(body.elts) == 1 else None
+        # Explicit checks rather than assert, which is stripped when python runs with -O.
+        if not isinstance(call, ast.Call) or call.func.id != 'TensorBoard':
+            raise ValueError("expected exactly one TensorBoard(...) call")
+        tb_params_dict = {kw.arg: kw.value for kw in call.keywords}
+    except (SyntaxError, ValueError, TypeError, AttributeError):
+        plpy.error("Invalid callbacks fit param.  Currently, "
+                    "only TensorBoard callbacks are accepted.")
+    accepted_tb_params = ['log_dir', 'histogram_freq', 'batch_size', 'update_freq',
+                          'write_graph', 'write_grads', 'write_images']  # Keras spells it write_grads
+    tb_params_dict = validate_and_literal_eval_keys(tb_params_dict, accepted_tb_params, accepted_tb_params)
+    if 'log_dir' not in tb_params_dict:
+        plpy.error("The TensorBoard callback requires the log_dir parameter.")
+    tb_params_dict['log_dir'] = "{0}{1}".format(tb_params_dict['log_dir'], current_seg_id)
+    return [TensorBoard(**tb_params_dict)]
+
 # Validate the keys of the given dictionary and run literal_eval on the
 # user-defined subset
 def validate_and_literal_eval_keys(keys_dict, literal_eval_list, accepted_list):
@@ -317,8 +351,8 @@
             try:
                 keys_dict[ckey] = ast.literal_eval(keys_dict[ckey])
             except ValueError:
-                plpy.error(("invalid input value for parameter {0}, "
-                            "please refer to the documentation").format(ckey))
+                plpy.error(("invalid input value for parameter {0}={1}, "
+                            "please refer to the documentation").format(ckey, keys_dict[ckey]))
     return keys_dict
 
 # Split and strip the whitespace of key=value formatted strings
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_automl.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_automl.sql_in
index 82f4301..6448ac8 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_automl.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_automl.sql_in
@@ -35,7 +35,8 @@
 automl_mst_table_summary;
 SELECT madlib_keras_automl('iris_data_packed', 'automl_output', 'iris_model_arch', 'automl_mst_table',
                            ARRAY[1], $${'loss': ['categorical_crossentropy'], 'optimizer_params_list': [{'optimizer': ['Adam', 'SGD'],
-    'lr': [0.01, 0.011, 'log']} ], 'metrics':['accuracy'] }$$, $${'batch_size': [50], 'epochs': [1]}$$,
+    'lr': [0.01, 0.011, 'log']} ], 'metrics':['accuracy'] }$$,
+    $${'batch_size': [50], 'epochs': [1], 'callbacks': ['[TensorBoard(log_dir="/tmp/tensorflow/scalars/")]']}$$,
     'hyperopt', 'num_configs=5, num_iterations=6, algorithm=rand', NULL, NULL, FALSE, NULL, 1, 'test1', 'test1 descr');
 
 SELECT assert(COUNT(*)=1, 'The length of table does not match with the inputs') FROM automl_mst_table;
@@ -329,7 +330,8 @@
 SELECT madlib_keras_automl('iris_data_packed', 'automl_output', 'iris_model_arch', 'automl_mst_table',
     ARRAY[1,2], $${'loss': ['categorical_crossentropy'], 'optimizer_params_list': [ {'optimizer': ['Adagrad', 'Adam'],
     'lr': [0.9, 0.95, 'log'], 'epsilon': [0.3, 0.5, 'log_near_one']}, {'optimizer': ['Adam', 'SGD'],
-    'lr': [0.6, 0.65, 'log']} ], 'metrics':['accuracy'] }$$, $${'batch_size': [2, 4], 'epochs': [3]}$$);
+    'lr': [0.6, 0.65, 'log']} ], 'metrics':['accuracy'] }$$,
+    $${'batch_size': [2, 4], 'epochs': [3], 'callbacks': ['[TensorBoard(log_dir="/tmp/tensorflow/scalars/")]']}$$);
 
 SELECT assert(COUNT(*)=1, 'The length of table does not match with the inputs') FROM automl_mst_table;
 SELECT assert(COUNT(*)=1, 'The length of table does not match with the inputs') FROM automl_mst_table_summary;
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
index 74aff3c..c816b1d 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
@@ -82,7 +82,7 @@
     'model_arch',
     1,
     $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
-    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    $$ batch_size=2, epochs=1, verbose=0, callbacks=[TensorBoard(log_dir='/tmp/tensorflow/single/')] $$::text,
     3,
     NULL,
     'cifar_10_sample_val');
@@ -107,7 +107,7 @@
     model_size > 0 AND
     madlib_version is NOT NULL AND
     compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text AND
-    fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
+    fit_params = $$ batch_size=2, epochs=1, verbose=0, callbacks=[TensorBoard(log_dir='/tmp/tensorflow/single/')] $$::text AND
     num_iterations = 3 AND
     metrics_compute_frequency = 3 AND
     num_classes[0] = 2 AND
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
index 2a815ba..7e5448b 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit_multiple.sql_in
@@ -394,7 +394,7 @@
         $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
     ],
     ARRAY[
-        $$batch_size=5,epochs=1$$,
+        $$batch_size=5,epochs=1, callbacks=[TensorBoard(log_dir='/tmp/tensorflow/scalars/')]$$,
         $$batch_size=10,epochs=1$$
     ]
 );
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 164d743..0c91eb4 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -1259,6 +1259,31 @@
             self.subject.parse_and_validate_compile_params(test_str)
         self.assertIn('invalid optimizer', str(error.exception))
 
+    def test_parse_callbacks_pass(self):
+        """A superuser's single TensorBoard(...) callback string yields a TensorBoard object."""
+        self.subject.is_superuser = MagicMock(return_value=True)
+        self.subject.current_user = MagicMock(return_value='aaa')
+        callbacks = self.subject.parse_callbacks("""'[TensorBoard(log_dir="/tmp/logs/fit")]'""")
+        expected_type = "<class 'tensorflow.python.keras.callbacks_v1.TensorBoard'>"
+        self.assertEqual(expected_type, str(type(callbacks[0])))
+
+    def test_parse_callbacks_fail(self):
+        """A callback whose name is not TensorBoard is rejected with an invalid-param error."""
+        self.subject.is_superuser = MagicMock(return_value=True)
+        self.subject.current_user = MagicMock(return_value='aaa')
+        misspelled = """'[TensorBrd(log_dir="/tmp/logs/fit")]'"""
+        with self.assertRaises(plpy.PLPYException) as raised:
+            self.subject.parse_callbacks(misspelled)
+        self.assertIn('Invalid callbacks fit param', str(raised.exception))
+
+    def test_parse_callbacks_superuser_fail(self):
+        """A non-superuser is refused even though the callback string is well formed."""
+        self.subject.is_superuser = MagicMock(return_value=False)
+        self.subject.current_user = MagicMock(return_value='aaa')
+        valid_callbacks = """'[TensorBoard(log_dir="/tmp/logs/fit")]'"""
+        with self.assertRaises(plpy.PLPYException) as raised:
+            self.subject.parse_callbacks(valid_callbacks)
+        self.assertIn('superuser', str(raised.exception))
 
 class MadlibKerasFitCommonValidatorTestCase(unittest.TestCase):
     def setUp(self):
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in
index 0a1d58f..7517be2 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras_model_selection_table.py_in
@@ -202,6 +202,39 @@
                 self.compile_params_grid,
                 self.fit_params_grid
             )
+    def test_callbacks_pass(self):
+        """Random search accepts a fit params grid that includes a TensorBoard callback."""
+        self.search_type = 'random'
+        self.fit_params_grid = """
+            {'batch_size': [8, 32, 'log'], 'epochs': [1, 2], 'callbacks': ['[TensorBoard(log_dir="/tmp/log/scalars/")]']}
+            """
+        ctor_args = (
+            self.madlib_schema,
+            self.model_selection_table,
+            self.model_arch_table,
+            self.model_id_list,
+            self.compile_params_grid,
+            self.fit_params_grid,
+            self.search_type, 10,
+        )
+        self.subject(*ctor_args)
+
+    def test_callbacks_fail(self):
+        """Grid search with this grid (TensorBoard callback plus a 'log' batch_size entry) must raise."""
+        self.search_type = 'grid'
+        self.fit_params_grid = """
+            {'batch_size': [8, 32, 'log'], 'epochs': [1, 2], 'callbacks': ['[TensorBoard(log_dir="/tmp/log/scalars/")]']}
+            """
+        ctor_args = (
+            self.madlib_schema,
+            self.model_selection_table,
+            self.model_arch_table,
+            self.model_id_list,
+            self.compile_params_grid,
+            self.fit_params_grid, self.search_type,
+        )
+        with self.assertRaises(plpy.PLPYException):
+            self.subject(*ctor_args)
 
     def test_duplicate_params(self):
         self.model_id_list = [1, 1, 2]