Merge pull request #51 from duhenglucky/transaction

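Add transaction message support and remove the batch message API (CreateBatchMessage, AddMessage, DestroyBatchMessage, SendBatchMessage)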
diff --git a/src/PythonWrapper.cpp b/src/PythonWrapper.cpp
index 404c92a..2783015 100644
--- a/src/PythonWrapper.cpp
+++ b/src/PythonWrapper.cpp
@@ -17,7 +17,6 @@
 #include "CCommon.h"
 #include "CMessage.h"
 #include "CMessageExt.h"
-#include "CBatchMessage.h"
 #include "CSendResult.h"
 #include "CProducer.h"
 #include "CPushConsumer.h"
@@ -33,6 +32,8 @@
         "PYTHON_CLIENT_VERSION: " PYTHON_CLIENT_VERSION ", BUILD DATE: " PYCLI_BUILD_DATE " ";
 
 map<CPushConsumer *, pair<PyObject *, object>> g_CallBackMap;
+map<CProducer *, PyObject *> g_TransactionCheckCallBackMap;
+
 
 class PyThreadStateLock {
 public:
@@ -96,18 +97,6 @@
     return SetDelayTimeLevel((CMessage *) msg, level);
 }
 
-//batch message
-void *PyCreateBatchMessage() {
-    return (void *) CreateBatchMessage();
-}
-
-int PyAddMessage(void *batchMsg, void *msg) {
-    return AddMessage((CBatchMessage *) batchMsg, (CMessage *) msg);
-}
-
-int PyDestroyBatchMessage(void *batchMsg) {
-    return DestroyBatchMessage((CBatchMessage *) batchMsg);
-}
 
 //messageExt
 const char *PyGetMessageTopic(PyMessageExt msgExt) {
@@ -134,6 +123,45 @@
     PyEval_InitThreads();  // ensure create GIL, for call Python callback from C.
     return (void *) CreateProducer(groupId);
 }
+
+// Creates a transaction producer that uses PyLocalTransactionCheckerCallback as its checker
+// and records the Python checker callable for it in g_TransactionCheckCallBackMap.
+// Returns NULL, registering nothing, when CreateTransactionProducer fails.
+void *PyCreateTransactionProducer(const char *groupId, PyObject *localTransactionCheckerCallback) {
+    PyEval_InitThreads();  // ensure create GIL, for call Python callback from C.
+    CProducer *producer = CreateTransactionProducer(groupId, &PyLocalTransactionCheckerCallback, NULL);
+    if (producer == NULL) {
+        return NULL;
+    }
+    // The PyObject * argument is a borrowed reference: own one so the callable cannot be
+    // collected while the producer may still call it. NOTE(review): PyDestroyTransactionProducer
+    // does not Py_DECREF it, so this reference is retained until process exit.
+    Py_XINCREF(localTransactionCheckerCallback);
+    g_TransactionCheckCallBackMap[producer] = localTransactionCheckerCallback;
+    return producer;
+}
+
+// Checker registered by PyCreateTransactionProducer. Under the GIL, calls the Python checker
+// stored for `producer` with the message and returns the TransactionStatus it produces.
+// Returns E_UNKNOWN_TRANSACTION when no checker is registered or the Python call raises.
+CTransactionStatus PyLocalTransactionCheckerCallback(CProducer *producer, CMessageExt *msg, void *data) {
+    PyThreadStateLock pyThreadLock;  // ensure hold GIL, before call python callback
+    map<CProducer *, PyObject *>::iterator iter = g_TransactionCheckCallBackMap.find(producer);
+    if (iter == g_TransactionCheckCallBackMap.end()) {
+        return E_UNKNOWN_TRANSACTION;
+    }
+    // `{.pMessageExt = msg}` is a C++20 designated initializer; value-initialize instead.
+    PyMessageExt message = PyMessageExt();
+    message.pMessageExt = msg;
+    try {
+        return boost::python::call<CTransactionStatus>(iter->second, message);
+    } catch (const boost::python::error_already_set &) {
+        // Report the Python error rather than letting it unwind into the C client.
+        PyErr_Print();
+        return E_UNKNOWN_TRANSACTION;
+    }
+}
+
 int PyDestroyProducer(void *producer) {
     return DestroyProducer((CProducer *) producer);
 }
@@ -190,23 +200,29 @@
     return SendMessageOneway((CProducer *) producer, (CMessage *) msg);
 }
 
-void PySendSuccessCallback(CSendResult result, CMessage *msg, void *pyCallback){
+// Success callback passed to SendAsync by PySendMessageAsync: converts the C result to a
+// PySendResult, calls the Python success callback, then frees the PyCallback it was given.
+void PySendSuccessCallback(CSendResult result, CMessage *msg, void *pyCallback) {
     PyThreadStateLock PyThreadLock;  // ensure hold GIL, before call python callback
     PySendResult sendResult;
     sendResult.sendStatus = result.sendStatus;
     sendResult.offset = result.offset;
     strncpy(sendResult.msgId, result.msgId, MAX_MESSAGE_ID_LENGTH - 1);
     sendResult.msgId[MAX_MESSAGE_ID_LENGTH - 1] = 0;
-    PyCallback *callback = (PyCallback *)pyCallback;
-    boost::python::call<void>(callback->successCallback, sendResult, (void *) msg);
-    delete pyCallback;
+    PyCallback *callback = (PyCallback *) pyCallback;
+    try {
+        boost::python::call<void>(callback->successCallback, sendResult, (void *) msg);
+    } catch (const boost::python::error_already_set &) {
+        PyErr_Print();  // report the Python error instead of letting it escape into the C client
+    }
+    delete callback;  // delete through the real type: deleting a void * is undefined behaviour
 }
 
 
-void PySendExceptionCallback(CMQException e, CMessage *msg, void *pyCallback){
+void PySendExceptionCallback(CMQException e, CMessage *msg, void *pyCallback) {
     PyThreadStateLock PyThreadLock;  // ensure hold GIL, before call python callback
     PyMQException exception;
-    PyCallback *callback = (PyCallback *)pyCallback;
+    PyCallback *callback = (PyCallback *) pyCallback;
     exception.error = e.error;
     exception.line = e.line;
     strncpy(exception.file, e.file, MAX_EXEPTION_FILE_LENGTH - 1);
@@ -219,30 +229,25 @@
-    delete pyCallback;
+    delete callback;  // delete through the real type: deleting a void * is undefined behaviour
 }
 
-int PySendMessageAsync(void *producer, void *msg, PyObject *sendSuccessCallback, PyObject *sendExceptionCallback){
-    PyCallback* pyCallback = new PyCallback();
+// Sends msg asynchronously. The heap-allocated PyCallback is handed to SendAsync as user data
+// and is deleted by PySendSuccessCallback / PySendExceptionCallback when they run.
+// NOTE(review): the two PyObject * callbacks are borrowed references and are not INCREF'd, so
+// the caller must keep them alive until the send completes.
+int PySendMessageAsync(void *producer, void *msg, PyObject *sendSuccessCallback, PyObject *sendExceptionCallback) {
+    PyCallback *pyCallback = new PyCallback();
     pyCallback->successCallback = sendSuccessCallback;
     pyCallback->exceptionCallback = sendExceptionCallback;
-    return SendAsync((CProducer *) producer,  (CMessage *) msg, &PySendSuccessCallback, &PySendExceptionCallback, (void *)pyCallback);
-}
-
-PySendResult PySendBatchMessage(void *producer, void *batchMessage) {
-    PySendResult ret;
-    CSendResult result;
-    SendBatchMessage((CProducer *) producer, (CBatchMessage *) batchMessage, &result);
-    ret.sendStatus = result.sendStatus;
-    ret.offset = result.offset;
-    strncpy(ret.msgId, result.msgId, MAX_MESSAGE_ID_LENGTH - 1);
-    ret.msgId[MAX_MESSAGE_ID_LENGTH - 1] = 0;
-    return ret;
+    return SendAsync((CProducer *) producer, (CMessage *) msg, &PySendSuccessCallback, &PySendExceptionCallback,
+                     (void *) pyCallback);
 }
 
 
 PySendResult PySendMessageOrderly(void *producer, void *msg, int autoRetryTimes, void *args, PyObject *queueSelector) {
     PySendResult ret;
     CSendResult result;
-    PyUserData userData = {queueSelector,args};
-    SendMessageOrderly((CProducer *) producer, (CMessage *) msg, &PyOrderlyCallbackInner, &userData, autoRetryTimes, &result);
+    PyUserData userData = {queueSelector, args};
+    SendMessageOrderly((CProducer *) producer, (CMessage *) msg, &PyOrderlyCallbackInner, &userData, autoRetryTimes,
+                       &result);
     ret.sendStatus = result.sendStatus;
     ret.offset = result.offset;
     strncpy(ret.msgId, result.msgId, MAX_MESSAGE_ID_LENGTH - 1);
@@ -251,7 +252,12 @@
 }
 
+// Queue selector that PySendMessageOrderly passes to SendMessageOrderly together with a
+// stack-allocated PyUserData {Python selector, user args}: forwards (size, msg, args) to the
+// Python selector and returns the int it produces (presumably the queue index to use).
+// NOTE(review): a Python exception here surfaces as boost::python::error_already_set
+// thrown through SendMessageOrderly; nothing handles it.
 int PyOrderlyCallbackInner(int size, CMessage *msg, void *args) {
-    PyUserData *userData = (PyUserData *)args;
+    PyUserData *userData = (PyUserData *) args;
     int index = boost::python::call<int>(userData->pyObject, size, (void *) msg, userData->pData);
     return index;
 }
@@ -267,6 +268,37 @@
     return ret;
 }
 
+// Local-transaction executor that PySendMessageInTransaction passes to SendMessageTransaction.
+// `data` is the PyUserData built there: calls the Python callback as callback(msg, args) and
+// returns the TransactionStatus it produces.
+// NOTE(review): no GIL is taken here, so this assumes the client calls it synchronously on the
+// Python thread that called SendMessageInTransaction (which holds the GIL) - confirm. A Python
+// exception surfaces as boost::python::error_already_set thrown through SendMessageTransaction.
+CTransactionStatus PyLocalTransactionExecuteCallback(CProducer *producer, CMessage *msg, void *data) {
+    PyUserData *localCallback = (PyUserData *) data;
+    CTransactionStatus status = boost::python::call<CTransactionStatus>(localCallback->pyObject, (void *) msg,
+                                                                        localCallback->pData);
+    return status;
+}
+
+// Sends msg as a transactional message with localTransactionCallback(msg, args) as the local
+// transaction executor, and copies the C send result into a PySendResult.
+// NOTE(review): `userData` lives on this stack frame, so the executor must not run after
+// SendMessageTransaction returns. SendMessageTransaction's return value is ignored; if it
+// does not fill `result` on failure, the returned PySendResult holds uninitialized data.
+PySendResult PySendMessageInTransaction(void *producer, void *msg, PyObject *localTransactionCallback, void *args) {
+    PyUserData userData = {localTransactionCallback, args};
+    PySendResult ret;
+    CSendResult result;
+    SendMessageTransaction((CProducer *) producer, (CMessage *) msg, &PyLocalTransactionExecuteCallback, &userData,
+                           &result);
+    ret.sendStatus = result.sendStatus;
+    ret.offset = result.offset;
+    strncpy(ret.msgId, result.msgId, MAX_MESSAGE_ID_LENGTH - 1);
+    ret.msgId[MAX_MESSAGE_ID_LENGTH - 1] = 0;
+    return ret;
+}
+
 //SendResult
 const char *PyGetSendResultMsgID(CSendResult &sendResult) {
     return (const char *) (sendResult.msgId);
 //SendResult
 const char *PyGetSendResultMsgID(CSendResult &sendResult) {
     return (const char *) (sendResult.msgId);
@@ -286,6 +307,15 @@
     }
     return DestroyPushConsumer(consumerInner);
 }
+
+// Destroys a producer created by PyCreateTransactionProducer, first dropping its entry from
+// g_TransactionCheckCallBackMap so the transaction checker can no longer look it up.
+int PyDestroyTransactionProducer(void *producer) {
+    CProducer *const target = static_cast<CProducer *>(producer);
+    g_TransactionCheckCallBackMap.erase(target);  // erasing by key is a no-op for unknown producers
+    return DestroyProducer(target);
+}
+
 int PyStartPushConsumer(void *consumer) {
     return StartPushConsumer((CPushConsumer *) consumer);
 }
@@ -308,7 +338,7 @@
     return RegisterMessageCallback(consumerInner, &PythonMessageCallBackInner);
 }
 
-int PyRegisterMessageCallbackOrderly(void *consumer, PyObject *pCallback, object args){
+int PyRegisterMessageCallbackOrderly(void *consumer, PyObject *pCallback, object args) {
     CPushConsumer *consumerInner = (CPushConsumer *) consumer;
     g_CallBackMap[consumerInner] = make_pair(pCallback, std::move(args));
     return RegisterMessageCallbackOrderly(consumerInner, &PythonMessageCallBackInner);
@@ -418,6 +448,10 @@
             .value("E_LOG_LEVEL_TRACE", E_LOG_LEVEL_TRACE)
             .value("E_LOG_LEVEL_LEVEL_NUM", E_LOG_LEVEL_LEVEL_NUM);
 
+    enum_<CTransactionStatus>("TransactionStatus")
+            .value("E_COMMIT_TRANSACTION", E_COMMIT_TRANSACTION)
+            .value("E_ROLLBACK_TRANSACTION", E_ROLLBACK_TRANSACTION)
+            .value("E_UNKNOWN_TRANSACTION", E_UNKNOWN_TRANSACTION);
 
     //For Message
     def("CreateMessage", PyCreateMessage, return_value_policy<return_opaque_pointer>());
@@ -430,11 +464,6 @@
     def("SetMessageProperty", PySetMessageProperty);
     def("SetDelayTimeLevel", PySetMessageDelayTimeLevel);
 
-    //For batch message
-    def("CreateBatchMessage", PyCreateBatchMessage, return_value_policy<return_opaque_pointer>());
-    def("AddMessage", PyAddMessage);
-    def("DestroyBatchMessage", PyDestroyBatchMessage);
-
     //For MessageExt
     def("GetMessageTopic", PyGetMessageTopic);
     def("GetMessageTags", PyGetMessageTags);
@@ -445,7 +474,9 @@
 
     //For producer
     def("CreateProducer", PyCreateProducer, return_value_policy<return_opaque_pointer>());
+    def("CreateTransactionProducer", PyCreateTransactionProducer, return_value_policy<return_opaque_pointer>());
     def("DestroyProducer", PyDestroyProducer);
+    def("DestroyTransactionProducer", PyDestroyTransactionProducer);
     def("StartProducer", PyStartProducer);
     def("ShutdownProducer", PyShutdownProducer);
     def("SetProducerNameServerAddress", PySetProducerNameServerAddress);
@@ -462,11 +493,11 @@
 
     def("SendMessageSync", PySendMessageSync);
     def("SendMessageAsync", PySendMessageAsync);
-    def("SendBatchMessage", PySendBatchMessage);
 
     def("SendMessageOneway", PySendMessageOneway);
     def("SendMessageOrderly", PySendMessageOrderly);
     def("SendMessageOrderlyByShardingKey", PySendMessageOrderlyByShardingKey);
+    def("SendMessageInTransaction", PySendMessageInTransaction);
 
     //For Consumer
     def("CreatePushConsumer", PyCreatePushConsumer, return_value_policy<return_opaque_pointer>());
diff --git a/src/PythonWrapper.h b/src/PythonWrapper.h
index 29a4952..732320b 100644
--- a/src/PythonWrapper.h
+++ b/src/PythonWrapper.h
@@ -18,7 +18,6 @@
 #include "CCommon.h"
 #include "CMessage.h"
 #include "CMessageExt.h"
-#include "CBatchMessage.h"
 #include "CSendResult.h"
 #include "CProducer.h"
 #include "CPushConsumer.h"
@@ -91,11 +90,6 @@
 int PySetMessageProperty(void *msg, const char *key, const char *value);
 int PySetMessageDelayTimeLevel(void *msg, int level);
 
-//batch message
-void *PyCreateBatchMessage();
-int PyAddMessage(void *batchMsg, void *msg);
-int PyDestroyBatchMessage(void *batchMsg);
-
 //messageExt
 const char *PyGetMessageTopic(PyMessageExt msgExt);
 const char *PyGetMessageTags(PyMessageExt msgExt);
@@ -106,7 +100,12 @@
 
 //producer
 void *PyCreateProducer(const char *groupId);
+CTransactionStatus PyLocalTransactionCheckerCallback(CProducer *producer, CMessageExt *msg, void *data);
+CTransactionStatus PyLocalTransactionExecuteCallback(CProducer *producer, CMessage *msg, void *data);
+void *PyCreateTransactionProducer(const char *groupId, PyObject *localTransactionCheckerCallback);
+
 int PyDestroyProducer(void *producer);
+int PyDestroyTransactionProducer(void *producer);
 int PyStartProducer(void *producer);
 int PyShutdownProducer(void *producer);
 int PySetProducerNameServerAddress(void *producer, const char *namesrv);
@@ -127,9 +126,10 @@
 void PySendExceptionCallback(CMQException e, CMessage *msg, void *pyCallback);
 int PySendMessageAsync(void *producer, void *msg, PyObject *sendSuccessCallback, PyObject *sendExceptionCallback);
 
-PySendResult PySendBatchMessage(void *producer, void *msg);
 PySendResult PySendMessageOrderly(void *producer, void *msg, int autoRetryTimes, void *args, PyObject *queueSelector);
 PySendResult PySendMessageOrderlyByShardingKey(void *producer, void *msg, const char *shardingKey);
+// Sends msg transactionally; localTransactionCallback(msg, args) runs as the local transaction.
+PySendResult PySendMessageInTransaction(void *producer, void *msg, PyObject *localTransactionCallback, void *args);
 
 int PyOrderlyCallbackInner(int size, CMessage *msg, void *args);
 
diff --git a/test/TestSendMessages.py b/test/TestSendMessages.py
index 9baf78e..179e1f1 100644
--- a/test/TestSendMessages.py
+++ b/test/TestSendMessages.py
@@ -31,8 +31,20 @@
     StartProducer(producer)
     return producer
 
+def transaction_local_checker(msg):
+    """Transaction checker for the test producer: logs the message id and always commits."""
+    print 'begin check for msg: ' + GetMessageId(msg)
+    return TransactionStatus.E_COMMIT_TRANSACTION
 
-producer = init_producer()
+def init_transaction_producer():
+    """Create and start a transaction producer that uses transaction_local_checker."""
+    producer = CreateTransactionProducer('TransactionTestProducer', transaction_local_checker)
+    SetProducerLogLevel(producer, CLogLevel.E_LOG_LEVEL_INFO)
+    SetProducerNameServerAddress(producer, name_srv)
+    StartProducer(producer)
+    return producer
+
+producer = init_transaction_producer()
 tag = 'rmq-tag'
 key = 'rmq-key'
 
@@ -257,24 +267,30 @@
     print 'send message failed'
     print 'error msg: ' + exception.GetMsg()
 
-def send_batch_message(batch_count):
+def send_transaction_message(count):
+    """Send `count` messages through SendMessageInTransaction on the module-level producer."""
     key = 'rmq-key'
-    print 'start send batch message'
+    print 'start send transaction message'
     tag = 'test'
-    batchMsg = CreateBatchMessage()
-
     for n in range(count):
         body = 'hi rmq message, now is' + str(n)
         msg = CreateMessage(topic)
         SetMessageBody(msg, body)
         SetMessageKeys(msg, key)
         SetMessageTags(msg, tag)
-        AddMessage(batchMsg, msg)
-        DestroyMessage(msg)
 
-    SendBatchMessage(producer, batchMsg)
-    DestroyBatchMessage(batchMsg)
-    print 'send batch message done'
+        SendMessageInTransaction(producer, msg, transaction_local_execute, None)
+        # release the native message from CreateMessage once the blocking send has returned
+        DestroyMessage(msg)
+    print 'send transaction message done'
+    # NOTE(review): 10000 s is roughly 2.8 hours; presumably this keeps the process alive so
+    # the broker can call transaction_local_checker back - confirm the intended duration.
+    time.sleep(10000)
+
+def transaction_local_execute(msg, args):
+    """Local transaction executor for the test: logs and returns E_UNKNOWN_TRANSACTION."""
+    print 'begin execute local transaction'
+    return TransactionStatus.E_UNKNOWN_TRANSACTION
 
 if __name__ == '__main__':
-    send_message_async(10)
+    send_transaction_message(10)