blob: babbd7a1cdf34207a98ffa100408e929cb5442e9 [file] [log] [blame]
package org.apache.cassandra.stress.util;
/*
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/
import java.net.InetAddress;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import com.google.common.collect.Iterators;
import com.datastax.driver.core.Host;
import com.datastax.driver.core.Metadata;
import org.apache.cassandra.stress.settings.StressSettings;
import org.apache.cassandra.thrift.*;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.thrift.TException;
/**
 * A {@link ThriftClient} that sends each keyed request directly to a replica of that key.
 * <p>
 * Replicas are looked up through the driver {@link Metadata}. When a node whitelist is configured, only
 * whitelisted replicas are used. If none of the replicas qualify, a random whitelisted node is chosen instead.
 * Raw Thrift connections are pooled per host address.
 * <p>
 * Prepared statements are identified by client-side ids (see {@link #prepare_cql3_query} and
 * {@link #prepare_cql_query}). Each pooled connection lazily maps a client-side id to its own server-side
 * statement id the first time it executes that statement.
 * <p>
 * Shared state is kept in concurrent collections. A pooled connection is removed from its queue while in use,
 * so only one caller uses it at a time.
 */
public class SmartThriftClient implements ThriftClient
{
    final String keyspace;
    final Metadata metadata;
    final StressSettings settings;
    // Idle connections, pooled per host address.
    final ConcurrentHashMap<InetAddress, ConcurrentLinkedQueue<Client>> cache = new ConcurrentHashMap<>();
    // Source of client-side prepared-statement ids.
    final AtomicInteger queryIdCounter = new AtomicInteger();
    // Client-side id -> query text, and query text -> client-side id.
    final ConcurrentHashMap<Integer, String> queryStrings = new ConcurrentHashMap<>();
    final ConcurrentHashMap<String, Integer> queryIds = new ConcurrentHashMap<>();
    // Nodes allowed to coordinate requests, or null when no whitelist is configured.
    final Set<InetAddress> whiteset;
    // The same nodes as whiteset, indexable for random selection (null when no whitelist is configured).
    final List<InetAddress> whitelist;
    // Rotates the first replica tried for each request, spreading load across replicas.
    private final AtomicInteger roundrobin = new AtomicInteger();

    public SmartThriftClient(StressSettings settings, String keyspace, Metadata metadata)
    {
        this.metadata = metadata;
        this.keyspace = keyspace;
        this.settings = settings;
        if (!settings.node.isWhiteList)
        {
            whiteset = null;
            whitelist = null;
        }
        else
        {
            whiteset = settings.node.resolveAllSpecified();
            whitelist = Arrays.asList(whiteset.toArray(new InetAddress[0]));
        }
    }

    /**
     * Returns the client-side id for {@code query}, registering it on first use.
     * <p>
     * The query text is stored in {@link #queryStrings} before the id becomes visible through
     * {@link #queryIds}. Any thread that obtains an id can therefore always resolve its text.
     */
    private Integer getId(String query)
    {
        Integer existing = queryIds.get(query);
        if (existing != null)
            return existing;

        Integer candidate = queryIdCounter.incrementAndGet();
        queryStrings.put(candidate, query);
        Integer winner = queryIds.putIfAbsent(query, candidate);
        if (winner == null)
            return candidate;

        // Another thread registered this query first, so our candidate id is never handed out.
        queryStrings.remove(candidate);
        return winner;
    }

    /**
     * A pooled Thrift connection to one host. It caches, per connection, the mapping from client-side
     * statement ids to the ids returned by this server's prepare calls.
     */
    final class Client
    {
        final Cassandra.Client client;
        final InetAddress server;
        // Not thread-safe. Only accessed by the single caller currently holding this connection.
        final Map<Integer, Integer> queryMap = new HashMap<>();

        Client(Cassandra.Client client, InetAddress server)
        {
            this.client = client;
            this.server = server;
        }

        /**
         * Returns this server's statement id for client-side id {@code id}, preparing it first if needed.
         *
         * @param cql3 true to prepare as CQL3, false to prepare as CQL2
         */
        Integer get(Integer id, boolean cql3) throws TException
        {
            Integer serverId = queryMap.get(id);
            if (serverId != null)
                return serverId;
            prepare(id, cql3);
            return queryMap.get(id);
        }

        /**
         * Prepares the query registered under {@code id} on this connection and records the server's statement id.
         *
         * @throws IllegalArgumentException if {@code id} was never returned by {@code getId}
         */
        void prepare(Integer id, boolean cql3) throws TException
        {
            String query = queryStrings.get(id);
            if (query == null)
                throw new IllegalArgumentException("Unknown prepared query id " + id);
            ByteBuffer bytes = ByteBufferUtil.bytes(query);
            Integer serverId = cql3
                             ? client.prepare_cql3_query(bytes, Compression.NONE).itemId
                             : client.prepare_cql_query(bytes, Compression.NONE).itemId;
            queryMap.put(id, serverId);
        }
    }

    /**
     * Borrows a connection to a node suitable for partition key {@code pk}.
     * <p>
     * The callers' {@code finally} blocks return the connection via {@link #release}.
     */
    private Client get(ByteBuffer pk)
    {
        InetAddress address = chooseReplica(pk);
        if (address == null)
            address = chooseFallback();
        return borrow(address);
    }

    /**
     * Picks a replica of {@code pk}, starting at a rotating position and skipping non-whitelisted replicas.
     *
     * @return the chosen replica's address, or null if the driver knows no acceptable replica
     */
    private InetAddress chooseReplica(ByteBuffer pk)
    {
        Set<Host> hosts = metadata.getReplicas(metadata.quote(keyspace), pk);
        int size = hosts.size();
        if (size == 0)
            return null;

        int start = roundrobin.incrementAndGet() % size;
        // The counter eventually wraps to negative values; keep the start index non-negative.
        if (start < 0)
            start = -start;

        for (int i = 0; i < size; i++)
        {
            Host host = Iterators.get(hosts.iterator(), (start + i) % size);
            if (whiteset == null || whiteset.contains(host.getAddress()))
                return host.getAddress();
        }
        return null;
    }

    /**
     * Picks a coordinator when no acceptable replica is known. This is a random whitelisted node when a
     * whitelist is configured, and otherwise a random host known to the driver metadata.
     *
     * @throws IllegalStateException if there is no candidate node at all
     */
    private InetAddress chooseFallback()
    {
        if (whitelist != null)
        {
            if (whitelist.isEmpty())
                throw new IllegalStateException("No whitelisted node available to coordinate requests for keyspace " + keyspace);
            return whitelist.get(ThreadLocalRandom.current().nextInt(whitelist.size()));
        }

        List<Host> all = new ArrayList<>(metadata.getAllHosts());
        if (all.isEmpty())
            throw new IllegalStateException("No replica or known host available to coordinate requests for keyspace " + keyspace);
        return all.get(ThreadLocalRandom.current().nextInt(all.size())).getAddress();
    }

    /** Takes an idle pooled connection to {@code address}, or opens a new one if none is idle. */
    private Client borrow(InetAddress address)
    {
        ConcurrentLinkedQueue<Client> q = cache.get(address);
        if (q == null)
        {
            ConcurrentLinkedQueue<Client> newQ = new ConcurrentLinkedQueue<>();
            q = cache.putIfAbsent(address, newQ);
            if (q == null)
                q = newQ;
        }
        Client pooled = q.poll();
        if (pooled != null)
            return pooled;
        return new Client(settings.getRawThriftClient(address.getHostAddress()), address);
    }

    /**
     * Returns a connection obtained from {@link #get(ByteBuffer)} to its host's pool.
     * <p>
     * NOTE(review): connections are returned even when the request threw. If the failure left the transport
     * unusable (e.g. a transport-level error), later callers will reuse a broken connection. Consider
     * discarding the connection in that case.
     */
    private void release(Client client)
    {
        cache.get(client.server).add(client);
    }

    /** Sends each row of {@code record} as its own batch_mutate to a replica of that row's key. */
    @Override
    public void batch_mutate(Map<ByteBuffer, Map<String, List<Mutation>>> record, ConsistencyLevel consistencyLevel) throws TException
    {
        for (Map.Entry<ByteBuffer, Map<String, List<Mutation>>> e : record.entrySet())
        {
            Client client = get(e.getKey());
            try
            {
                client.client.batch_mutate(Collections.singletonMap(e.getKey(), e.getValue()), consistencyLevel);
            }
            finally
            {
                release(client);
            }
        }
    }

    @Override
    public List<ColumnOrSuperColumn> get_slice(ByteBuffer key, ColumnParent parent, SlicePredicate predicate, ConsistencyLevel consistencyLevel) throws TException
    {
        Client client = get(key);
        try
        {
            return client.client.get_slice(key, parent, predicate, consistencyLevel);
        }
        finally
        {
            release(client);
        }
    }

    @Override
    public void insert(ByteBuffer key, ColumnParent column_parent, Column column, ConsistencyLevel consistency_level) throws TException
    {
        Client client = get(key);
        try
        {
            client.client.insert(key, column_parent, column, consistency_level);
        }
        finally
        {
            release(client);
        }
    }

    /** Executes a CQL2 query; {@code key} is used only to choose the node the query is sent to. */
    @Override
    public CqlResult execute_cql_query(String query, ByteBuffer key, Compression compression) throws TException
    {
        Client client = get(key);
        try
        {
            return client.client.execute_cql_query(ByteBufferUtil.bytes(query), compression);
        }
        finally
        {
            release(client);
        }
    }

    /** Executes a CQL3 query; {@code key} is used only to choose the node the query is sent to. */
    @Override
    public CqlResult execute_cql3_query(String query, ByteBuffer key, Compression compression, ConsistencyLevel consistency) throws TException
    {
        Client client = get(key);
        try
        {
            return client.client.execute_cql3_query(ByteBufferUtil.bytes(query), compression, consistency);
        }
        finally
        {
            release(client);
        }
    }

    /**
     * Registers a CQL3 query and returns its client-side id. Nothing is sent to a server here; each
     * connection prepares the query the first time it executes the id.
     */
    @Override
    public Integer prepare_cql3_query(String query, Compression compression) throws TException
    {
        return getId(query);
    }

    @Override
    public CqlResult execute_prepared_cql3_query(int queryId, ByteBuffer key, List<ByteBuffer> values, ConsistencyLevel consistency) throws TException
    {
        Client client = get(key);
        try
        {
            return client.client.execute_prepared_cql3_query(client.get(queryId, true), values, consistency);
        }
        finally
        {
            release(client);
        }
    }

    /**
     * Registers a CQL2 query and returns its client-side id. Nothing is sent to a server here; each
     * connection prepares the query the first time it executes the id.
     */
    @Override
    public Integer prepare_cql_query(String query, Compression compression) throws TException
    {
        return getId(query);
    }

    @Override
    public CqlResult execute_prepared_cql_query(int queryId, ByteBuffer key, List<ByteBuffer> values) throws TException
    {
        Client client = get(key);
        try
        {
            // CQL2 statements must be prepared with prepare_cql_query, not prepare_cql3_query.
            return client.client.execute_prepared_cql_query(client.get(queryId, false), values);
        }
        finally
        {
            release(client);
        }
    }

    @Override
    public Map<ByteBuffer, List<ColumnOrSuperColumn>> multiget_slice(List<ByteBuffer> keys, ColumnParent column_parent, SlicePredicate predicate, ConsistencyLevel consistency_level) throws TException
    {
        throw new UnsupportedOperationException();
    }

    @Override
    public List<KeySlice> get_range_slices(ColumnParent column_parent, SlicePredicate predicate, KeyRange range, ConsistencyLevel consistency_level) throws TException
    {
        throw new UnsupportedOperationException();
    }

    @Override
    public List<KeySlice> get_indexed_slices(ColumnParent column_parent, IndexClause index_clause, SlicePredicate column_predicate, ConsistencyLevel consistency_level) throws TException
    {
        throw new UnsupportedOperationException();
    }
}