// blob: 657d72fbed92eac171fcf40ab0f8f3b180a4d57f [file] [log] [blame]
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
namespace Casbin.Rbac
{
/// <summary>
/// Role manager that keeps one name-to-<see cref="Role"/> map per domain and answers
/// role, user and link queries against those maps. Role names and domains can
/// optionally be compared through caller-supplied matching functions.
/// </summary>
public class DefaultRoleManager : IRoleManager
{
    // Key under which the roles of the default domain are registered in _allDomains.
    private readonly string _defaultDomain = string.Empty;

    // Every known domain (including the default one) mapped to its roles by name.
    private readonly ConcurrentDictionary<string, ConcurrentDictionary<string, Role>> _allDomains = new();

    // Forwarded to Role.HasRole as the hierarchy limit when checking links.
    private readonly int _maxHierarchyLevel;

    // Lazily captured copy of _allDomains.Keys; reset to null by AddLink and Clear.
    private IEnumerable<string> _cachedAllDomains;

    // Roles of the default domain; the same instance is stored in _allDomains[_defaultDomain].
    private readonly ConcurrentDictionary<string, Role> _defaultRoles = new();

    /// <summary>
    /// Optional predicate used to compare role names. Queries call it as
    /// <c>MatchingFunc(requestedName, storedName)</c>. When it is <c>null</c>,
    /// names are looked up by exact key.
    /// </summary>
    public Func<string, string, bool> MatchingFunc { get; set; }

    /// <summary>
    /// Optional predicate used to compare domains. When it is <c>null</c>,
    /// only the requested domain itself is consulted.
    /// </summary>
    public Func<string, string, bool> DomainMatchingFunc { get; set; }

    /// <summary>True when <see cref="MatchingFunc"/> has been set.</summary>
    public bool HasPattern => MatchingFunc is not null;

    /// <summary>True when <see cref="DomainMatchingFunc"/> has been set.</summary>
    public bool HasDomainPattern => DomainMatchingFunc is not null;

    /// <summary>
    /// Creates a role manager and registers the (empty) default domain.
    /// </summary>
    /// <param name="maxHierarchyLevel">Hierarchy limit passed to <c>Role.HasRole</c> by <see cref="HasLink"/>.</param>
    public DefaultRoleManager(int maxHierarchyLevel)
    {
        _allDomains[_defaultDomain] = _defaultRoles;
        _maxHierarchyLevel = maxHierarchyLevel;
    }

    /// <summary>
    /// Returns every known domain that contains a role matching <paramref name="name"/>.
    /// </summary>
    /// <param name="name">Role name (or pattern subject when <see cref="MatchingFunc"/> is set).</param>
    /// <returns>The set of matching domain keys; empty when there is none.</returns>
    public virtual IEnumerable<string> GetDomains(string name)
    {
        _cachedAllDomains ??= _allDomains.Keys;
        var domains = new HashSet<string>();
        foreach (string domain in _cachedAllDomains)
        {
            if (AnyRolesInDomain(name, domain))
            {
                domains.Add(domain);
            }
        }

        return domains;
    }

    /// <summary>
    /// Returns the role names reported by <c>Role.GetRoles()</c> for <paramref name="name"/>.
    /// </summary>
    /// <param name="name">Name whose roles are requested.</param>
    /// <param name="domain">Domain to query; <c>null</c> selects the default domain.</param>
    /// <returns>
    /// The roles found in <paramref name="domain"/>; when <see cref="DomainMatchingFunc"/> is set,
    /// the distinct union over <paramref name="domain"/> and every domain it matches.
    /// </returns>
    public virtual IEnumerable<string> GetRoles(string name, string domain = null)
    {
        domain ??= _defaultDomain;
        if (HasDomainPattern is false)
        {
            return GetRolesInDomain(name, domain);
        }

        var roleNames = new List<string>();
        foreach (string matchDomain in GetPatternDomains(domain))
        {
            roleNames.AddRange(GetRolesInDomain(name, matchDomain));
        }

        return roleNames.Distinct();
    }

    /// <summary>
    /// Returns the names of all entries for which <c>Role.HasDirectRole(name)</c> is true.
    /// </summary>
    /// <param name="name">Role name whose direct holders are requested.</param>
    /// <param name="domain">Domain to query; <c>null</c> selects the default domain.</param>
    /// <returns>
    /// The holders found in <paramref name="domain"/>; when <see cref="DomainMatchingFunc"/> is set,
    /// the distinct union over <paramref name="domain"/> and every domain it matches.
    /// </returns>
    public virtual IEnumerable<string> GetUsers(string name, string domain = null)
    {
        domain ??= _defaultDomain;
        if (HasDomainPattern is false)
        {
            return GetUsersInDomain(name, domain);
        }

        var userNames = new List<string>();
        foreach (string matchDomain in GetPatternDomains(domain))
        {
            userNames.AddRange(GetUsersInDomain(name, matchDomain));
        }

        return userNames.Distinct();
    }

    /// <summary>
    /// Determines whether <paramref name="name1"/> is linked to <paramref name="name2"/>.
    /// </summary>
    /// <param name="name1">Source name.</param>
    /// <param name="name2">Target name.</param>
    /// <param name="domain">Domain to query; <c>null</c> selects the default domain.</param>
    /// <returns>
    /// <c>true</c> when both names are equal, or when <c>Role.HasRole</c> reports the link in
    /// <paramref name="domain"/> (or, with <see cref="DomainMatchingFunc"/> set, in any matched domain).
    /// </returns>
    public virtual bool HasLink(string name1, string name2, string domain = null)
    {
        // Identical names are always considered linked, without consulting any domain.
        if (string.Equals(name1, name2))
        {
            return true;
        }

        domain ??= _defaultDomain;
        if (HasDomainPattern is false)
        {
            return HasLinkInDomain(name1, name2, domain);
        }

        foreach (string matchDomain in GetPatternDomains(domain))
        {
            if (HasLinkInDomain(name1, name2, matchDomain))
            {
                return true;
            }
        }

        return false;
    }

    /// <summary>
    /// Adds <paramref name="name2"/> as a role of <paramref name="name1"/>, creating either
    /// entry (and the domain) when missing. Invalidates the cached domain list.
    /// </summary>
    /// <param name="name1">Name that receives the role.</param>
    /// <param name="name2">Name of the role being added.</param>
    /// <param name="domain">Domain to modify; <c>null</c> selects the default domain.</param>
    public virtual void AddLink(string name1, string name2, string domain = null)
    {
        domain ??= _defaultDomain;
        // A new domain may be created below, so the cached key list can no longer be trusted.
        _cachedAllDomains = null;
        AddLinkInDomain(name1, name2, domain);
    }

    /// <summary>
    /// Removes <paramref name="name2"/> from the roles of <paramref name="name1"/>.
    /// </summary>
    /// <param name="name1">Name that loses the role.</param>
    /// <param name="name2">Name of the role being removed.</param>
    /// <param name="domain">Domain to modify; <c>null</c> selects the default domain.</param>
    /// <remarks>
    /// Calls <c>ThrowHelper.ThrowOneOfNamesNotFoundException()</c> when the domain or either
    /// name is unknown.
    /// </remarks>
    public virtual void DeleteLink(string name1, string name2, string domain = null)
    {
        domain ??= _defaultDomain;
        if (TryGetDomainRoles(domain, out var roles) is false
            || roles.TryGetValue(name1, out var role1) is false
            || roles.TryGetValue(name2, out var role2) is false)
        {
            ThrowHelper.ThrowOneOfNamesNotFoundException();
            return;
        }

        role1.DeleteRole(role2);
    }

    /// <summary>
    /// Removes every domain and role, then re-registers the empty default domain.
    /// </summary>
    public virtual void Clear()
    {
        _allDomains.Clear();
        _defaultRoles.Clear();
        _allDomains[_defaultDomain] = _defaultRoles;
        _cachedAllDomains = null;
    }

    // Resolves the role map of a domain; the default domain always resolves to _defaultRoles.
    private bool TryGetDomainRoles(string domain, out ConcurrentDictionary<string, Role> roles)
    {
        if (domain == _defaultDomain)
        {
            roles = _defaultRoles;
            return true;
        }

        return _allDomains.TryGetValue(domain, out roles);
    }

    // True when the domain exists and holds an entry whose key equals (or, with a pattern, matches) name.
    private bool AnyRolesInDomain(string name, string domain)
    {
        if (TryGetDomainRoles(domain, out var roles) is false)
        {
            return false;
        }

        if (HasPattern is false)
        {
            return roles.ContainsKey(name);
        }

        foreach (var role in roles)
        {
            if (MatchingFunc(name, role.Key))
            {
                return true;
            }
        }

        return false;
    }

    // Collects Role.GetRoles() of the entry named name, or of every entry matched by the pattern.
    // With a pattern the result is not de-duplicated here.
    private IEnumerable<string> GetRolesInDomain(string name, string domain)
    {
        if (TryGetDomainRoles(domain, out var roles) is false)
        {
            return Enumerable.Empty<string>();
        }

        if (HasPattern is false)
        {
            return roles.TryGetValue(name, out var exactRole)
                ? exactRole.GetRoles()
                : Enumerable.Empty<string>();
        }

        var rolesNames = new List<string>();
        foreach (var role in roles)
        {
            if (MatchingFunc(name, role.Key))
            {
                rolesNames.AddRange(role.Value.GetRoles());
            }
        }

        return rolesNames;
    }

    // Lists the keys of all entries in the domain for which Role.HasDirectRole(name) is true.
    private IEnumerable<string> GetUsersInDomain(string name, string domain)
    {
        if (TryGetDomainRoles(domain, out var roles) is false)
        {
            return Enumerable.Empty<string>();
        }

        var userNames = new List<string>();
        foreach (var role in roles)
        {
            if (role.Value.HasDirectRole(name))
            {
                userNames.Add(role.Key);
            }
        }

        return userNames;
    }

    // Checks the link inside a single domain, by exact key or through every pattern-matched entry.
    private bool HasLinkInDomain(string name1, string name2, string domain)
    {
        if (TryGetDomainRoles(domain, out var roles) is false)
        {
            return false;
        }

        if (HasPattern is false)
        {
            return roles.TryGetValue(name1, out var role1)
                   && role1.HasRole(name2, _maxHierarchyLevel);
        }

        foreach (var role in roles)
        {
            if (MatchingFunc(name1, role.Key) is false)
            {
                continue;
            }

            if (role.Value.HasRole(name2, _maxHierarchyLevel, MatchingFunc))
            {
                return true;
            }
        }

        return false;
    }

    // Creates the domain and both entries when missing, links them, then propagates the link
    // across matching domains (domain pattern) and matching names (name pattern).
    private void AddLinkInDomain(string name1, string name2, string domain)
    {
        // Value-factory overloads avoid allocating a throwaway instance when the key already exists.
        ConcurrentDictionary<string, Role> roles = domain == _defaultDomain
            ? _defaultRoles
            : _allDomains.GetOrAdd(domain, _ => new ConcurrentDictionary<string, Role>());

        // Both flags are captured before either entry is inserted.
        bool role1IsNew = roles.ContainsKey(name1) is false;
        bool role2IsNew = roles.ContainsKey(name2) is false;

        Role role1 = roles.GetOrAdd(name1, key => new Role(key, domain));
        Role role2 = roles.GetOrAdd(name2, key => new Role(key, domain));
        role1.AddRole(role2);

        if (HasDomainPattern)
        {
            if (role1IsNew)
            {
                AddLinksFromMatchingDomains(role1);
            }

            if (role2IsNew)
            {
                AddLinksToMatchingDomains(role2);
            }
        }

        if (HasPattern is false)
        {
            return;
        }

        foreach (var role in roles)
        {
            string name = role.Key;
            if (name == name1 || name == name2)
            {
                // Skip only the two endpoints. Leaving the loop here would stop at an
                // arbitrary position, because dictionary enumeration order is unspecified
                // and both endpoints are always present.
                continue;
            }

            if (MatchingFunc(name, name1) || MatchingFunc(name1, name))
            {
                role.Value.AddRole(role1);
            }

            if (MatchingFunc(name, name2) || MatchingFunc(name2, name))
            {
                role2.AddRole(role.Value);
            }
        }
    }

    // For every other domain returned by GetMatchingDomains(role.Domain), adds `role` to the
    // same-named (or pattern-matched) entries of that domain.
    private void AddLinksFromMatchingDomains(Role role)
    {
        if (HasDomainPattern is false)
        {
            return;
        }

        foreach (string domain in GetMatchingDomains(role.Domain))
        {
            if (domain == role.Domain || _allDomains.TryGetValue(domain, out var matchingDomain) is false)
            {
                continue;
            }

            if (HasPattern is false)
            {
                if (matchingDomain.TryGetValue(role.Name, out var sameNameRole) && role != sameNameRole)
                {
                    sameNameRole.AddRole(role);
                }

                continue;
            }

            foreach (string roleName in matchingDomain.Keys)
            {
                if (MatchingFunc(roleName, role.Name) is false)
                {
                    continue;
                }

                if (matchingDomain.TryGetValue(roleName, out var matchingRole) && role != matchingRole)
                {
                    matchingRole.AddRole(role);
                }
            }
        }
    }

    // For every other domain returned by GetPatternDomains(role.Domain), adds the same-named
    // (or pattern-matched) entries of that domain to `role`.
    private void AddLinksToMatchingDomains(Role role)
    {
        if (HasDomainPattern is false)
        {
            return;
        }

        foreach (string domain in GetPatternDomains(role.Domain))
        {
            if (domain == role.Domain || _allDomains.TryGetValue(domain, out var matchingDomain) is false)
            {
                continue;
            }

            if (HasPattern is false)
            {
                if (matchingDomain.TryGetValue(role.Name, out var sameNameRole) && role != sameNameRole)
                {
                    role.AddRole(sameNameRole);
                }

                continue;
            }

            foreach (string roleName in matchingDomain.Keys)
            {
                if (MatchingFunc(role.Name, roleName) is false)
                {
                    continue;
                }

                if (matchingDomain.TryGetValue(roleName, out var matchingRole) && matchingRole != role)
                {
                    role.AddRole(matchingRole);
                }
            }
        }
    }

    // Returns `domain` followed by every other known domain key for which
    // DomainMatchingFunc(domain, key) is true.
    private IEnumerable<string> GetPatternDomains(string domain)
    {
        List<string> matchDomains = new() { domain };
        _cachedAllDomains ??= _allDomains.Keys;
        if (HasDomainPattern)
        {
            matchDomains.AddRange(_cachedAllDomains.Where(key =>
                DomainMatchingFunc(domain, key) && key != domain));
        }

        return matchDomains;
    }

    // Returns `domainPattern` followed by every other known domain key for which
    // DomainMatchingFunc(key, domainPattern) is true (arguments reversed versus GetPatternDomains).
    private IEnumerable<string> GetMatchingDomains(string domainPattern)
    {
        List<string> matchDomains = new() { domainPattern };
        _cachedAllDomains ??= _allDomains.Keys;
        if (HasDomainPattern)
        {
            matchDomains.AddRange(_cachedAllDomains.Where(key =>
                DomainMatchingFunc(key, domainPattern) && key != domainPattern));
        }

        return matchDomains;
    }
}
}