1 package com.atlassian.scheduler.core.util;
2
3 import org.junit.Test;
4 import org.mockito.invocation.InvocationOnMock;
5 import org.mockito.stubbing.Answer;
6
7 import java.io.IOException;
8 import java.io.ObjectStreamClass;
9 import java.io.Serializable;
10 import java.util.Map;
11
12 import static com.atlassian.scheduler.core.Constants.BYTES_DEADF00D;
13 import static com.atlassian.scheduler.core.Constants.BYTES_EMPTY_MAP;
14 import static com.atlassian.scheduler.core.Constants.BYTES_NULL;
15 import static com.atlassian.scheduler.core.Constants.BYTES_PARAMETERS;
16 import static com.atlassian.scheduler.core.Constants.EMPTY_MAP;
17 import static com.atlassian.scheduler.core.Constants.PARAMETERS;
18 import static org.hamcrest.Matchers.instanceOf;
19 import static org.junit.Assert.assertEquals;
20 import static org.junit.Assert.assertSame;
21 import static org.junit.Assert.assertThat;
22 import static org.junit.Assert.fail;
23 import static org.mockito.Mockito.mock;
24 import static org.mockito.Mockito.verify;
25 import static org.mockito.Mockito.when;
26
27
28
29
30 @SuppressWarnings("ConstantConditions")
31 public class ClassLoaderAwareObjectInputStreamTest {
32 @Test
33 public void testNulls() throws Exception {
34 assertConstructorThrows(IllegalArgumentException.class, null, null);
35 assertConstructorThrows(IllegalArgumentException.class, null, BYTES_PARAMETERS);
36 assertConstructorThrows(IllegalArgumentException.class, getClass().getClassLoader(), null);
37 }
38
39 @Test
40 public void testInvalidSerializationData() {
41 assertConstructorThrows(IOException.class, getClass().getClassLoader(), BYTES_DEADF00D);
42 }
43
44 @Test
45 public void testNullMap() throws IOException, ClassNotFoundException {
46 assertReadObject(null, getClass().getClassLoader(), BYTES_NULL);
47 }
48
49 @Test
50 public void testEmptyMap() throws IOException, ClassNotFoundException {
51 assertReadObject(EMPTY_MAP, getClass().getClassLoader(), BYTES_EMPTY_MAP);
52 }
53
54 @Test
55 public void testParameters() throws IOException, ClassNotFoundException {
56 assertReadObject(PARAMETERS, getClass().getClassLoader(), BYTES_PARAMETERS);
57 }
58
59 @Test
60 public void testResolveClassYieldsFirstExceptionWhenNotFound() throws Exception {
61
62 final ClassNotFoundException cnfe = new ClassNotFoundException("Expected");
63 assertResolveClassThrows(cnfe, classLoaderThatThrows(cnfe), "foo.bar");
64 }
65
66 @SuppressWarnings({"rawtypes", "unchecked"})
67 @Test
68 public void testResolveClassSucceedsUsingOurClassLoader() throws Exception {
69 final ClassLoader classLoader = mock(ClassLoader.class);
70 when(classLoader.loadClass(String.class.getName())).thenReturn((Class) String.class);
71
72 assertResolveClass(String.class, classLoader, String.class.getName());
73
74
75 verify(classLoader).loadClass(String.class.getName());
76 }
77
78 @Test
79 public void testResolveClassSucceedsByFallbackForNormalClasses() throws Exception {
80
81 final ClassNotFoundException cnfe = new ClassNotFoundException("Expected");
82 assertResolveClass(String.class, classLoaderThatThrows(cnfe), String.class.getName());
83 }
84
85
86 @Test
87 public void testResolveClassSucceedsByFallbackForPrimitives() throws Exception {
88
89 final ClassNotFoundException cnfe = new ClassNotFoundException("Expected");
90 assertResolveClass(Integer.TYPE, classLoaderThatThrows(cnfe), "int");
91 }
92
93 @Test
94 public void testResolveClassDoesNotAttemptFallbackOnSerializationErrors() throws Exception {
95
96 final IOException ioe = new IOException("Expected");
97 assertResolveClassThrows(ioe, classLoaderThatThrows(ioe), String.class.getName());
98 }
99
100
101 static ClassLoader classLoaderThatThrows(final Throwable e) {
102 return mock(ClassLoader.class, new Answer() {
103 @Override
104 public Object answer(final InvocationOnMock invocation) throws Throwable {
105 throw e;
106 }
107 });
108 }
109
110 static ObjectStreamClass desc(String className) {
111 final ObjectStreamClass osc = mock(ObjectStreamClass.class);
112 when(osc.getName()).thenReturn(className);
113 return osc;
114 }
115
116 static void assertConstructorThrows(final Class<? extends Exception> expected,
117 final ClassLoader classLoader, final byte[] bytes) {
118 try {
119 final ClassLoaderAwareObjectInputStream is = new ClassLoaderAwareObjectInputStream(classLoader, bytes);
120 is.close();
121 try {
122 fail("Expected construction to fail with " + expected.getName() + ", but it succeeded!");
123 } finally {
124 is.close();
125 }
126 } catch (Exception ex) {
127 assertThat(ex, instanceOf(expected));
128 }
129 }
130
131 private static void assertReadObject(final Map<String, Serializable> expected,
132 final ClassLoader classLoader, final byte[] bytes) throws IOException, ClassNotFoundException {
133 final ClassLoaderAwareObjectInputStream is = new ClassLoaderAwareObjectInputStream(classLoader, bytes);
134 try {
135 assertEquals(expected, is.readObject());
136 } finally {
137 is.close();
138 }
139 }
140
141 private static void assertResolveClass(final Class<?> expected,
142 final ClassLoader classLoader, final String className) throws IOException, ClassNotFoundException {
143 final ClassLoaderAwareObjectInputStream is = new ClassLoaderAwareObjectInputStream(classLoader, BYTES_EMPTY_MAP);
144 try {
145 assertEquals(expected, is.resolveClass(desc(className)));
146 } finally {
147 is.close();
148 }
149 }
150
151 private static void assertResolveClassThrows(final Exception expected,
152 final ClassLoader classLoader, final String className) throws IOException, ClassNotFoundException {
153 final ClassLoaderAwareObjectInputStream is = new ClassLoaderAwareObjectInputStream(classLoader, BYTES_EMPTY_MAP);
154 try {
155 throw new AssertionError("Expected resolveClass to throw " + expected + ", but got " +
156 is.resolveClass(desc(className)).getName());
157 } catch (Exception ex) {
158 assertSame(expected, ex);
159 } finally {
160 is.close();
161 }
162 }
163 }
164