/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| |
| package org.apache.commons.math4.neuralnet; |
| |
| import java.io.Serializable; |
| import java.io.ObjectInputStream; |
| import java.util.NoSuchElementException; |
| import java.util.List; |
| import java.util.ArrayList; |
| import java.util.Set; |
| import java.util.HashSet; |
| import java.util.Collection; |
| import java.util.Iterator; |
| import java.util.Comparator; |
| import java.util.Collections; |
| import java.util.Map; |
| import java.util.concurrent.ConcurrentHashMap; |
| import java.util.concurrent.atomic.AtomicLong; |
| import java.util.stream.Collectors; |
| |
| import org.apache.commons.math4.neuralnet.internal.NeuralNetException; |
| |
| /** |
| * Neural network, composed of {@link Neuron} instances and the links |
| * between them. |
| * |
| * Although updating a neuron's state is thread-safe, modifying the |
| * network's topology (adding or removing links) is not. |
| * |
| * @since 3.3 |
| */ |
| public class Network |
| implements Iterable<Neuron>, |
| Serializable { |
| /** Serializable. */ |
| private static final long serialVersionUID = 20130207L; |
| /** Neurons. */ |
| private final ConcurrentHashMap<Long, Neuron> neuronMap |
| = new ConcurrentHashMap<>(); |
| /** Next available neuron identifier. */ |
| private final AtomicLong nextId; |
| /** Neuron's features set size. */ |
| private final int featureSize; |
| /** Links. */ |
| private final ConcurrentHashMap<Long, Set<Long>> linkMap |
| = new ConcurrentHashMap<>(); |
| |
| /** |
| * Comparator that prescribes an order of the neurons according |
| * to the increasing order of their identifier. |
| */ |
| public static class NeuronIdentifierComparator |
| implements Comparator<Neuron>, |
| Serializable { |
| /** Version identifier. */ |
| private static final long serialVersionUID = 20130207L; |
| |
| /** {@inheritDoc} */ |
| @Override |
| public int compare(Neuron a, |
| Neuron b) { |
| final long aId = a.getIdentifier(); |
| final long bId = b.getIdentifier(); |
| return aId < bId ? -1 : |
| aId > bId ? 1 : 0; |
| } |
| } |
| |
| /** |
| * Constructor with restricted access, solely used for deserialization. |
| * |
| * @param nextId Next available identifier. |
| * @param featureSize Number of features. |
| * @param neuronList Neurons. |
| * @param neighbourIdList Links associated to each of the neurons in |
| * {@code neuronList}. |
| * @throws IllegalStateException if an inconsistency is detected |
| * (which probably means that the serialized form has been corrupted). |
| */ |
| Network(long nextId, |
| int featureSize, |
| Neuron[] neuronList, |
| long[][] neighbourIdList) { |
| final int numNeurons = neuronList.length; |
| if (numNeurons != neighbourIdList.length) { |
| throw new IllegalStateException(); |
| } |
| |
| for (int i = 0; i < numNeurons; i++) { |
| final Neuron n = neuronList[i]; |
| final long id = n.getIdentifier(); |
| if (id >= nextId) { |
| throw new IllegalStateException(); |
| } |
| neuronMap.put(id, n); |
| linkMap.put(id, new HashSet<Long>()); |
| } |
| |
| for (int i = 0; i < numNeurons; i++) { |
| final long aId = neuronList[i].getIdentifier(); |
| final Set<Long> aLinks = linkMap.get(aId); |
| for (final Long bId : neighbourIdList[i]) { |
| if (neuronMap.get(bId) == null) { |
| throw new IllegalStateException(); |
| } |
| addLinkToLinkSet(aLinks, bId); |
| } |
| } |
| |
| this.nextId = new AtomicLong(nextId); |
| this.featureSize = featureSize; |
| } |
| |
| /** |
| * @param initialIdentifier Identifier for the first neuron that |
| * will be added to this network. |
| * @param featureSize Size of the neuron's features. |
| */ |
| public Network(long initialIdentifier, |
| int featureSize) { |
| nextId = new AtomicLong(initialIdentifier); |
| this.featureSize = featureSize; |
| } |
| |
| /** |
| * Performs a deep copy of this instance. |
| * Upon return, the copied and original instances will be independent: |
| * Updating one will not affect the other. |
| * |
| * @return a new instance with the same state as this instance. |
| * @since 3.6 |
| */ |
| public synchronized Network copy() { |
| final Network copy = new Network(nextId.get(), |
| featureSize); |
| |
| |
| for (final Map.Entry<Long, Neuron> e : neuronMap.entrySet()) { |
| copy.neuronMap.put(e.getKey(), e.getValue().copy()); |
| } |
| |
| for (final Map.Entry<Long, Set<Long>> e : linkMap.entrySet()) { |
| copy.linkMap.put(e.getKey(), new HashSet<>(e.getValue())); |
| } |
| |
| return copy; |
| } |
| |
| /** |
| * {@inheritDoc} |
| */ |
| @Override |
| public Iterator<Neuron> iterator() { |
| return neuronMap.values().iterator(); |
| } |
| |
| /** |
| * Creates a list of the neurons, sorted in a custom order. |
| * |
| * @param comparator {@link Comparator} used for sorting the neurons. |
| * @return a list of neurons, sorted in the order prescribed by the |
| * given {@code comparator}. |
| * @see NeuronIdentifierComparator |
| */ |
| public Collection<Neuron> getNeurons(Comparator<Neuron> comparator) { |
| final List<Neuron> neurons = new ArrayList<>(neuronMap.values()); |
| |
| Collections.sort(neurons, comparator); |
| |
| return neurons; |
| } |
| |
| /** |
| * Creates a neuron and assigns it a unique identifier. |
| * |
| * @param features Initial values for the neuron's features. |
| * @return the neuron's identifier. |
| * @throws IllegalArgumentException if the length of {@code features} |
| * is different from the expected size (as set by the |
| * {@link #Network(long,int) constructor}). |
| */ |
| public long createNeuron(double[] features) { |
| if (features.length != featureSize) { |
| throw new NeuralNetException(NeuralNetException.SIZE_MISMATCH, |
| features.length, featureSize); |
| } |
| |
| final long id = createNextId(); |
| neuronMap.put(id, new Neuron(id, features)); |
| linkMap.put(id, new HashSet<Long>()); |
| return id; |
| } |
| |
| /** |
| * Deletes a neuron. |
| * Links from all neighbours to the removed neuron will also be |
| * {@link #deleteLink(Neuron,Neuron) deleted}. |
| * |
| * @param neuron Neuron to be removed from this network. |
| * @throws NoSuchElementException if {@code n} does not belong to |
| * this network. |
| */ |
| public void deleteNeuron(Neuron neuron) { |
| // Delete links to from neighbours. |
| getNeighbours(neuron).forEach(neighbour -> deleteLink(neighbour, neuron)); |
| |
| // Remove neuron. |
| neuronMap.remove(neuron.getIdentifier()); |
| } |
| |
| /** |
| * Gets the size of the neurons' features set. |
| * |
| * @return the size of the features set. |
| */ |
| public int getFeaturesSize() { |
| return featureSize; |
| } |
| |
| /** |
| * Adds a link from neuron {@code a} to neuron {@code b}. |
| * Note: the link is not bi-directional; if a bi-directional link is |
| * required, an additional call must be made with {@code a} and |
| * {@code b} exchanged in the argument list. |
| * |
| * @param a Neuron. |
| * @param b Neuron. |
| * @throws NoSuchElementException if the neurons do not exist in the |
| * network. |
| */ |
| public void addLink(Neuron a, |
| Neuron b) { |
| // Check that the neurons belong to this network. |
| final long aId = a.getIdentifier(); |
| if (a != getNeuron(aId)) { |
| throw new NoSuchElementException(Long.toString(aId)); |
| } |
| final long bId = b.getIdentifier(); |
| if (b != getNeuron(bId)) { |
| throw new NoSuchElementException(Long.toString(bId)); |
| } |
| |
| // Add link from "a" to "b". |
| addLinkToLinkSet(linkMap.get(aId), bId); |
| } |
| |
| /** |
| * Adds a link to neuron {@code id} in given {@code linkSet}. |
| * Note: no check verifies that the identifier indeed belongs |
| * to this network. |
| * |
| * @param linkSet Neuron identifier. |
| * @param id Neuron identifier. |
| */ |
| private void addLinkToLinkSet(Set<Long> linkSet, |
| long id) { |
| linkSet.add(id); |
| } |
| |
| /** |
| * Deletes the link between neurons {@code a} and {@code b}. |
| * |
| * @param a Neuron. |
| * @param b Neuron. |
| * @throws NoSuchElementException if the neurons do not exist in the |
| * network. |
| */ |
| public void deleteLink(Neuron a, |
| Neuron b) { |
| // Check that the neurons belong to this network. |
| final long aId = a.getIdentifier(); |
| if (a != getNeuron(aId)) { |
| throw new NoSuchElementException(Long.toString(aId)); |
| } |
| final long bId = b.getIdentifier(); |
| if (b != getNeuron(bId)) { |
| throw new NoSuchElementException(Long.toString(bId)); |
| } |
| |
| // Delete link from "a" to "b". |
| deleteLinkFromLinkSet(linkMap.get(aId), bId); |
| } |
| |
| /** |
| * Deletes a link to neuron {@code id} in given {@code linkSet}. |
| * Note: no check verifies that the identifier indeed belongs |
| * to this network. |
| * |
| * @param linkSet Neuron identifier. |
| * @param id Neuron identifier. |
| */ |
| private void deleteLinkFromLinkSet(Set<Long> linkSet, |
| long id) { |
| linkSet.remove(id); |
| } |
| |
| /** |
| * Retrieves the neuron with the given (unique) {@code id}. |
| * |
| * @param id Identifier. |
| * @return the neuron associated with the given {@code id}. |
| * @throws NoSuchElementException if the neuron does not exist in the |
| * network. |
| */ |
| public Neuron getNeuron(long id) { |
| final Neuron n = neuronMap.get(id); |
| if (n == null) { |
| throw new NoSuchElementException(Long.toString(id)); |
| } |
| return n; |
| } |
| |
| /** |
| * Retrieves the neurons in the neighbourhood of any neuron in the |
| * {@code neurons} list. |
| * @param neurons Neurons for which to retrieve the neighbours. |
| * @return the list of neighbours. |
| * @see #getNeighbours(Iterable,Iterable) |
| */ |
| public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons) { |
| return getNeighbours(neurons, null); |
| } |
| |
| /** |
| * Retrieves the neurons in the neighbourhood of any neuron in the |
| * {@code neurons} list. |
| * The {@code exclude} list allows to retrieve the "concentric" |
| * neighbourhoods by removing the neurons that belong to the inner |
| * "circles". |
| * |
| * @param neurons Neurons for which to retrieve the neighbours. |
| * @param exclude Neurons to exclude from the returned list. |
| * Can be {@code null}. |
| * @return the list of neighbours. |
| */ |
| public Collection<Neuron> getNeighbours(Iterable<Neuron> neurons, |
| Iterable<Neuron> exclude) { |
| final Set<Long> idList = new HashSet<>(); |
| neurons.forEach(n -> idList.addAll(linkMap.get(n.getIdentifier()))); |
| |
| if (exclude != null) { |
| exclude.forEach(n -> idList.remove(n.getIdentifier())); |
| } |
| |
| return idList.stream().map(this::getNeuron).collect(Collectors.toList()); |
| } |
| |
| /** |
| * Retrieves the neighbours of the given neuron. |
| * |
| * @param neuron Neuron for which to retrieve the neighbours. |
| * @return the list of neighbours. |
| * @see #getNeighbours(Neuron,Iterable) |
| */ |
| public Collection<Neuron> getNeighbours(Neuron neuron) { |
| return getNeighbours(neuron, null); |
| } |
| |
| /** |
| * Retrieves the neighbours of the given neuron. |
| * |
| * @param neuron Neuron for which to retrieve the neighbours. |
| * @param exclude Neurons to exclude from the returned list. |
| * Can be {@code null}. |
| * @return the list of neighbours. |
| */ |
| public Collection<Neuron> getNeighbours(Neuron neuron, |
| Iterable<Neuron> exclude) { |
| final Set<Long> idList = linkMap.get(neuron.getIdentifier()); |
| if (exclude != null) { |
| for (final Neuron n : exclude) { |
| idList.remove(n.getIdentifier()); |
| } |
| } |
| |
| final List<Neuron> neuronList = new ArrayList<>(); |
| for (final Long id : idList) { |
| neuronList.add(getNeuron(id)); |
| } |
| |
| return neuronList; |
| } |
| |
| /** |
| * Creates a neuron identifier. |
| * |
| * @return a value that will serve as a unique identifier. |
| */ |
| private Long createNextId() { |
| return nextId.getAndIncrement(); |
| } |
| |
| /** |
| * Prevents proxy bypass. |
| * |
| * @param in Input stream. |
| */ |
| private void readObject(ObjectInputStream in) { |
| throw new IllegalStateException(); |
| } |
| |
| /** |
| * Custom serialization. |
| * |
| * @return the proxy instance that will be actually serialized. |
| */ |
| private Object writeReplace() { |
| final Neuron[] neuronList = neuronMap.values().toArray(new Neuron[0]); |
| final long[][] neighbourIdList = new long[neuronList.length][]; |
| |
| for (int i = 0; i < neuronList.length; i++) { |
| final Collection<Neuron> neighbours = getNeighbours(neuronList[i]); |
| final long[] neighboursId = new long[neighbours.size()]; |
| int count = 0; |
| for (final Neuron n : neighbours) { |
| neighboursId[count] = n.getIdentifier(); |
| ++count; |
| } |
| neighbourIdList[i] = neighboursId; |
| } |
| |
| return new SerializationProxy(nextId.get(), |
| featureSize, |
| neuronList, |
| neighbourIdList); |
| } |
| |
| /** |
| * Serialization. |
| */ |
| private static class SerializationProxy implements Serializable { |
| /** Serializable. */ |
| private static final long serialVersionUID = 20130207L; |
| /** Next identifier. */ |
| private final long nextId; |
| /** Number of features. */ |
| private final int featureSize; |
| /** Neurons. */ |
| private final Neuron[] neuronList; |
| /** Links. */ |
| private final long[][] neighbourIdList; |
| |
| /** |
| * @param nextId Next available identifier. |
| * @param featureSize Number of features. |
| * @param neuronList Neurons. |
| * @param neighbourIdList Links associated to each of the neurons in |
| * {@code neuronList}. |
| */ |
| SerializationProxy(long nextId, |
| int featureSize, |
| Neuron[] neuronList, |
| long[][] neighbourIdList) { |
| this.nextId = nextId; |
| this.featureSize = featureSize; |
| this.neuronList = neuronList; |
| this.neighbourIdList = neighbourIdList; |
| } |
| |
| /** |
| * Custom serialization. |
| * |
| * @return the {@link Network} for which this instance is the proxy. |
| */ |
| private Object readResolve() { |
| return new Network(nextId, |
| featureSize, |
| neuronList, |
| neighbourIdList); |
| } |
| } |
| } |