| package org.apache.cassandra.stress.util; |
| /* |
| * |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| * |
| */ |
| |
| |
| import java.net.InetAddress; |
| import java.nio.ByteBuffer; |
| import java.util.*; |
| import java.util.concurrent.ConcurrentHashMap; |
| import java.util.concurrent.ConcurrentLinkedQueue; |
| import java.util.concurrent.ThreadLocalRandom; |
| import java.util.concurrent.atomic.AtomicInteger; |
| |
| import com.google.common.collect.Iterators; |
| |
| import com.datastax.driver.core.Host; |
| import com.datastax.driver.core.Metadata; |
| import org.apache.cassandra.stress.settings.StressSettings; |
| import org.apache.cassandra.thrift.*; |
| import org.apache.cassandra.utils.ByteBufferUtil; |
| import org.apache.thrift.TException; |
| |
| public class SmartThriftClient implements ThriftClient |
| { |
| |
| final String keyspace; |
| final Metadata metadata; |
| final StressSettings settings; |
| final ConcurrentHashMap<InetAddress, ConcurrentLinkedQueue<Client>> cache = new ConcurrentHashMap<>(); |
| |
| final AtomicInteger queryIdCounter = new AtomicInteger(); |
| final ConcurrentHashMap<Integer, String> queryStrings = new ConcurrentHashMap<>(); |
| final ConcurrentHashMap<String, Integer> queryIds = new ConcurrentHashMap<>(); |
| final Set<InetAddress> whiteset; |
| final List<InetAddress> whitelist; |
| |
| public SmartThriftClient(StressSettings settings, String keyspace, Metadata metadata) |
| { |
| this.metadata = metadata; |
| this.keyspace = keyspace; |
| this.settings = settings; |
| if (!settings.node.isWhiteList) |
| { |
| whiteset = null; |
| whitelist = null; |
| } |
| else |
| { |
| whiteset = settings.node.resolveAllSpecified(); |
| whitelist = Arrays.asList(whiteset.toArray(new InetAddress[0])); |
| } |
| } |
| |
| private final AtomicInteger roundrobin = new AtomicInteger(); |
| |
| private Integer getId(String query) |
| { |
| Integer r; |
| if ((r = queryIds.get(query)) != null) |
| return r; |
| r = queryIdCounter.incrementAndGet(); |
| if (queryIds.putIfAbsent(query, r) == null) |
| { |
| queryStrings.put(r, query); |
| return r; |
| } |
| return queryIds.get(query); |
| } |
| |
| final class Client |
| { |
| final Cassandra.Client client; |
| final InetAddress server; |
| final Map<Integer, Integer> queryMap = new HashMap<>(); |
| |
| Client(Cassandra.Client client, InetAddress server) |
| { |
| this.client = client; |
| this.server = server; |
| } |
| |
| Integer get(Integer id, boolean cql3) throws TException |
| { |
| Integer serverId = queryMap.get(id); |
| if (serverId != null) |
| return serverId; |
| prepare(id, cql3); |
| return queryMap.get(id); |
| } |
| |
| void prepare(Integer id, boolean cql3) throws TException |
| { |
| String query; |
| while ( null == (query = queryStrings.get(id)) ) ; |
| if (cql3) |
| { |
| Integer serverId = client.prepare_cql3_query(ByteBufferUtil.bytes(query), Compression.NONE).itemId; |
| queryMap.put(id, serverId); |
| } |
| else |
| { |
| Integer serverId = client.prepare_cql_query(ByteBufferUtil.bytes(query), Compression.NONE).itemId; |
| queryMap.put(id, serverId); |
| } |
| } |
| } |
| |
| private Client get(ByteBuffer pk) |
| { |
| Set<Host> hosts = metadata.getReplicas(metadata.quote(keyspace), pk); |
| InetAddress address = null; |
| if (hosts.size() > 0) |
| { |
| int pos = roundrobin.incrementAndGet() % hosts.size(); |
| for (int i = 0 ; address == null && i < hosts.size() ; i++) |
| { |
| if (pos < 0) |
| pos = -pos; |
| Host host = Iterators.get(hosts.iterator(), (pos + i) % hosts.size()); |
| if (whiteset == null || whiteset.contains(host.getAddress())) |
| address = host.getAddress(); |
| } |
| } |
| if (address == null) |
| address = whitelist.get(ThreadLocalRandom.current().nextInt(whitelist.size())); |
| ConcurrentLinkedQueue<Client> q = cache.get(address); |
| if (q == null) |
| { |
| ConcurrentLinkedQueue<Client> newQ = new ConcurrentLinkedQueue<Client>(); |
| q = cache.putIfAbsent(address, newQ); |
| if (q == null) |
| q = newQ; |
| } |
| Client tclient = q.poll(); |
| if (tclient != null) |
| return tclient; |
| return new Client(settings.getRawThriftClient(address.getHostAddress()), address); |
| } |
| |
| @Override |
| public void batch_mutate(Map<ByteBuffer, Map<String, List<Mutation>>> record, ConsistencyLevel consistencyLevel) throws TException |
| { |
| for (Map.Entry<ByteBuffer, Map<String, List<Mutation>>> e : record.entrySet()) |
| { |
| Client client = get(e.getKey()); |
| try |
| { |
| client.client.batch_mutate(Collections.singletonMap(e.getKey(), e.getValue()), consistencyLevel); |
| } finally |
| { |
| cache.get(client.server).add(client); |
| } |
| } |
| } |
| |
| @Override |
| public List<ColumnOrSuperColumn> get_slice(ByteBuffer key, ColumnParent parent, SlicePredicate predicate, ConsistencyLevel consistencyLevel) throws TException |
| { |
| Client client = get(key); |
| try |
| { |
| return client.client.get_slice(key, parent, predicate, consistencyLevel); |
| } finally |
| { |
| cache.get(client.server).add(client); |
| } |
| } |
| |
| @Override |
| public void insert(ByteBuffer key, ColumnParent column_parent, Column column, ConsistencyLevel consistency_level) throws TException |
| { |
| Client client = get(key); |
| try |
| { |
| client.client.insert(key, column_parent, column, consistency_level); |
| } finally |
| { |
| cache.get(client.server).add(client); |
| } |
| } |
| |
| @Override |
| public CqlResult execute_cql_query(String query, ByteBuffer key, Compression compression) throws TException |
| { |
| Client client = get(key); |
| try |
| { |
| return client.client.execute_cql_query(ByteBufferUtil.bytes(query), compression); |
| } finally |
| { |
| cache.get(client.server).add(client); |
| } |
| } |
| |
| @Override |
| public CqlResult execute_cql3_query(String query, ByteBuffer key, Compression compression, ConsistencyLevel consistency) throws TException |
| { |
| Client client = get(key); |
| try |
| { |
| return client.client.execute_cql3_query(ByteBufferUtil.bytes(query), compression, consistency); |
| } finally |
| { |
| cache.get(client.server).add(client); |
| } |
| } |
| |
| @Override |
| public Integer prepare_cql3_query(String query, Compression compression) throws TException |
| { |
| return getId(query); |
| } |
| |
| @Override |
| public CqlResult execute_prepared_cql3_query(int queryId, ByteBuffer key, List<ByteBuffer> values, ConsistencyLevel consistency) throws TException |
| { |
| Client client = get(key); |
| try |
| { |
| return client.client.execute_prepared_cql3_query(client.get(queryId, true), values, consistency); |
| } finally |
| { |
| cache.get(client.server).add(client); |
| } |
| } |
| |
| @Override |
| public Integer prepare_cql_query(String query, Compression compression) throws TException |
| { |
| return getId(query); |
| } |
| |
| @Override |
| public CqlResult execute_prepared_cql_query(int queryId, ByteBuffer key, List<ByteBuffer> values) throws TException |
| { |
| Client client = get(key); |
| try |
| { |
| return client.client.execute_prepared_cql_query(client.get(queryId, true), values); |
| } finally |
| { |
| cache.get(client.server).add(client); |
| } |
| } |
| |
| @Override |
| public Map<ByteBuffer, List<ColumnOrSuperColumn>> multiget_slice(List<ByteBuffer> keys, ColumnParent column_parent, SlicePredicate predicate, ConsistencyLevel consistency_level) throws TException |
| { |
| throw new UnsupportedOperationException(); |
| } |
| |
| @Override |
| public List<KeySlice> get_range_slices(ColumnParent column_parent, SlicePredicate predicate, KeyRange range, ConsistencyLevel consistency_level) throws TException |
| { |
| throw new UnsupportedOperationException(); |
| } |
| |
| @Override |
| public List<KeySlice> get_indexed_slices(ColumnParent column_parent, IndexClause index_clause, SlicePredicate column_predicate, ConsistencyLevel consistency_level) throws TException |
| { |
| throw new UnsupportedOperationException(); |
| } |
| |
| } |