// blob: 61cb98308b0e814024034cdb2f6958bec43aa2e9
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.ranger.common.db;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import javax.persistence.EntityManager;
import javax.persistence.NoResultException;
import javax.persistence.Query;
import javax.persistence.Table;
import javax.persistence.TypedQuery;
import org.apache.ranger.authorization.hadoop.config.RangerAdminConfig;
import org.apache.ranger.biz.RangerBizUtil;
import org.apache.ranger.common.AppConstants;
import org.apache.ranger.db.RangerDaoManager;
import org.apache.ranger.db.RangerDaoManagerBase;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Generic base class for JPA-backed data access objects.
 *
 * <p>The entity class handled by a concrete DAO is resolved reflectively from the first type
 * argument of the DAO's direct generic superclass, so concrete DAOs are expected to be declared
 * as {@code class FooDao extends BaseDao<Foo>}.
 *
 * <p>Unless stated otherwise, mutating operations flush the persistence context immediately
 * except when {@link RangerBizUtil#isBulkMode()} reports {@code true}.
 *
 * @param <T> entity type handled by the concrete DAO
 */
public abstract class BaseDao<T> {
    private static final Logger logger = LoggerFactory.getLogger(BaseDao.class);

    private static final String PROP_BATCH_DELETE_BATCH_SIZE    = "ranger.admin.dao.batch.delete.batch.size";
    private static final int    DEFAULT_BATCH_DELETE_BATCH_SIZE = 1000;
    private static final int    BATCH_DELETE_BATCH_SIZE;
    private static final String GDS_TABLES                      = "x_gds_";

    static {
        int batchSize;

        try {
            batchSize = RangerAdminConfig.getInstance().getInt(PROP_BATCH_DELETE_BATCH_SIZE, DEFAULT_BATCH_DELETE_BATCH_SIZE);

            if (batchSize > DEFAULT_BATCH_DELETE_BATCH_SIZE) {
                logger.warn("Configuration {}={}, which is larger than default value {}", PROP_BATCH_DELETE_BATCH_SIZE, batchSize, DEFAULT_BATCH_DELETE_BATCH_SIZE);
            }
        } catch (Exception e) {
            // Typically a NumberFormatException caused by an invalid value in the config file.
            // Fall back to the default, but leave a trace so the misconfiguration is visible.
            logger.warn("Failed to read configuration {}; using default value {}", PROP_BATCH_DELETE_BATCH_SIZE, DEFAULT_BATCH_DELETE_BATCH_SIZE, e);

            batchSize = DEFAULT_BATCH_DELETE_BATCH_SIZE;
        }

        BATCH_DELETE_BATCH_SIZE = batchSize;

        logger.info("{}={}", PROP_BATCH_DELETE_BATCH_SIZE, BATCH_DELETE_BATCH_SIZE);
    }

    /** DAO manager this DAO was created from. */
    protected RangerDaoManager daoManager;

    /** Entity manager used for all persistence operations of this DAO. */
    EntityManager em;

    /** Entity class resolved from the concrete DAO's generic superclass. */
    protected Class<T> tClass;

    /**
     * Creates a DAO bound to the default entity manager of the given DAO manager.
     *
     * @param daoManager DAO manager; must be a {@link RangerDaoManager} instance
     */
    public BaseDao(RangerDaoManagerBase daoManager) {
        this.daoManager = (RangerDaoManager) daoManager;

        this.init(daoManager.getEntityManager());
    }

    /**
     * Creates a DAO bound to the entity manager of the named persistence unit.
     *
     * @param daoManager             DAO manager; must be a {@link RangerDaoManager} instance
     * @param persistenceContextUnit name passed to {@code RangerDaoManager.getEntityManager(String)}
     */
    public BaseDao(RangerDaoManagerBase daoManager, String persistenceContextUnit) {
        this.daoManager = (RangerDaoManager) daoManager;

        EntityManager em = this.daoManager.getEntityManager(persistenceContextUnit);

        this.init(em);
    }

    /**
     * Stores the entity manager and resolves {@link #tClass} from the first type argument of
     * the runtime class's direct generic superclass.
     *
     * <p>NOTE: a DAO whose direct superclass is not parameterized (for example a subclass of a
     * concrete DAO) fails here with a {@link ClassCastException}.
     */
    @SuppressWarnings("unchecked")
    private void init(EntityManager em) {
        this.em = em;

        ParameterizedType genericSuperclass = (ParameterizedType) getClass().getGenericSuperclass();
        Type              type              = genericSuperclass.getActualTypeArguments()[0];

        if (type instanceof ParameterizedType) {
            // Entity type is itself generic, e.g. BaseDao<Foo<Bar>>: keep only the raw class.
            this.tClass = (Class<T>) ((ParameterizedType) type).getRawType();
        } else {
            this.tClass = (Class<T>) type;
        }
    }

    /**
     * @return the entity manager this DAO operates on
     */
    public EntityManager getEntityManager() {
        return this.em;
    }

    /**
     * Persists the given entity, flushing unless bulk mode is active.
     *
     * @param obj entity to persist
     * @return the same {@code obj} instance
     */
    public T create(T obj) {
        em.persist(obj);

        flushUnlessBulkMode();

        return obj;
    }

    /**
     * Persists all given entities. Outside bulk mode the persistence context is flushed after
     * every {@code RangerBizUtil.batchPersistSize} entities and once more at the end.
     *
     * @param obj entities to persist
     * @return the same list instance that was passed in
     */
    public List<T> batchCreate(List<T> obj) {
        int persistedCount = 0;

        // Iterate instead of indexing so that non-random-access lists are not traversed quadratically.
        for (T entity : obj) {
            em.persist(entity);

            persistedCount++;

            // Flush once per *completed* batch; the guard also avoids a division by zero
            // when the batch size is configured as 0 or negative.
            if (RangerBizUtil.batchPersistSize > 0 && (persistedCount % RangerBizUtil.batchPersistSize == 0) && !RangerBizUtil.isBulkMode()) {
                em.flush();
            }
        }

        flushUnlessBulkMode();

        return obj;
    }

    /**
     * Executes the given named bulk-delete query for the supplied ids.
     *
     * <p>When the configured batch size ({@value #PROP_BATCH_DELETE_BATCH_SIZE}) is positive,
     * the ids are split into consecutive slices of at most that size and the query is executed
     * once per slice; otherwise the query is executed once with the whole list. Nothing is
     * executed for a {@code null} or empty id list.
     *
     * @param namedQuery name of the named query performing the delete
     * @param ids        ids to delete
     * @param paramName  name of the query parameter that receives the id collection
     */
    public void batchDeleteByIds(String namedQuery, List<Long> ids, String paramName) {
        if (ids == null || ids.isEmpty()) {
            return; // nothing to delete; also avoids issuing a statement with an empty IN-list
        }

        if (BATCH_DELETE_BATCH_SIZE <= 0) {
            executeDeleteByIds(namedQuery, ids, paramName);

            return;
        }

        final int idCount = ids.size();

        for (int fromIndex = 0; fromIndex < idCount; ) {
            // Computed from the remaining count so that a very large batch size cannot overflow int.
            int toIndex = fromIndex + Math.min(BATCH_DELETE_BATCH_SIZE, idCount - fromIndex);

            if (logger.isDebugEnabled()) {
                logger.debug("batchDeleteByIds({}, idCount={}): deleting fromIndex={}, toIndex={}", namedQuery, idCount, fromIndex, toIndex);
            }

            executeDeleteByIds(namedQuery, ids.subList(fromIndex, toIndex), paramName);

            fromIndex = toIndex;
        }
    }

    /**
     * Runs the named delete query once, binding {@code ids} to {@code paramName}.
     * The untyped query overload is used because a delete statement has no result type.
     */
    private void executeDeleteByIds(String namedQuery, List<Long> ids, String paramName) {
        getEntityManager()
                .createNamedQuery(namedQuery)
                .setParameter(paramName, ids)
                .executeUpdate();
    }

    /**
     * Merges the given entity into the persistence context, flushing unless bulk mode is active.
     *
     * @param obj entity to merge
     * @return the instance that was passed in (not the instance returned by {@code merge})
     */
    public T update(T obj) {
        em.merge(obj);

        flushUnlessBulkMode();

        return obj;
    }

    /**
     * Removes the entity with the given id, if it exists.
     *
     * @param id primary key; may be {@code null}
     * @return always {@code true}
     */
    public boolean remove(Long id) {
        return remove(getById(id));
    }

    /**
     * Removes the given entity, re-attaching it first if it is detached, and flushes unless
     * bulk mode is active.
     *
     * @param obj entity to remove; {@code null} is a no-op
     * @return always {@code true}
     */
    public boolean remove(T obj) {
        if (obj == null) {
            return true;
        }

        em.remove(attach(obj));

        flushUnlessBulkMode();

        return true;
    }

    /** Flushes the persistence context unconditionally. */
    public void flush() {
        em.flush();
    }

    /** Clears the persistence context, detaching all managed entities. */
    public void clear() {
        em.clear();
    }

    /**
     * Persists the given entity with explicit control over flushing.
     *
     * @param obj   entity to persist
     * @param flush whether to flush immediately
     * @return the same {@code obj} instance
     */
    public T create(T obj, boolean flush) {
        em.persist(obj);

        if (flush) {
            em.flush();
        }

        return obj;
    }

    /**
     * Merges the given entity with explicit control over flushing.
     *
     * @param obj   entity to merge
     * @param flush whether to flush immediately
     * @return the instance that was passed in (not the instance returned by {@code merge})
     */
    public T update(T obj, boolean flush) {
        em.merge(obj);

        if (flush) {
            em.flush();
        }

        return obj;
    }

    /**
     * Removes the given entity with explicit control over flushing. Like {@link #remove(Object)},
     * a detached entity is re-attached before removal.
     *
     * @param obj   entity to remove; {@code null} is a no-op
     * @param flush whether to flush immediately
     * @return always {@code true}
     */
    public boolean remove(T obj, boolean flush) {
        if (obj == null) {
            return true;
        }

        em.remove(attach(obj));

        if (flush) {
            em.flush();
        }

        return true;
    }

    /**
     * Looks up an entity by primary key.
     *
     * @param id primary key; may be {@code null}
     * @return the entity, or {@code null} if {@code id} is {@code null} or nothing was found
     */
    public T getById(Long id) {
        if (id == null) {
            return null;
        }

        try {
            return em.find(tClass, id);
        } catch (NoResultException e) {
            return null;
        }
    }

    /**
     * Runs a named query with a single named parameter and returns its result list.
     *
     * @param namedQuery name of the query; {@code null} yields an empty list
     * @param paramName  name of the query parameter
     * @param refId      value bound to {@code paramName}
     * @return the query results; an empty list when {@code namedQuery} is {@code null}
     */
    public List<T> findByNamedQuery(String namedQuery, String paramName, Object refId) {
        List<T> ret = new ArrayList<T>();

        if (namedQuery == null) {
            return ret;
        }

        try {
            TypedQuery<T> qry = em.createNamedQuery(namedQuery, tClass);

            qry.setParameter(paramName, refId);

            ret = qry.getResultList();
        } catch (NoResultException e) {
            // ignore: treated the same as an empty result
        }

        return ret;
    }

    /**
     * Runs the named query {@code <EntitySimpleName>.findByParentId} with parameter
     * {@code parentId}.
     *
     * @param parentId value bound to the {@code parentId} parameter
     * @return the query results
     */
    public List<T> findByParentId(Long parentId) {
        String namedQuery = tClass.getSimpleName() + ".findByParentId";

        return findByNamedQuery(namedQuery, "parentId", parentId);
    }

    /**
     * Equivalent to {@link #executeQueryInSecurityContext(Class, Query, boolean)} with
     * {@code userPrefFilter} set to {@code true}.
     */
    public List<T> executeQueryInSecurityContext(Class<T> clazz, Query query) {
        return executeQueryInSecurityContext(clazz, query, true);
    }

    /**
     * Executes the query and returns its result list. {@code clazz} and {@code userPrefFilter}
     * are currently not used by this implementation.
     *
     * @param clazz          entity class (unused)
     * @param query          query to execute
     * @param userPrefFilter unused
     * @return the query's result list
     */
    @SuppressWarnings("unchecked")
    public List<T> executeQueryInSecurityContext(Class<T> clazz, Query query, boolean userPrefFilter) {
        return query.getResultList();
    }

    /**
     * Executes the query and returns its result list as ids. The caller is responsible for
     * passing a query that actually selects {@code Long} values; this is not checked here.
     *
     * @param query query to execute
     * @return the query's result list
     */
    @SuppressWarnings("unchecked")
    public List<Long> getIds(Query query) {
        return (List<Long>) query.getResultList();
    }

    /**
     * Executes a query expected to produce a single {@code Long} (for example a count).
     *
     * @param clazz entity class (unused)
     * @param query query to execute
     * @return the single result cast to {@code Long}
     */
    public Long executeCountQueryInSecurityContext(Class<T> clazz, Query query) { //NOPMD
        return (Long) query.getSingleResult();
    }

    /**
     * @return all entities of this DAO's type, selected via the entity class's simple name
     */
    public List<T> getAll() {
        TypedQuery<T> qry = em.createQuery("SELECT t FROM " + tClass.getSimpleName() + " t", tClass);

        return qry.getResultList();
    }

    /**
     * @return the number of entities of this DAO's type
     */
    public Long getAllCount() {
        TypedQuery<Long> qry = em.createQuery("SELECT count(t) FROM " + tClass.getSimpleName() + " t", Long.class);

        return qry.getSingleResult();
    }

    /**
     * Repositions a database sequence. Only Oracle and PostgreSQL are handled; for any other
     * DB flavor this method does nothing.
     *
     * <p>{@code seqName} is concatenated into the SQL text (identifiers cannot be bound as
     * parameters), so it must come from trusted code, never from user input.
     *
     * @param seqName   name of the sequence
     * @param nextValue value used to reposition the sequence
     */
    public void updateSequence(String seqName, long nextValue) {
        final int dbFlavor = RangerBizUtil.getDBFlavor();

        if (dbFlavor == AppConstants.DB_FLAVOR_ORACLE) {
            // Temporarily change the increment, consume one value to jump ahead, then restore
            // the increment to 1.
            // NOTE(review): nextValue == 1 yields "INCREMENT BY 0", which Oracle is expected to
            // reject - confirm callers never pass 1.
            String[] queries = {
                "ALTER SEQUENCE " + seqName + " INCREMENT BY " + (nextValue - 1),
                "select " + seqName + ".nextval from dual",
                "ALTER SEQUENCE " + seqName + " INCREMENT BY 1 NOCACHE NOCYCLE"
            };

            for (String query : queries) {
                getEntityManager().createNativeQuery(query).executeUpdate();
            }
        } else if (dbFlavor == AppConstants.DB_FLAVOR_POSTGRES) {
            String query = "SELECT setval('" + seqName + "', " + nextValue + ")";

            getEntityManager().createNativeQuery(query).getSingleResult();
        }
    }

    /**
     * Switches {@code IDENTITY_INSERT} on or off for this DAO's table. Does nothing unless the
     * DB flavor is SQL Server. A {@link SQLException} raised while executing the statement is
     * logged and not propagated.
     *
     * @param identityInsert {@code true} for {@code ON}, {@code false} for {@code OFF}
     * @throws NullPointerException if the entity class has no {@link Table} annotation
     */
    public void setIdentityInsert(boolean identityInsert) {
        if (RangerBizUtil.getDBFlavor() != AppConstants.DB_FLAVOR_SQLSERVER) {
            logger.debug("Ignoring BaseDao.setIdentityInsert(). This should be executed if DB flavor is sqlserver.");

            return;
        }

        EntityManager entityMgr         = getEntityManager();
        String        identityInsertStr = identityInsert ? "ON" : "OFF";
        Table         table             = tClass.getAnnotation(Table.class);

        if (table == null) {
            throw new NullPointerException("Required annotation `Table` not found");
        }

        String tableName = table.name();

        // Only the statement is closed here; the connection is obtained from the entity manager
        // and is deliberately left open.
        try (Statement stmt = entityMgr.unwrap(Connection.class).createStatement()) {
            stmt.execute("SET IDENTITY_INSERT " + tableName + " " + identityInsertStr);
        } catch (SQLException e) {
            logger.error("Error while settion identity_insert {}", identityInsertStr, e);
        }
    }

    /**
     * Rewrites references to a user id in this DAO's table: every row whose {@code paramName}
     * column equals {@code oldID} gets that column set to {@code 1} when the table name contains
     * {@code x_gds_}, and to {@code NULL} otherwise. Logs a warning if rows were updated or if
     * the entity class has no {@link Table} annotation.
     *
     * <p>{@code paramName} is concatenated into the SQL text as a column name, so it must come
     * from trusted code, never from user input.
     *
     * @param paramName column holding the user id reference
     * @param oldID     user id whose references are replaced
     */
    public void updateUserIDReference(String paramName, long oldID) {
        Table table = tClass.getAnnotation(Table.class);

        if (table == null) {
            logger.warn("Required annotation `Table` not found");

            return;
        }

        String tableName    = table.name();
        // NOTE(review): "1" is presumably the id of a built-in user that takes over GDS rows - confirm.
        String updatedValue = tableName.contains(GDS_TABLES) ? "1" : "null";
        String query        = "update " + tableName + " set " + paramName + "=" + updatedValue + " where " + paramName + "=" + oldID;
        int    count        = getEntityManager().createNativeQuery(query).executeUpdate();

        if (count > 0) {
            logger.warn("{} records updated in table '{}' with: set {}={} where {}={}", count, tableName, paramName, updatedValue, paramName, oldID);
        }
    }

    /**
     * Queries the database for its version string.
     *
     * @return the version reported by the database, or {@code "Not Available"} if the DB flavor
     *         is not recognized or the query fails (the failure is logged)
     */
    public String getDBVersion() {
        String dbVersion = "Not Available";
        String query     = getVersionQuery(RangerBizUtil.getDBFlavor());

        if (query == null) {
            return dbVersion;
        }

        try {
            dbVersion = (String) getEntityManager().createNativeQuery(query).getSingleResult();
        } catch (Exception ex) {
            logger.error("Error occurred while fetching the DB version.", ex);
        }

        return dbVersion;
    }

    /** Returns the version query for the given DB flavor, or {@code null} if the flavor is not recognized. */
    private static String getVersionQuery(int dbFlavor) {
        if (dbFlavor == AppConstants.DB_FLAVOR_MYSQL || dbFlavor == AppConstants.DB_FLAVOR_POSTGRES) {
            return "SELECT version()";
        } else if (dbFlavor == AppConstants.DB_FLAVOR_ORACLE) {
            return "SELECT banner from v$version where rownum<2";
        } else if (dbFlavor == AppConstants.DB_FLAVOR_SQLSERVER || dbFlavor == AppConstants.DB_FLAVOR_SQLANYWHERE) {
            return "SELECT @@version";
        }

        return null;
    }

    /** Flushes the persistence context unless {@code RangerBizUtil.isBulkMode()} is {@code true}. */
    private void flushUnlessBulkMode() {
        if (!RangerBizUtil.isBulkMode()) {
            em.flush();
        }
    }

    /** Returns a managed instance for {@code obj}: {@code obj} itself if already managed, otherwise the result of {@code merge}. */
    private T attach(T obj) {
        return em.contains(obj) ? obj : em.merge(obj);
    }
}