| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """Base definitions for RPC.""" |
| # pylint: disable=invalid-name |
| |
| import socket |
| import time |
| import json |
| import errno |
| import struct |
| import random |
| import logging |
| |
| from .._ffi.base import py_str |
| |
# Magic header for the RPC data plane
RPC_MAGIC = 0xFF271
# Magic header for the RPC tracker (control plane)
RPC_TRACKER_MAGIC = 0x2F271
# Success response
RPC_CODE_SUCCESS = RPC_MAGIC + 0
# Duplicate key in proxy
RPC_CODE_DUPLICATE = RPC_MAGIC + 1
# Cannot find a matching key in server
RPC_CODE_MISMATCH = RPC_MAGIC + 2

# Module-level logger, registered under the name "RPCServer"
logger = logging.getLogger("RPCServer")
| |
| |
class TrackerCode(object):
    """Integer codes used when talking to the RPC tracker.

    The members are plain ``int`` class attributes (not an ``enum.Enum``),
    so they compare equal to the raw integers.
    """

    # Status-like values
    FAIL = -1
    SUCCESS = 0

    # Remaining codes, numbered consecutively from 1
    PING = 1
    STOP = 2
    PUT = 3
    REQUEST = 4
    UPDATE_INFO = 5
    SUMMARY = 6
    GET_PENDING_MATCHKEYS = 7
| |
| |
# NOTE(review): presumably a bit mask (128 == 1 << 7) related to RPC session
# codes; no use of it is visible in this section -- confirm before changing.
RPC_SESS_MASK = 128
| |
| |
def get_addr_family(addr):
    """Return the socket address family to use for an address.

    Parameters
    ----------
    addr : tuple
        Address tuple; element 0 is the host and element 1 is the port.

    Returns
    -------
    family : int
        The address family of the first entry that ``socket.getaddrinfo``
        reports for the host/port over TCP.
    """
    host = addr[0]
    port = addr[1]
    candidates = socket.getaddrinfo(host, port, 0, 0, socket.IPPROTO_TCP)
    # Each entry is (family, type, proto, canonname, sockaddr); only the
    # family of the first candidate is needed.
    first_entry = candidates[0]
    return first_entry[0]
| |
| |
def recvall(sock, nbytes):
    """Read exactly ``nbytes`` bytes from a socket.

    Parameters
    ----------
    sock : Socket
        The socket to read from.

    nbytes : int
        Number of bytes to be received.

    Returns
    -------
    data : bytes
        The received bytes, concatenated in arrival order.

    Raises
    ------
    IOError
        If ``recv`` returns an empty chunk before ``nbytes`` bytes arrived.
    """
    pieces = []
    remaining = nbytes
    while remaining > 0:
        # Ask for at most 1024 bytes per call.
        piece = sock.recv(min(remaining, 1024))
        if not piece:
            raise IOError("connection reset")
        pieces.append(piece)
        remaining -= len(piece)
    return b"".join(pieces)
| |
| |
def sendjson(sock, data):
    """Send a Python value to the remote side as length-prefixed JSON.

    The value is serialized with ``json.dumps``; a 4-byte little-endian
    signed length is sent first, followed by the UTF-8 encoded JSON text.

    Parameters
    ----------
    sock : Socket
        The socket to write to.

    data : object
        Python value to be sent.
    """
    payload = json.dumps(data).encode("utf-8")
    # json.dumps escapes non-ASCII characters by default, so the byte length
    # of the payload equals the character length of the JSON text.
    header = struct.pack("<i", len(payload))
    sock.sendall(header)
    sock.sendall(payload)
| |
| |
def recvjson(sock):
    """Receive a Python value sent by the remote side as length-prefixed JSON.

    Reads a 4-byte little-endian signed length, then that many bytes,
    converts them to text with ``py_str`` and parses the text as JSON.

    Parameters
    ----------
    sock : Socket
        The socket to read from.

    Returns
    -------
    value : object
        The value received.
    """
    header = recvall(sock, 4)
    (length,) = struct.unpack("<i", header)
    payload = recvall(sock, length)
    return json.loads(py_str(payload))
| |
| |
def random_key(prefix, cmap=None):
    """Generate a random key that starts with ``prefix``.

    Parameters
    ----------
    prefix : str
        The string prefix.

    cmap : dict
        Conflict map. When it is non-empty, keys are regenerated until one
        is found that is not contained in it.

    Returns
    -------
    key : str
        The generated random key: ``prefix`` followed by the string form of
        a ``random.random()`` value.
    """
    while True:
        candidate = prefix + str(random.random())
        # With no (or an empty) conflict map the first candidate is accepted.
        if not cmap or candidate not in cmap:
            return candidate
| |
| |
def connect_with_retry(addr, timeout=60, retry_period=5):
    """Connect to a TCP address, retrying while the connection is refused.

    This function is only reliable for short periods of server restart.

    Parameters
    ----------
    addr : tuple
        Address tuple; it is passed to ``get_addr_family`` and to
        ``socket.connect``.

    timeout : float
        Number of seconds, measured from the first attempt, after which a
        refused connection is no longer retried.

    retry_period : float
        Number of seconds to sleep before the next attempt.

    Returns
    -------
    sock : socket.socket
        The connected socket.

    Raises
    ------
    socket.error
        If creating or connecting the socket fails for any reason other
        than ``ECONNREFUSED``.
    RuntimeError
        If the connection is still refused once ``timeout`` has elapsed.
    """
    tstart = time.time()
    while True:
        sock = None
        try:
            sock = socket.socket(get_addr_family(addr), socket.SOCK_STREAM)
            sock.connect(addr)
            return sock
        except socket.error as sock_err:
            # Release the descriptor of the failed attempt; otherwise every
            # retry (and the re-raise path) leaks one open socket.
            if sock is not None:
                sock.close()
            if sock_err.args[0] not in (errno.ECONNREFUSED,):
                raise
        period = time.time() - tstart
        if period > timeout:
            raise RuntimeError("Failed to connect to server %s" % str(addr))
        logger.warning(
            "Cannot connect to tracker %s, retry in %g secs...", str(addr), retry_period
        )
        time.sleep(retry_period)