[BEAM-1833] Preserve input names at graph construction and through proto translation. (#15202)
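
This change stores an AppliedPTransform's main inputs as a dict keyed by
input tag instead of a flat tuple, so user-assigned names survive graph
construction and round-trip through the runner API proto. A minimal sketch
of the resulting behavior, adapted from the new test_input_names test below
(MyPTransform and the tag names are illustrative only):

    import apache_beam as beam

    class MyPTransform(beam.PTransform):
      def expand(self, pcolls):
        # The dict piped in arrives intact, keyed by the caller's tags.
        return pcolls.values() | beam.Flatten()

    p = beam.Pipeline()
    inputs = {tag: p | tag >> beam.Create([tag]) for tag in 'ABC'}
    _ = inputs | MyPTransform()

    # The composite's proto now records the dict keys as input tags.
    proto = p.to_runner_api()
    for t in proto.components.transforms.values():
      if t.unique_name == 'MyPTransform':
        print(sorted(t.inputs.keys()))  # expected: ['A', 'B', 'C']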

diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index 69362a4..d0aaa2f 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -61,6 +61,7 @@
 from typing import FrozenSet
 from typing import Iterable
 from typing import List
+from typing import Mapping
 from typing import Optional
 from typing import Sequence
 from typing import Set
@@ -271,7 +272,7 @@
     output_replacements = {
     }  # type: Dict[AppliedPTransform, List[Tuple[pvalue.PValue, Optional[str]]]]
     input_replacements = {
-    }  # type: Dict[AppliedPTransform, Sequence[Union[pvalue.PBegin, pvalue.PCollection]]]
+    }  # type: Dict[AppliedPTransform, Mapping[str, Union[pvalue.PBegin, pvalue.PCollection]]]
     side_input_replacements = {
     }  # type: Dict[AppliedPTransform, List[pvalue.AsSideInput]]
 
@@ -297,7 +298,7 @@
               original_transform_node.parent,
               replacement_transform,
               original_transform_node.full_label,
-              original_transform_node.inputs)
+              original_transform_node.main_inputs)
 
           replacement_transform_node.resource_hints = (
               original_transform_node.resource_hints)
@@ -437,11 +438,11 @@
                 output_replacements[transform_node].append((tag, replacement))
 
         if replace_input:
-          new_input = [
-              input if not input in output_map else output_map[input]
-              for input in transform_node.inputs
-          ]
-          input_replacements[transform_node] = new_input
+          new_inputs = {
+              tag: input if input not in output_map else output_map[input]
+              for (tag, input) in transform_node.main_inputs.items()
+          }
+          input_replacements[transform_node] = new_inputs
 
         if replace_side_inputs:
           new_side_inputs = []
@@ -670,15 +671,18 @@
 
     pvalueish, inputs = transform._extract_input_pvalues(pvalueish)
     try:
-      inputs = tuple(inputs)
-      for leaf_input in inputs:
-        if not isinstance(leaf_input, pvalue.PValue):
-          raise TypeError
+      if not isinstance(inputs, dict):
+        inputs = {str(ix): input for (ix, input) in enumerate(inputs)}
     except TypeError:
       raise NotImplementedError(
           'Unable to extract PValue inputs from %s; either %s does not accept '
           'inputs of this format, or it does not properly override '
           '_extract_input_pvalues' % (pvalueish, transform))
+    for t, leaf_input in inputs.items():
+      if not isinstance(leaf_input, pvalue.PValue) or not isinstance(t, str):
+        raise NotImplementedError(
+            '%s does not properly override _extract_input_pvalues, '
+            'returned %s from %s' % (transform, inputs, pvalueish))
 
     current = AppliedPTransform(
         self._current_transform(), transform, full_label, inputs)
@@ -705,7 +709,8 @@
         if result.producer is None:
           result.producer = current
 
-        self._infer_result_type(transform, inputs, result)
+        # TODO(BEAM-1833): Pass the full inputs dict.
+        self._infer_result_type(transform, tuple(inputs.values()), result)
 
         assert isinstance(result.producer.inputs, tuple)
         # The DoOutputsTuple adds the PCollection to the outputs when accessed
@@ -940,7 +945,7 @@
     for id in proto.components.transforms:
       transform = context.transforms.get_by_id(id)
       if not transform.inputs and transform.transform.__class__ in has_pbegin:
-        transform.inputs = (pvalue.PBegin(p), )
+        transform.main_inputs = {'None': pvalue.PBegin(p)}
 
     if return_context:
       return p, context  # type: ignore  # too complicated for now
@@ -1030,7 +1035,7 @@
       parent,  # type:  Optional[AppliedPTransform]
       transform,  # type: Optional[ptransform.PTransform]
       full_label,  # type: str
-      inputs,  # type: Optional[Sequence[Union[pvalue.PBegin, pvalue.PCollection]]]
+      main_inputs,  # type: Optional[Mapping[str, Union[pvalue.PBegin, pvalue.PCollection]]]
       environment_id=None,  # type: Optional[str]
       annotations=None, # type: Optional[Dict[str, bytes]]
   ):
@@ -1043,7 +1048,7 @@
     # reusing PTransform instances in different contexts (apply() calls) without
     # any interference. This is particularly useful for composite transforms.
     self.full_label = full_label
-    self.inputs = inputs or ()
+    self.main_inputs = dict(main_inputs or {})
 
     self.side_inputs = tuple() if transform is None else transform.side_inputs
     self.outputs = {}  # type: Dict[Union[str, int, None], pvalue.PValue]
@@ -1076,6 +1081,10 @@
       }
     self.annotations = annotations
 
+  @property
+  def inputs(self):
+    return tuple(self.main_inputs.values())
+
   def __repr__(self):
     # type: () -> str
     return "%s(%s, %s)" % (
@@ -1109,8 +1118,8 @@
     if isinstance(self.transform, external.ExternalTransform):
       self.transform.replace_named_outputs(self.named_outputs())
 
-  def replace_inputs(self, inputs):
-    self.inputs = inputs
+  def replace_inputs(self, main_inputs):
+    self.main_inputs = main_inputs
 
     # Importing locally to prevent circular dependency issues.
     from apache_beam.transforms import external
@@ -1215,12 +1224,11 @@
 
   def named_inputs(self):
     # type: () -> Dict[str, pvalue.PValue]
-    # TODO(BEAM-1833): Push names up into the sdk construction.
     if self.transform is None:
-      assert not self.inputs and not self.side_inputs
+      assert not self.main_inputs and not self.side_inputs
       return {}
     else:
-      return self.transform._named_inputs(self.inputs, self.side_inputs)
+      return self.transform._named_inputs(self.main_inputs, self.side_inputs)
 
   def named_outputs(self):
     # type: () -> Dict[str, pvalue.PCollection]
@@ -1309,10 +1317,10 @@
       pardo_payload = None
       side_input_tags = []
 
-    main_inputs = [
-        context.pcollections.get_by_id(id) for tag,
-        id in proto.inputs.items() if tag not in side_input_tags
-    ]
+    main_inputs = {
+        tag: context.pcollections.get_by_id(id)
+        for (tag, id) in proto.inputs.items() if tag not in side_input_tags
+    }
 
     transform = ptransform.PTransform.from_runner_api(proto, context)
     if transform and proto.environment_id:
@@ -1334,7 +1342,7 @@
         parent=None,
         transform=transform,
         full_label=proto.unique_name,
-        inputs=main_inputs,
+        main_inputs=main_inputs,
         environment_id=None,
         annotations=proto.annotations)
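
Note: the read-only inputs property added above keeps tuple-based callers
working, since it is derived from main_inputs (dict insertion order is
preserved on Python 3.7+). A hedged sketch, assuming pc_a and pc_b are
existing PCollections and AppliedPTransform is imported from
apache_beam.pipeline:

    node = AppliedPTransform(
        None, beam.Flatten(), 'label', {'a': pc_a, 'b': pc_b})
    assert node.main_inputs == {'a': pc_a, 'b': pc_b}
    assert node.inputs == (pc_a, pc_b)  # tuple view for legacy callers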
 
diff --git a/sdks/python/apache_beam/pipeline_test.py b/sdks/python/apache_beam/pipeline_test.py
index c2a23fd..731d5e3 100644
--- a/sdks/python/apache_beam/pipeline_test.py
+++ b/sdks/python/apache_beam/pipeline_test.py
@@ -972,6 +972,24 @@
     for transform_id in runner_api_proto.components.transforms:
       self.assertRegex(transform_id, r'[a-zA-Z0-9-_]+')
 
+  def test_input_names(self):
+    class MyPTransform(beam.PTransform):
+      def expand(self, pcolls):
+        return pcolls.values() | beam.Flatten()
+
+    p = beam.Pipeline()
+    input_names = set('ABC')
+    inputs = {x: p | x >> beam.Create([x]) for x in input_names}
+    inputs | MyPTransform()  # pylint: disable=expression-not-assigned
+    runner_api_proto = Pipeline.to_runner_api(p)
+
+    for transform_proto in runner_api_proto.components.transforms.values():
+      if transform_proto.unique_name == 'MyPTransform':
+        self.assertEqual(set(transform_proto.inputs.keys()), input_names)
+        break
+    else:
+      self.fail('Unable to find transform.')
+
   def test_display_data(self):
     class MyParentTransform(beam.PTransform):
       def expand(self, p):
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
index c112bdd..bbaf52c 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -305,7 +305,7 @@
                     parent,
                     beam.Map(lambda x: (b'', x)),
                     transform_node.full_label + '/MapToVoidKey%s' % ix,
-                    (side_input.pvalue, ))
+                    {'input': side_input.pvalue})
                 new_side_input.pvalue.producer = map_to_void_key
                 map_to_void_key.add_output(new_side_input.pvalue, None)
                 parent.add_part(map_to_void_key)
diff --git a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
index ec4edc9..e7ce71b 100644
--- a/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
+++ b/sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
@@ -348,7 +348,8 @@
     pcoll2.element_type = typehints.Any
     pcoll3.element_type = typehints.KV[typehints.Any, typehints.Any]
     for pcoll in [pcoll1, pcoll2, pcoll3]:
-      applied = AppliedPTransform(None, beam.GroupByKey(), "label", [pcoll])
+      applied = AppliedPTransform(
+          None, beam.GroupByKey(), "label", {'pcoll': pcoll})
       applied.outputs[None] = PCollection(None)
       common.group_by_key_input_visitor().visit_transform(applied)
       self.assertEqual(
@@ -367,7 +368,7 @@
     for pcoll in [pcoll1, pcoll2]:
       with self.assertRaisesRegex(ValueError, err_msg):
         common.group_by_key_input_visitor().visit_transform(
-            AppliedPTransform(None, beam.GroupByKey(), "label", [pcoll]))
+            AppliedPTransform(None, beam.GroupByKey(), "label", {'in': pcoll}))
 
   def test_group_by_key_input_visitor_for_non_gbk_transforms(self):
     p = TestPipeline()
@@ -375,7 +376,7 @@
     for transform in [beam.Flatten(), beam.Map(lambda x: x)]:
       pcoll.element_type = typehints.Any
       common.group_by_key_input_visitor().visit_transform(
-          AppliedPTransform(None, transform, "label", [pcoll]))
+          AppliedPTransform(None, transform, "label", {'in': pcoll}))
       self.assertEqual(pcoll.element_type, typehints.Any)
 
   def test_flatten_input_with_visitor_with_single_input(self):
@@ -387,11 +388,11 @@
 
   def _test_flatten_input_visitor(self, input_type, output_type, num_inputs):
     p = TestPipeline()
-    inputs = []
-    for _ in range(num_inputs):
+    inputs = {}
+    for ix in range(num_inputs):
       input_pcoll = PCollection(p)
       input_pcoll.element_type = input_type
-      inputs.append(input_pcoll)
+      inputs[str(ix)] = input_pcoll
     output_pcoll = PCollection(p)
     output_pcoll.element_type = output_type
 
@@ -399,7 +400,7 @@
     flatten.add_output(output_pcoll, None)
     DataflowRunner.flatten_input_visitor().visit_transform(flatten)
     for _ in range(num_inputs):
-      self.assertEqual(inputs[0].element_type, output_type)
+      self.assertEqual(inputs['0'].element_type, output_type)
 
   def test_gbk_then_flatten_input_visitor(self):
     p = TestPipeline(
@@ -442,7 +443,7 @@
         z: (x, y, z),
         beam.pvalue.AsSingleton(pc),
         beam.pvalue.AsMultiMap(pc))
-    applied_transform = AppliedPTransform(None, transform, "label", [pc])
+    applied_transform = AppliedPTransform(None, transform, "label", {'pc': pc})
     DataflowRunner.side_input_visitor(
         use_fn_api=True).visit_transform(applied_transform)
     self.assertEqual(2, len(applied_transform.side_inputs))
diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py b/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py
index 456636c..448fca7 100644
--- a/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py
+++ b/sdks/python/apache_beam/runners/interactive/pipeline_instrument.py
@@ -685,8 +685,8 @@
 
       def visit_transform(self, transform_node):
         if transform_node.inputs:
-          input_list = list(transform_node.inputs)
-          for i, input_pcoll in enumerate(input_list):
+          main_inputs = dict(transform_node.main_inputs)
+          for tag, input_pcoll in main_inputs.items():
             key = self._pin.cache_key(input_pcoll)
 
             # Replace the input pcollection with the cached pcollection (if it
@@ -694,9 +694,9 @@
             if key in self._pin._cached_pcoll_read:
               # Ignore this pcoll in the final pruned instrumented pipeline.
               self._pin._ignored_targets.add(input_pcoll)
-              input_list[i] = self._pin._cached_pcoll_read[key]
+              main_inputs[tag] = self._pin._cached_pcoll_read[key]
           # Update the transform with its new inputs.
-          transform_node.inputs = tuple(input_list)
+          transform_node.main_inputs = main_inputs
 
     v = ReadCacheWireVisitor(self)
     pipeline.visit(v)
diff --git a/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py b/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
index 40d19a7..a3f91c0 100644
--- a/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
+++ b/sdks/python/apache_beam/runners/interactive/pipeline_instrument_test.py
@@ -297,11 +297,11 @@
 
       def visit_transform(self, transform_node):
         if transform_node.inputs:
-          input_list = list(transform_node.inputs)
-          for i in range(len(input_list)):
-            if input_list[i] == init_pcoll:
-              input_list[i] = cached_init_pcoll
-          transform_node.inputs = tuple(input_list)
+          main_inputs = dict(transform_node.main_inputs)
+          for tag in main_inputs:
+            if main_inputs[tag] == init_pcoll:
+              main_inputs[tag] = cached_init_pcoll
+          transform_node.main_inputs = main_inputs
 
     v = TestReadCacheWireVisitor()
     p_origin.visit(v)
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index 788485cd..50f7715 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -50,6 +50,7 @@
 from typing import Callable
 from typing import Dict
 from typing import List
+from typing import Mapping
 from typing import Optional
 from typing import Sequence
 from typing import Tuple
@@ -253,7 +254,7 @@
       return self.visit_nested(node)
 
 
-def get_named_nested_pvalues(pvalueish):
+def get_named_nested_pvalues(pvalueish, as_inputs=False):
   if isinstance(pvalueish, tuple):
     # Check to see if it's a named tuple.
     fields = getattr(pvalueish, '_fields', None)
@@ -262,16 +263,22 @@
     else:
       tagged_values = enumerate(pvalueish)
   elif isinstance(pvalueish, list):
+    if as_inputs:
+      # Treat the full list as a single value for eager evaluation.
+      yield None, pvalueish
+      return
     tagged_values = enumerate(pvalueish)
   elif isinstance(pvalueish, dict):
     tagged_values = pvalueish.items()
   else:
-    if isinstance(pvalueish, (pvalue.PValue, pvalue.DoOutputsTuple)):
+    if as_inputs or isinstance(pvalueish,
+                               (pvalue.PValue, pvalue.DoOutputsTuple)):
       yield None, pvalueish
     return
 
   for tag, subvalue in tagged_values:
-    for subtag, subsubvalue in get_named_nested_pvalues(subvalue):
+    for subtag, subsubvalue in get_named_nested_pvalues(
+        subvalue, as_inputs=as_inputs):
       if subtag is None:
         yield tag, subsubvalue
       else:
@@ -569,6 +576,8 @@
   def __ror__(self, left, label=None):
     """Used to apply this PTransform to non-PValues, e.g., a tuple."""
     pvalueish, pvalues = self._extract_input_pvalues(left)
+    if isinstance(pvalues, dict):
+      pvalues = tuple(pvalues.values())
     pipelines = [v.pipeline for v in pvalues if isinstance(v, pvalue.PValue)]
     if pvalues and not pipelines:
       deferred = False
@@ -597,8 +606,7 @@
     # pylint: enable=wrong-import-order, wrong-import-position
     replacements = {
         id(v): p | 'CreatePInput%s' % ix >> Create(v, reshuffle=False)
-        for ix,
-        v in enumerate(pvalues)
+        for (ix, v) in enumerate(pvalues)
         if not isinstance(v, pvalue.PValue) and v is not None
     }
     pvalueish = _SetInputPValues().visit(pvalueish, replacements)
@@ -628,19 +636,11 @@
     if isinstance(pvalueish, pipeline.Pipeline):
       pvalueish = pvalue.PBegin(pvalueish)
 
-    def _dict_tuple_leaves(pvalueish):
-      if isinstance(pvalueish, tuple):
-        for a in pvalueish:
-          for p in _dict_tuple_leaves(a):
-            yield p
-      elif isinstance(pvalueish, dict):
-        for a in pvalueish.values():
-          for p in _dict_tuple_leaves(a):
-            yield p
-      else:
-        yield pvalueish
-
-    return pvalueish, tuple(_dict_tuple_leaves(pvalueish))
+    return pvalueish, {
+        str(tag): value
+        for (tag, value) in get_named_nested_pvalues(
+            pvalueish, as_inputs=True)
+    }
 
   def _pvaluish_from_dict(self, input_dict):
     if len(input_dict) == 1:
@@ -648,16 +648,15 @@
     else:
       return input_dict
 
-  def _named_inputs(self, inputs, side_inputs):
-    # type: (Sequence[pvalue.PValue], Sequence[Any]) -> Dict[str, pvalue.PValue]
+  def _named_inputs(self, main_inputs, side_inputs):
+    # type: (Mapping[str, pvalue.PValue], Sequence[Any]) -> Dict[str, pvalue.PValue]
 
     """Returns the dictionary of named inputs (including side inputs) as they
     should be named in the beam proto.
     """
-    # TODO(BEAM-1833): Push names up into the sdk construction.
     main_inputs = {
-        str(ix): input
-        for (ix, input) in enumerate(inputs)
+        tag: input
+        for (tag, input) in main_inputs.items()
         if isinstance(input, pvalue.PCollection)
     }
     named_side_inputs = {(SIDE_INPUT_PREFIX + '%s') % ix: si.pvalue