// blob: c9ef55bbca14b1c95b5ee164f88499166b0750f7 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Azure Storage helpers.
//!
//! This module provides utilities and shared abstractions for services built
//! on Azure Storage, such as Azure Blob Storage (`services-azblob`) or
//! Azure Data Lake Storage (`services-azdls`).
use std::collections::HashMap;
use http::Uri;
use http::response::Parts;
use reqsign::{AzureStorageConfig, AzureStorageCredential};
use crate::{Error, ErrorKind, Result};
/// Builds an [`AzureStorageConfig`] from an [Azure connection string][1].
///
/// Fields that are absent from the connection string are simply left unset;
/// callers may still fill them in on the returned configuration afterwards.
///
/// `storage` decides which endpoint-related fields are consulted. For
/// [`AzureStorageService::Blob`], a development storage connection string
/// short-circuits to the development storage settings.
///
/// # Errors
///
/// Returns an error when the connection string cannot be parsed into
/// `key=value` fields, or when collecting the endpoint fails (for example
/// because of an invalid `DefaultEndpointsProtocol`).
///
/// [1]: https://learn.microsoft.com/en-us/azure/storage/common/storage-configure-connection-string
pub(crate) fn azure_config_from_connection_string(
    conn_str: &str,
    storage: AzureStorageService,
) -> Result<AzureStorageConfig> {
    let key_values = parse_connection_string(conn_str)?;

    // Development storage is only looked up for Blob Storage.
    let development = if storage == AzureStorageService::Blob {
        collect_blob_development_config(&key_values, &storage)
    } else {
        None
    };
    if let Some(dev) = development {
        return Ok(AzureStorageConfig {
            account_name: Some(dev.account_name),
            account_key: Some(dev.account_key),
            endpoint: Some(dev.endpoint),
            ..Default::default()
        });
    }

    let account_name = key_values.get("AccountName").cloned();
    let endpoint = collect_endpoint(&key_values, &storage)?;
    let mut config = AzureStorageConfig {
        account_name,
        endpoint,
        ..Default::default()
    };

    // Credentials are optional; leave the config untouched when none are found.
    if let Some(credentials) = collect_credentials(&key_values) {
        set_credentials(&mut config, credentials);
    }

    Ok(config)
}
/// The service that a connection string refers to. The type influences
/// interpretation of endpoint-related fields during parsing.
///
/// The enum only has unit variants, so it is cheap to copy and can be
/// compared and printed (e.g. in assertion messages).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum AzureStorageService {
    /// Azure Blob Storage.
    Blob,
    /// Azure File Storage.
    #[cfg(feature = "services-azfile")]
    File,
    /// Azure Data Lake Storage Gen2.
    /// Backed by Blob Storage but exposed through a different endpoint (`dfs`).
    #[cfg(feature = "services-azdls")]
    Adls,
}
/// Extracts the storage account name from an Azure Storage endpoint.
///
/// The endpoint is expected to look like
/// `[http(s)://]<account>.<service>.<suffix>[/]`, where `<suffix>` is one of
/// the known Azure Storage endpoint suffixes.
///
/// Returns `None` when the endpoint does not have that shape, when the suffix
/// is not a known one, or when the account label is empty.
pub(crate) fn azure_account_name_from_endpoint(endpoint: &str) -> Option<String> {
    /// Known Azure Storage endpoint suffixes.
    const KNOWN_ENDPOINT_SUFFIXES: &[&str] = &[
        "core.windows.net",       // Azure public cloud
        "core.usgovcloudapi.net", // Azure US Government
        "core.chinacloudapi.cn",  // Azure China
    ];

    // The scheme is optional; drop it if present.
    let endpoint: &str = endpoint
        .strip_prefix("http://")
        .or_else(|| endpoint.strip_prefix("https://"))
        .unwrap_or(endpoint);

    let (account_name, service_endpoint) = endpoint.split_once('.')?;
    // An endpoint such as `https://.blob.core.windows.net` has no account
    // name; report that as "not found" instead of returning an empty string.
    if account_name.is_empty() {
        return None;
    }

    let (_storage_service, endpoint_suffix) = service_endpoint.split_once('.')?;

    // Tolerate trailing slashes after the suffix.
    KNOWN_ENDPOINT_SUFFIXES
        .contains(&endpoint_suffix.trim_end_matches('/'))
        .then(|| account_name.to_string())
}
/// Takes a semicolon-delimited Azure Storage connection string and returns
/// key-value pairs split from it.
fn parse_connection_string(conn_str: &str) -> Result<HashMap<String, String>> {
conn_str
.trim()
.replace("\n", "")
.split(';')
.filter(|&field| !field.is_empty())
.map(|field| {
let (key, value) = field.trim().split_once('=').ok_or(Error::new(
ErrorKind::ConfigInvalid,
format!("Invalid connection string, expected '=' in field: {field}"),
))?;
Ok((key.to_string(), value.to_string()))
})
.collect()
}
/// Reads the Azurite (development storage) settings from a parsed connection
/// string.
///
/// Returns `None` unless `UseDevelopmentStorage` is exactly `true`. Otherwise
/// `AccountName`, `AccountKey` and `DevelopmentStorageProxyUri` are taken from
/// the connection string, falling back to the Azurite defaults, and the
/// endpoint is built as `<proxy uri>/<account name>`. Trailing slashes on the
/// proxy URI are ignored so the endpoint never contains `//` before the
/// account name.
///
/// `storage` is only used to assert, in debug builds, that the caller asked
/// for Blob Storage.
fn collect_blob_development_config(
    key_values: &HashMap<String, String>,
    storage: &AzureStorageService,
) -> Option<DevelopmentStorageConfig> {
    debug_assert!(
        storage == &AzureStorageService::Blob,
        "Azurite Development Storage only supports Blob Storage"
    );

    // Azurite defaults.
    const AZURITE_DEFAULT_STORAGE_ACCOUNT_NAME: &str = "devstoreaccount1";
    const AZURITE_DEFAULT_STORAGE_ACCOUNT_KEY: &str =
        "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==";
    const AZURITE_DEFAULT_BLOB_URI: &str = "http://127.0.0.1:10000";

    // Compare as `&str` so no `String` is allocated just for the check.
    if key_values.get("UseDevelopmentStorage").map(String::as_str) != Some("true") {
        return None; // Not using development storage
    }

    // Pick the configured value or the default as a borrow, then allocate once.
    let account_name = key_values
        .get("AccountName")
        .map_or(AZURITE_DEFAULT_STORAGE_ACCOUNT_NAME, String::as_str)
        .to_string();
    let account_key = key_values
        .get("AccountKey")
        .map_or(AZURITE_DEFAULT_STORAGE_ACCOUNT_KEY, String::as_str)
        .to_string();
    let development_proxy_uri = key_values
        .get("DevelopmentStorageProxyUri")
        .map_or(AZURITE_DEFAULT_BLOB_URI, String::as_str)
        .trim_end_matches('/');

    Some(DevelopmentStorageConfig {
        endpoint: format!("{development_proxy_uri}/{account_name}"),
        account_name,
        account_key,
    })
}
/// Helper struct to hold development storage aka Azurite configuration.
///
/// Produced by `collect_blob_development_config` and copied field by field
/// into an `AzureStorageConfig` by `azure_config_from_connection_string`.
struct DevelopmentStorageConfig {
/// Storage account name: the `AccountName` field of the connection string,
/// or the Azurite default when that field is absent.
account_name: String,
/// Storage account key: the `AccountKey` field of the connection string,
/// or the Azurite default when that field is absent.
account_key: String,
/// Endpoint of the form `<proxy uri>/<account name>`.
endpoint: String,
}
/// Parses an endpoint from the key-value pairs if possible.
///
/// Users are still able to later supplement configuration with an endpoint,
/// so endpoint-related fields aren't enforced: `Ok(None)` is returned when
/// the connection string does not carry enough information.
///
/// Blob and File storage first look at their dedicated endpoint field
/// (`BlobEndpoint` / `FileEndpoint`) and otherwise build the endpoint from
/// parts; ADLS is always built from parts.
///
/// # Errors
///
/// Propagates the error from [`collect_endpoint_from_parts`].
fn collect_endpoint(
    key_values: &HashMap<String, String>,
    storage: &AzureStorageService,
) -> Result<Option<String>> {
    match storage {
        AzureStorageService::Blob => collect_or_build_endpoint(key_values, "BlobEndpoint", "blob"),
        #[cfg(feature = "services-azfile")]
        AzureStorageService::File => collect_or_build_endpoint(key_values, "FileEndpoint", "file"),
        // ADLS doesn't have a dedicated endpoint field and we can only
        // build it from parts. The helper already returns an owned
        // `Result<Option<String>>`, so it is forwarded as is.
        #[cfg(feature = "services-azdls")]
        AzureStorageService::Adls => collect_endpoint_from_parts(key_values, "dfs"),
    }
}
/// Picks the credentials described by the connection string, if any.
///
/// A `SharedAccessSignature` field wins over shared-key credentials. Shared-key
/// credentials require both `AccountName` and `AccountKey`.
///
/// Returns `None` when neither is available. This is not an error because
/// e.g. Azure Active Directory configuration is typically not passed via
/// connection strings, and users may also set credentials manually on the
/// configuration.
fn collect_credentials(key_values: &HashMap<String, String>) -> Option<AzureStorageCredential> {
    if let Some(signature) = key_values.get("SharedAccessSignature") {
        return Some(AzureStorageCredential::SharedAccessSignature(
            signature.clone(),
        ));
    }

    // Both parts of a shared key must be present; otherwise default to no
    // authentication.
    let name = key_values.get("AccountName")?;
    let key = key_values.get("AccountKey")?;
    Some(AzureStorageCredential::SharedKey(name.clone(), key.clone()))
}
/// Copies the given credentials into the matching fields of `config`.
///
/// Shared-key credentials set `account_name` and `account_key`; a shared
/// access signature sets `sas_token`; bearer tokens leave `config` untouched.
fn set_credentials(config: &mut AzureStorageConfig, creds: AzureStorageCredential) {
    match creds {
        AzureStorageCredential::SharedKey(name, key) => {
            config.account_name = Some(name);
            config.account_key = Some(key);
        }
        AzureStorageCredential::SharedAccessSignature(token) => config.sas_token = Some(token),
        // Bearer tokens shouldn't be passed via connection strings.
        AzureStorageCredential::BearerToken(..) => {}
    }
}
/// Returns the endpoint stored under `endpoint_key`, or builds one from parts.
///
/// When `endpoint_key` is present its value is returned unchanged. Otherwise
/// the result of [`collect_endpoint_from_parts`] for `service_name` is
/// forwarded, which may be `Ok(None)` if the parts are incomplete.
///
/// # Errors
///
/// Propagates the error from [`collect_endpoint_from_parts`].
fn collect_or_build_endpoint(
    key_values: &HashMap<String, String>,
    endpoint_key: &str,
    service_name: &str,
) -> Result<Option<String>> {
    match key_values.get(endpoint_key) {
        Some(endpoint) => Ok(Some(endpoint.clone())),
        // The helper already yields an owned `Result<Option<String>>`, so
        // there is nothing to unwrap, clone and re-wrap here.
        None => collect_endpoint_from_parts(key_values, service_name),
    }
}
fn collect_endpoint_from_parts(
key_values: &HashMap<String, String>,
storage_endpoint_name: &str,
) -> Result<Option<String>> {
let (account_name, endpoint_suffix) = match (
key_values.get("AccountName"),
key_values.get("EndpointSuffix"),
) {
(Some(name), Some(suffix)) => (name, suffix),
_ => return Ok(None), // Can't build an endpoint if one of them is missing
};
let protocol = key_values
.get("DefaultEndpointsProtocol")
.map(String::as_str)
.unwrap_or("https"); // Default to HTTPS if not specified
if protocol != "http" && protocol != "https" {
return Err(Error::new(
ErrorKind::ConfigInvalid,
format!("Invalid DefaultEndpointsProtocol: {protocol}"),
));
}
Ok(Some(format!(
"{protocol}://{account_name}.{storage_endpoint_name}.{endpoint_suffix}"
)))
}
/// Add response context to error.
///
/// This helper function will:
///
/// - remove sensitive or useless headers from parts.
/// - fetch uri if parts extensions contains `Uri`.
/// - censor sensitive SAS URI query parameters
pub fn with_azure_error_response_context(mut err: Error, mut parts: Parts) -> Error {
if let Some(uri) = parts.extensions.get::<Uri>() {
err = err.with_context("uri", censor_sas_uri(uri));
}
// The following headers may contains sensitive information.
parts.headers.remove("Set-Cookie");
parts.headers.remove("WWW-Authenticate");
parts.headers.remove("Proxy-Authenticate");
err = err.with_context("response", format!("{parts:?}"));
err
}
/// Renders `uri` as a string with any `sig=` query parameter removed.
///
/// URIs without a query string are rendered unchanged.
///
/// There is a large set of query parameters specified for SAS URIs. Some of
/// them may be useful to an attacker, but the most important part is the
/// signature: without it, an attacker will not be able to replay the request.
/// For now, only the signature is removed.
///
/// - <https://learn.microsoft.com/en-us/rest/api/storageservices/create-account-sas>
/// - <https://learn.microsoft.com/en-us/rest/api/storageservices/create-service-sas>
/// - <https://learn.microsoft.com/en-us/rest/api/storageservices/create-user-delegation-sas>
fn censor_sas_uri(uri: &Uri) -> String {
    let query = match uri.query() {
        Some(query) => query,
        None => return uri.to_string(),
    };

    // Keep every parameter except the signature, preserving their order.
    let censored_query = query
        .split('&')
        .filter(|param| !param.starts_with("sig="))
        .collect::<Vec<_>>()
        .join("&");
    let path_and_query = format!("{}?{}", uri.path(), censored_query);

    // Rebuild the URI from the original scheme and authority plus the
    // censored path and query.
    let mut parts = uri.clone().into_parts();
    parts.path_and_query = Some(path_and_query.try_into().unwrap());
    Uri::from_parts(parts).unwrap().to_string()
}
#[cfg(test)]
mod tests {
use http::Uri;
use reqsign::AzureStorageConfig;
use crate::raw::azure::censor_sas_uri;
use super::{
AzureStorageService, azure_account_name_from_endpoint, azure_config_from_connection_string,
};
// Table-driven test for `azure_config_from_connection_string`.
//
// Each case is `(name, (service, connection string), expected)`. An expected
// `Some(config)` means parsing must succeed and produce those fields; `None`
// means parsing must return an error.
#[test]
fn test_azure_config_from_connection_string() {
// `mut` is only exercised by the feature-gated `push`/`extend` calls below,
// so it is unused when those features are disabled.
#[allow(unused_mut)]
let mut test_cases = vec![
("minimal fields",
(AzureStorageService::Blob, "BlobEndpoint=https://testaccount.blob.core.windows.net/"),
Some(AzureStorageConfig{
endpoint: Some("https://testaccount.blob.core.windows.net/".to_string()),
..Default::default()
}),
),
("basic creds and blob endpoint",
(AzureStorageService::Blob, "AccountName=testaccount;AccountKey=testkey;BlobEndpoint=https://testaccount.blob.core.windows.net/"),
Some(AzureStorageConfig{
account_name: Some("testaccount".to_string()),
account_key: Some("testkey".to_string()),
endpoint: Some("https://testaccount.blob.core.windows.net/".to_string()),
..Default::default()
}),
),
("SAS token",
(AzureStorageService::Blob, "SharedAccessSignature=blablabla"),
Some(AzureStorageConfig{
sas_token: Some("blablabla".to_string()),
..Default::default()
}),
),
("endpoint from parts",
(AzureStorageService::Blob, "AccountName=testaccount;EndpointSuffix=core.windows.net;DefaultEndpointsProtocol=https"),
Some(AzureStorageConfig{
endpoint: Some("https://testaccount.blob.core.windows.net".to_string()),
account_name: Some("testaccount".to_string()),
..Default::default()
}),
),
("endpoint from parts and no protocol",
(AzureStorageService::Blob, "AccountName=testaccount;EndpointSuffix=core.windows.net"),
Some(AzureStorageConfig{
// Defaults to https
endpoint: Some("https://testaccount.blob.core.windows.net".to_string()),
account_name: Some("testaccount".to_string()),
..Default::default()
}),
),
("prefers sas over key",
(AzureStorageService::Blob, "AccountName=testaccount;AccountKey=testkey;SharedAccessSignature=sas_token"),
Some(AzureStorageConfig{
sas_token: Some("sas_token".to_string()),
account_name: Some("testaccount".to_string()),
..Default::default()
}),
),
("development storage",
(AzureStorageService::Blob, "UseDevelopmentStorage=true",),
Some(AzureStorageConfig{
account_name: Some("devstoreaccount1".to_string()),
account_key: Some("Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==".to_string()),
endpoint: Some("http://127.0.0.1:10000/devstoreaccount1".to_string()),
..Default::default()
}),
),
("development storage with custom account values",
(AzureStorageService::Blob, "UseDevelopmentStorage=true;AccountName=myAccount;AccountKey=myKey"),
Some(AzureStorageConfig {
endpoint: Some("http://127.0.0.1:10000/myAccount".to_string()),
account_name: Some("myAccount".to_string()),
account_key: Some("myKey".to_string()),
..Default::default()
}),
),
("development storage with custom uri",
(AzureStorageService::Blob, "UseDevelopmentStorage=true;DevelopmentStorageProxyUri=http://127.0.0.1:12345"),
Some(AzureStorageConfig {
endpoint: Some("http://127.0.0.1:12345/devstoreaccount1".to_string()),
account_name: Some("devstoreaccount1".to_string()),
account_key: Some("Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==".to_string()),
..Default::default()
}),
),
("unknown key is ignored",
(AzureStorageService::Blob, "SomeUnknownKey=123;BlobEndpoint=https://testaccount.blob.core.windows.net/"),
Some(AzureStorageConfig{
endpoint: Some("https://testaccount.blob.core.windows.net/".to_string()),
..Default::default()
}),
),
("leading and trailing `;`",
(AzureStorageService::Blob, ";AccountName=testaccount;"),
Some(AzureStorageConfig {
account_name: Some("testaccount".to_string()),
..Default::default()
}),
),
("line breaks",
(AzureStorageService::Blob, r#"
AccountName=testaccount;
AccountKey=testkey;
EndpointSuffix=core.windows.net;
DefaultEndpointsProtocol=https"#),
Some(AzureStorageConfig {
account_name: Some("testaccount".to_string()),
account_key: Some("testkey".to_string()),
endpoint: Some("https://testaccount.blob.core.windows.net".to_string()),
..Default::default()
}),
),
("missing equals",
(AzureStorageService::Blob, "AccountNameexample;AccountKey=example;EndpointSuffix=core.windows.net;DefaultEndpointsProtocol=https",),
None, // This should fail due to missing '='
),
("with invalid protocol",
(AzureStorageService::Blob, "DefaultEndpointsProtocol=ftp;AccountName=example;EndpointSuffix=core.windows.net",),
None, // This should fail due to invalid protocol
),
];
// Cases below only compile when the corresponding service feature is enabled.
#[cfg(feature = "services-azdls")]
test_cases.push(
("adls endpoint from parts",
(AzureStorageService::Adls, "AccountName=testaccount;EndpointSuffix=core.windows.net;DefaultEndpointsProtocol=https"),
Some(AzureStorageConfig{
account_name: Some("testaccount".to_string()),
endpoint: Some("https://testaccount.dfs.core.windows.net".to_string()),
..Default::default()
}),
)
);
#[cfg(feature = "services-azfile")]
test_cases.extend(vec![
(
"file endpoint from field",
(
AzureStorageService::File,
"FileEndpoint=https://testaccount.file.core.windows.net",
),
Some(AzureStorageConfig {
endpoint: Some("https://testaccount.file.core.windows.net".to_string()),
..Default::default()
}),
),
(
"file endpoint from parts",
(
AzureStorageService::File,
"AccountName=testaccount;EndpointSuffix=core.windows.net",
),
Some(AzureStorageConfig {
account_name: Some("testaccount".to_string()),
endpoint: Some("https://testaccount.file.core.windows.net".to_string()),
..Default::default()
}),
),
]);
#[cfg(feature = "services-azdls")]
test_cases.push((
"azdls development storage",
(AzureStorageService::Adls, "UseDevelopmentStorage=true"),
Some(AzureStorageConfig::default()), // Azurite doesn't support ADLSv2, so we ignore this case
));
// Run every case; the case name is used as the failure message.
for (name, (storage, conn_str), expected) in test_cases {
let actual = azure_config_from_connection_string(conn_str, storage);
if let Some(expected) = expected {
assert_azure_storage_config_eq(&actual.expect(name), &expected, name);
} else {
assert!(actual.is_err(), "Expected error for case: {name}");
}
}
}
/// Checks account name extraction for known, unknown and malformed endpoints.
#[test]
fn test_azure_account_name_from_endpoint() {
    // (endpoint, expected account name)
    let cases: [(&str, Option<&str>); 7] = [
        ("https://account.blob.core.windows.net", Some("account")),
        ("https://account.blob.core.usgovcloudapi.net", Some("account")),
        ("https://account.blob.core.chinacloudapi.cn", Some("account")),
        ("https://account.dfs.core.windows.net", Some("account")),
        ("https://account.blob.core.windows.net/", Some("account")),
        ("https://account.blob.unknown.suffix.com", None),
        ("http://blob.core.windows.net", None),
    ];

    for (endpoint, expected) in cases {
        let expected = expected.map(str::to_string);
        assert_eq!(
            azure_account_name_from_endpoint(endpoint),
            expected,
            "Endpoint: {endpoint}"
        );
    }
}
/// The `sig` query parameter must be dropped while the others are kept in order.
#[test]
fn test_azure_uri_context_removes_sig() {
    let with_signature = "https://foo.bar/path?foo=foo&sig=SENSITIVE&bar=bar"
        .parse::<Uri>()
        .unwrap();

    assert_eq!(
        censor_sas_uri(&with_signature),
        "https://foo.bar/path?foo=foo&bar=bar"
    );
}
/// Compares the `account_name`, `account_key`, `endpoint` and `sas_token`
/// fields of two configs one by one, naming the mismatching field and the
/// test case in the failure message.
fn assert_azure_storage_config_eq(
    actual: &AzureStorageConfig,
    expected: &AzureStorageConfig,
    name: &str,
) {
    // (field name, actual value, expected value), checked in this order.
    let fields = [
        ("account_name", &actual.account_name, &expected.account_name),
        ("account_key", &actual.account_key, &expected.account_key),
        ("endpoint", &actual.endpoint, &expected.endpoint),
        ("sas_token", &actual.sas_token, &expected.sas_token),
    ];

    for (field, got, want) in fields {
        assert_eq!(got, want, "{field} mismatch: {name}");
    }
}
}