blob: 205684d64be9458fc853275dcddee3743869438e [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.qpid.server.user.connection.limits.config;
import java.security.Principal;
import java.time.Duration;
import java.time.Instant;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Map.Entry;
import java.util.NavigableSet;
import java.util.Objects;
import java.util.Queue;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.security.auth.Subject;
import org.apache.qpid.server.security.group.GroupPrincipal;
import org.apache.qpid.server.security.limit.ConnectionLimitException;
import org.apache.qpid.server.security.limit.ConnectionLimiter;
import org.apache.qpid.server.security.limit.ConnectionSlot;
import org.apache.qpid.server.transport.AMQPConnection;
import org.apache.qpid.server.user.connection.limits.outcome.AcceptRegistration;
import org.apache.qpid.server.user.connection.limits.outcome.RejectRegistration;
final class PortConnectionCounter
{
/** Produces the counter of a user the first time that user identifier is seen. */
private final Function<String, ConnectionCounter> _connectionCounterFactory;

/** Connection counters keyed by the name of the authorized principal. */
private final Map<String, ConnectionCounter> _connectionCounters = new ConcurrentHashMap<>();

/**
 * Creates the port counter; the per-user counters are produced lazily
 * by the factory obtained from the given builder.
 *
 * @param builder the builder that supplies the connection counter factory
 */
PortConnectionCounter(AbstractBuilder<?> builder)
{
    _connectionCounterFactory = builder.getConnectionCounterFactory();
}
/**
 * Registers the connection with the counter of the user that owns it.
 * The counter is looked up by the name of the authorized principal and
 * created on first use.
 *
 * @param connection the connection to register
 * @param subLimiter the limiter passed on to the user counter
 * @return the registration produced by the user counter
 * @throws ConnectionLimitException when the connection has no authorized principal
 */
public AcceptRegistration register(AMQPConnection<?> connection, ConnectionLimiter subLimiter)
{
    final Principal authorized = connection.getAuthorizedPrincipal();
    if (authorized != null)
    {
        final String user = authorized.getName();
        // The counter is resolved before the groups are collected, as in the chained form.
        final ConnectionCounter userCounter =
                _connectionCounters.computeIfAbsent(user, _connectionCounterFactory);
        final Set<String> groups = collectGroupPrincipals(connection.getSubject());
        return userCounter.registerConnection(user, groups, connection, subLimiter);
    }
    throw new ConnectionLimitException("Unauthorized connection is forbidden");
}
/**
 * Collects the names of all {@link GroupPrincipal}s of the subject.
 *
 * @param subject the subject of the connection, may be {@code null}
 * @return an empty set for a {@code null} subject, otherwise a new set with the group names
 */
private Set<String> collectGroupPrincipals(Subject subject)
{
    if (subject != null)
    {
        final Set<String> groupNames = new HashSet<>();
        subject.getPrincipals(GroupPrincipal.class)
               .forEach(group -> groupNames.add(group.getName()));
        return groupNames;
    }
    return Collections.emptySet();
}
/**
 * Creates a builder of the port connection counter.
 * <p>
 * The period is used only to choose the kind of builder; its value is
 * not handed over to the builder.
 *
 * @param defaultFrequencyPeriod {@code null} or a negative duration selects the builder
 *                               that works with connection count limits only, any other
 *                               value selects the builder that also handles frequency limits
 * @return a new builder
 */
static Builder newBuilder(Duration defaultFrequencyPeriod)
{
    final boolean frequencySupported =
            defaultFrequencyPeriod != null && !defaultFrequencyPeriod.isNegative();
    if (frequencySupported)
    {
        return new BuilderImpl();
    }
    return new BuilderWithoutFrequencyImpl();
}
/**
 * Fluent builder that collects connection limit rules and creates a {@link PortConnectionCounter}.
 */
public interface Builder
{
/**
 * Adds a single rule to the builder.
 * The implementation in this file ignores a {@code null} rule and a rule without any limits.
 *
 * @param rule the rule to add
 * @return a builder for further calls (the implementation in this file returns itself)
 */
Builder add(Rule rule);
/**
 * Adds every rule of the collection to the builder.
 * The implementation in this file accepts a {@code null} collection and treats it as empty.
 *
 * @param rule the rules to add
 * @return a builder for further calls (the implementation in this file returns itself)
 */
Builder addAll(Collection<? extends Rule> rule);
/**
 * Creates the port connection counter from the rules added so far.
 *
 * @return a new {@link PortConnectionCounter}
 */
PortConnectionCounter build();
}
/**
 * Counter of the connections of one user. It is also a {@link ConnectionSlot};
 * the implementations in this file decrease the connection count in {@code free()}.
 */
private interface ConnectionCounter extends ConnectionSlot
{
/**
 * Registers a new connection of the user.
 * The implementations in this file throw a rejection when the user is blocked
 * or one of the user's limits would be broken.
 *
 * @param userId     the user identifier, the name of the authorized principal
 * @param groups     the names of the groups the user belongs to
 * @param connection the connection being registered
 * @param subLimiter the limiter the connection is registered with as well
 * @return the accepted registration
 */
AcceptRegistration registerConnection(
String userId, Set<String> groups, AMQPConnection<?> connection, ConnectionLimiter subLimiter);
}
/**
 * Base of the builders: gathers the limits per identity and the default
 * limits that apply to all users.
 *
 * @param <T> the type of the limits the builder works with
 */
abstract static class AbstractBuilder<T extends CombinableLimit<T>> implements Builder
{
    /** Limits keyed by the identity of the rule (every identity except the "all users" one). */
    final Map<String, T> _userLimits = new HashMap<>();

    /** Limits collected from the rules addressed to all users. */
    T _defaultUserLimits;

    /**
     * @param defaultUserLimits the initial default limits
     */
    AbstractBuilder(T defaultUserLimits)
    {
        _defaultUserLimits = defaultUserLimits;
    }

    /**
     * Converts the rule to the limits of the type {@code T}.
     *
     * @param rule the rule, never {@code null}
     * @return the limits of the rule
     */
    abstract T newLimits(Rule rule);

    /**
     * @return the factory that creates the connection counter of a user
     */
    abstract Function<String, ConnectionCounter> getConnectionCounterFactory();

    @Override
    public Builder add(Rule rule)
    {
        addImpl(rule);
        return this;
    }

    @Override
    public Builder addAll(Collection<? extends Rule> rule)
    {
        if (rule == null)
        {
            return this;
        }
        rule.forEach(this::addImpl);
        return this;
    }

    @Override
    public PortConnectionCounter build()
    {
        return new PortConnectionCounter(this);
    }

    /**
     * Merges the limits of the rule into the default limits or into the limits
     * of the rule identity. A {@code null} rule and a rule with empty limits are skipped.
     */
    private void addImpl(Rule rule)
    {
        if (rule == null)
        {
            return;
        }
        final T limits = newLimits(rule);
        if (limits.isEmpty())
        {
            return;
        }
        final String identity = rule.getIdentity();
        if (!RulePredicates.isAllUser(identity))
        {
            // The limits already stored are the receiver of the merge, the new limits the argument.
            _userLimits.merge(identity, limits, (existing, added) -> existing.mergeWith(added));
            return;
        }
        // For the defaults it is the other way round: the new limits are the receiver.
        _defaultUserLimits = limits.mergeWith(_defaultUserLimits);
    }
}
/**
 * Builder of the counter that supports both the connection count limit
 * and the connection frequency limits.
 */
private static final class BuilderImpl extends AbstractBuilder<ConnectionLimits>
{
    BuilderImpl()
    {
        super(ConnectionLimits.noLimits());
    }

    /**
     * Uses the rule itself as the limits after checking its frequency settings.
     *
     * @param rule the rule to convert
     * @return the rule
     * @throws IllegalArgumentException when the rule has a frequency limit
     *                                  but its frequency period is missing or negative
     */
    @Override
    ConnectionLimits newLimits(Rule rule)
    {
        if (rule.getFrequencyLimit() != null &&
            (rule.getFrequencyPeriod() == null || rule.getFrequencyPeriod().isNegative()))
        {
            throw new IllegalArgumentException("Connection frequency limit needs a time period");
        }
        return rule;
    }

    /**
     * Creates the factory of the per-user counters.
     * <p>
     * The limits are compiled into one {@link LimitCompiler} at the moment of this call
     * and that single instance is shared by all the counters. The compiler copies the
     * limit map in its constructor and only reads it afterwards, so the sharing avoids
     * one copy of the whole limit map per user and makes the factory independent of
     * any later change of the builder.
     *
     * @return the factory that creates the connection counter of a user
     */
    @Override
    Function<String, ConnectionCounter> getConnectionCounterFactory()
    {
        final LimitCompiler<ConnectionLimits> compiler =
                new LimitCompiler<>(_userLimits, _defaultUserLimits, ConnectionLimits::noLimits);
        return userId -> new CombinedConnectionCounterImpl(compiler);
    }
}
/**
 * Builder of the counter that supports the connection count limit only.
 */
private static final class BuilderWithoutFrequencyImpl extends AbstractBuilder<ConnectionCountLimit>
{
    BuilderWithoutFrequencyImpl()
    {
        super(ConnectionCountLimit.noLimits());
    }

    /**
     * Converts the rule to a connection count limit.
     *
     * @param rule the rule to convert
     * @return the count limit created from the rule
     */
    @Override
    ConnectionCountLimit newLimits(Rule rule)
    {
        return ConnectionCountLimit.newInstance(rule);
    }

    /**
     * Creates the factory of the per-user counters.
     * <p>
     * The limits are compiled into one {@link LimitCompiler} at the moment of this call
     * and that single instance is shared by all the counters. The compiler copies the
     * limit map in its constructor and only reads it afterwards, so the sharing avoids
     * one copy of the whole limit map per user and makes the factory independent of
     * any later change of the builder.
     *
     * @return the factory that creates the connection counter of a user
     */
    @Override
    Function<String, ConnectionCounter> getConnectionCounterFactory()
    {
        final LimitCompiler<ConnectionCountLimit> compiler =
                new LimitCompiler<>(_userLimits, _defaultUserLimits, ConnectionCountLimit::noLimits);
        return userId -> new ConnectionCounterImpl(compiler);
    }
}
/**
 * Counter of the connections of one user that checks the connection count limit
 * and the connection frequency limits. The counter and the registration times are
 * guarded by the monitor of the instance.
 */
private static final class CombinedConnectionCounterImpl implements ConnectionCounter
{
    private final LimitCompiler<ConnectionLimits> _limits;

    /** Times of the accepted registrations, the oldest first; used by the frequency check. */
    private final Queue<Instant> _registrationTime = new LinkedList<>();

    /** Number of the registered connections that have not been freed yet. */
    private long _counter = 0L;

    CombinedConnectionCounterImpl(LimitCompiler<ConnectionLimits> limits)
    {
        _limits = limits;
    }

    /** Releases one connection; the counter never drops below zero. */
    @Override
    public synchronized void free()
    {
        _counter = (_counter > 1L) ? _counter - 1L : 0L;
    }

    /**
     * Compiles the limits of the user and registers the connection against
     * whichever of the count and frequency limits are set.
     *
     * @return the accepted registration
     */
    @Override
    public AcceptRegistration registerConnection(
            String userId, Set<String> groups, AMQPConnection<?> connection, ConnectionLimiter subLimiter)
    {
        final ConnectionLimits compiled = _limits.compileLimits(userId, groups);
        if (compiled.isUserBlocked())
        {
            throw RejectRegistration.blockedUser(userId, connection.getPort().getName());
        }

        final Integer maxCount = compiled.getCountLimit();
        final Map<Duration, Integer> maxFrequencies = compiled.getFrequencyLimits();
        final boolean frequencyLimited = !maxFrequencies.isEmpty();
        final boolean countLimited = maxCount != null;

        if (frequencyLimited && countLimited)
        {
            return bothLimits(maxCount, maxFrequencies, userId, connection, subLimiter);
        }
        if (frequencyLimited)
        {
            return frequencyLimit(maxFrequencies, userId, connection, subLimiter);
        }
        if (countLimited)
        {
            return countLimit(maxCount, userId, connection, subLimiter);
        }
        return noLimits(userId, connection, subLimiter);
    }

    /** Registers the connection without any check; the registration times are not needed. */
    private synchronized AcceptRegistration noLimits(
            String id, AMQPConnection<?> connection, ConnectionLimiter subLimiter)
    {
        _registrationTime.clear();
        return admit(id, connection, subLimiter);
    }

    /** Registers the connection when the count limit allows it. */
    private synchronized AcceptRegistration countLimit(
            Integer countLimit, String id, AMQPConnection<?> connection, ConnectionLimiter subLimiter)
    {
        _registrationTime.clear();
        checkCounter(countLimit, id, connection);
        return admit(id, connection, subLimiter);
    }

    /** Registers the connection when the frequency limits allow it. */
    private synchronized AcceptRegistration frequencyLimit(
            Map<Duration, Integer> frequencyLimit, String id, AMQPConnection<?> connection, ConnectionLimiter subLimiter)
    {
        final Instant acceptedAt = checkFrequencyLimit(frequencyLimit, id, connection);
        final AcceptRegistration accepted = admit(id, connection, subLimiter);
        _registrationTime.add(acceptedAt);
        return accepted;
    }

    /** Registers the connection when the frequency limits and then the count limit allow it. */
    private synchronized AcceptRegistration bothLimits(
            Integer countLimit, Map<Duration, Integer> frequencyLimit, String id,
            AMQPConnection<?> connection, ConnectionLimiter subLimiter)
    {
        final Instant acceptedAt = checkFrequencyLimit(frequencyLimit, id, connection);
        checkCounter(countLimit, id, connection);
        final AcceptRegistration accepted = admit(id, connection, subLimiter);
        _registrationTime.add(acceptedAt);
        return accepted;
    }

    /**
     * Registers the connection with the sub-limiter, creates the accepted registration
     * and only then increases the counter, so a failure leaves the counter untouched.
     * Must be called with the monitor of the instance held.
     */
    private AcceptRegistration admit(String id, AMQPConnection<?> connection, ConnectionLimiter subLimiter)
    {
        final AcceptRegistration accepted = AcceptRegistration.newInstance(
                chainTo(subLimiter.register(connection)),
                id,
                _counter + 1L,
                connection.getPort().getName());
        _counter++;
        return accepted;
    }

    /** Rejects the registration when the counter has already reached the count limit. */
    private void checkCounter(Integer countLimit, String id, AMQPConnection<?> connection)
    {
        if (_counter >= countLimit)
        {
            throw RejectRegistration.breakingConnectionCount(id, countLimit, connection.getPort().getName());
        }
    }

    /**
     * Checks the recorded registration times against every frequency period and
     * drops the times that are older than all the periods.
     *
     * @return the current time, to be recorded when the registration is accepted
     */
    private Instant checkFrequencyLimit(
            final Map<Duration, Integer> frequencyLimit, String id, AMQPConnection<?> connection)
    {
        final Instant now = Instant.now();

        // Periods ordered by their start, so the longest period comes first.
        final NavigableSet<FrequencyPeriod> openPeriods = new TreeSet<>();
        frequencyLimit.entrySet().forEach(limit -> openPeriods.add(new FrequencyPeriod(now, limit)));

        // Number of the recorded registrations from the current one to the newest one.
        int remaining = _registrationTime.size();
        // Stays true until the first registration that falls into a period is met.
        boolean stale = true;

        final Iterator<Instant> times = _registrationTime.iterator();
        while (times.hasNext() && !openPeriods.isEmpty())
        {
            final Instant registeredAt = times.next();
            final Iterator<FrequencyPeriod> candidates = openPeriods.iterator();
            while (candidates.hasNext())
            {
                final FrequencyPeriod period = candidates.next();
                if (!period.isPeriodBeginningBefore(registeredAt))
                {
                    // The periods are sorted, none of the following ones covers this time.
                    break;
                }
                if (!period.isCounterUnderLimit(remaining))
                {
                    throw RejectRegistration.breakingConnectionFrequency(
                            id, period.getLimit(), period.getDuration(), connection.getPort().getName());
                }
                stale = false;
                // The period has been verified, it is not checked again.
                candidates.remove();
            }
            remaining--;
            if (stale)
            {
                times.remove();
            }
        }
        return now;
    }
}
/**
 * Counter of the connections of one user that checks the connection count limit only.
 * The counter is guarded by the monitor of the instance.
 */
private static final class ConnectionCounterImpl implements ConnectionCounter
{
    private final LimitCompiler<ConnectionCountLimit> _limits;

    /** Number of the registered connections that have not been freed yet. */
    private long _counter = 0L;

    ConnectionCounterImpl(LimitCompiler<ConnectionCountLimit> limits)
    {
        _limits = limits;
    }

    /** Releases one connection; the counter never drops below zero. */
    @Override
    public synchronized void free()
    {
        _counter = (_counter > 1L) ? _counter - 1L : 0L;
    }

    /**
     * Compiles the limits of the user and registers the connection,
     * checking the count limit when there is one.
     *
     * @return the accepted registration
     */
    @Override
    public AcceptRegistration registerConnection(
            String userId, Set<String> groups, AMQPConnection<?> connection, ConnectionLimiter subLimiter)
    {
        final ConnectionCountLimit compiled = _limits.compileLimits(userId, groups);
        if (compiled.isUserBlocked())
        {
            throw RejectRegistration.blockedUser(userId, connection.getPort().getName());
        }
        final Integer maxCount = compiled.getCountLimit();
        return maxCount == null
                ? noLimits(userId, connection, subLimiter)
                : countLimit(maxCount, userId, connection, subLimiter);
    }

    /** Registers the connection without any check. */
    private synchronized AcceptRegistration noLimits(
            String id, AMQPConnection<?> connection, ConnectionLimiter subLimiter)
    {
        return admit(id, connection, subLimiter);
    }

    /** Registers the connection unless the counter has already reached the count limit. */
    private synchronized AcceptRegistration countLimit(
            Integer countLimit, String id, AMQPConnection<?> connection, ConnectionLimiter subLimiter)
    {
        if (_counter >= countLimit)
        {
            throw RejectRegistration.breakingConnectionCount(id, countLimit, connection.getPort().getName());
        }
        return admit(id, connection, subLimiter);
    }

    /**
     * Registers the connection with the sub-limiter, creates the accepted registration
     * and only then increases the counter, so a failure leaves the counter untouched.
     * Must be called with the monitor of the instance held.
     */
    private AcceptRegistration admit(String id, AMQPConnection<?> connection, ConnectionLimiter subLimiter)
    {
        final AcceptRegistration accepted = AcceptRegistration.newInstance(
                chainTo(subLimiter.register(connection)),
                id,
                _counter + 1L,
                connection.getPort().getName());
        _counter++;
        return accepted;
    }
}
/**
 * Combines the limits of a user, of the user's groups and the default limits
 * into the effective limits of the user. The limit map is copied on construction
 * and only read afterwards.
 *
 * @param <T> the type of the limits
 */
private static final class LimitCompiler<T extends CombinableLimit<T>>
{
    private final Map<String, T> _limitMap;
    private final T _defaultLimits;
    private final Supplier<T> _noLimits;

    /**
     * @param limitMap      the limits keyed by a user or a group identity; copied
     * @param defaultLimits the limits used as the last resort, not {@code null}
     * @param noLimits      the supplier of an empty limit, not {@code null}
     */
    LimitCompiler(Map<String, T> limitMap, T defaultLimits, Supplier<T> noLimits)
    {
        _limitMap = new HashMap<>(limitMap);
        _defaultLimits = Objects.requireNonNull(defaultLimits);
        _noLimits = Objects.requireNonNull(noLimits);
    }

    /**
     * Builds the effective limits: the user's own limits, then the merged limits
     * of all the user's groups, then the default limits.
     *
     * @param userId the user identifier
     * @param groups the names of the groups of the user
     * @return the combined limits
     */
    public T compileLimits(String userId, Set<String> groups)
    {
        final T ownLimits = _noLimits.get().mergeWith(_limitMap.get(userId));

        T mergedGroupLimits = _noLimits.get();
        for (final String groupName : groups)
        {
            final T limitsOfGroup = _limitMap.get(groupName);
            mergedGroupLimits = mergedGroupLimits.mergeWith(limitsOfGroup);
        }

        return ownLimits
                .then(mergedGroupLimits)
                .then(_defaultLimits);
    }
}
/**
 * Time window that ends at a given instant together with the limit of the
 * registrations allowed within it. Ordering, equality and hash code are
 * based on the start of the window only.
 */
private static final class FrequencyPeriod implements Comparable<FrequencyPeriod>
{
    private final Instant _startTime;
    private final Duration _duration;
    private final int _limit;

    /**
     * @param now   the end of the window
     * @param limit the length of the window (key) and the limit within it (value)
     */
    FrequencyPeriod(Instant now, Map.Entry<Duration, Integer> limit)
    {
        final Duration length = limit.getKey();
        _startTime = now.minus(length);
        _duration = length;
        _limit = limit.getValue();
    }

    @Override
    public int compareTo(FrequencyPeriod o)
    {
        final Instant otherStart = o._startTime;
        return _startTime.compareTo(otherStart);
    }

    @Override
    public int hashCode()
    {
        return _startTime.hashCode();
    }

    @Override
    public boolean equals(Object o)
    {
        return o instanceof FrequencyPeriod
               && _startTime.equals(((FrequencyPeriod) o)._startTime);
    }

    /**
     * @return {@code true} when the window starts strictly before the given time
     */
    boolean isPeriodBeginningBefore(Instant time)
    {
        return _startTime.isBefore(time);
    }

    /**
     * @return {@code true} when the counter is strictly lower than the limit
     */
    boolean isCounterUnderLimit(int counter)
    {
        return _limit > counter;
    }

    Duration getDuration()
    {
        return _duration;
    }

    int getLimit()
    {
        return _limit;
    }
}
}