blob: 385b414a1691e698603bf00ecfb5b42a7fa47135 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ssl
from ssl import SSLContext
from pyignite.constants import SSL_DEFAULT_CIPHERS, SSL_DEFAULT_VERSION
from pyignite.exceptions import ParameterError
def wrap(socket, ssl_params):
    """
    Wrap *socket* in SSL when ``ssl_params['use_ssl']`` is truthy.

    :param socket: the socket to (possibly) wrap.
    :param ssl_params: mapping of SSL options, forwarded as-is to
     :func:`create_ssl_context`.
    :return: the SSL-wrapped socket, or *socket* itself, unchanged, when
     ``use_ssl`` is missing or falsy.
    """
    if ssl_params.get('use_ssl'):
        ssl_context = create_ssl_context(ssl_params)
        # NOTE(review): no server_hostname is passed here, so a context with
        # check_hostname enabled would be rejected by wrap_socket().
        return ssl_context.wrap_socket(sock=socket)
    return socket
def check_ssl_params(params):
    """
    Validate the names of SSL-related connection parameters.

    :param params: iterable of parameter names (iterating a dict yields
     its keys).
    :raises ParameterError: on the first name that is not one of the
     recognised SSL parameters.
    """
    known_names = (
        'use_ssl',
        'ssl_version',
        'ssl_ciphers',
        'ssl_cert_reqs',
        'ssl_keyfile',
        'ssl_keyfile_password',
        'ssl_certfile',
        'ssl_ca_certfile',
    )
    for name in params:
        if name in known_names:
            continue
        raise ParameterError(
            'Unexpected parameter for connection initialization: `{}`'.format(name)
        )
def create_ssl_context(ssl_params):
    """
    Build an SSL context from connection SSL parameters.

    :param ssl_params: mapping of SSL options. Keys read here: ``use_ssl``,
     ``ssl_version``, ``ssl_ciphers``, ``ssl_cert_reqs``, ``ssl_keyfile``,
     ``ssl_keyfile_password``, ``ssl_certfile``, ``ssl_ca_certfile``.
     Missing keys fall back to ``SSL_DEFAULT_VERSION``,
     ``SSL_DEFAULT_CIPHERS``, ``ssl.CERT_NONE`` or None respectively.
    :return: a configured ``SSLContext``, or None when ``use_ssl`` is
     missing or falsy.
    :raises ValueError: if ``ssl_keyfile`` is given without
     ``ssl_certfile``.
    """
    if not ssl_params.get('use_ssl'):
        return None

    keyfile = ssl_params.get('ssl_keyfile', None)
    certfile = ssl_params.get('ssl_certfile', None)
    # load_cert_chain() requires a certificate file; a lone key file is
    # a configuration error.
    if keyfile and not certfile:
        raise ValueError("certfile must be specified")

    password = ssl_params.get('ssl_keyfile_password', None)
    ssl_version = ssl_params.get('ssl_version', SSL_DEFAULT_VERSION)
    ciphers = ssl_params.get('ssl_ciphers', SSL_DEFAULT_CIPHERS)
    cert_reqs = ssl_params.get('ssl_cert_reqs', ssl.CERT_NONE)
    ca_certs = ssl_params.get('ssl_ca_certfile', None)

    context = SSLContext(ssl_version)
    if cert_reqs == ssl.CERT_NONE:
        # Some protocols (e.g. ssl.PROTOCOL_TLS_CLIENT) enable check_hostname
        # by default, and setting verify_mode to CERT_NONE while it is on
        # raises ValueError. Hostname matching is meaningless without
        # certificate verification, so turn it off first. For other
        # protocols it is already False and this is a no-op.
        context.check_hostname = False
    context.verify_mode = cert_reqs

    if ca_certs:
        context.load_verify_locations(ca_certs)
    if certfile:
        context.load_cert_chain(certfile, keyfile, password)
    if ciphers:
        context.set_ciphers(ciphers)

    return context