feat(transaction): add trasaction producer message (#261)
* feat(transaction): add trasaction producer message
diff --git a/core/api.go b/core/api.go
index b0273e0..f141f3f 100644
--- a/core/api.go
+++ b/core/api.go
@@ -120,6 +120,23 @@
SendMessageOrderlyByShardingKey(msg *Message, shardingkey string) (*SendResult, error)
}
+// NewTransactionProducer create a new trasaction producer with config
+func NewTransactionProducer(config *ProducerConfig, listener TransactionLocalListener, arg interface{}) (TransactionProducer, error) {
+ return newDefaultTransactionProducer(config, listener, arg)
+}
+
+// TransactionExecutor local executor for transaction message
+type TransactionLocalListener interface {
+ Execute(m *Message, arg interface{}) TransactionStatus
+ Check(m *MessageExt, arg interface{}) TransactionStatus
+}
+
+type TransactionProducer interface {
+ baseAPI
+ // send a transaction message with sync
+ SendMessageTransaction(msg *Message, arg interface{}) (*SendResult, error)
+}
+
// NewPushConsumer create a new consumer with config.
func NewPushConsumer(config *PushConsumerConfig) (PushConsumer, error) {
return newPushConsumer(config)
diff --git a/core/message.go b/core/message.go
index 8c32847..f6f4637 100644
--- a/core/message.go
+++ b/core/message.go
@@ -37,12 +37,18 @@
Body string
DelayTimeLevel int
Property map[string]string
+ cmsg *C.struct_CMessage
}
func (msg *Message) String() string {
return fmt.Sprintf("[Topic: %s, Tags: %s, Keys: %s, Body: %s, DelayTimeLevel: %d, Property: %v]",
msg.Topic, msg.Tags, msg.Keys, msg.Body, msg.DelayTimeLevel, msg.Property)
}
+func (msg *Message) GetProperty(key string) string {
+ ck := C.CString(key)
+ defer C.free(unsafe.Pointer(ck))
+ return C.GoString(C.GetOriginMessageProperty(msg.cmsg, ck))
+}
func goMsgToC(gomsg *Message) *C.struct_CMessage {
cs := C.CString(gomsg.Topic)
@@ -73,6 +79,19 @@
return cmsg
}
+func cMsgToGo(cMsg *C.struct_CMessage) *Message {
+ gomsg := &Message{}
+
+ gomsg.Topic = C.GoString(C.GetOriginMessageTopic(cMsg))
+ gomsg.Tags = C.GoString(C.GetOriginMessageTags(cMsg))
+ gomsg.Keys = C.GoString(C.GetOriginMessageKeys(cMsg))
+ gomsg.Body = C.GoString(C.GetOriginMessageBody(cMsg))
+ gomsg.DelayTimeLevel = int(C.GetOriginDelayTimeLevel(cMsg))
+ gomsg.cmsg = cMsg
+
+ return gomsg
+}
+
//MessageExt used for consume
type MessageExt struct {
Message
@@ -99,7 +118,9 @@
//GetProperty get the message property by key from message ext
func (msgExt *MessageExt) GetProperty(key string) string {
- return C.GoString(C.GetMessageProperty(msgExt.cmsgExt, C.CString(key)))
+ ck := C.CString(key)
+ defer C.free(unsafe.Pointer(ck))
+ return C.GoString(C.GetMessageProperty(msgExt.cmsgExt, ck))
}
func cmsgExtToGo(cmsg *C.struct_CMessageExt) *MessageExt {
diff --git a/core/transaction_funcs.go b/core/transaction_funcs.go
new file mode 100644
index 0000000..a41dd4c
--- /dev/null
+++ b/core/transaction_funcs.go
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package rocketmq
+
+/*
+#cgo LDFLAGS: -L/usr/local/lib -lrocketmq
+
+#include <rocketmq/CMessage.h>
+#include <rocketmq/CMessageExt.h>
+#include <rocketmq/CProducer.h>
+*/
+import "C"
+import "sync"
+
+var transactionProducerMap sync.Map
+
+//export localTransactionExecutorCallback
+func localTransactionExecutorCallback(cproducer *C.CProducer, msg *C.CMessage, arg interface{}) C.int {
+ producer, exist := transactionProducerMap.Load(cproducer)
+ if !exist {
+ return C.int(UnknownTransaction)
+ }
+
+ message := cMsgToGo(msg)
+ listenerWrap, exist := producer.(*defaultTransactionProducer).listenerFuncsMap.Load(cproducer)
+ if !exist {
+ status := listenerWrap.(TransactionLocalListener).Execute(message, arg)
+ return C.int(status)
+ }
+ return C.int(UnknownTransaction)
+}
+
+//export localTransactionCheckerCallback
+func localTransactionCheckerCallback(cproducer *C.CProducer, msg *C.CMessageExt, arg interface{}) C.int {
+ producer, exist := transactionProducerMap.Load(cproducer)
+ if !exist {
+ return C.int(UnknownTransaction)
+ }
+
+ message := cmsgExtToGo(msg)
+ listener, exist := producer.(*defaultTransactionProducer).listenerFuncsMap.Load(cproducer)
+ if !exist {
+ status := listener.(TransactionLocalListener).Check(message, arg)
+ return C.int(status)
+ }
+ return C.int(UnknownTransaction)
+}
diff --git a/core/transaction_producer.go b/core/transaction_producer.go
new file mode 100644
index 0000000..da57b33
--- /dev/null
+++ b/core/transaction_producer.go
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package rocketmq
+
+/*
+#cgo LDFLAGS: -L/usr/local/lib/ -lrocketmq
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <rocketmq/CMessage.h>
+#include <rocketmq/CProducer.h>
+#include <rocketmq/CSendResult.h>
+#include <rocketmq/CTransactionStatus.h>
+
+extern int localTransactionCheckerCallback(CProducer *producer, CMessageExt *msg,void *userData);
+int transactionChecker_cgo(CProducer *producer, CMessageExt *msg, void *userData) {
+ return localTransactionCheckerCallback(producer, msg, userData);
+}
+
+extern int localTransactionExecutorCallback(CProducer *producer, CMessage *msg,void *userData);
+int transactionExecutor_cgo(CProducer *producer, CMessage *msg, void *userData) {
+ return localTransactionExecutorCallback(producer, msg, userData);
+}
+*/
+import "C"
+import (
+ "errors"
+ log "github.com/sirupsen/logrus"
+ "sync"
+ "unsafe"
+)
+
+type TransactionStatus int
+
+const (
+ CommitTransaction = TransactionStatus(C.E_COMMIT_TRANSACTION)
+ RollbackTransaction = TransactionStatus(C.E_ROLLBACK_TRANSACTION)
+ UnknownTransaction = TransactionStatus(C.E_UNKNOWN_TRANSACTION)
+)
+
+func (status TransactionStatus) String() string {
+ switch status {
+ case CommitTransaction:
+ return "CommitTransaction"
+ case RollbackTransaction:
+ return "RollbackTransaction"
+ case UnknownTransaction:
+ return "UnknownTransaction"
+ default:
+ return "UnknownTransaction"
+ }
+}
+func newDefaultTransactionProducer(config *ProducerConfig, listener TransactionLocalListener, arg interface{}) (*defaultTransactionProducer, error) {
+ if config == nil {
+ return nil, errors.New("config is nil")
+ }
+
+ if config.GroupID == "" {
+ return nil, errors.New("GroupId is empty")
+ }
+
+ if config.NameServer == "" && config.NameServerDomain == "" {
+ return nil, errors.New("NameServer and NameServerDomain is empty")
+ }
+
+ producer := &defaultTransactionProducer{config: config}
+ cs := C.CString(config.GroupID)
+ var cproduer *C.struct_CProducer
+
+ cproduer = C.CreateTransactionProducer(cs, (C.CLocalTransactionCheckerCallback)(unsafe.Pointer(C.transactionChecker_cgo)), unsafe.Pointer(&arg))
+
+ C.free(unsafe.Pointer(cs))
+
+ if cproduer == nil {
+ return nil, errors.New("create transaction Producer failed")
+ }
+
+ var err rmqError
+ if config.NameServer != "" {
+ cs = C.CString(config.NameServer)
+ err = rmqError(C.SetProducerNameServerAddress(cproduer, cs))
+ C.free(unsafe.Pointer(cs))
+ if err != NIL {
+ return nil, err
+ }
+ }
+
+ if config.NameServerDomain != "" {
+ cs = C.CString(config.NameServerDomain)
+ err = rmqError(C.SetProducerNameServerDomain(cproduer, cs))
+ C.free(unsafe.Pointer(cs))
+ if err != NIL {
+ return nil, err
+ }
+ }
+
+ if config.InstanceName != "" {
+ cs = C.CString(config.InstanceName)
+ err = rmqError(C.SetProducerInstanceName(cproduer, cs))
+ C.free(unsafe.Pointer(cs))
+ if err != NIL {
+ return nil, err
+ }
+ }
+
+ if config.Credentials != nil {
+ ak := C.CString(config.Credentials.AccessKey)
+ sk := C.CString(config.Credentials.SecretKey)
+ ch := C.CString(config.Credentials.Channel)
+ err = rmqError(C.SetProducerSessionCredentials(cproduer, ak, sk, ch))
+
+ C.free(unsafe.Pointer(ak))
+ C.free(unsafe.Pointer(sk))
+ C.free(unsafe.Pointer(ch))
+ if err != NIL {
+ return nil, err
+ }
+ }
+
+ if config.LogC != nil {
+ cs = C.CString(config.LogC.Path)
+ err = rmqError(C.SetProducerLogPath(cproduer, cs))
+ C.free(unsafe.Pointer(cs))
+ if err != NIL {
+ return nil, err
+ }
+
+ err = rmqError(C.SetProducerLogFileNumAndSize(cproduer, C.int(config.LogC.FileNum), C.long(config.LogC.FileSize)))
+ if err != NIL {
+ return nil, err
+ }
+
+ err = rmqError(C.SetProducerLogLevel(cproduer, C.CLogLevel(config.LogC.Level)))
+ if err != NIL {
+ return nil, err
+ }
+ }
+
+ if config.SendMsgTimeout > 0 {
+ err = rmqError(C.SetProducerSendMsgTimeout(cproduer, C.int(config.SendMsgTimeout)))
+ if err != NIL {
+ return nil, err
+ }
+ }
+
+ if config.CompressLevel > 0 {
+ err = rmqError(C.SetProducerCompressLevel(cproduer, C.int(config.CompressLevel)))
+ if err != NIL {
+ return nil, err
+ }
+ }
+
+ if config.MaxMessageSize > 0 {
+ err = rmqError(C.SetProducerMaxMessageSize(cproduer, C.int(config.MaxMessageSize)))
+ if err != NIL {
+ return nil, err
+ }
+ }
+
+ producer.cproduer = cproduer
+ transactionProducerMap.Store(cproduer, producer)
+ producer.listenerFuncsMap.Store(cproduer, listener)
+ return producer, nil
+}
+
+type defaultTransactionProducer struct {
+ config *ProducerConfig
+ cproduer *C.struct_CProducer
+ listenerFuncsMap sync.Map
+}
+
+func (p *defaultTransactionProducer) String() string {
+ return p.config.String()
+}
+
+// Start the producer.
+func (p *defaultTransactionProducer) Start() error {
+ err := rmqError(C.StartProducer(p.cproduer))
+ if err != NIL {
+ return err
+ }
+ return nil
+}
+
+// Shutdown the producer.
+func (p *defaultTransactionProducer) Shutdown() error {
+ err := rmqError(C.ShutdownProducer(p.cproduer))
+
+ if err != NIL {
+ return err
+ }
+
+ err = rmqError(int(C.DestroyProducer(p.cproduer)))
+ if err != NIL {
+ return err
+ }
+
+ return err
+}
+
+func (p *defaultTransactionProducer) SendMessageTransaction(msg *Message, arg interface{}) (*SendResult, error) {
+ cmsg := goMsgToC(msg)
+ defer C.DestroyMessage(cmsg)
+
+ var sr C.struct__SendResult_
+ err := rmqError(C.SendMessageTransaction(p.cproduer, cmsg, (C.CLocalTransactionExecutorCallback)(unsafe.Pointer(C.transactionExecutor_cgo)), unsafe.Pointer(&arg), &sr))
+ if err != NIL {
+ log.Warnf("send message error, error is: %s", err.Error())
+ return nil, err
+ }
+
+ result := &SendResult{}
+ result.Status = SendStatus(sr.sendStatus)
+ result.MsgId = C.GoString(&sr.msgId[0])
+ result.Offset = int64(sr.offset)
+ return result, nil
+}
diff --git a/demos/transaction_producer.go b/demos/transaction_producer.go
new file mode 100644
index 0000000..a4b64b2
--- /dev/null
+++ b/demos/transaction_producer.go
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package main
+
+import (
+ "fmt"
+ "github.com/apache/rocketmq-client-go/core"
+ "time"
+)
+
+// Change to main if you want to run it directly
+func main4() {
+ pConfig := &rocketmq.ProducerConfig{
+ ClientConfig: rocketmq.ClientConfig{
+ GroupID: "GID_XXXXXXXXXXXX",
+ NameServer: "http://XXXXXXXXXXXXXXXXXX:80",
+ Credentials: &rocketmq.SessionCredentials{
+ AccessKey: "Your Access Key",
+ SecretKey: "Your Secret Key",
+ Channel: "ALIYUN/OtherChannel",
+ },
+ },
+ //Set to Common Producer as default.
+ ProducerModel: rocketmq.CommonProducer,
+ }
+ sendTransactionMessage(pConfig)
+}
+
+type MyTransactionLocalListener struct {
+}
+
+func (l *MyTransactionLocalListener) Execute(m *rocketmq.Message, arg interface{}) rocketmq.TransactionStatus {
+ return rocketmq.UnknownTransaction
+}
+func (l *MyTransactionLocalListener) Check(m *rocketmq.MessageExt, arg interface{}) rocketmq.TransactionStatus {
+ return rocketmq.CommitTransaction
+}
+func sendTransactionMessage(config *rocketmq.ProducerConfig) {
+ listener := &MyTransactionLocalListener{}
+ producer, err := rocketmq.NewTransactionProducer(config, listener, listener)
+
+ if err != nil {
+ fmt.Println("create Transaction producer failed, error:", err)
+ return
+ }
+
+ err = producer.Start()
+ if err != nil {
+ fmt.Println("start Transaction producer error", err)
+ return
+ }
+ defer producer.Shutdown()
+
+ fmt.Printf("Transaction producer: %s started... \n", producer)
+ for i := 0; i < 10; i++ {
+ msg := fmt.Sprintf("%s-%d", "Hello,Transaction MQ Message-", i)
+ result, err := producer.SendMessageTransaction(&rocketmq.Message{Topic: "YourTopicXXXXXXXX", Body: msg}, msg)
+ if err != nil {
+ fmt.Println("Error:", err)
+ }
+ fmt.Printf("send message: %s result: %s\n", msg, result)
+ }
+ time.Sleep(time.Duration(1) * time.Minute)
+ fmt.Println("shutdown Transaction producer.")
+}