blob: 78c6418f2bf8d9c6ce8ff79d5fc5419726ef0694 [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::sync::Arc;
use log::debug;
use super::HF_SCHEME;
use super::config::HfConfig;
use super::core::HfCore;
use super::deleter::HfDeleter;
use super::lister::HfLister;
use super::reader::HfReader;
use super::uri::{HfRepo, HfRepoType};
use super::writer::HfWriter;
use opendal_core::raw::*;
use opendal_core::*;
/// [Hugging Face](https://huggingface.co/docs/huggingface_hub/package_reference/hf_api)'s API support.
#[doc = include_str!("docs.md")]
#[derive(Debug, Default)]
pub struct HfBuilder {
    /// Settings accumulated by the setter methods; consumed by `Builder::build`.
    pub(super) config: HfConfig,
}
impl HfBuilder {
    /// Set repo type of this backend. Default is model.
    ///
    /// Available values:
    /// - model
    /// - dataset
    /// - datasets (alias for dataset)
    /// - space
    /// - bucket
    ///
    /// [Reference](https://huggingface.co/docs/hub/repositories)
    pub fn repo_type(mut self, repo_type: &str) -> Self {
        if repo_type.is_empty() {
            return self;
        }
        // Values that fail to parse are silently ignored, keeping the
        // previously configured (or default) repo type.
        if let Ok(parsed) = HfRepoType::parse(repo_type) {
            self.config.repo_type = parsed;
        }
        self
    }

    /// Set repo id of this backend. This is required.
    ///
    /// Repo id consists of the account name and the repository name.
    ///
    /// For example, model's repo id looks like:
    /// - meta-llama/Llama-2-7b
    ///
    /// Dataset's repo id looks like:
    /// - databricks/databricks-dolly-15k
    pub fn repo_id(mut self, repo_id: &str) -> Self {
        if repo_id.is_empty() {
            return self;
        }
        self.config.repo_id = Some(repo_id.to_string());
        self
    }

    /// Set revision of this backend. Default is main.
    ///
    /// Revision can be a branch name or a commit hash.
    ///
    /// For example, revision can be:
    /// - main
    /// - 1d0c4eb
    pub fn revision(mut self, revision: &str) -> Self {
        if revision.is_empty() {
            return self;
        }
        self.config.revision = Some(revision.to_string());
        self
    }

    /// Set root of this backend.
    ///
    /// All operations will happen under this root.
    pub fn root(mut self, root: &str) -> Self {
        // An empty string clears the root back to None.
        self.config.root = (!root.is_empty()).then(|| root.to_string());
        self
    }

    /// Set the token of this backend.
    ///
    /// This is optional.
    pub fn token(mut self, token: &str) -> Self {
        if token.is_empty() {
            return self;
        }
        self.config.token = Some(token.to_string());
        self
    }

    /// configure the Hub base url. You might want to set this variable if your
    /// organization is using a Private Hub https://huggingface.co/enterprise
    ///
    /// Default is "https://huggingface.co"
    pub fn endpoint(mut self, endpoint: &str) -> Self {
        if endpoint.is_empty() {
            return self;
        }
        self.config.endpoint = Some(endpoint.to_string());
        self
    }
}
impl Builder for HfBuilder {
type Config = HfConfig;
fn build(self) -> Result<impl Access> {
debug!("backend build started: {:?}", &self);
let repo_type = self.config.repo_type;
debug!("backend use repo_type: {:?}", &repo_type);
let repo_id = match &self.config.repo_id {
Some(repo_id) => Ok(repo_id.clone()),
None => Err(Error::new(ErrorKind::ConfigInvalid, "repo_id is empty")
.with_operation("Builder::build")
.with_context("service", HF_SCHEME)),
}?;
debug!("backend use repo_id: {}", &repo_id);
let revision = match &self.config.revision {
Some(revision) => revision.clone(),
None => "main".to_string(),
};
debug!("backend use revision: {}", &revision);
let root = normalize_root(&self.config.root.unwrap_or_default());
debug!("backend use root: {}", &root);
let token = self.config.token.as_ref().cloned();
let endpoint = match &self.config.endpoint {
Some(endpoint) => endpoint.clone(),
None => {
// Try to read from HF_ENDPOINT env var which is used
// by the official huggingface clients.
if let Ok(env_endpoint) = std::env::var("HF_ENDPOINT") {
env_endpoint
} else {
"https://huggingface.co".to_string()
}
}
};
debug!("backend use endpoint: {}", &endpoint);
let info: Arc<AccessorInfo> = {
let am = AccessorInfo::default();
am.set_scheme(HF_SCHEME).set_native_capability(Capability {
stat: true,
read: true,
write: token.is_some(),
delete: token.is_some(),
delete_max_size: Some(100),
list: true,
list_with_recursive: true,
shared: true,
..Default::default()
});
am.into()
};
let repo = HfRepo::new(repo_type, repo_id, Some(revision.clone()));
debug!("backend repo uri: {:?}", repo.uri(&root, ""));
Ok(HfBackend {
core: Arc::new(HfCore::build(info, repo, root, token, endpoint)?),
})
}
}
/// Backend for Hugging Face service
#[derive(Debug, Clone)]
pub struct HfBackend {
    /// Shared core built from the repo, root, token and endpoint resolved
    /// in `Builder::build`; cloning the backend clones only this `Arc`.
    pub(crate) core: Arc<HfCore>,
}
impl Access for HfBackend {
    type Reader = HfReader;
    type Writer = HfWriter;
    type Lister = oio::PageLister<HfLister>;
    type Deleter = oio::BatchDeleter<HfDeleter>;

    fn info(&self) -> Arc<AccessorInfo> {
        Arc::clone(&self.core.info)
    }

    async fn stat(&self, path: &str, _: OpStat) -> Result<RpStat> {
        // The root is reported as a directory unconditionally.
        if path == "/" {
            return Ok(RpStat::new(Metadata::new(EntryMode::DIR)));
        }

        // Non-bucket repos resolve through the path-info endpoint.
        if !self.core.repo.is_bucket() {
            let info = self.core.path_info(path).await?;
            return Ok(RpStat::new(info.metadata()?));
        }

        // Bucket repos: trailing-slash paths are directories, everything
        // else is looked up as a file.
        if path.ends_with('/') {
            return Ok(RpStat::new(Metadata::new(EntryMode::DIR)));
        }
        match self.core.maybe_xet_file(path).await? {
            Some(file) => {
                let length = file.file_size().unwrap_or(0);
                Ok(RpStat::new(
                    Metadata::new(EntryMode::FILE).with_content_length(length),
                ))
            }
            None => Err(Error::new(ErrorKind::NotFound, "path not found")),
        }
    }

    async fn read(&self, path: &str, args: OpRead) -> Result<(RpRead, Self::Reader)> {
        let r = HfReader::try_new(&self.core, path, args.range()).await?;
        Ok((RpRead::default(), r))
    }

    async fn list(&self, path: &str, args: OpList) -> Result<(RpList, Self::Lister)> {
        let l = HfLister::new(self.core.clone(), path.to_string(), args.recursive());
        Ok((RpList::default(), oio::PageLister::new(l)))
    }

    async fn write(&self, path: &str, _args: OpWrite) -> Result<(RpWrite, Self::Writer)> {
        let w = HfWriter::try_new(self.core.clone(), path.to_string()).await?;
        Ok((RpWrite::default(), w))
    }

    async fn delete(&self) -> Result<(RpDelete, Self::Deleter)> {
        // Batch size comes from the capability advertised at build time.
        let limit = self.core.info.full_capability().delete_max_size;
        let deleter = HfDeleter::new(self.core.clone());
        Ok((RpDelete::default(), oio::BatchDeleter::new(deleter, limit)))
    }
}
#[cfg(test)]
pub(super) mod test_utils {
    use super::HfBuilder;
    use opendal_core::Operator;
    use opendal_core::layers::HttpClientLayer;
    use opendal_core::raw::HttpClient;

    /// Attach a plain reqwest-backed HTTP client layer to the operator.
    fn finish_operator(op: Operator) -> Operator {
        op.layer(HttpClientLayer::new(HttpClient::with(reqwest::Client::new())))
    }

    /// Read-only operator over the public `openai-community/gpt2` model repo.
    pub fn gpt2_operator() -> Operator {
        let builder = HfBuilder::default()
            .repo_type("model")
            .repo_id("openai-community/gpt2");
        finish_operator(Operator::new(builder).unwrap().finish())
    }

    /// Read-only operator over the public `google-research-datasets/mbpp`
    /// dataset repo.
    pub fn mbpp_operator() -> Operator {
        let builder = HfBuilder::default()
            .repo_type("dataset")
            .repo_id("google-research-datasets/mbpp");
        finish_operator(Operator::new(builder).unwrap().finish())
    }

    /// Operator over a writable test bucket; requires the
    /// `HF_OPENDAL_BUCKET` and `HF_OPENDAL_TOKEN` environment variables.
    pub fn testing_bucket_operator() -> Operator {
        let repo_id = std::env::var("HF_OPENDAL_BUCKET").expect("HF_OPENDAL_BUCKET must be set");
        let token = std::env::var("HF_OPENDAL_TOKEN").expect("HF_OPENDAL_TOKEN must be set");
        let builder = HfBuilder::default()
            .repo_type("bucket")
            .repo_id(&repo_id)
            .token(&token);
        finish_operator(Operator::new(builder).unwrap().finish())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The `datasets` alias must parse to the dataset repo type.
    #[test]
    fn build_accepts_datasets_alias() {
        let builder = HfBuilder::default()
            .repo_type("datasets")
            .repo_id("org/repo");
        builder
            .build()
            .expect("builder should accept datasets alias");
    }

    /// `space` is a valid repo type.
    #[test]
    fn build_accepts_space_repo_type() {
        let builder = HfBuilder::default()
            .repo_type("space")
            .repo_id("org/space");
        builder
            .build()
            .expect("builder should accept space repo type");
    }

    /// Both the short (`hf`) and long (`huggingface`) URI schemes must be
    /// registered, and both resolve to the same backend scheme.
    #[test]
    fn test_both_schemes_are_supported() {
        use opendal_core::OperatorRegistry;

        let registry = OperatorRegistry::get();
        super::super::register_hf_service(registry);

        let short = registry
            .load("hf://user/repo")
            .expect("short scheme should be registered and work");
        assert_eq!(short.info().scheme(), "hf");

        let long = registry
            .load("huggingface://user/repo")
            .expect("long scheme should be registered and work");
        assert_eq!(long.info().scheme(), "hf");
    }
}