blob: b7f44f3219f297494d785cc55ff9c08dca3ab3f7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.reef.io.network.impl;
import org.apache.reef.wake.Identifier;
import org.apache.reef.wake.IdentifierFactory;
import org.apache.reef.wake.remote.Codec;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
* Codec for {@link NetworkConnectionServiceMessage}.
* The records of each message are encoded and decoded with the codec of the connection factory
* that is registered under the message's connection factory id.
*/
final class NetworkConnectionServiceMessageCodec implements Codec<NetworkConnectionServiceMessage> {

  private static final Logger LOG = Logger.getLogger(NetworkConnectionServiceMessageCodec.class.getName());

  private final IdentifierFactory factory;

  /**
   * Contains entries of (id of connection factory, instance of connection factory).
   */
  private final Map<String, NetworkConnectionFactory> connFactoryMap;

  /**
   * Contains entries of (instance of codec, boolean whether the codec is streaming or not).
   */
  private final ConcurrentMap<Codec, Boolean> isStreamingCodecMap;

  /**
   * Constructs a network connection service message codec.
   *
   * @param factory        identifier factory used to rebuild the source and destination ids on decode
   * @param connFactoryMap map from connection factory id to connection factory; the codec of the
   *                       matching factory is used for the records of each message
   */
  NetworkConnectionServiceMessageCodec(
      final IdentifierFactory factory,
      final Map<String, NetworkConnectionFactory> connFactoryMap) {
    this.factory = factory;
    this.connFactoryMap = connFactoryMap;
    this.isStreamingCodecMap = new ConcurrentHashMap<>();
  }

  /**
   * Encodes a network connection service message to bytes.
   *
   * <p>Layout: connection factory id, source id and destination id (each written with
   * {@code writeUTF}), the number of records as an int, then the records. A streaming codec
   * writes every record directly to the stream; any other codec produces a byte array per
   * record, which is written as an int length followed by the bytes.
   *
   * @param obj a message
   * @return bytes
   * @throws NullPointerException if no connection factory is registered for the message's
   *                              connection factory id
   * @throws RuntimeException     wrapping any {@link IOException} raised while writing
   */
  @Override
  public byte[] encode(final NetworkConnectionServiceMessage obj) {
    final Codec codec = codecFor(obj.getConnectionFactoryId());
    final boolean streaming = isStreaming(codec);

    try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
         DataOutputStream daos = new DataOutputStream(baos)) {
      daos.writeUTF(obj.getConnectionFactoryId());
      daos.writeUTF(obj.getSrcId().toString());
      daos.writeUTF(obj.getDestId().toString());
      daos.writeInt(obj.getData().size());

      if (streaming) {
        final StreamingCodec streamingCodec = (StreamingCodec) codec;
        for (final Object rec : obj.getData()) {
          streamingCodec.encodeToStream(rec, daos);
        }
      } else {
        for (final Object message : obj.getData()) {
          final byte[] bytes = codec.encode(message);
          daos.writeInt(bytes.length);
          daos.write(bytes);
        }
      }

      // Make sure everything has reached the byte array before taking the snapshot.
      daos.flush();
      return baos.toByteArray();
    } catch (final IOException e) {
      throw new RuntimeException("IOException", e);
    }
  }

  /**
   * Decodes a network connection service message from bytes.
   *
   * <p>Expects the layout produced by {@link #encode(NetworkConnectionServiceMessage)}.
   *
   * @param data bytes
   * @return a message
   * @throws NullPointerException if no connection factory is registered for the encoded
   *                              connection factory id
   * @throws RuntimeException     wrapping any {@link IOException} raised while reading, including
   *                              the {@link EOFException} raised when {@code data} ends before a
   *                              length-prefixed record is complete
   */
  @Override
  public NetworkConnectionServiceMessage decode(final byte[] data) {
    try (ByteArrayInputStream bais = new ByteArrayInputStream(data);
         DataInputStream dais = new DataInputStream(bais)) {
      final String connFactoryId = dais.readUTF();
      final Identifier srcId = factory.getNewInstance(dais.readUTF());
      final Identifier destId = factory.getNewInstance(dais.readUTF());
      final int size = dais.readInt();
      final List list = new ArrayList(size);
      final Codec codec = codecFor(connFactoryId);

      if (isStreaming(codec)) {
        final StreamingCodec streamingCodec = (StreamingCodec) codec;
        for (int i = 0; i < size; i++) {
          list.add(streamingCodec.decodeFromStream(dais));
        }
      } else {
        for (int i = 0; i < size; i++) {
          final byte[] bytes = new byte[dais.readInt()];
          // read(byte[]) may stop short and leave the tail of the buffer zero-filled, so a
          // truncated payload would be decoded as if it were valid. readFully either fills
          // the whole buffer or throws EOFException.
          dais.readFully(bytes);
          list.add(codec.decode(bytes));
        }
      }

      return new NetworkConnectionServiceMessage(connFactoryId, srcId, destId, list);
    } catch (final IOException e) {
      throw new RuntimeException("IOException", e);
    }
  }

  /**
   * Looks up the codec of the connection factory registered under the given id.
   *
   * @param connFactoryId id of the connection factory
   * @return the codec returned by that connection factory
   * @throws NullPointerException if no connection factory is registered under the id
   */
  private Codec codecFor(final String connFactoryId) {
    final NetworkConnectionFactory connFactory = connFactoryMap.get(connFactoryId);
    if (connFactory == null) {
      LOG.log(Level.WARNING, "No connection factory registered for id {0}", connFactoryId);
      // Same exception type as dereferencing the missing map entry, but with the id attached.
      throw new NullPointerException("No connection factory registered for id " + connFactoryId);
    }
    return connFactory.getCodec();
  }

  /**
   * Tells whether the codec is a {@link StreamingCodec}, caching the answer per codec instance.
   *
   * @param codec the codec to classify
   * @return {@code true} if the codec is a streaming codec
   */
  private boolean isStreaming(final Codec codec) {
    Boolean streaming = isStreamingCodecMap.get(codec);
    if (streaming == null) {
      streaming = codec instanceof StreamingCodec;
      isStreamingCodecMap.putIfAbsent(codec, streaming);
    }
    return streaming;
  }
}