blob: 181c34ad526e6dd032d25735f87256b2f70bd455 [file] [log] [blame]
/*
* Copyright 1999-2011 Alibaba Group.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.dubbo.remoting.exchange.codec;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import com.alibaba.dubbo.common.extension.ExtensionLoader;
import com.alibaba.dubbo.common.io.Bytes;
import com.alibaba.dubbo.common.io.StreamUtils;
import com.alibaba.dubbo.common.io.UnsafeByteArrayInputStream;
import com.alibaba.dubbo.common.io.UnsafeByteArrayOutputStream;
import com.alibaba.dubbo.common.logger.Logger;
import com.alibaba.dubbo.common.logger.LoggerFactory;
import com.alibaba.dubbo.common.serialize.ObjectInput;
import com.alibaba.dubbo.common.serialize.ObjectOutput;
import com.alibaba.dubbo.common.serialize.Serialization;
import com.alibaba.dubbo.common.utils.StringUtils;
import com.alibaba.dubbo.remoting.Channel;
import com.alibaba.dubbo.remoting.RemotingException;
import com.alibaba.dubbo.remoting.exchange.Request;
import com.alibaba.dubbo.remoting.exchange.Response;
import com.alibaba.dubbo.remoting.exchange.support.DefaultFuture;
import com.alibaba.dubbo.remoting.telnet.codec.TelnetCodec;
/**
 * ExchangeCodec: encodes {@link Request} and {@link Response} objects as framed binary
 * messages and decodes them back. Any other message, and any input that does not start
 * with the exchange magic number, is delegated to {@link TelnetCodec}.
 * <p>
 * Frame layout, as written by {@code encodeRequest}/{@code encodeResponse}:
 * <pre>
 * bytes  0-1   magic number 0xdabb
 * byte   2     flags: 0x80 request, 0x40 two-way, 0x20 event; low 5 bits = serialization id
 * byte   3     response status (left as 0 for requests)
 * bytes  4-11  request id (long)
 * bytes 12-15  body length (int)
 * bytes 16-    body, written with the serialization selected for the channel
 * </pre>
 *
 * @author qianlei
 * @author william.liangf
 */
public class ExchangeCodec extends TelnetCodec {

    private static final Logger logger = LoggerFactory.getLogger(ExchangeCodec.class);

    // header length.
    protected static final int HEADER_LENGTH = 16;

    // magic header.
    protected static final short MAGIC = (short) 0xdabb;

    protected static final byte MAGIC_HIGH = Bytes.short2bytes(MAGIC)[0];

    protected static final byte MAGIC_LOW = Bytes.short2bytes(MAGIC)[1];

    // message flag.
    protected static final byte FLAG_REQUEST = (byte) 0x80;

    protected static final byte FLAG_TWOWAY = (byte) 0x40;

    protected static final byte FLAG_EVENT = (byte) 0x20;

    protected static final int SERIALIZATION_MASK = 0x1f;

    /** Serialization extensions keyed by content-type id; filled once by the static initializer below. */
    private static final Map<Byte, Serialization> ID_SERIALIZATION_MAP = new HashMap<Byte, Serialization>();

    static {
        Set<String> supportedExtensions = ExtensionLoader.getExtensionLoader(Serialization.class).getSupportedExtensions();
        for (String name : supportedExtensions) {
            Serialization serialization = ExtensionLoader.getExtensionLoader(Serialization.class).getExtension(name);
            byte idByte = serialization.getContentTypeId();
            Serialization existing = ID_SERIALIZATION_MAP.get(idByte);
            if (existing != null) {
                // The first extension registered for an id wins; later duplicates are skipped.
                logger.error("Serialization extension " + serialization.getClass().getName()
                        + " has duplicate id to Serialization extension "
                        + existing.getClass().getName()
                        + ", ignore this Serialization extension");
                continue;
            }
            ID_SERIALIZATION_MAP.put(idByte, serialization);
        }
    }

    /**
     * Returns the magic number that prefixes every exchange frame.
     *
     * @return {@link #MAGIC}
     */
    public Short getMagicCode() {
        return MAGIC;
    }

    /**
     * Encodes {@code msg} onto {@code os}. Requests and responses are written as exchange
     * frames; anything else is handed to {@link TelnetCodec}.
     */
    public void encode(Channel channel, OutputStream os, Object msg) throws IOException {
        if (msg instanceof Request) {
            encodeRequest(channel, os, (Request) msg);
        } else if (msg instanceof Response) {
            encodeResponse(channel, os, (Response) msg);
        } else {
            super.encode(channel, os, msg);
        }
    }

    /**
     * Decodes the next message from {@code is}. It reads up to {@link #HEADER_LENGTH} bytes
     * of header, then delegates to {@link #decode(Channel, InputStream, int, byte[])}.
     */
    public Object decode(Channel channel, InputStream is) throws IOException {
        int readable = is.available();
        byte[] header = new byte[Math.min(readable, HEADER_LENGTH)];
        // InputStream.read may return fewer bytes than requested; keep reading until full.
        fillBuffer(is, header, 0, header.length);
        return decode(channel, is, readable, header);
    }

    /**
     * Decodes one message given the bytes already consumed from {@code is}.
     *
     * @param channel  channel the bytes arrived on
     * @param is       stream positioned right after {@code header}
     * @param readable number of bytes that were available before {@code header} was read
     * @param header   leading bytes of the message (at most {@link #HEADER_LENGTH} when called
     *                 from {@link #decode(Channel, InputStream)})
     * @return a {@link Request}, a {@link Response}, {@code NEED_MORE_INPUT} when the frame is
     *         incomplete, or whatever {@link TelnetCodec} decodes for non-exchange input
     * @throws IOException if the header declares a negative body length, the payload is too
     *         large, or the stream cannot be read
     */
    protected Object decode(Channel channel, InputStream is, int readable, byte[] header) throws IOException {
        // check magic number.
        if ((readable > 0 && header[0] != MAGIC_HIGH)
                || (readable > 1 && header[1] != MAGIC_LOW)) {
            return decodeNonExchangeBytes(channel, is, readable, header);
        }
        // check length.
        if (readable < HEADER_LENGTH) {
            return NEED_MORE_INPUT;
        }
        // get data length.
        int len = Bytes.bytes2int(header, 12);
        if (len < 0) {
            // A well-formed frame never has a negative length; this one is corrupt or hostile.
            throw new IOException("Invalid data length " + len + " in exchange frame header");
        }
        checkPayload(channel, len);
        // Use long arithmetic so a length near Integer.MAX_VALUE cannot overflow.
        long frameLength = (long) len + HEADER_LENGTH;
        if (readable < frameLength) {
            return NEED_MORE_INPUT;
        }
        // limit input stream so the deserializer cannot read into the next frame.
        if (readable != frameLength) {
            is = StreamUtils.limitedInputStream(is, len);
        }

        byte flag = header[2];
        byte proto = (byte) (flag & SERIALIZATION_MASK);
        Serialization s = getSerializationById(proto);
        if (s == null) {
            s = getSerialization(channel);
        }
        ObjectInput in = s.deserialize(channel.getUrl(), is);
        // get request id.
        long id = Bytes.bytes2long(header, 4);
        if ((flag & FLAG_REQUEST) == 0) {
            return decodeResponseFrame(channel, in, flag, header[3], id);
        }
        return decodeRequestFrame(channel, in, flag, id);
    }

    /**
     * Handles input that does not start with the exchange magic number. It reads all
     * {@code readable} bytes and looks for a later magic number. If one is found, the stream is
     * rewound to it so the next call can decode an exchange frame. The leading bytes are passed
     * to {@link TelnetCodec}.
     */
    private Object decodeNonExchangeBytes(Channel channel, InputStream is, int readable, byte[] header) throws IOException {
        int length = header.length;
        if (length < readable) {
            header = Bytes.copyOf(header, readable);
            fillBuffer(is, header, length, readable - length);
        }
        for (int i = 1; i < header.length - 1; i++) {
            if (header[i] == MAGIC_HIGH && header[i + 1] == MAGIC_LOW) {
                // NOTE(review): the rewind requires the stream to be an UnsafeByteArrayInputStream;
                // any other stream type fails here with ClassCastException.
                UnsafeByteArrayInputStream bis = (UnsafeByteArrayInputStream) is;
                bis.position(bis.position() - header.length + i);
                header = Bytes.copyOf(header, i);
                break;
            }
        }
        return super.decode(channel, is, readable, header);
    }

    /**
     * Builds a {@link Response} from a frame whose {@link #FLAG_REQUEST} bit is clear.
     * Failures while decoding an OK payload are reported on the response as
     * {@code CLIENT_ERROR} instead of being thrown.
     */
    private Response decodeResponseFrame(Channel channel, ObjectInput in, byte flag, byte status, long id) throws IOException {
        Response res = new Response(id);
        if ((flag & FLAG_EVENT) != 0) {
            // Event-flagged responses are always decoded as heartbeats.
            res.setEvent(Response.HEARTBEAT_EVENT);
        }
        res.setStatus(status);
        if (status == Response.OK) {
            try {
                Object data;
                if (res.isHeartbeat()) {
                    data = decodeHeartbeatData(channel, in);
                } else if (res.isEvent()) {
                    data = decodeEventData(channel, in);
                } else {
                    data = decodeResponseData(channel, in, getRequestData(id));
                }
                res.setResult(data);
            } catch (Throwable t) {
                res.setStatus(Response.CLIENT_ERROR);
                res.setErrorMessage(StringUtils.toString(t));
            }
        } else {
            // Non-OK responses carry only an error message.
            res.setErrorMessage(in.readUTF());
        }
        return res;
    }

    /**
     * Builds a {@link Request} from a frame whose {@link #FLAG_REQUEST} bit is set. A payload
     * that fails to decode marks the request as broken and stores the failure as its data.
     */
    private Request decodeRequestFrame(Channel channel, ObjectInput in, byte flag, long id) {
        Request req = new Request(id);
        req.setVersion("2.0.0");
        req.setTwoWay((flag & FLAG_TWOWAY) != 0);
        if ((flag & FLAG_EVENT) != 0) {
            // Every event-flagged request is initially marked as a heartbeat.
            req.setEvent(Request.HEARTBEAT_EVENT);
        }
        try {
            Object data;
            if (req.isHeartbeat()) {
                data = decodeHeartbeatData(channel, in);
            } else if (req.isEvent()) {
                data = decodeEventData(channel, in);
            } else {
                data = decodeRequestData(channel, in);
            }
            req.setData(data);
        } catch (Throwable t) {
            // bad request
            req.setBroken(true);
            req.setData(t);
        }
        return req;
    }

    /**
     * Looks up the data of the pending request with the given id.
     *
     * @return the original request's data, or {@code null} if there is no pending future or
     *         it has no request
     */
    protected Object getRequestData(long id) {
        DefaultFuture future = DefaultFuture.getFuture(id);
        if (future == null) {
            return null;
        }
        Request req = future.getRequest();
        if (req == null) {
            return null;
        }
        return req.getData();
    }

    /**
     * Writes {@code req} as an exchange frame. The flags byte carries the request flag, the
     * two-way and event bits, and the channel's serialization id.
     */
    protected void encodeRequest(Channel channel, OutputStream os, Request req) throws IOException {
        Serialization serialization = getSerialization(channel);
        byte[] header = new byte[HEADER_LENGTH];
        Bytes.short2bytes(MAGIC, header);
        header[2] = (byte) (FLAG_REQUEST | serialization.getContentTypeId());
        if (req.isTwoWay()) {
            header[2] |= FLAG_TWOWAY;
        }
        if (req.isEvent()) {
            header[2] |= FLAG_EVENT;
        }
        Bytes.long2bytes(req.getId(), header, 4);

        UnsafeByteArrayOutputStream bos = new UnsafeByteArrayOutputStream(1024);
        ObjectOutput out = serialization.serialize(channel.getUrl(), bos);
        if (req.isEvent()) {
            encodeEventData(channel, out, req.getData());
        } else {
            encodeRequestData(channel, out, req.getData());
        }
        writeFrame(os, header, out, bos);
    }

    /**
     * Writes {@code res} as an exchange frame. If encoding fails for a non-event response
     * that is not already a BAD_RESPONSE, a BAD_RESPONSE carrying the failure is sent on the
     * channel instead. Otherwise the failure is rethrown.
     */
    protected void encodeResponse(Channel channel, OutputStream os, Response res) throws IOException {
        try {
            Serialization serialization = getSerialization(channel);
            byte[] header = new byte[HEADER_LENGTH];
            Bytes.short2bytes(MAGIC, header);
            // Responses never set FLAG_REQUEST; only heartbeats are flagged as events.
            header[2] = serialization.getContentTypeId();
            if (res.isHeartbeat()) {
                header[2] |= FLAG_EVENT;
            }
            byte status = res.getStatus();
            header[3] = status;
            Bytes.long2bytes(res.getId(), header, 4);

            UnsafeByteArrayOutputStream bos = new UnsafeByteArrayOutputStream(1024);
            ObjectOutput out = serialization.serialize(channel.getUrl(), bos);
            // encode response data or error message.
            if (status == Response.OK) {
                if (res.isHeartbeat()) {
                    encodeHeartbeatData(channel, out, res.getResult());
                } else {
                    encodeResponseData(channel, out, res.getResult());
                }
            } else {
                out.writeUTF(res.getErrorMessage());
            }
            writeFrame(os, header, out, bos);
        } catch (Throwable t) {
            // Send the failure info to the consumer; otherwise it can only wait for a timeout.
            if (!res.isEvent() && res.getStatus() != Response.BAD_RESPONSE) {
                try {
                    // FIXME: should this error be logged in the codec, or handled uniformly in the IoHandler's caught()?
                    logger.warn("Fail to encode response: " + res + ", send bad_response info instead, cause: " + t.getMessage(), t);
                    Response r = new Response(res.getId(), res.getVersion());
                    r.setStatus(Response.BAD_RESPONSE);
                    r.setErrorMessage("Failed to send response: " + res + ", cause: " + StringUtils.toString(t));
                    channel.send(r);
                    return;
                } catch (RemotingException e) {
                    logger.warn("Failed to send bad_response info back: " + res + ", cause: " + e.getMessage(), e);
                }
            }
            // Rethrow the original exception.
            if (t instanceof IOException) {
                throw (IOException) t;
            } else if (t instanceof RuntimeException) {
                throw (RuntimeException) t;
            } else if (t instanceof Error) {
                throw (Error) t;
            } else {
                throw new RuntimeException(t.getMessage(), t);
            }
        }
    }

    /**
     * Flushes the serialized body, stores its length in the header, and writes header then body
     * to {@code os}.
     */
    private static void writeFrame(OutputStream os, byte[] header, ObjectOutput out,
                                   UnsafeByteArrayOutputStream bos) throws IOException {
        out.flushBuffer();
        bos.flush();
        bos.close();
        byte[] data = bos.toByteArray();
        Bytes.int2bytes(data.length, header, 12);
        os.write(header); // write header.
        os.write(data); // write data.
    }

    /**
     * Reads into {@code buf[off, off + len)}, repeating after short reads. It stops early only
     * when the stream reports end of input or makes no progress.
     *
     * @return number of bytes actually read
     */
    private static int fillBuffer(InputStream is, byte[] buf, int off, int len) throws IOException {
        int total = 0;
        while (total < len) {
            int n = is.read(buf, off + total, len - total);
            if (n <= 0) {
                break;
            }
            total += n;
        }
        return total;
    }

    /**
     * Reads one object from {@code in}. A missing class is converted into an
     * {@link IOException} that keeps the original exception as its cause.
     */
    private static Object readObjectOrFail(ObjectInput in) throws IOException {
        try {
            return in.readObject();
        } catch (ClassNotFoundException e) {
            IOException ioe = new IOException(StringUtils.toString("Read object failed.", e));
            ioe.initCause(e);
            throw ioe;
        }
    }

    /** Returns the serialization registered for content-type {@code id}, or {@code null}. */
    private static Serialization getSerializationById(byte id) {
        return ID_SERIALIZATION_MAP.get(id);
    }

    @Override
    protected Object decodeData(ObjectInput in) throws IOException {
        return decodeRequestData(in);
    }

    /**
     * Reads a heartbeat payload.
     *
     * @deprecated not called from within this class; kept for existing callers and overriders.
     */
    @Deprecated
    protected Object decodeHeartbeatData(ObjectInput in) throws IOException {
        return readObjectOrFail(in);
    }

    /** Reads a request payload as a single object. */
    protected Object decodeRequestData(ObjectInput in) throws IOException {
        return readObjectOrFail(in);
    }

    /** Reads a response payload as a single object. */
    protected Object decodeResponseData(ObjectInput in) throws IOException {
        return readObjectOrFail(in);
    }

    @Override
    protected void encodeData(ObjectOutput out, Object data) throws IOException {
        encodeRequestData(out, data);
    }

    /** Writes an event payload as a single object. */
    private void encodeEventData(ObjectOutput out, Object data) throws IOException {
        out.writeObject(data);
    }

    /**
     * Writes a heartbeat payload in the same way as an event payload.
     *
     * @deprecated kept for existing callers and overriders.
     */
    @Deprecated
    protected void encodeHeartbeatData(ObjectOutput out, Object data) throws IOException {
        encodeEventData(out, data);
    }

    /** Writes a request payload as a single object. */
    protected void encodeRequestData(ObjectOutput out, Object data) throws IOException {
        out.writeObject(data);
    }

    /** Writes a response payload as a single object. */
    protected void encodeResponseData(ObjectOutput out, Object data) throws IOException {
        out.writeObject(data);
    }

    @Override
    protected Object decodeData(Channel channel, ObjectInput in) throws IOException {
        return decodeRequestData(channel, in);
    }

    /** Reads an event payload as a single object. */
    private Object decodeEventData(Channel channel, ObjectInput in) throws IOException {
        return readObjectOrFail(in);
    }

    /**
     * Reads a heartbeat payload. Called by {@code decode} for every event-flagged frame.
     *
     * @deprecated kept for existing callers and overriders.
     */
    @Deprecated
    protected Object decodeHeartbeatData(Channel channel, ObjectInput in) throws IOException {
        return readObjectOrFail(in);
    }

    protected Object decodeRequestData(Channel channel, ObjectInput in) throws IOException {
        return decodeRequestData(in);
    }

    protected Object decodeResponseData(Channel channel, ObjectInput in) throws IOException {
        return decodeResponseData(in);
    }

    /**
     * Reads a response payload. {@code requestData} is the data of the matching pending
     * request (possibly {@code null}); this implementation ignores it.
     */
    protected Object decodeResponseData(Channel channel, ObjectInput in, Object requestData) throws IOException {
        return decodeResponseData(channel, in);
    }

    @Override
    protected void encodeData(Channel channel, ObjectOutput out, Object data) throws IOException {
        encodeRequestData(channel, out, data);
    }

    private void encodeEventData(Channel channel, ObjectOutput out, Object data) throws IOException {
        encodeEventData(out, data);
    }

    /**
     * Writes a heartbeat payload.
     *
     * @deprecated kept for existing callers and overriders.
     */
    @Deprecated
    protected void encodeHeartbeatData(Channel channel, ObjectOutput out, Object data) throws IOException {
        encodeHeartbeatData(out, data);
    }

    protected void encodeRequestData(Channel channel, ObjectOutput out, Object data) throws IOException {
        encodeRequestData(out, data);
    }

    protected void encodeResponseData(Channel channel, ObjectOutput out, Object data) throws IOException {
        encodeResponseData(out, data);
    }
}