[improve][fn] Support --retain-key-ordering (KEY_SHARED sub) in Python functions (#20756)

diff --git a/pulsar-functions/instance/src/main/python/python_instance.py b/pulsar-functions/instance/src/main/python/python_instance.py
index 57edbf9..2ab3ccc 100755
--- a/pulsar-functions/instance/src/main/python/python_instance.py
+++ b/pulsar-functions/instance/src/main/python/python_instance.py
@@ -147,6 +147,8 @@
     if self.instance_config.function_details.retainOrdering or \
       self.instance_config.function_details.processingGuarantees == Function_pb2.ProcessingGuarantees.Value("EFFECTIVELY_ONCE"):
       mode = pulsar._pulsar.ConsumerType.Failover
+    elif self.instance_config.function_details.retainKeyOrdering:
+      mode = pulsar._pulsar.ConsumerType.KeyShared
 
     position = pulsar._pulsar.InitialPosition.Latest
     if self.instance_config.function_details.source.subscriptionPosition == Function_pb2.SubscriptionPosition.Value("EARLIEST"):
diff --git a/pulsar-functions/utils/src/main/java/org/apache/pulsar/functions/utils/FunctionConfigUtils.java b/pulsar-functions/utils/src/main/java/org/apache/pulsar/functions/utils/FunctionConfigUtils.java
index 4d7e1c9..8e95e0f 100644
--- a/pulsar-functions/utils/src/main/java/org/apache/pulsar/functions/utils/FunctionConfigUtils.java
+++ b/pulsar-functions/utils/src/main/java/org/apache/pulsar/functions/utils/FunctionConfigUtils.java
@@ -714,10 +714,6 @@
         if (functionConfig.getMaxMessageRetries() != null && functionConfig.getMaxMessageRetries() >= 0) {
             throw new IllegalArgumentException("Message retries not yet supported in python");
         }
-
-        if (functionConfig.getRetainKeyOrdering() != null && functionConfig.getRetainKeyOrdering()) {
-            throw new IllegalArgumentException("Retain Key Orderering not yet supported in python");
-        }
     }
 
     private static void doGolangChecks(FunctionConfig functionConfig) {
diff --git a/tests/integration/src/test/java/org/apache/pulsar/tests/integration/functions/PulsarFunctionsTest.java b/tests/integration/src/test/java/org/apache/pulsar/tests/integration/functions/PulsarFunctionsTest.java
index db18451..b78a832 100644
--- a/tests/integration/src/test/java/org/apache/pulsar/tests/integration/functions/PulsarFunctionsTest.java
+++ b/tests/integration/src/test/java/org/apache/pulsar/tests/integration/functions/PulsarFunctionsTest.java
@@ -62,10 +62,12 @@
 import org.apache.pulsar.client.impl.PulsarClientImpl;
 import org.apache.pulsar.client.impl.schema.generic.GenericJsonRecord;
 import org.apache.pulsar.common.functions.ConsumerConfig;
+import org.apache.pulsar.common.functions.FunctionConfig;
 import org.apache.pulsar.common.policies.data.FunctionStatsImpl;
 import org.apache.pulsar.common.policies.data.FunctionStatus;
 import org.apache.pulsar.common.policies.data.FunctionStatusUtil;
 import org.apache.pulsar.common.policies.data.SchemaCompatibilityStrategy;
+import org.apache.pulsar.common.policies.data.SubscriptionStats;
 import org.apache.pulsar.common.policies.data.TopicStats;
 import org.apache.pulsar.common.schema.KeyValue;
 import org.apache.pulsar.common.schema.KeyValueEncodingType;
@@ -379,10 +381,10 @@
         if (runtime == Runtime.PYTHON) {
             submitFunction(
                     runtime, inputTopicName, outputTopicName, functionName, EXCEPTION_FUNCTION_PYTHON_FILE,
-                    EXCEPTION_PYTHON_CLASS, schema);
+                    EXCEPTION_PYTHON_CLASS, schema, null);
         } else {
             submitFunction(
-                    runtime, inputTopicName, outputTopicName, functionName, null, EXCEPTION_JAVA_CLASS, schema);
+                    runtime, inputTopicName, outputTopicName, functionName, null, EXCEPTION_JAVA_CLASS, schema, null);
         }
 
         // get function info
@@ -563,7 +565,7 @@
                         PUBLISH_JAVA_CLASS,
                         schema,
                         Collections.singletonMap("publish-topic", outputTopicName),
-                        null, null, null, null, null);
+                        null, null, null, null, null, null);
                 break;
             case PYTHON:
                 ConsumerConfig consumerConfig = new ConsumerConfig();
@@ -580,7 +582,7 @@
                         PUBLISH_PYTHON_CLASS,
                         schema,
                         Collections.singletonMap("publish-topic", outputTopicName),
-                        objectMapper.writeValueAsString(inputSpecs), "string", null, null, null);
+                        objectMapper.writeValueAsString(inputSpecs), "string", null, null, null, null);
                 break;
             case GO:
                 submitFunction(
@@ -592,7 +594,7 @@
                         null,
                         schema,
                         Collections.singletonMap("publish-topic", outputTopicName),
-                        null, null, null, null, null);
+                        null, null, null, null, null, null);
         }
 
         // get function info
@@ -667,6 +669,15 @@
                                            boolean pyZip,
                                            boolean multipleInput,
                                            boolean withExtraDeps) throws Exception {
+        testExclamationFunction(runtime, isTopicPattern, pyZip, multipleInput, withExtraDeps, null);
+    }
+
+    protected void testExclamationFunction(Runtime runtime,
+                                           boolean isTopicPattern,
+                                           boolean pyZip,
+                                           boolean multipleInput,
+                                           boolean withExtraDeps,
+                                           java.util.function.Consumer<CommandGenerator> commandGeneratorConsumer) throws Exception {
         if (functionRuntimeType == FunctionRuntimeType.THREAD && (runtime == Runtime.PYTHON || runtime == Runtime.GO)) {
             // python&go can only run on process mode
             return;
@@ -696,10 +707,10 @@
 
         // submit the exclamation function
         submitExclamationFunction(
-                runtime, inputTopicName, outputTopicName, functionName, pyZip, withExtraDeps, schema);
+                runtime, inputTopicName, outputTopicName, functionName, pyZip, withExtraDeps, schema, commandGeneratorConsumer);
 
         // get function info
-        getFunctionInfoSuccess(functionName);
+        final String info = getFunctionInfoSuccess(functionName);
 
         // get function stats
         getFunctionStatsEmpty(functionName);
@@ -741,6 +752,9 @@
                 break;
         }
 
+        checkSubscriptionType(inputTopicName,
+                ObjectMapperFactory.getMapper().getObjectMapper().readValue(info, FunctionConfig.class));
+
         // delete function
         deleteFunction(functionName);
 
@@ -752,6 +766,41 @@
 
     }
 
+    private void checkSubscriptionType(String topic, FunctionConfig config) {
+        List<String> topics = new ArrayList<>();
+        if (topic.endsWith(".*")) {
+            topics.add(topic.substring(0, topic.length() - 2) + "1");
+            topics.add(topic.substring(0, topic.length() - 2) + "2");
+        } else if (topic.contains(",")) {
+            topics.addAll(Arrays.asList(topic.split(",")));
+        } else {
+            topics.add(topic);
+        }
+        topics.stream().forEach(t -> {
+            try {
+                ContainerExecResult result = pulsarCluster.getAnyBroker().execCmd(
+                        PulsarCluster.ADMIN_SCRIPT,
+                        "topics",
+                        "stats",
+                        t);
+                TopicStats topicStats = ObjectMapperFactory.getMapper().reader()
+                        .readValue(result.getStdout(), TopicStats.class);
+                assertEquals(topicStats.getSubscriptions().size(), 1);
+                final SubscriptionStats sub = topicStats.getSubscriptions().values().iterator()
+                        .next();
+                if (config.getRetainOrdering()) {
+                    assertEquals(sub.getType(), "Failover");
+                } else if (config.getRetainKeyOrdering()) {
+                    assertEquals(sub.getType(), "Key_Shared");
+                } else {
+                    assertEquals(sub.getType(), "Shared");
+                }
+            } catch (Exception e) {
+                fail("Command should have exited with non-zero");
+            }
+        });
+    }
+
     private void submitExclamationFunction(Runtime runtime,
                                            String inputTopicName,
                                            String outputTopicName,
@@ -759,6 +808,18 @@
                                            boolean pyZip,
                                            boolean withExtraDeps,
                                            Schema<?> schema) throws Exception {
+        submitExclamationFunction(runtime, inputTopicName, outputTopicName, functionName, pyZip,
+                withExtraDeps, schema, null);
+    }
+
+    private void submitExclamationFunction(Runtime runtime,
+                                           String inputTopicName,
+                                           String outputTopicName,
+                                           String functionName,
+                                           boolean pyZip,
+                                           boolean withExtraDeps,
+                                           Schema<?> schema,
+                                           java.util.function.Consumer<CommandGenerator> commandGeneratorConsumer) throws Exception {
         submitFunction(
                 runtime,
                 inputTopicName,
@@ -768,7 +829,8 @@
                 withExtraDeps,
                 false,
                 getExclamationClass(runtime, pyZip, withExtraDeps),
-                schema);
+                schema,
+                commandGeneratorConsumer);
     }
 
     private <T> void submitFunction(Runtime runtime,
@@ -779,7 +841,8 @@
                                     boolean withExtraDeps,
                                     boolean isPublishFunction,
                                     String functionClass,
-                                    Schema<T> inputTopicSchema) throws Exception {
+                                    Schema<T> inputTopicSchema,
+                                    java.util.function.Consumer<CommandGenerator> commandGeneratorConsumer) throws Exception {
 
         String file = null;
         if (Runtime.JAVA == runtime) {
@@ -802,7 +865,8 @@
             }
         }
 
-        submitFunction(runtime, inputTopicName, outputTopicName, functionName, file, functionClass, inputTopicSchema);
+        submitFunction(runtime, inputTopicName, outputTopicName, functionName, file, functionClass, inputTopicSchema,
+                commandGeneratorConsumer);
     }
 
     private <T> void submitFunction(Runtime runtime,
@@ -811,9 +875,11 @@
                                     String functionName,
                                     String functionFile,
                                     String functionClass,
-                                    Schema<T> inputTopicSchema) throws Exception {
+                                    Schema<T> inputTopicSchema,
+                                    java.util.function.Consumer<CommandGenerator> commandGeneratorConsumer) throws Exception {
         submitFunction(runtime, inputTopicName, outputTopicName, functionName, functionFile, functionClass,
-                inputTopicSchema, null, null, null, null, null, null);
+                inputTopicSchema, null, null, null, null, null, null,
+                commandGeneratorConsumer);
     }
 
     private <T> void submitFunction(Runtime runtime,
@@ -828,7 +894,8 @@
                                     String outputSchemaType,
                                     SubscriptionInitialPosition subscriptionInitialPosition,
                                     String inputTypeClassName,
-                                    String outputTypeClassName) throws Exception {
+                                    String outputTypeClassName,
+                                    java.util.function.Consumer<CommandGenerator> commandGeneratorConsumer) throws Exception {
 
         if (StringUtils.isNotEmpty(inputTopicName)) {
             ensureSubscriptionCreated(
@@ -864,6 +931,9 @@
         if (outputTypeClassName != null) {
             generator.setOutputTypeClassName(outputTypeClassName);
         }
+        if (commandGeneratorConsumer != null) {
+            commandGeneratorConsumer.accept(generator);
+        }
         String command = "";
 
         switch (runtime) {
@@ -994,7 +1064,7 @@
         }
     }
 
-    protected void getFunctionInfoSuccess(String functionName) throws Exception {
+    protected String getFunctionInfoSuccess(String functionName) throws Exception {
         ContainerExecResult result = pulsarCluster.getAnyWorker().execCmd(
                 PulsarCluster.ADMIN_SCRIPT,
                 "functions",
@@ -1006,8 +1076,10 @@
 
         log.info("FUNCTION STATE: {}", result.getStdout());
         assertTrue(result.getStdout().contains("\"name\": \"" + functionName + "\""));
+        return result.getStdout();
     }
 
+
     protected void getFunctionStatsEmpty(String functionName) throws Exception {
         ContainerExecResult result = pulsarCluster.getAnyWorker().execCmd(
                 PulsarCluster.ADMIN_SCRIPT,
@@ -1257,7 +1329,8 @@
         }
 
         for (int i = 0; i < numMessages; i++) {
-            Message<String> msg = consumer.receive(30, TimeUnit.SECONDS);
+            log.info("Trying to receive message.. {}/{}", i, numMessages);
+            Message<String> msg = consumer.receive(30, TimeUnit.MINUTES);
             log.info("Received: {}", msg.getValue());
             assertTrue(expectedMessages.contains(msg.getValue()));
             expectedMessages.remove(msg.getValue());
@@ -1364,7 +1437,8 @@
                 false,
                 false,
                 AutoSchemaFunction.class.getName(),
-                Schema.AVRO(CustomObject.class));
+                Schema.AVRO(CustomObject.class),
+                null);
 
         // get function info
         getFunctionInfoSuccess(functionName);
@@ -1474,7 +1548,8 @@
                     functionName,
                     null,
                     AvroSchemaTestFunction.class.getName(),
-                    Schema.AVRO(AvroTestObject.class));
+                    Schema.AVRO(AvroTestObject.class),
+                    null);
         } else if (runtime == Runtime.PYTHON) {
             ConsumerConfig consumerConfig = new ConsumerConfig();
             consumerConfig.setSchemaType("avro");
@@ -1490,7 +1565,8 @@
                     AVRO_SCHEMA_PYTHON_CLASS,
                     Schema.AVRO(AvroTestObject.class),
                     null, objectMapper.writeValueAsString(inputSpecs), "avro", null,
-                    "avro_schema_test_function.AvroTestObject", "avro_schema_test_function.AvroTestObject");
+                    "avro_schema_test_function.AvroTestObject", "avro_schema_test_function.AvroTestObject",
+                    null);
         }
         log.info("pulsar submitFunction");
 
@@ -1567,7 +1643,7 @@
         // submit the exclamation function
         submitFunction(runtime, inputTopicName, outputTopicName, functionName, null,
                 InitializableFunction.class.getName(), schema,
-                Collections.singletonMap("publish-topic", outputTopicName), null, null, null, null, null);
+                Collections.singletonMap("publish-topic", outputTopicName), null, null, null, null, null, null);
 
         // publish and consume result
         publishAndConsumeMessages(inputTopicName, outputTopicName, numMessages);
@@ -1760,7 +1836,7 @@
                 null,
                 null,
                 SchemaType.NONE.name(),
-                SubscriptionInitialPosition.Earliest, null, null);
+                SubscriptionInitialPosition.Earliest, null, null, null);
         try {
             if (keyValue) {
                 @Cleanup
@@ -1876,7 +1952,8 @@
                 functionName,
                 null,
                 RecordFunction.class.getName(),
-                Schema.AUTO_CONSUME());
+                Schema.AUTO_CONSUME(),
+                null);
         try {
             @Cleanup
             Producer<String> producer = pulsarClient
@@ -1946,7 +2023,7 @@
                 null,
                 inputSpecNode.toString(),
                 SchemaType.AUTO_PUBLISH.name().toUpperCase(),
-                SubscriptionInitialPosition.Earliest, null, null);
+                SubscriptionInitialPosition.Earliest, null, null, null);
 
         getFunctionInfoSuccess(functionName);
 
diff --git a/tests/integration/src/test/java/org/apache/pulsar/tests/integration/functions/python/PulsarFunctionsPythonTest.java b/tests/integration/src/test/java/org/apache/pulsar/tests/integration/functions/python/PulsarFunctionsPythonTest.java
index 87a52d2..9ba210b 100644
--- a/tests/integration/src/test/java/org/apache/pulsar/tests/integration/functions/python/PulsarFunctionsPythonTest.java
+++ b/tests/integration/src/test/java/org/apache/pulsar/tests/integration/functions/python/PulsarFunctionsPythonTest.java
@@ -69,4 +69,21 @@
         testAvroSchemaFunction(Runtime.PYTHON);
     }
 
+    @Test(groups = {"python_function", "function"})
+    public void testRetainOrderingTest() throws Exception {
+        testExclamationFunction(Runtime.PYTHON, false, false, false,
+                false, generator -> {
+                    generator.setRetainOrdering(true);
+                });
+    }
+
+    @Test(groups = {"python_function", "function"})
+    public void testRetainKeyOrderingTest() throws Exception {
+        testExclamationFunction(Runtime.PYTHON, false, false, false,
+                false, generator -> {
+                    System.out.println("calling generator.setRetainKeyOrdering(true);");
+                    generator.setRetainKeyOrdering(true);
+                });
+    }
+
 }
diff --git a/tests/integration/src/test/java/org/apache/pulsar/tests/integration/functions/utils/CommandGenerator.java b/tests/integration/src/test/java/org/apache/pulsar/tests/integration/functions/utils/CommandGenerator.java
index adc791f..e0fbd60 100644
--- a/tests/integration/src/test/java/org/apache/pulsar/tests/integration/functions/utils/CommandGenerator.java
+++ b/tests/integration/src/test/java/org/apache/pulsar/tests/integration/functions/utils/CommandGenerator.java
@@ -64,6 +64,8 @@
     private String outputTypeClassName;
     private String schemaType;
     private SubscriptionInitialPosition subscriptionInitialPosition;
+    private Boolean retainOrdering;
+    private Boolean retainKeyOrdering;
 
     private Map<String, String> userConfig = new HashMap<>();
     public static final String JAVAJAR = "/pulsar/examples/java-test-functions.jar";
@@ -227,6 +229,12 @@
         if (subscriptionInitialPosition != null) {
             commandBuilder.append(" --subs-position " + subscriptionInitialPosition.name());
         }
+        if (retainOrdering != null) {
+            commandBuilder.append(" --retain-ordering ");
+        }
+        if (retainKeyOrdering != null) {
+            commandBuilder.append(" --retain-key-ordering ");
+        }
 
         switch (runtime){
             case JAVA: