blob: 524d1bc04088324b19ea896da7e1aaa437b84f05 [file] [log] [blame]
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Implementation of Heron's application level protocol for Python"""
import random
import socket
import struct
from heron.common.src.python.utils.log import Log
class HeronProtocol(object):
    """Static helpers for the framing used by Heron's network protocol.

    As consumed by ``decode_packet``, the data part of a packet is laid out as:
    a 4-byte type-name length, the type name, a REQID of ``REQID.REQID_SIZE``
    bytes, a 4-byte message length and finally the serialized message.
    Integers are packed with ``INT_PACK_FMT`` (big-endian unsigned 32-bit).
    """
    INT_PACK_FMT = ">I"
    HEADER_SIZE = 4

    @staticmethod
    def pack_int(i):
        """Returns ``i`` packed as a 4-byte bytestring using ``INT_PACK_FMT``"""
        fmt = HeronProtocol.INT_PACK_FMT
        return struct.pack(fmt, i)

    @staticmethod
    def unpack_int(i):
        """Returns the int stored in the 4-byte bytestring ``i``"""
        (value,) = struct.unpack(HeronProtocol.INT_PACK_FMT, i)
        return value

    @staticmethod
    def get_size_to_pack_string(string):
        """Returns the packed size of ``string``: its length plus a 4-byte length prefix"""
        return len(string) + 4

    @staticmethod
    def get_size_to_pack_message(message):
        """Returns the packed size of ``message``: ``ByteSize()`` plus a 4-byte length prefix"""
        return message.ByteSize() + 4

    @staticmethod
    def read_new_packet(dispatcher):
        """Creates a fresh IncomingPacket and performs one read on it

        :param dispatcher: asyncore's dispatcher class from which the packet is read
        :returns: the IncomingPacket, which may still be incomplete after this read
        """
        incoming = IncomingPacket()
        incoming.read(dispatcher)
        return incoming

    @staticmethod
    def decode_packet(packet):
        """Splits a complete IncomingPacket into its three parts

        :param packet: IncomingPacket whose ``is_complete`` flag is set
        :returns: tuple of (typename, reqid, serialized message)
        :raises RuntimeError: if the packet is not complete
        """
        if not packet.is_complete:
            raise RuntimeError("In decode_packet(): Packet corrupted")

        payload = packet.data

        # walk through the payload with a running offset instead of re-slicing it
        cursor = 4
        typename_len = HeronProtocol.unpack_int(payload[:cursor])

        typename = payload[cursor:cursor + typename_len]
        cursor += typename_len

        reqid = REQID.unpack(payload[cursor:cursor + REQID.REQID_SIZE])
        cursor += REQID.REQID_SIZE

        msg_len = HeronProtocol.unpack_int(payload[cursor:cursor + 4])
        cursor += 4

        serialized_msg = payload[cursor:cursor + msg_len]
        return typename, reqid, serialized_msg
class OutgoingPacket(object):
    """Wrapper class for an outgoing packet.

    ``raw`` keeps the complete packet as bytes, while ``to_send`` holds the
    part of it that has not been handed to the socket yet.
    """

    def __init__(self, raw_data):
        """Stores the packet contents as bytes

        :param raw_data: bytes-like object holding the whole packet, header included
        """
        # bytes() rather than str(): under Python 3, str(b"...") produces the
        # repr text "b'...'" instead of the payload itself.
        self.raw = bytes(raw_data)
        # bytes are immutable, so sharing the object with ``raw`` is safe
        self.to_send = self.raw

    def __len__(self):
        """Returns the size of the whole packet, regardless of how much was sent"""
        return len(self.raw)

    @staticmethod
    def create_packet(reqid, message):
        """Creates an OutgoingPacket from a given reqid and message

        The packet consists of a 4-byte header carrying the size of the data
        that follows, then the length-prefixed type name, the packed REQID and
        the length-prefixed serialized message.

        :param reqid: REQID object
        :param message: protocol buffer object
        :returns: OutgoingPacket holding the serialized packet
        """
        assert message.IsInitialized()

        # The type name goes on the wire as bytes; encode it first so that the
        # advertised length is the length of what is actually written.
        typename = message.DESCRIPTOR.full_name
        if not isinstance(typename, bytes):
            typename = typename.encode("utf-8")

        # calculate the total size of the data that follows the header
        datasize = (HeronProtocol.get_size_to_pack_string(typename) +
                    REQID.REQID_SIZE +
                    HeronProtocol.get_size_to_pack_message(message))

        # Assemble everything as bytes; mixing str and bytes raises TypeError
        # on Python 3.
        packet = b"".join([
            # first write out how much data there is, as the header
            HeronProtocol.pack_int(datasize),
            # next write the type string
            HeronProtocol.pack_int(len(typename)),
            typename,
            # reqid
            reqid.pack(),
            # add the proto
            HeronProtocol.pack_int(message.ByteSize()),
            message.SerializeToString(),
        ])
        return OutgoingPacket(packet)

    @property
    def sent_complete(self):
        """Indicates whether this packet is successfully sent"""
        return not self.to_send

    def send(self, dispatcher):
        """Sends the unsent part of this packet to dispatcher's socket

        ``dispatcher.send`` is expected to return the number of bytes it
        accepted; that many bytes are dropped from ``to_send``. Does nothing
        when the packet has already been sent completely.
        """
        if self.sent_complete:
            return
        sent = dispatcher.send(self.to_send)
        self.to_send = self.to_send[sent:]
class IncomingPacket(object):
    """Helper class for an incoming packet.

    A packet is accumulated over one or more ``read`` calls: first the header
    of ``HeronProtocol.HEADER_SIZE`` bytes, which carries the size of the data,
    then the data itself. ``is_header_read`` and ``is_complete`` track progress.
    """

    def __init__(self):
        """Initializes an empty IncomingPacket object"""
        # Bytes accumulators: what gets appended comes from dispatcher.recv(),
        # which yields bytes on Python 3, and str + bytes raises TypeError.
        self.header = b''
        self.data = b''
        self.is_header_read = False
        self.is_complete = False
        # for debugging identification purposes
        self.id = random.getrandbits(8)

    @staticmethod
    def create_packet(header, data):
        """Creates an IncomingPacket object from header and data

        This method is for testing purposes. The progress flags are set only
        when the given pieces have exactly the expected sizes.
        """
        packet = IncomingPacket()
        packet.header = header
        packet.data = data

        if len(header) == HeronProtocol.HEADER_SIZE:
            packet.is_header_read = True
            if len(data) == packet.get_datasize():
                packet.is_complete = True

        return packet

    def convert_to_raw(self):
        """Converts this IncomingPacket object to its raw form (header followed by data)

        This method is for testing purposes
        """
        return self.header + self.data

    def get_datasize(self):
        """Returns the datasize of the packet

        :returns: (int) datasize of the packet, or -1 if header is incomplete
        """
        if not self.is_header_read:
            return -1
        return HeronProtocol.unpack_int(self.header)

    def get_pktsize(self):
        """Returns the number of bytes held so far, header and data together"""
        return len(self.data) + len(self.header)

    def read(self, dispatcher):
        """Reads incoming data from asyncore.dispatcher

        Reads whatever is missing from the header and, once the header is
        complete, whatever is missing from the data. The packet may still be
        incomplete when this returns; call ``read`` again later to continue.

        :param dispatcher: object whose ``recv(n)`` returns up to ``n`` bytes
        :raises RuntimeError: on a socket error other than EAGAIN/EWOULDBLOCK
        """
        # function-scope import: use the documented errno module instead of
        # reaching for it through the socket module
        import errno

        try:
            if not self.is_header_read:
                # try reading header
                to_read = HeronProtocol.HEADER_SIZE - len(self.header)
                self.header += dispatcher.recv(to_read)
                if len(self.header) == HeronProtocol.HEADER_SIZE:
                    self.is_header_read = True
                else:
                    Log.debug("Header read incomplete; read %d bytes of header" % len(self.header))
                    return

            if self.is_header_read and not self.is_complete:
                # try reading data
                to_read = self.get_datasize() - len(self.data)
                self.data += dispatcher.recv(to_read)
                if len(self.data) == self.get_datasize():
                    self.is_complete = True
        except socket.error as e:
            if e.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
                # Not fatal: nothing to read right now, the caller reads again later
                Log.debug("Try again error")
            else:
                # Fatal error
                Log.debug("Fatal error when reading IncomingPacket")
                raise RuntimeError("Fatal error occured in IncomingPacket.read()")

    def __str__(self):
        return "Packet ID: %s, header: %s, complete: %s" % \
            (str(self.id), self.is_header_read, self.is_complete)
class REQID(object):
    """Request identifier: a thin wrapper around ``REQID_SIZE`` raw bytes.

    Two REQIDs compare equal when their ``bytes`` attributes are equal; the
    all-zero value is treated specially by ``is_zero`` and ``__str__``.
    """
    REQID_SIZE = 32

    def __init__(self, data_bytes):
        """Wraps ``data_bytes`` without copying it"""
        self.bytes = data_bytes

    @staticmethod
    def generate():
        """Returns a REQID made of ``REQID_SIZE`` random bytes, for a request"""
        random_bytes = bytearray(random.getrandbits(8) for _ in range(REQID.REQID_SIZE))
        return REQID(random_bytes)

    @staticmethod
    def generate_zero():
        """Returns a REQID made of ``REQID_SIZE`` zero bytes, for a message"""
        # bytearray(n) yields n zero bytes
        return REQID(bytearray(REQID.REQID_SIZE))

    def pack(self):
        """Returns the underlying bytes of this REQID"""
        return self.bytes

    def is_zero(self):
        """Tells whether this REQID equals the zero REQID"""
        zero = REQID.generate_zero()
        return self == zero

    @staticmethod
    def unpack(raw_data):
        """Builds a REQID from a given bytestring (copied into a bytearray)"""
        return REQID(bytearray(raw_data))

    def __eq__(self, another):
        # anything without a ``bytes`` attribute is simply not equal
        if not hasattr(another, 'bytes'):
            return False
        return self.bytes == another.bytes

    def __hash__(self):
        # hash through the string form so that equal REQIDs hash alike
        return hash(str(self))

    def __str__(self):
        if self.is_zero():
            return "ZERO"
        # decimal value of every byte, concatenated without a separator
        return ''.join(map(str, self.bytes))
class StatusCode(object):
    """Integer status codes attached to a Response; ``OK`` is 0."""
    (
        OK,              # 0
        WRITE_ERROR,     # 1
        READ_ERROR,      # 2
        INVALID_PACKET,  # 3
        CONNECT_ERROR,   # 4
        CLOSE_ERROR,     # 5
        TIMEOUT_ERROR,   # 6
    ) = range(7)