blob: a278abf6ccc4fd66c1db0e3ea00760747c904b6c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.zookeeper.server;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.jute.Index;
import org.apache.jute.InputArchive;
import org.apache.jute.OutputArchive;
import org.apache.zookeeper.ZooDefs;
import org.apache.zookeeper.data.ACL;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Bidirectional cache between ACL lists and compact {@code long} identifiers,
 * with a usage counter per identifier.
 *
 * <p>Three maps are kept in step: {@link #longKeyMap} (id to ACL list),
 * {@link #aclKeyMap} (ACL list to id) and {@link #referenceCounter} (id to
 * usage count). An entry is dropped from all three once its count falls to
 * zero or below, either eagerly in {@link #removeUsage(Long)} or in bulk via
 * {@link #purgeUnused()}.
 *
 * <p>The reserved identifier {@code -1} stands for
 * {@link ZooDefs.Ids#OPEN_ACL_UNSAFE}; it is never stored in the maps and is
 * never reference counted.
 *
 * <p>All public methods synchronize on this instance. The map fields are
 * package-private; code that touches them directly is responsible for its own
 * locking.
 */
public class ReferenceCountedACLCache {

    private static final Logger LOG = LoggerFactory.getLogger(ReferenceCountedACLCache.class);

    /** Identifier to ACL list. */
    final Map<Long, List<ACL>> longKeyMap = new HashMap<Long, List<ACL>>();

    /** ACL list to identifier (inverse of {@link #longKeyMap}). */
    final Map<List<ACL>, Long> aclKeyMap = new HashMap<List<ACL>, Long>();

    /** Identifier to number of recorded usages. */
    final Map<Long, AtomicLongWithEquals> referenceCounter = new HashMap<Long, AtomicLongWithEquals>();

    /** Reserved identifier returned for a {@code null} ACL list. */
    private static final long OPEN_UNSAFE_ACL_ID = -1L;

    /**
     * The most recent identifier handed out. New identifiers are produced by
     * pre-incrementing this value, and {@link #deserialize(InputArchive)}
     * raises it to the largest identifier it reads.
     */
    long aclIndex = 0;

    /**
     * Converts a list of ACLs to its identifier, allocating a new identifier
     * if the list is not cached yet, and records one more usage of it.
     *
     * @param acls the ACL list; {@code null} maps to the reserved open/unsafe id
     * @return the identifier that maps to {@code acls}
     */
    public synchronized Long convertAcls(List<ACL> acls) {
        if (acls == null) {
            return OPEN_UNSAFE_ACL_ID;
        }

        Long ret = aclKeyMap.get(acls);
        if (ret == null) {
            ret = incrementIndex();
            longKeyMap.put(ret, acls);
            aclKeyMap.put(acls, ret);
        }

        addUsage(ret);
        return ret;
    }

    /**
     * Converts an identifier back to its list of ACLs.
     *
     * @param longVal the identifier, may be {@code null}
     * @return {@code null} for a {@code null} identifier,
     *         {@link ZooDefs.Ids#OPEN_ACL_UNSAFE} for the reserved id,
     *         otherwise the cached ACL list
     * @throws RuntimeException if the identifier is not present in the cache
     */
    public synchronized List<ACL> convertLong(Long longVal) {
        if (longVal == null) {
            return null;
        }
        if (longVal == OPEN_UNSAFE_ACL_ID) {
            return ZooDefs.Ids.OPEN_ACL_UNSAFE;
        }
        List<ACL> acls = longKeyMap.get(longVal);
        if (acls == null) {
            LOG.error("ERROR: ACL not available for long {}", longVal);
            throw new RuntimeException("Failed to fetch acls for " + longVal);
        }
        return acls;
    }

    /** Advances {@link #aclIndex} and returns the new value. */
    private long incrementIndex() {
        return ++aclIndex;
    }

    /**
     * Null-safe test for the reserved open/unsafe identifier. Comparing a
     * boxed {@code Long} to a primitive with {@code ==} unboxes it and would
     * throw on {@code null}.
     */
    private static boolean isOpenUnsafeId(Long id) {
        return id != null && id.longValue() == OPEN_UNSAFE_ACL_ID;
    }

    /**
     * Replaces the cache content with the entries read from {@code ia}.
     *
     * <p>Reads an int entry count tagged {@code "map"}, then for each entry a
     * long tagged {@code "long"} followed by a vector of ACL records tagged
     * {@code "acls"}. Every loaded identifier starts with a usage count of
     * zero; callers are expected to re-register usages via
     * {@link #addUsage(Long)}.
     *
     * @param ia the archive to read from
     * @throws IOException if the archive fails to read
     * @throws RuntimeException if an entry has no ACL vector
     */
    public synchronized void deserialize(InputArchive ia) throws IOException {
        clear();
        int i = ia.readInt("map");
        while (i > 0) {
            Long val = ia.readLong("long");
            // Keep the allocator ahead of every identifier already in use.
            if (aclIndex < val) {
                aclIndex = val;
            }
            List<ACL> aclList = new ArrayList<ACL>();
            Index j = ia.startVector("acls");
            if (j == null) {
                throw new RuntimeException("Incorrent format of InputArchive when deserialize DataTree - missing acls");
            }
            while (!j.done()) {
                ACL acl = new ACL();
                acl.deserialize(ia, "acl");
                aclList.add(acl);
                j.incr();
            }
            longKeyMap.put(val, aclList);
            aclKeyMap.put(aclList, val);
            referenceCounter.put(val, new AtomicLongWithEquals(0));
            i--;
        }
    }

    /**
     * Writes every cached (identifier, ACL list) pair to {@code oa} in the
     * layout consumed by {@link #deserialize(InputArchive)}. Usage counts are
     * not written.
     *
     * @param oa the archive to write to
     * @throws IOException if the archive fails to write
     */
    public synchronized void serialize(OutputArchive oa) throws IOException {
        oa.writeInt(longKeyMap.size(), "map");
        Set<Map.Entry<Long, List<ACL>>> set = longKeyMap.entrySet();
        for (Map.Entry<Long, List<ACL>> val : set) {
            oa.writeLong(val.getKey(), "long");
            List<ACL> aclList = val.getValue();
            oa.startVector(aclList, "acls");
            for (ACL acl : aclList) {
                acl.serialize(oa, "acl");
            }
            oa.endVector(aclList, "acls");
        }
    }

    /**
     * Returns the number of distinct ACL lists currently cached.
     *
     * <p>Synchronized like every other public method because the backing
     * {@code HashMap} is mutated under this instance's monitor.
     *
     * @return the number of entries in {@link #aclKeyMap}
     */
    public synchronized int size() {
        return aclKeyMap.size();
    }

    /** Empties all three maps; {@link #aclIndex} is left untouched. */
    private void clear() {
        aclKeyMap.clear();
        longKeyMap.clear();
        referenceCounter.clear();
    }

    /**
     * Records one more usage of the given identifier.
     *
     * <p>The reserved open/unsafe id is not counted. An identifier that is
     * {@code null} or not present in the cache is logged and ignored.
     *
     * @param acl the identifier whose usage count should be incremented
     */
    public synchronized void addUsage(Long acl) {
        if (isOpenUnsafeId(acl)) {
            return;
        }

        if (!longKeyMap.containsKey(acl)) {
            LOG.info("Ignoring acl {} as it does not exist in the cache", acl);
            return;
        }

        AtomicLong count = referenceCounter.get(acl);
        if (count == null) {
            referenceCounter.put(acl, new AtomicLongWithEquals(1));
        } else {
            count.incrementAndGet();
        }
    }

    /**
     * Records one fewer usage of the given identifier and evicts the entry
     * from all maps once its count reaches zero or below.
     *
     * <p>The reserved open/unsafe id is not counted. An identifier that is
     * {@code null} or not present in the cache is logged and ignored.
     *
     * @param acl the identifier whose usage count should be decremented
     */
    public synchronized void removeUsage(Long acl) {
        if (isOpenUnsafeId(acl)) {
            return;
        }

        if (!longKeyMap.containsKey(acl)) {
            LOG.info("Ignoring acl {} as it does not exist in the cache", acl);
            return;
        }

        // A cached id without a counter has no recorded usages; treat it as
        // already at zero so the entry is evicted instead of dereferencing null.
        AtomicLongWithEquals count = referenceCounter.get(acl);
        long newCount = (count == null) ? 0L : count.decrementAndGet();
        if (newCount <= 0) {
            referenceCounter.remove(acl);
            aclKeyMap.remove(longKeyMap.get(acl));
            longKeyMap.remove(acl);
        }
    }

    /**
     * Evicts every identifier whose usage count is zero or below from all
     * three maps.
     */
    public synchronized void purgeUnused() {
        Iterator<Map.Entry<Long, AtomicLongWithEquals>> refCountIter = referenceCounter.entrySet().iterator();
        while (refCountIter.hasNext()) {
            Map.Entry<Long, AtomicLongWithEquals> entry = refCountIter.next();
            if (entry.getValue().get() <= 0) {
                Long acl = entry.getKey();
                aclKeyMap.remove(longKeyMap.get(acl));
                longKeyMap.remove(acl);
                // Remove through the iterator so the traversal stays valid.
                refCountIter.remove();
            }
        }
    }

    /**
     * An {@link AtomicLong} whose {@code equals}/{@code hashCode} are based on
     * the current numeric value rather than object identity.
     */
    private static class AtomicLongWithEquals extends AtomicLong {

        private static final long serialVersionUID = 3355155896813725462L;

        public AtomicLongWithEquals(long i) {
            super(i);
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            return equals((AtomicLongWithEquals) o);
        }

        /** Value comparison against another counter; {@code that} must be non-null. */
        public boolean equals(AtomicLongWithEquals that) {
            return get() == that.get();
        }

        @Override
        public int hashCode() {
            return 31 * Long.valueOf(get()).hashCode();
        }
    }
}