blob: 99cff6627322a947f0a85d57fd84d8906d7c335f [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import hmac
import time
import base64
from hashlib import sha256
from libcloud.http import LibcloudConnection
from libcloud.utils.py3 import ET, b, httplib, urlparse, urlencode, basestring
from libcloud.utils.xml import fixxpath
from libcloud.common.base import (
BaseDriver,
RawResponse,
XmlResponse,
ConnectionUserAndKey,
CertificateConnection,
)
from libcloud.common.types import LibcloudError, InvalidCredsError, MalformedResponseError
from libcloud.common.azure_arm import AzureAuthJsonResponse, publicEnvironments
# The time format for headers in Azure requests
AZURE_TIME_FORMAT = "%a, %d %b %Y %H:%M:%S GMT"
class AzureRedirectException(Exception):
def __init__(self, response):
self.location = response.headers["location"]
class AzureResponse(XmlResponse):
valid_response_codes = [
httplib.NOT_FOUND,
httplib.CONFLICT,
httplib.BAD_REQUEST,
# added TEMPORARY_REDIRECT as this can sometimes be
# sent by azure instead of a success or fail response
httplib.TEMPORARY_REDIRECT,
# Used by Azure Blobs range downloads
httplib.PARTIAL_CONTENT,
]
def success(self):
i = int(self.status)
return 200 <= i <= 299 or i in self.valid_response_codes
def parse_error(self, msg=None):
error_msg = "Unknown error"
try:
# Azure does give some meaningful errors, but is inconsistent
# Some APIs respond with an XML error. Others just dump HTML
body = self.parse_body()
# pylint: disable=no-member
if type(body) == ET.Element:
code = body.findtext(fixxpath(xpath="Code"))
message = body.findtext(fixxpath(xpath="Message"))
message = message.split("\n")[0]
error_msg = "{}: {}".format(code, message)
except MalformedResponseError:
pass
if msg:
error_msg = "{} - {}".format(msg, error_msg)
if self.status in [httplib.UNAUTHORIZED, httplib.FORBIDDEN]:
raise InvalidCredsError(error_msg)
raise LibcloudError("%s Status code: %d." % (error_msg, self.status), driver=self)
def parse_body(self):
is_redirect = int(self.status) == httplib.TEMPORARY_REDIRECT
if is_redirect and self.connection.driver.follow_redirects:
raise AzureRedirectException(self)
else:
return super().parse_body()
class AzureRawResponse(RawResponse):
pass
class AzureBaseDriver(BaseDriver):
name = "Microsoft Azure Service Management API"
class AzureActiveDirectoryConnection(ConnectionUserAndKey):
"""
Represents a single connection to Azure using Azure AD for Blob
"""
conn_class = LibcloudConnection
driver = AzureBaseDriver
name = "Azure AD Auth"
responseCls = AzureResponse
rawResponseCls = AzureRawResponse
API_VERSION = "2017-11-09"
def __init__(
self,
key,
secret,
secure=True,
host=None,
port=None,
tenant_id=None,
identity=None,
cloud_environment="default",
**kwargs,
):
super().__init__(identity, secret, **kwargs)
if isinstance(cloud_environment, basestring):
cloud_environment = publicEnvironments[cloud_environment]
if not isinstance(cloud_environment, dict):
raise Exception(
"cloud_environment must be one of '%s' or a dict "
"containing keys 'resourceManagerEndpointUrl', "
"'activeDirectoryEndpointUrl', "
"'activeDirectoryResourceId', "
"'storageEndpointSuffix'" % ("', '".join(publicEnvironments.keys()))
)
self.login_host = urlparse.urlparse(
cloud_environment["activeDirectoryEndpointUrl"]
).hostname
self.login_resource = cloud_environment["activeDirectoryResourceId"]
self.host = host
self.identity = identity
self.tenant_id = tenant_id
self.storage_account_id = key
def add_default_headers(self, headers):
headers["x-ms-date"] = time.strftime(AZURE_TIME_FORMAT, time.gmtime())
headers["x-ms-version"] = self.API_VERSION
headers["Content-Type"] = "application/xml"
headers["Authorization"] = "Bearer %s" % self.access_token
return headers
def get_client_credentials(self):
"""
Log in and get bearer token used to authorize API requests.
"""
conn = self.conn_class(self.login_host, 443, timeout=self.timeout)
conn.connect()
params = urlencode(
{
"grant_type": "client_credentials",
"client_id": self.user_id,
"client_secret": self.key,
"resource": "https://storage.azure.com/",
}
)
headers = {"Content-type": "application/x-www-form-urlencoded"}
conn.request("POST", "/%s/oauth2/token" % self.tenant_id, params, headers)
js = AzureAuthJsonResponse(conn.getresponse(), conn)
self.access_token = js.object["access_token"]
self.expires_on = js.object["expires_on"]
def connect(self, **kwargs):
self.get_client_credentials()
return super().connect(**kwargs)
def request(
self,
action,
params=None,
data=None,
headers=None,
method="GET",
raw=False,
stream=False,
json=None,
retry_failed=None,
*kwargs,
):
# Log in again if the token has expired or is going to expire soon
# (next 5 minutes).
if (time.time() + 300) >= int(self.expires_on):
self.get_client_credentials()
return super().request(
action,
params=params,
data=data,
headers=headers,
method=method,
raw=raw,
stream=stream,
json=json,
retry_failed=retry_failed,
)
class AzureConnection(ConnectionUserAndKey):
"""
Represents a single connection to Azure
"""
responseCls = AzureResponse
rawResponseCls = AzureRawResponse
API_VERSION = "2012-02-12"
def add_default_params(self, params):
return params
def pre_connect_hook(self, params, headers):
headers = copy.deepcopy(headers)
# We have to add a date header in GMT
headers["x-ms-date"] = time.strftime(AZURE_TIME_FORMAT, time.gmtime())
headers["x-ms-version"] = self.API_VERSION
# Add the authorization header
headers["Authorization"] = self._get_azure_auth_signature(
method=self.method,
headers=headers,
params=params,
account=self.user_id,
secret_key=self.key,
path=self.action,
)
# Azure cribs about this in 'raw' connections
headers.pop("Host", None)
return params, headers
def _get_azure_auth_signature(self, method, headers, params, account, secret_key, path="/"):
"""
Signature = Base64( HMAC-SHA1( YourSecretAccessKeyID,
UTF-8-Encoding-Of( StringToSign ) ) ) );
StringToSign = HTTP-VERB + "\n" +
Content-Encoding + "\n" +
Content-Language + "\n" +
Content-Length + "\n" +
Content-MD5 + "\n" +
Content-Type + "\n" +
Date + "\n" +
If-Modified-Since + "\n" +
If-Match + "\n" +
If-None-Match + "\n" +
If-Unmodified-Since + "\n" +
Range + "\n" +
CanonicalizedHeaders +
CanonicalizedResource;
"""
xms_header_values = []
param_list = []
# Split the x-ms headers and normal headers and make everything
# lower case
headers_copy = {}
for header, value in headers.items():
header = header.lower()
value = str(value).strip()
if header.startswith("x-ms-"):
xms_header_values.append((header, value))
else:
headers_copy[header] = value
# Get the values for the headers in the specific order
special_header_values = self._format_special_header_values(headers_copy, method)
# Prepare the first section of the string to be signed
values_to_sign = [method] + special_header_values
# string_to_sign = '\n'.join([method] + special_header_values)
# The x-ms-* headers have to be in lower case and sorted
xms_header_values.sort()
for header, value in xms_header_values:
values_to_sign.append("{}:{}".format(header, value))
# Add the canonicalized path
values_to_sign.append("/{}{}".format(account, path))
# URL query parameters (sorted and lower case)
for key, value in params.items():
param_list.append((key.lower(), str(value).strip()))
param_list.sort()
for key, value in param_list:
values_to_sign.append("{}:{}".format(key, value))
string_to_sign = b("\n".join(values_to_sign))
secret_key = b(secret_key)
b64_hmac = base64.b64encode(hmac.new(secret_key, string_to_sign, digestmod=sha256).digest())
return "SharedKey {}:{}".format(self.user_id, b64_hmac.decode("utf-8"))
def _format_special_header_values(self, headers, method):
is_change = method not in ("GET", "HEAD")
is_old_api = self.API_VERSION <= "2014-02-14"
special_header_keys = [
"content-encoding",
"content-language",
"content-length",
"content-md5",
"content-type",
"date",
"if-modified-since",
"if-match",
"if-none-match",
"if-unmodified-since",
"range",
]
special_header_values = []
for header in special_header_keys:
header = header.lower() # Just for safety
if header in headers:
special_header_values.append(headers[header])
elif header == "content-length" and is_change and is_old_api:
# For old API versions, the Content-Length header must be '0'
# https://docs.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key#content-length-header-in-version-2014-02-14-and-earlier
special_header_values.append("0")
else:
special_header_values.append("")
return special_header_values
class AzureServiceManagementConnection(CertificateConnection):
# This needs the following approach -
# 1. Make request using LibcloudHTTPSConnection which is a overloaded
# class which takes in a client certificate
# 2. Depending on the type of operation use a PollingConnection
# when the response id is returned
# 3. The Response can be used in an AzureServiceManagementResponse
"""
Authentication class for "Service Account" authentication.
"""
driver = AzureBaseDriver
responseCls = AzureResponse
rawResponseCls = AzureRawResponse
name = "Azure Service Management API Connection"
host = "management.core.windows.net"
keyfile = ""
def __init__(self, subscription_id, key_file, *args, **kwargs):
"""
Check to see if PyCrypto is available, and convert key file path into a
key string if the key is in a file.
:param subscription_id: Azure subscription ID.
:type subscription_id: ``str``
:param key_file: The PEM file used to authenticate with the service.
:type key_file: ``str``
"""
super().__init__(key_file, *args, **kwargs)
self.subscription_id = subscription_id
keypath = os.path.expanduser(key_file)
self.keyfile = keypath
is_file_path = os.path.exists(keypath) and os.path.isfile(keypath)
if not is_file_path:
raise InvalidCredsError(
"You need an certificate PEM file to authenticate with "
"Microsoft Azure. This can be found in the portal."
)
self.key_file = key_file
def add_default_headers(self, headers):
"""
@inherits: :class:`Connection.add_default_headers`
TODO: move to constant..
"""
headers["x-ms-version"] = "2014-05-01"
headers["x-ms-date"] = time.strftime(AZURE_TIME_FORMAT, time.gmtime())
# headers['host'] = self.host
return headers