blob: 500631186f84e3736e2ba1c9e473792a2aaf5837 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.rocketmq.client.producer;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.rocketmq.client.exception.MQBrokerException;
import org.apache.rocketmq.client.exception.MQClientException;
import org.apache.rocketmq.client.log.ClientLogger;
import org.apache.rocketmq.common.ServiceThread;
import org.apache.rocketmq.common.message.Message;
import org.apache.rocketmq.common.message.MessageBatch;
import org.apache.rocketmq.common.message.MessageClientIDSetter;
import org.apache.rocketmq.common.message.MessageConst;
import org.apache.rocketmq.common.message.MessageDecoder;
import org.apache.rocketmq.common.message.MessageQueue;
import org.apache.rocketmq.logging.InternalLogger;
import org.apache.rocketmq.remoting.exception.RemotingException;
/**
 * Client-side accumulator that merges messages sent one by one into {@link MessageBatch}es.
 *
 * <p>Messages are grouped by {@link AggregateKey} (topic, optional target queue, {@code waitStoreMsgOK} flag and
 * tag). A group is sent as a single batch once its accumulated body size exceeds {@link #getBatchMaxBytes()} or
 * once it has been open for {@link #getBatchMaxDelayMs()} milliseconds. Synchronous senders block until their
 * group has been sent; asynchronous senders are completed through their {@link SendCallback}. Two guard threads
 * periodically drive pending groups and evict closed or empty ones from the lookup maps.
 *
 * <p>{@link #tryAddMessage(Message)} reserves body bytes against {@link #getTotalBatchMaxBytes()}; every group
 * subtracts its accumulated body size from that budget once it has been sent or has failed.
 * NOTE(review): pairing {@code tryAddMessage} with the {@code send} methods is the caller's job and is not
 * visible in this class.
 */
public class ProduceAccumulator {
    // Global budget, in bytes, of message bodies that may be held at once (default 32 MiB).
    // volatile: written by configuration calls, read by sender and guard threads.
    private volatile long totalHoldSize = 32 * 1024 * 1024;
    // A group is sent once its accumulated body size exceeds this many bytes (default 32 KiB).
    private volatile long holdSize = 32 * 1024;
    // A group is sent once it has been open for this many milliseconds (default 10 ms).
    private volatile int holdMs = 10;
    private static final InternalLogger log = ClientLogger.getLog();
    private final GuardForSyncSendService guardThreadForSyncSend;
    private final GuardForAsyncSendService guardThreadForAsyncSend;
    // Open groups of synchronous sends, keyed by aggregation key.
    private final Map<AggregateKey, MessageAccumulation> syncSendBatchs = new ConcurrentHashMap<AggregateKey, MessageAccumulation>();
    // Open groups of asynchronous sends, keyed by aggregation key.
    private final Map<AggregateKey, MessageAccumulation> asyncSendBatchs = new ConcurrentHashMap<AggregateKey, MessageAccumulation>();
    // Bytes currently reserved by tryAddMessage and not yet released by a completed group.
    private final AtomicLong currentlyHoldSize = new AtomicLong(0);
    private final String instanceName;

    /**
     * Creates an accumulator whose guard threads are named after the given client instance.
     *
     * @param instanceName client instance name, used only to build the guard thread names
     */
    public ProduceAccumulator(String instanceName) {
        this.instanceName = instanceName;
        this.guardThreadForSyncSend = new GuardForSyncSendService(this.instanceName);
        this.guardThreadForAsyncSend = new GuardForAsyncSendService(this.instanceName);
    }

    /**
     * Background thread for synchronous groups. Every {@code max(1, holdMs / 2)} ms it notifies the senders
     * waiting on each open group so that one of them can send the group once it is due. It also closes empty
     * groups and evicts closed ones from {@link #syncSendBatchs}.
     */
    private class GuardForSyncSendService extends ServiceThread {
        private final String serviceName;

        public GuardForSyncSendService(String clientInstanceName) {
            serviceName = String.format("Client_%s_GuardForSyncSend", clientInstanceName);
        }

        @Override public String getServiceName() {
            return serviceName;
        }

        @Override public void run() {
            ProduceAccumulator.log.info(this.getServiceName() + " service started");
            while (!this.isStopped()) {
                try {
                    this.doWork();
                } catch (Exception e) {
                    ProduceAccumulator.log.warn(this.getServiceName() + " service has exception. ", e);
                }
            }
            ProduceAccumulator.log.info(this.getServiceName() + " service end");
        }

        private void doWork() throws Exception {
            Collection<MessageAccumulation> values = syncSendBatchs.values();
            final int sleepTime = Math.max(1, holdMs / 2);
            for (MessageAccumulation v : values) {
                v.wakeup();
                // Lock order (group monitor, then closed) matches MessageAccumulation.add/send.
                synchronized (v) {
                    synchronized (v.closed) {
                        if (v.messagesSize.get() == 0) {
                            v.closed.set(true);
                        }
                        if (v.closed.get()) {
                            // Evict empty groups and groups that have already been sent. Otherwise a sent group
                            // would stay in the map until another message with the same key arrived.
                            syncSendBatchs.remove(v.aggregateKey, v);
                        } else {
                            v.notify();
                        }
                    }
                }
            }
            Thread.sleep(sleepTime);
        }
    }

    /**
     * Background thread for asynchronous groups. Every {@code max(1, holdMs / 2)} ms it sends each group that is
     * due, closes empty groups and evicts closed ones from {@link #asyncSendBatchs}.
     */
    private class GuardForAsyncSendService extends ServiceThread {
        private final String serviceName;

        public GuardForAsyncSendService(String clientInstanceName) {
            serviceName = String.format("Client_%s_GuardForAsyncSend", clientInstanceName);
        }

        @Override public String getServiceName() {
            return serviceName;
        }

        @Override public void run() {
            ProduceAccumulator.log.info(this.getServiceName() + " service started");
            while (!this.isStopped()) {
                try {
                    this.doWork();
                } catch (Exception e) {
                    ProduceAccumulator.log.warn(this.getServiceName() + " service has exception. ", e);
                }
            }
            ProduceAccumulator.log.info(this.getServiceName() + " service end");
        }

        private void doWork() throws Exception {
            Collection<MessageAccumulation> values = asyncSendBatchs.values();
            final int sleepTime = Math.max(1, holdMs / 2);
            for (MessageAccumulation v : values) {
                if (v.readyToSend()) {
                    v.send(null);
                }
                synchronized (v.closed) {
                    if (v.messagesSize.get() == 0) {
                        v.closed.set(true);
                    }
                    if (v.closed.get()) {
                        // Evict empty groups and groups that have just been (or were already) sent.
                        asyncSendBatchs.remove(v.aggregateKey, v);
                    }
                }
            }
            Thread.sleep(sleepTime);
        }
    }

    /** Starts both guard threads. */
    void start() {
        guardThreadForSyncSend.start();
        guardThreadForAsyncSend.start();
    }

    /**
     * Stops both guard threads, then sends every asynchronous group that still holds messages.
     *
     * <p>Without the async guard nothing else would flush groups that are not yet due, so their callbacks would
     * never be completed. Sending a group that is already closed is a no-op. Synchronous waiters do not need a
     * flush here because they wake up on their own once their group is due.
     */
    void shutdown() {
        guardThreadForSyncSend.shutdown();
        guardThreadForAsyncSend.shutdown();
        for (MessageAccumulation batch : asyncSendBatchs.values()) {
            if (batch.messagesSize.get() > 0) {
                batch.send(null);
            }
        }
    }

    /** @return maximum time, in ms, a group stays open before it is sent */
    int getBatchMaxDelayMs() {
        return holdMs;
    }

    /**
     * Sets the maximum time a group stays open before it is sent.
     *
     * @param holdMs delay in milliseconds, within [1, 30000]
     * @throws IllegalArgumentException if {@code holdMs} is out of range
     */
    void batchMaxDelayMs(int holdMs) {
        if (holdMs <= 0 || holdMs > 30 * 1000) {
            throw new IllegalArgumentException(String.format("batchMaxDelayMs expect between 1ms and 30s, but get %d!", holdMs));
        }
        this.holdMs = holdMs;
    }

    /** @return body size, in bytes, above which a group is sent immediately */
    long getBatchMaxBytes() {
        return holdSize;
    }

    /**
     * Sets the body size above which a group is sent immediately.
     *
     * @param holdSize size in bytes, within [1, 2 MiB]
     * @throws IllegalArgumentException if {@code holdSize} is out of range
     */
    void batchMaxBytes(long holdSize) {
        if (holdSize <= 0 || holdSize > 2 * 1024 * 1024) {
            throw new IllegalArgumentException(String.format("batchMaxBytes expect between 1B and 2MB, but get %d!", holdSize));
        }
        this.holdSize = holdSize;
    }

    /** @return global budget, in bytes, of message bodies the accumulator may hold at once */
    long getTotalBatchMaxBytes() {
        // Previously returned holdSize (the per-group limit) by mistake.
        return totalHoldSize;
    }

    /**
     * Sets the global budget of message bodies the accumulator may hold at once.
     *
     * @param totalHoldSize budget in bytes, must be positive
     * @throws IllegalArgumentException if {@code totalHoldSize <= 0}
     */
    void totalBatchMaxBytes(long totalHoldSize) {
        if (totalHoldSize <= 0) {
            throw new IllegalArgumentException(String.format("totalBatchMaxBytes must bigger then 0, but get %d!", totalHoldSize));
        }
        this.totalHoldSize = totalHoldSize;
    }

    /**
     * Returns the open group registered under {@code aggregateKey} in {@code batches}, registering a new one
     * atomically when none exists.
     */
    private MessageAccumulation getOrCreateBatch(Map<AggregateKey, MessageAccumulation> batches,
        AggregateKey aggregateKey, DefaultMQProducer defaultMQProducer) {
        MessageAccumulation batch = batches.get(aggregateKey);
        if (batch != null) {
            return batch;
        }
        batch = new MessageAccumulation(aggregateKey, defaultMQProducer);
        MessageAccumulation previous = batches.putIfAbsent(aggregateKey, batch);
        return previous == null ? batch : previous;
    }

    /**
     * Adds {@code msg} to the synchronous group for {@code partitionKey} and blocks until that group is sent.
     * If the group closed before the message could join it, the group is dropped and a fresh one is used.
     *
     * @return the per-message result extracted from the batch send result
     */
    private SendResult sendSync(AggregateKey partitionKey, Message msg,
        DefaultMQProducer defaultMQProducer) throws InterruptedException, MQBrokerException, RemotingException, MQClientException {
        while (true) {
            MessageAccumulation batch = getOrCreateBatch(syncSendBatchs, partitionKey, defaultMQProducer);
            int index = batch.add(msg);
            if (index != -1) {
                return batch.sendResults[index];
            }
            syncSendBatchs.remove(partitionKey, batch);
        }
    }

    /**
     * Adds {@code msg} and its callback to the asynchronous group for {@code partitionKey}. If the group closed
     * before the message could join it, the group is dropped and a fresh one is used.
     */
    private void sendAsync(AggregateKey partitionKey, Message msg, SendCallback sendCallback,
        DefaultMQProducer defaultMQProducer) throws InterruptedException, RemotingException, MQClientException {
        while (true) {
            MessageAccumulation batch = getOrCreateBatch(asyncSendBatchs, partitionKey, defaultMQProducer);
            if (batch.add(msg, sendCallback)) {
                return;
            }
            asyncSendBatchs.remove(partitionKey, batch);
        }
    }

    /**
     * Sends {@code msg} as part of a batch and blocks until the batch has been sent. Queue selection is left to
     * the producer.
     *
     * @return the result for this message
     * @throws MQClientException also when another thread sent the batch and that send failed
     */
    SendResult send(Message msg,
        DefaultMQProducer defaultMQProducer) throws InterruptedException, MQBrokerException, RemotingException, MQClientException {
        return sendSync(new AggregateKey(msg), msg, defaultMQProducer);
    }

    /**
     * Sends {@code msg} to {@code mq} as part of a batch and blocks until the batch has been sent.
     *
     * @return the result for this message
     * @throws MQClientException also when another thread sent the batch and that send failed
     */
    SendResult send(Message msg, MessageQueue mq,
        DefaultMQProducer defaultMQProducer) throws InterruptedException, MQBrokerException, RemotingException, MQClientException {
        return sendSync(new AggregateKey(msg, mq), msg, defaultMQProducer);
    }

    /**
     * Queues {@code msg} for batched asynchronous sending. {@code sendCallback} is completed exactly once when
     * the batch has been sent or has failed.
     */
    void send(Message msg, SendCallback sendCallback,
        DefaultMQProducer defaultMQProducer) throws InterruptedException, RemotingException, MQClientException {
        sendAsync(new AggregateKey(msg), msg, sendCallback, defaultMQProducer);
    }

    /**
     * Queues {@code msg} for batched asynchronous sending to {@code mq}. {@code sendCallback} is completed
     * exactly once when the batch has been sent or has failed.
     */
    void send(Message msg, MessageQueue mq,
        SendCallback sendCallback,
        DefaultMQProducer defaultMQProducer) throws InterruptedException, RemotingException, MQClientException {
        sendAsync(new AggregateKey(msg, mq), msg, sendCallback, defaultMQProducer);
    }

    /**
     * Reserves the body size of {@code message} in the global budget.
     *
     * <p>The check is against the bytes held before this reservation, so the budget can be exceeded by at most
     * one message body.
     *
     * @return {@code true} if the bytes were reserved, {@code false} if the budget is already exhausted
     */
    boolean tryAddMessage(Message message) {
        synchronized (currentlyHoldSize) {
            if (currentlyHoldSize.get() < totalHoldSize) {
                currentlyHoldSize.addAndGet(message.getBody().length);
                return true;
            } else {
                return false;
            }
        }
    }

    /** Immutable grouping key: messages with equal keys may be merged into the same batch. */
    private static class AggregateKey {
        public final String topic;
        // null when the producer chooses the queue.
        public final MessageQueue mq;
        public final boolean waitStoreMsgOK;
        public final String tag;

        public AggregateKey(Message message) {
            this(message.getTopic(), null, message.isWaitStoreMsgOK(), message.getTags());
        }

        public AggregateKey(Message message, MessageQueue mq) {
            this(message.getTopic(), mq, message.isWaitStoreMsgOK(), message.getTags());
        }

        public AggregateKey(String topic, MessageQueue mq, boolean waitStoreMsgOK, String tag) {
            this.topic = topic;
            this.mq = mq;
            this.waitStoreMsgOK = waitStoreMsgOK;
            this.tag = tag;
        }

        @Override public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            AggregateKey key = (AggregateKey) o;
            return waitStoreMsgOK == key.waitStoreMsgOK
                && Objects.equals(topic, key.topic)
                && Objects.equals(mq, key.mq)
                && Objects.equals(tag, key.tag);
        }

        @Override public int hashCode() {
            return Objects.hash(topic, mq, waitStoreMsgOK, tag);
        }
    }

    /**
     * One open group of messages sharing an {@link AggregateKey}.
     *
     * <p>{@link #closed} doubles as the lock for the message list, callbacks, keys and {@link #count}. Once it is
     * set, no further messages are accepted and the group is sent at most once. Synchronous senders additionally
     * wait on this object's monitor, which also guards {@link #sendResults} and {@link #sendException}.
     */
    private class MessageAccumulation {
        private final DefaultMQProducer defaultMQProducer;
        private final LinkedList<Message> messages;
        private final LinkedList<SendCallback> sendCallbacks;
        private final Set<String> keys;
        private final AtomicBoolean closed;
        // Per-message results of a successful synchronous send; null until then.
        private SendResult[] sendResults;
        // Failure of a synchronous send, surfaced to every waiter of the group.
        private Exception sendException;
        private final AggregateKey aggregateKey;
        // Accumulated body size, in bytes, of all messages added to this group.
        private final AtomicInteger messagesSize;
        private int count;
        private final long createTime;
        // Ensures the group's bytes are subtracted from currentlyHoldSize exactly once.
        private final AtomicBoolean holdSizeReleased = new AtomicBoolean(false);
        // Ensures asynchronous callbacks are completed exactly once.
        private final AtomicBoolean completed = new AtomicBoolean(false);

        public MessageAccumulation(AggregateKey aggregateKey, DefaultMQProducer defaultMQProducer) {
            this.defaultMQProducer = defaultMQProducer;
            this.messages = new LinkedList<Message>();
            this.sendCallbacks = new LinkedList<SendCallback>();
            this.keys = new HashSet<String>();
            this.closed = new AtomicBoolean(false);
            this.messagesSize = new AtomicInteger(0);
            this.aggregateKey = aggregateKey;
            this.count = 0;
            this.createTime = System.currentTimeMillis();
        }

        /** @return whether the group exceeds the size limit or has been open for at least holdMs */
        private boolean readyToSend() {
            if (this.messagesSize.get() > holdSize
                || System.currentTimeMillis() >= this.createTime + holdMs) {
                return true;
            }
            return false;
        }

        /**
         * Adds a message to this group for synchronous sending and blocks until the group has been sent.
         *
         * @return the message's index in {@link #sendResults}, or -1 if the group was already closed
         * @throws MQClientException if the group was sent by another thread and that send failed, or it was
         *                           closed without being sent
         */
        public int add(
            Message msg) throws InterruptedException, MQBrokerException, RemotingException, MQClientException {
            int ret = -1;
            synchronized (this.closed) {
                if (this.closed.get()) {
                    return ret;
                }
                ret = this.count++;
                this.messages.add(msg);
                messagesSize.addAndGet(msg.getBody().length);
                String msgKeys = msg.getKeys();
                if (msgKeys != null) {
                    this.keys.addAll(Arrays.asList(msgKeys.split(MessageConst.KEY_SEPARATOR)));
                }
            }
            synchronized (this) {
                while (!this.closed.get()) {
                    if (readyToSend()) {
                        this.send();
                        break;
                    }
                    // Timed wait until the group is due, so waiters do not rely solely on the guard thread
                    // (which may already be shut down) to wake them up.
                    long remaining = this.createTime + holdMs - System.currentTimeMillis();
                    this.wait(Math.max(1L, remaining));
                }
                if (this.sendResults == null) {
                    // Another thread's send of this group failed: report it instead of letting the caller
                    // dereference a missing result.
                    throw new MQClientException("batch send failed, the message was not sent", this.sendException);
                }
                return ret;
            }
        }

        /**
         * Adds a message and its callback to this group for asynchronous sending. If the group is then due, it
         * is sent on the calling thread.
         *
         * @return {@code false} if the group was already closed and the message was not added
         */
        public boolean add(Message msg,
            SendCallback sendCallback) throws InterruptedException, RemotingException, MQClientException {
            synchronized (this.closed) {
                if (this.closed.get()) {
                    return false;
                }
                this.count++;
                this.messages.add(msg);
                this.sendCallbacks.add(sendCallback);
                messagesSize.getAndAdd(msg.getBody().length);
            }
            if (readyToSend()) {
                this.send(sendCallback);
            }
            return true;
        }

        /** Notifies one thread waiting on this group, unless the group is already closed. */
        public synchronized void wakeup() {
            if (this.closed.get()) {
                return;
            }
            this.notify();
        }

        /** Builds the wire batch: the encoded messages plus the group's topic, flag, union of keys and tag. */
        private MessageBatch batch() {
            MessageBatch messageBatch = new MessageBatch(this.messages);
            messageBatch.setTopic(this.aggregateKey.topic);
            messageBatch.setWaitStoreMsgOK(this.aggregateKey.waitStoreMsgOK);
            messageBatch.setKeys(this.keys);
            messageBatch.setTags(this.aggregateKey.tag);
            MessageClientIDSetter.setUniqID(messageBatch);
            messageBatch.setBody(MessageDecoder.encodeMessages(this.messages));
            return messageBatch;
        }

        /**
         * Splits a batch send result into one result per message.
         *
         * <p>When the batch msgId contains commas, it and the offsetMsgId are split per message and the queue
         * offset is incremented per message. Otherwise every message shares the batch result.
         *
         * @return a fresh array of {@link #count} results; nothing is assigned if validation fails
         * @throws IllegalArgumentException if the result is null or its id lists do not match {@link #count}
         */
        private SendResult[] splitSendResults(SendResult sendResult) {
            if (sendResult == null) {
                throw new IllegalArgumentException("sendResult is null");
            }
            boolean isBatchConsumerQueue = !sendResult.getMsgId().contains(",");
            SendResult[] results = new SendResult[this.count];
            if (!isBatchConsumerQueue) {
                String[] msgIds = sendResult.getMsgId().split(",");
                String[] offsetMsgIds = sendResult.getOffsetMsgId().split(",");
                if (offsetMsgIds.length != this.count || msgIds.length != this.count) {
                    throw new IllegalArgumentException("sendResult is illegal");
                }
                for (int i = 0; i < this.count; i++) {
                    results[i] = new SendResult(sendResult.getSendStatus(), msgIds[i],
                        sendResult.getMessageQueue(), sendResult.getQueueOffset() + i,
                        sendResult.getTransactionId(), offsetMsgIds[i], sendResult.getRegionId());
                }
            } else {
                Arrays.fill(results, sendResult);
            }
            return results;
        }

        /** Subtracts this group's accumulated body size from the global budget, at most once. */
        private void releaseHoldSize() {
            if (holdSizeReleased.compareAndSet(false, true)) {
                currentlyHoldSize.addAndGet(-messagesSize.get());
            }
        }

        /**
         * Closes the group and sends it synchronously. Must be called while holding this object's monitor,
         * because it calls {@link #notifyAll()}.
         *
         * <p>On success {@link #sendResults} is filled. On failure the exception is recorded for the other
         * waiters and rethrown. In both cases the hold size is released and waiters are woken up.
         */
        private void send() throws InterruptedException, MQClientException, MQBrokerException, RemotingException {
            synchronized (this.closed) {
                if (this.closed.getAndSet(true)) {
                    return;
                }
            }
            try {
                if (defaultMQProducer == null) {
                    throw new IllegalArgumentException("defaultMQProducer is null, can not send message");
                }
                MessageBatch messageBatch = this.batch();
                SendResult sendResult = defaultMQProducer.sendDirect(messageBatch, aggregateKey.mq, null);
                this.sendResults = this.splitSendResults(sendResult);
            } catch (Exception e) {
                this.sendException = e;
                // Precise rethrow: only the checked types declared by this method can reach here.
                throw e;
            } finally {
                releaseHoldSize();
                this.notifyAll();
            }
        }

        /**
         * Closes the group and sends it asynchronously. Every registered callback is completed exactly once,
         * whether the failure is reported via the producer callback or thrown synchronously.
         *
         * @param sendCallback unused; callbacks are taken from {@link #sendCallbacks}
         */
        private void send(SendCallback sendCallback) {
            synchronized (this.closed) {
                if (this.closed.getAndSet(true)) {
                    return;
                }
            }
            try {
                if (defaultMQProducer == null) {
                    throw new IllegalArgumentException("defaultMQProducer is null, can not send message");
                }
                MessageBatch messageBatch = this.batch();
                defaultMQProducer.sendDirect(messageBatch, aggregateKey.mq, new SendCallback() {
                    @Override public void onSuccess(SendResult sendResult) {
                        completeSuccessfully(sendResult);
                    }

                    @Override public void onException(Throwable e) {
                        completeExceptionally(e);
                    }
                });
            } catch (Exception e) {
                // Previously the hold size was not released on this path.
                completeExceptionally(e);
            }
        }

        /**
         * Dispatches per-message results to the callbacks. If the result cannot be split, every callback is
         * failed instead. An exception thrown by one user callback is logged and does not affect the others.
         */
        private void completeSuccessfully(SendResult sendResult) {
            SendResult[] results;
            try {
                results = splitSendResults(sendResult);
            } catch (Exception e) {
                completeExceptionally(e);
                return;
            }
            if (results.length != sendCallbacks.size()) {
                completeExceptionally(new IllegalArgumentException("sendResult is illegal"));
                return;
            }
            if (!completed.compareAndSet(false, true)) {
                return;
            }
            try {
                int i = 0;
                Iterator<SendCallback> it = sendCallbacks.iterator();
                while (it.hasNext()) {
                    SendCallback callback = it.next();
                    try {
                        callback.onSuccess(results[i++]);
                    } catch (Exception e) {
                        log.warn("SendCallback.onSuccess threw an exception for an accumulated batch", e);
                    }
                }
            } finally {
                releaseHoldSize();
            }
        }

        /**
         * Fails every callback with {@code cause}, unless the group was already completed. An exception thrown
         * by one user callback is logged and does not affect the others.
         */
        private void completeExceptionally(Throwable cause) {
            if (!completed.compareAndSet(false, true)) {
                return;
            }
            try {
                for (SendCallback callback : sendCallbacks) {
                    try {
                        callback.onException(cause);
                    } catch (Exception e) {
                        log.warn("SendCallback.onException threw an exception for an accumulated batch", e);
                    }
                }
            } finally {
                releaseHoldSize();
            }
        }
    }
}