Fix: extension loading gets stuck in a concurrent environment (#184)
* Fix: extension loading gets stuck in a concurrent environment
- Add a new struct called LoadExtensionPromise
- Remove the async modifier from ExtensionDirectory methods
Close #183
* Ftr: use RwLock instead of unsafe
* Rft: simplify the extension promise resolve code
diff --git a/dubbo/src/extension/mod.rs b/dubbo/src/extension/mod.rs
index 5641bea..c1d0395 100644
--- a/dubbo/src/extension/mod.rs
+++ b/dubbo/src/extension/mod.rs
@@ -22,8 +22,9 @@
};
use dubbo_base::{extension_param::ExtensionType, url::UrlParam, StdError, Url};
use dubbo_logger::tracing::{error, info};
+use std::{future::Future, pin::Pin, sync::Arc};
use thiserror::Error;
-use tokio::sync::oneshot;
+use tokio::sync::{oneshot, RwLock};
pub static EXTENSIONS: once_cell::sync::Lazy<ExtensionDirectoryCommander> =
once_cell::sync::Lazy::new(|| ExtensionDirectory::init());
@@ -41,13 +42,11 @@
let mut extension_directory = ExtensionDirectory::default();
// register static registry extension
- let _ = extension_directory
- .register(
- StaticRegistry::name(),
- StaticRegistry::convert_to_extension_factories(),
- ExtensionType::Registry,
- )
- .await;
+ let _ = extension_directory.register(
+ StaticRegistry::name(),
+ StaticRegistry::convert_to_extension_factories(),
+ ExtensionType::Registry,
+ );
while let Some(extension_opt) = rx.recv().await {
match extension_opt {
@@ -57,20 +56,19 @@
extension_type,
tx,
) => {
- let result = extension_directory
- .register(extension_name, extension_factories, extension_type)
- .await;
+ let result = extension_directory.register(
+ extension_name,
+ extension_factories,
+ extension_type,
+ );
let _ = tx.send(result);
}
ExtensionOpt::Remove(extension_name, extension_type, tx) => {
- let result = extension_directory
- .remove(extension_name, extension_type)
- .await;
+ let result = extension_directory.remove(extension_name, extension_type);
let _ = tx.send(result);
}
ExtensionOpt::Load(url, extension_type, tx) => {
- let result = extension_directory.load(url, extension_type).await;
- let _ = tx.send(result);
+ let _ = extension_directory.load(url, extension_type, tx);
}
}
}
@@ -79,7 +77,7 @@
ExtensionDirectoryCommander { sender: tx }
}
- async fn register(
+ fn register(
&mut self,
extension_name: String,
extension_factories: ExtensionFactories,
@@ -89,40 +87,53 @@
ExtensionType::Registry => match extension_factories {
ExtensionFactories::RegistryExtensionFactory(registry_extension_factory) => {
self.registry_extension_loader
- .register(extension_name, registry_extension_factory)
- .await;
+ .register(extension_name, registry_extension_factory);
Ok(())
}
},
}
}
- async fn remove(
+ fn remove(
&mut self,
extension_name: String,
extension_type: ExtensionType,
) -> Result<(), StdError> {
match extension_type {
ExtensionType::Registry => {
- self.registry_extension_loader.remove(extension_name).await;
+ self.registry_extension_loader.remove(extension_name);
Ok(())
}
}
}
- async fn load(
+ fn load(
&mut self,
url: Url,
extension_type: ExtensionType,
- ) -> Result<Extensions, StdError> {
+ callback: oneshot::Sender<Result<Extensions, StdError>>,
+ ) {
match extension_type {
ExtensionType::Registry => {
- let extension = self.registry_extension_loader.load(&url).await;
+ let extension = self.registry_extension_loader.load(url);
match extension {
- Ok(extension) => Ok(Extensions::Registry(extension)),
+ Ok(mut extension) => {
+ tokio::spawn(async move {
+ let extension = extension.resolve().await;
+ match extension {
+ Ok(extension) => {
+ let _ = callback.send(Ok(Extensions::Registry(extension)));
+ }
+ Err(err) => {
+ error!("load extension failed: {}", err);
+ let _ = callback.send(Err(err));
+ }
+ }
+ });
+ }
Err(err) => {
error!("load extension failed: {}", err);
- Err(err)
+ let _ = callback.send(Err(err));
}
}
}
@@ -130,6 +141,95 @@
}
}
+type ExtensionCreator<T> = Box<
+ dyn Fn(Url) -> Pin<Box<dyn Future<Output = Result<T, StdError>> + Send + 'static>>
+ + Send
+ + Sync
+ + 'static,
+>;
+pub(crate) struct ExtensionPromiseResolver<T> {
+ resolved_data: Option<T>,
+ creator: ExtensionCreator<T>,
+ url: Url,
+}
+
+impl<T> ExtensionPromiseResolver<T>
+where
+ T: Send + Clone + 'static,
+{
+ fn new(creator: ExtensionCreator<T>, url: Url) -> Self {
+ ExtensionPromiseResolver {
+ resolved_data: None,
+ creator,
+ url,
+ }
+ }
+
+ fn resolved_data(&self) -> Option<T> {
+ self.resolved_data.clone()
+ }
+
+ async fn resolve(&mut self) -> Result<T, StdError> {
+ match (self.creator)(self.url.clone()).await {
+ Ok(data) => {
+ self.resolved_data = Some(data.clone());
+ Ok(data)
+ }
+ Err(err) => {
+ error!("create extension failed: {}", err);
+ Err(LoadExtensionError::new(
+ "load extension failed, create extension occur an error".to_string(),
+ )
+ .into())
+ }
+ }
+ }
+}
+
+pub(crate) struct LoadExtensionPromise<T> {
+ resolver: Arc<RwLock<ExtensionPromiseResolver<T>>>,
+}
+
+impl<T> LoadExtensionPromise<T>
+where
+ T: Send + Clone + 'static,
+{
+ pub(crate) fn new(creator: ExtensionCreator<T>, url: Url) -> Self {
+ let resolver = ExtensionPromiseResolver::new(creator, url);
+ LoadExtensionPromise {
+ resolver: Arc::new(RwLock::new(resolver)),
+ }
+ }
+
+ pub(crate) async fn resolve(&mut self) -> Result<T, StdError> {
+ // get read lock
+ let resolver_read_lock = self.resolver.read().await;
+ // if extension is not None, return it
+ if let Some(extension) = resolver_read_lock.resolved_data() {
+ return Ok(extension);
+ }
+ drop(resolver_read_lock);
+
+ let mut write_lock = self.resolver.write().await;
+
+ match write_lock.resolved_data() {
+ Some(extension) => Ok(extension),
+ None => {
+ let extension = write_lock.resolve().await;
+ extension
+ }
+ }
+ }
+}
+
+impl<T> Clone for LoadExtensionPromise<T> {
+ fn clone(&self) -> Self {
+ LoadExtensionPromise {
+ resolver: self.resolver.clone(),
+ }
+ }
+}
+
pub struct ExtensionDirectoryCommander {
sender: tokio::sync::mpsc::Sender<ExtensionOpt>,
}
@@ -280,7 +380,7 @@
fn name() -> String;
- async fn create(url: &Url) -> Result<Self::Target, StdError>;
+ async fn create(url: Url) -> Result<Self::Target, StdError>;
}
#[allow(private_bounds)]
diff --git a/dubbo/src/extension/registry_extension.rs b/dubbo/src/extension/registry_extension.rs
index e27d6a5..ce998bc 100644
--- a/dubbo/src/extension/registry_extension.rs
+++ b/dubbo/src/extension/registry_extension.rs
@@ -29,6 +29,7 @@
use crate::extension::{
ConvertToExtensionFactories, Extension, ExtensionFactories, ExtensionMetaInfo, ExtensionType,
+ LoadExtensionPromise,
};
// extension://0.0.0.0/?extension-type=registry&extension-name=nacos&registry-url=nacos://127.0.0.1:8848
@@ -78,27 +79,9 @@
T: Extension<Target = Box<dyn Registry + Send + 'static>>,
{
fn convert_to_extension_factories() -> ExtensionFactories {
- fn constrain<F>(f: F) -> F
- where
- F: for<'a> Fn(
- &'a Url,
- ) -> Pin<
- Box<
- dyn Future<Output = Result<Box<dyn Registry + Send + 'static>, StdError>>
- + Send
- + 'a,
- >,
- >,
- {
- f
- }
-
- let constructor = constrain(|url: &Url| {
- let f = <T as Extension>::create(url);
- Box::pin(f)
- });
-
- ExtensionFactories::RegistryExtensionFactory(RegistryExtensionFactory::new(constructor))
+ ExtensionFactories::RegistryExtensionFactory(RegistryExtensionFactory::new(
+ <T as Extension>::create,
+ ))
}
}
@@ -108,19 +91,18 @@
}
impl RegistryExtensionLoader {
- pub(crate) async fn register(
- &mut self,
- extension_name: String,
- factory: RegistryExtensionFactory,
- ) {
+ pub(crate) fn register(&mut self, extension_name: String, factory: RegistryExtensionFactory) {
self.factories.insert(extension_name, factory);
}
- pub(crate) async fn remove(&mut self, extension_name: String) {
+ pub(crate) fn remove(&mut self, extension_name: String) {
self.factories.remove(&extension_name);
}
- pub(crate) async fn load(&mut self, url: &Url) -> Result<RegistryProxy, StdError> {
+ pub(crate) fn load(
+ &mut self,
+ url: Url,
+ ) -> Result<LoadExtensionPromise<RegistryProxy>, StdError> {
let extension_name = url.query::<ExtensionName>().unwrap();
let extension_name = extension_name.value();
let factory = self.factories.get_mut(&extension_name).ok_or_else(|| {
@@ -129,19 +111,19 @@
extension_name
))
})?;
- factory.create(url).await
+ factory.create(url)
}
}
-type RegistryConstructor = for<'a> fn(
- &'a Url,
+type RegistryConstructor = fn(
+ Url,
) -> Pin<
- Box<dyn Future<Output = Result<Box<dyn Registry + Send + 'static>, StdError>> + Send + 'a>,
+ Box<dyn Future<Output = Result<Box<dyn Registry + Send + 'static>, StdError>> + Send>,
>;
pub(crate) struct RegistryExtensionFactory {
constructor: RegistryConstructor,
- instances: HashMap<String, RegistryProxy>,
+ instances: HashMap<String, LoadExtensionPromise<RegistryProxy>>,
}
impl RegistryExtensionFactory {
@@ -154,7 +136,10 @@
}
impl RegistryExtensionFactory {
- pub(super) async fn create(&mut self, url: &Url) -> Result<RegistryProxy, StdError> {
+ pub(super) fn create(
+ &mut self,
+ url: Url,
+ ) -> Result<LoadExtensionPromise<RegistryProxy>, StdError> {
let registry_url = url.query::<RegistryUrl>().unwrap();
let registry_url = registry_url.value();
let url_str = registry_url.as_str().to_string();
@@ -164,10 +149,28 @@
Ok(proxy)
}
None => {
- let registry = (self.constructor)(url).await?;
- let proxy = <RegistryProxy as From<Box<dyn Registry + Send>>>::from(registry);
- self.instances.insert(url_str, proxy.clone());
- Ok(proxy)
+ let constructor = self.constructor;
+
+ let creator = move |url: Url| {
+ let registry = constructor(url);
+ Box::pin(async move {
+ let registry = registry.await?;
+ let proxy =
+ <RegistryProxy as From<Box<dyn Registry + Send>>>::from(registry);
+ Ok(proxy)
+ })
+ as Pin<
+ Box<
+ dyn Future<Output = Result<RegistryProxy, StdError>>
+ + Send
+ + 'static,
+ >,
+ >
+ };
+
+ let promise = LoadExtensionPromise::new(Box::new(creator), url);
+ self.instances.insert(url_str, promise.clone());
+ Ok(promise)
}
}
}
diff --git a/dubbo/src/registry/registry.rs b/dubbo/src/registry/registry.rs
index 85a8168..f998778 100644
--- a/dubbo/src/registry/registry.rs
+++ b/dubbo/src/registry/registry.rs
@@ -200,7 +200,7 @@
"static".to_string()
}
- async fn create(url: &Url) -> Result<Self::Target, StdError> {
+ async fn create(url: Url) -> Result<Self::Target, StdError> {
// url example:
// extension://0.0.0.0?extension-type=registry&extension-name=static&registry=static://127.0.0.1
let static_invoker_urls = url.query::<StaticInvokerUrls>();
diff --git a/registry/nacos/src/lib.rs b/registry/nacos/src/lib.rs
index 204846b..0507452 100644
--- a/registry/nacos/src/lib.rs
+++ b/registry/nacos/src/lib.rs
@@ -256,7 +256,7 @@
"nacos".to_string()
}
- async fn create(url: &Url) -> Result<Self::Target, StdError> {
+ async fn create(url: Url) -> Result<Self::Target, StdError> {
// url example:
// extension://0.0.0.0?extension-type=registry&extension-name=nacos&registry=nacos://127.0.0.1:8848
let registry_url = url.query::<RegistryUrl>().unwrap();
@@ -446,7 +446,7 @@
extension_url.add_query_param(ExtensionName::new("nacos".to_string()));
extension_url.add_query_param(RegistryUrl::new("nacos://127.0.0.1:8848/org.apache.dubbo.registry.RegistryService?application=dubbo-demo-triple-api-provider&dubbo=2.0.2&interface=org.apache.dubbo.registry.RegistryService&pid=7015".parse().unwrap()));
- let registry = NacosRegistry::create(&extension_url).await.unwrap();
+ let registry = NacosRegistry::create(extension_url).await.unwrap();
let mut service_url: Url = "tri://127.0.0.1:50052/org.apache.dubbo.demo.GreeterService?anyhost=true&application=dubbo-demo-triple-api-provider&background=false&deprecated=false&dubbo=2.0.2&dynamic=true&generic=false&interface=org.apache.dubbo.demo.GreeterService&methods=sayHello,sayHelloAsync&pid=7015&service-name-mapping=true&side=provider&timestamp=1670060843807".parse().unwrap();
@@ -478,7 +478,7 @@
extension_url.add_query_param(ExtensionName::new("nacos".to_string()));
extension_url.add_query_param(RegistryUrl::new("nacos://127.0.0.1:8848/org.apache.dubbo.registry.RegistryService?application=dubbo-demo-triple-api-provider&dubbo=2.0.2&interface=org.apache.dubbo.registry.RegistryService&pid=7015".parse().unwrap()));
- let registry = NacosRegistry::create(&extension_url).await.unwrap();
+ let registry = NacosRegistry::create(extension_url).await.unwrap();
let mut service_url: Url = "tri://127.0.0.1:50052/org.apache.dubbo.demo.GreeterService?anyhost=true&application=dubbo-demo-triple-api-provider&background=false&deprecated=false&dubbo=2.0.2&dynamic=true&generic=false&interface=org.apache.dubbo.demo.GreeterService&methods=sayHello,sayHelloAsync&pid=7015&service-name-mapping=true&side=provider&timestamp=1670060843807".parse().unwrap();
@@ -518,7 +518,7 @@
extension_url.add_query_param(ExtensionName::new("nacos".to_string()));
extension_url.add_query_param(RegistryUrl::new("nacos://127.0.0.1:8848/org.apache.dubbo.registry.RegistryService?application=dubbo-demo-triple-api-provider&dubbo=2.0.2&interface=org.apache.dubbo.registry.RegistryService&pid=7015".parse().unwrap()));
- let registry = NacosRegistry::create(&extension_url).await.unwrap();
+ let registry = NacosRegistry::create(extension_url).await.unwrap();
let mut service_url: Url = "tri://127.0.0.1:50052/org.apache.dubbo.demo.GreeterService?anyhost=true&application=dubbo-demo-triple-api-provider&background=false&deprecated=false&dubbo=2.0.2&dynamic=true&generic=false&interface=org.apache.dubbo.demo.GreeterService&methods=sayHello,sayHelloAsync&pid=7015&service-name-mapping=true&side=provider&timestamp=1670060843807".parse().unwrap();
@@ -562,7 +562,7 @@
extension_url.add_query_param(ExtensionName::new("nacos".to_string()));
extension_url.add_query_param(RegistryUrl::new("nacos://127.0.0.1:8848/org.apache.dubbo.registry.RegistryService?application=dubbo-demo-triple-api-provider&dubbo=2.0.2&interface=org.apache.dubbo.registry.RegistryService&pid=7015".parse().unwrap()));
- let registry = NacosRegistry::create(&extension_url).await.unwrap();
+ let registry = NacosRegistry::create(extension_url).await.unwrap();
let mut service_url: Url = "tri://127.0.0.1:50052/org.apache.dubbo.demo.GreeterService?anyhost=true&application=dubbo-demo-triple-api-provider&background=false&deprecated=false&dubbo=2.0.2&dynamic=true&generic=false&interface=org.apache.dubbo.demo.GreeterService&methods=sayHello,sayHelloAsync&pid=7015&service-name-mapping=true&side=provider&timestamp=1670060843807".parse().unwrap();