// blob: 9c917198c26ed18a3db8bc1f0ced05d52294eb8f [file] [log] [blame]
/**
*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.yoko.rmi.impl;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.logging.Logger;
import org.omg.CORBA.MARSHAL;
public final class CopyState {
static Logger logger = Logger.getLogger(CopyState.class.getName());
private IdentityHashMap copied;
private IdentityHashMap recursionResolverMap;
private TypeRepository rep;
private static Object recursionCheck = new Object();
public CopyState(TypeRepository rep) {
this.rep = rep;
this.copied = new IdentityHashMap();
this.recursionResolverMap = new IdentityHashMap();
}
public void registerRecursion(CopyRecursionResolver resolver) {
Object key = resolver.orig;
CopyRecursionResolver orig = (CopyRecursionResolver) recursionResolverMap
.get(key);
resolver.next = orig;
recursionResolverMap.put(key, resolver);
logger.fine("registering recursion resolver " + resolver + " for "
+ key.getClass() + "@" + System.identityHashCode(key));
}
public void put(Object orig, Object copy) {
Object old = copied.put(orig, copy);
if (old == recursionCheck) {
for (CopyRecursionResolver resolver = (CopyRecursionResolver) recursionResolverMap
.get(orig); resolver != null; resolver = resolver.next) {
logger
.fine("invoking " + resolver + " for "
+ orig.getClass() + "@"
+ System.identityHashCode(orig) + " ===> "
+ copy.getClass() + "@"
+ System.identityHashCode(copy));
resolver.resolve(copy);
}
}
}
private String spaces(int c) {
StringBuffer sb = new StringBuffer();
while (c-- != 0) {
sb.append(' ');
}
return sb.toString();
}
private int idx;
public Object copy(Object orig) throws CopyRecursionException {
if (orig == null)
return orig;
Object copy = copied.get(orig);
if (copy != null) {
if (copy == recursionCheck) {
logger.fine("throwign CopyRecursion for " + orig.getClass()
+ "@" + System.identityHashCode(orig));
throw new CopyRecursionException(this, orig);
}
return copy;
}
Class origClass = orig.getClass();
logger.fine("[" + hashCode() + "]" + spaces(idx++)
+ "copying instance of " + origClass);
TypeDescriptor desc = rep.getDescriptor(origClass);
copied.put(orig, recursionCheck);
copy = desc.copyObject(orig, this);
copied.put(orig, copy);
logger.fine(spaces(--idx) + "=> " + copy);
return copy;
}
public ObjectWriter createObjectWriter(final java.io.Serializable obj) {
try {
return (ObjectWriter) java.security.AccessController
.doPrivileged(new java.security.PrivilegedExceptionAction() {
public Object run() throws IOException {
return new Writer(obj);
}
});
} catch (java.security.PrivilegedActionException e) {
throw (InternalError)new InternalError(e.getMessage()).initCause(e);
}
}
class Writer extends ObjectWriter {
List contents = new ArrayList();
private void enqueue(Object o) {
contents.add(o);
}
Writer(java.io.Serializable obj) throws java.io.IOException {
super(obj);
}
public ObjectReader getObjectReader(Object val) {
try {
// register mapping from old value to new
put(object, val);
return new Reader((Serializable) val, contents);
} catch (IOException ex) {
throw (MARSHAL)new MARSHAL(ex.getMessage()).initCause(ex);
}
}
public void write(int val) throws java.io.IOException {
beforeWriteData();
enqueue(new Byte((byte) val));
}
public void write(byte[] val) throws java.io.IOException {
write(val, 0, val.length);
}
public void write(byte[] arr, int off, int len)
throws java.io.IOException {
if (arr == null || arr.length == 0) {
beforeWriteData();
enqueue(null);
} else if (off == 0 && len == arr.length) {
beforeWriteData();
enqueue(arr);
} else {
byte[] data = new byte[len];
System.arraycopy(arr, off, data, 0, len);
beforeWriteData();
enqueue(data);
}
}
public void writeBoolean(boolean val) throws java.io.IOException {
beforeWriteData();
enqueue(new Boolean(val));
}
public void writeByte(int val) throws java.io.IOException {
beforeWriteData();
enqueue(new Byte((byte) val));
}
public void writeShort(int val) throws java.io.IOException {
beforeWriteData();
enqueue(new Short((short) val));
}
public void writeChar(int val) throws java.io.IOException {
beforeWriteData();
enqueue(new Character((char) val));
}
public void writeInt(int val) throws java.io.IOException {
beforeWriteData();
enqueue(new Integer(val));
}
public void writeLong(long val) throws java.io.IOException {
beforeWriteData();
enqueue(new Long(val));
}
public void writeFloat(float val) throws java.io.IOException {
beforeWriteData();
enqueue(new Float(val));
}
public void writeDouble(double val) throws java.io.IOException {
beforeWriteData();
enqueue(new Double(val));
}
public void writeBytes(java.lang.String val) throws java.io.IOException {
beforeWriteData();
byte[] data = val.getBytes();
enqueue(data);
}
public void writeChars(java.lang.String val) throws java.io.IOException {
beforeWriteData();
char[] data = val.toCharArray();
enqueue(data);
}
public void writeUTF(java.lang.String val) throws java.io.IOException {
beforeWriteData();
enqueue(val);
}
public void writeObjectOverride(final Object obj) {
Object copy;
try {
copy = copy(obj);
} catch (CopyRecursionException rec) {
copy = recursionCheck;
// save position of newly added element
final int idx = contents.size();
registerRecursion(new CopyRecursionResolver(obj) {
public void resolve(Object newValue) {
contents.set(idx, newValue);
}
});
}
enqueue(copy);
}
public void writeValueObject(Object obj) {
writeObjectOverride(obj);
}
public void writeCorbaObject(Object obj) {
writeObjectOverride(obj);
}
public void writeRemoteObject(Object obj) {
writeObjectOverride(obj);
}
public void writeAny(Object obj) {
writeObjectOverride(obj);
}
/*
* (non-Javadoc)
*
* @see org.apache.yoko.rmi.impl.ObjectWriter#_startValue(java.lang.String)
*/
protected void _startValue(String rep_id) throws IOException {
}
/*
* (non-Javadoc)
*
* @see org.apache.yoko.rmi.impl.ObjectWriter#_endValue()
*/
protected void _endValue() throws IOException {
}
protected void _nullValue() throws IOException {
}
}
public class Reader extends ObjectReaderBase {
int cpos;
List contents;
private Object dequeue() {
Object result = contents.get(cpos++);
if (result == recursionCheck) {
throw new IllegalStateException("recursion was not resolved?");
} else {
return result;
}
}
private void undequeue(Object o) {
contents.set(--cpos, o);
}
Reader(java.io.Serializable obj, List data) throws IOException {
super(obj);
contents = data;
cpos = 0;
}
public Object readAbstractObject() {
return dequeue();
}
public Object readAny() {
return dequeue();
}
public Object readValueObject() {
return dequeue();
}
public Object readValueObject(Class<?> clz) {
return dequeue();
}
public org.omg.CORBA.Object readCorbaObject(Class<?> type) {
return (org.omg.CORBA.Object) dequeue();
}
public java.rmi.Remote readRemoteObject(Class<?> type) {
return (java.rmi.Remote) dequeue();
}
public void readFully(byte[] arr, int off, int val)
throws java.io.IOException {
while (val > 0) {
int bytes = read(arr, off, val);
off += bytes;
val -= bytes;
}
}
public int read(byte[] arr, int off, int len)
throws java.io.IOException {
Object obj = dequeue();
if (obj instanceof byte[]) {
byte[] data = (byte[]) obj;
if (data.length <= len) {
System.arraycopy(data, 0, arr, off, data.length);
return data.length;
} else {
// copy only portion of the data
System.arraycopy(data, 0, arr, off, len);
byte[] newdata = new byte[data.length - len];
System.arraycopy(data, len, newdata, 0, data.length - len);
undequeue(newdata);
return len;
}
} else if (obj instanceof Byte) {
Byte b = (Byte) obj;
arr[off] = b.byteValue();
return 1;
} else {
throw new IOException("stream contents is not byte[], it is: "
+ obj.getClass().getName());
}
}
public int skipBytes(int len) throws java.io.IOException {
byte[] data = new byte[len];
readFully(data);
return len;
}
public boolean readBoolean() throws java.io.IOException {
try {
return ((Boolean) dequeue()).booleanValue();
} catch (ClassCastException ex) {
IOException iox = new IOException(ex.getMessage());
iox.initCause(ex);
throw iox;
}
}
public byte readByte() throws java.io.IOException {
try {
return ((Byte) dequeue()).byteValue();
} catch (ClassCastException ex) {
IOException iox = new IOException(ex.getMessage());
iox.initCause(ex);
throw iox;
}
}
public short readShort() throws java.io.IOException {
try {
return ((Short) dequeue()).shortValue();
} catch (ClassCastException ex) {
IOException iox = new IOException(ex.getMessage());
iox.initCause(ex);
throw iox;
}
}
public char readChar() throws java.io.IOException {
try {
return ((Character) dequeue()).charValue();
} catch (ClassCastException ex) {
IOException iox = new IOException(ex.getMessage());
iox.initCause(ex);
throw iox;
}
}
public int readInt() throws java.io.IOException {
try {
return ((Integer) dequeue()).intValue();
} catch (ClassCastException ex) {
IOException iox = new IOException(ex.getMessage());
iox.initCause(ex);
throw iox;
}
}
public long readLong() throws java.io.IOException {
try {
return ((Long) dequeue()).longValue();
} catch (ClassCastException ex) {
IOException iox = new IOException(ex.getMessage());
iox.initCause(ex);
throw iox;
}
}
public float readFloat() throws java.io.IOException {
try {
return ((Float) dequeue()).floatValue();
} catch (ClassCastException ex) {
IOException iox = new IOException(ex.getMessage());
iox.initCause(ex);
throw iox;
}
}
public double readDouble() throws java.io.IOException {
try {
return ((Double) dequeue()).doubleValue();
} catch (ClassCastException ex) {
IOException iox = new IOException(ex.getMessage());
iox.initCause(ex);
throw iox;
}
}
/** @deprecated */
public java.lang.String readLine() throws java.io.IOException {
throw new InternalError("cannot use readline");
}
public java.lang.String readUTF() throws java.io.IOException {
return (String) dequeue();
}
@Override
protected void _startValue() {
}
@Override
protected void _endValue() {
}
}
}