1 package com.atlassian.plugin.servlet.filter;
2
import java.io.File;
import java.io.IOException;

import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import com.atlassian.plugin.IllegalPluginStateException;
import com.atlassian.plugin.Plugin;
import com.atlassian.plugin.PluginArtifact;
import com.atlassian.plugin.classloader.PluginClassLoader;
import com.atlassian.plugin.impl.DefaultDynamicPlugin;
import com.atlassian.plugin.servlet.descriptors.ServletFilterModuleDescriptor;
import com.atlassian.plugin.servlet.descriptors.ServletFilterModuleDescriptorBuilder;
import com.atlassian.plugin.servlet.filter.FilterTestUtils.FilterAdapter;
import com.atlassian.plugin.test.PluginJarBuilder;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;

import static com.atlassian.plugin.servlet.filter.FilterTestUtils.emptyChain;
import static com.atlassian.plugin.servlet.filter.FilterTestUtils.newList;
import static com.atlassian.plugin.test.PluginTestUtils.getFileForResource;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
39
40 @RunWith (MockitoJUnitRunner.class)
41 public class TestDelegatingPluginFilter
42 {
43 @Rule
44 public ExpectedException expectedException = ExpectedException.none();
45
46 @Mock
47 private HttpServletRequest httpServletRequest;
48
49 @Mock
50 private HttpServletResponse httpServletResponse;
51
52 @Before
53 public void setUp() throws Exception
54 {
55 when(httpServletRequest.getPathInfo()).thenReturn("/servlet");
56 }
57
58 @Test
59 public void testPluginClassLoaderIsThreadContextClassLoaderWhenFiltering() throws Exception
60 {
61 createClassLoaderCheckingFilter("filter").doFilter(httpServletRequest, httpServletResponse, emptyChain);
62 }
63
64 @Test
65 public void testClassLoaderResetDuringFilterChainExecution() throws Exception
66 {
67 final ClassLoader initialClassLoader = Thread.currentThread().getContextClassLoader();
68 final FilterChain chain = new FilterChain()
69 {
70 public void doFilter(final ServletRequest servletRequest, final ServletResponse servletResponse)
71 throws IOException, ServletException
72 {
73 assertThat(Thread.currentThread().getContextClassLoader(), is(initialClassLoader));
74 }
75 };
76 createClassLoaderCheckingFilter("filter").doFilter(httpServletRequest, httpServletResponse, chain);
77 }
78
79 @Test
80 public void testPluginClassLoaderIsThreadContextLoaderWhenFiltersInChainAreFromDifferentPlugins() throws Exception
81 {
82 final Iterable<Filter> filters = newList(
83 createClassLoaderCheckingFilter("filter-1"),
84 createClassLoaderCheckingFilter("filter-2"),
85 createClassLoaderCheckingFilter("filter-3")
86 );
87 final FilterChain chain = new IteratingFilterChain(filters.iterator(), emptyChain);
88 chain.doFilter(httpServletRequest, httpServletResponse);
89 }
90
91 @Test
92 public void testPluginClassLoaderIsRestoredProperlyWhenAnExceptionIsThrownFromFilter() throws Exception
93 {
94 final Iterable<Filter> filters = newList(
95 createClassLoaderCheckingFilter("filter-1"),
96 createClassLoaderCheckingFilter("filter-2"),
97 createExceptionThrowingFilter("exception-filter"),
98 createClassLoaderCheckingFilter("filter-3")
99 );
100 final FilterChain chain = new IteratingFilterChain(filters.iterator(), new FilterChain()
101 {
102 public void doFilter(final ServletRequest request, final ServletResponse response)
103 throws IOException, ServletException
104 {
105 throw new ServletException("Exception should be thrown before reaching here.");
106 }
107 });
108 expectedException.expect(ServletException.class);
109 expectedException.expectMessage("exception-filter");
110 chain.doFilter(httpServletRequest, httpServletResponse);
111 }
112
113 @Test
114 public void pluginCanBeUninstalledFromFilterChain() throws Exception
115 {
116 final Plugin plugin = mock(Plugin.class);
117 when(plugin.getClassLoader()).thenReturn(Thread.currentThread().getContextClassLoader());
118
119 final FilterAdapter filter = new FilterAdapter()
120 {
121 @Override
122 public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
123 throws IOException, ServletException
124 {
125 chain.doFilter(request, response);
126 }
127 };
128
129 final ServletFilterModuleDescriptor servletFilterModuleDescriptor = mock(ServletFilterModuleDescriptor.class);
130 when(servletFilterModuleDescriptor.getPlugin()).thenReturn(plugin);
131 when(servletFilterModuleDescriptor.getModule()).thenReturn(filter);
132
133
134
135 final boolean chainCalled[] = { false };
136 final FilterChain chain = new FilterChain()
137 {
138 public void doFilter(final ServletRequest servletRequest, final ServletResponse servletResponse)
139 throws IOException, ServletException
140 {
141
142
143 when(plugin.getClassLoader()).thenThrow(new IllegalPluginStateException("Plugin Uninstalled"));
144 chainCalled[0] = true;
145 }
146 };
147
148 final DelegatingPluginFilter delegatingPluginFilter = new DelegatingPluginFilter(servletFilterModuleDescriptor);
149
150 delegatingPluginFilter.doFilter(httpServletRequest, httpServletResponse, chain);
151 assertThat(chainCalled[0], is(true));
152 }
153
154 private Filter createClassLoaderCheckingFilter(final String name) throws Exception
155 {
156 final File pluginFile = new PluginJarBuilder()
157 .addFormattedJava("my.SimpleFilter",
158 "package my;" +
159 "import java.io.IOException;" +
160 "import javax.servlet.Filter;" +
161 "import javax.servlet.FilterChain;" +
162 "import javax.servlet.FilterConfig;" +
163 "import javax.servlet.ServletException;" +
164 "import javax.servlet.ServletRequest;" +
165 "import javax.servlet.ServletResponse;" +
166 "" +
167 "public class SimpleFilter implements Filter" +
168 "{" +
169 " String name;" +
170 " public void init(FilterConfig filterConfig) throws ServletException" +
171 " {" +
172 " name = filterConfig.getInitParameter('name');" +
173 " }" +
174 "" +
175 " public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException" +
176 " {" +
177 " response.getWriter().write('entered: ' + name + '\');" +
178 " chain.doFilter(request, response);" +
179 " response.getWriter().write('exiting: ' + name + '\');" +
180 " }" +
181 " public void destroy() {}" +
182 "}")
183 .addFile("atlassian-plugin.xml", getFileForResource("com/atlassian/plugin/servlet/filter/atlassian-plugin-filter.xml"))
184 .build();
185 final PluginClassLoader pluginClassLoader = new PluginClassLoader(pluginFile);
186 final PluginArtifact pluginArtifact = mock(PluginArtifact.class);
187 final Plugin plugin = new DefaultDynamicPlugin(pluginArtifact, pluginClassLoader);
188 final FilterAdapter testFilter = new FilterAdapter()
189 {
190 @Override
191 public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
192 throws IOException, ServletException
193 {
194 assertThat(name + " plugin ClassLoader should be current when entering",
195 Thread.currentThread().getContextClassLoader(), is((ClassLoader) pluginClassLoader));
196 chain.doFilter(request, response);
197 assertThat(name + " plugin ClassLoader should be current when exiting",
198 Thread.currentThread().getContextClassLoader(), is((ClassLoader) pluginClassLoader));
199 }
200 };
201
202 final ServletFilterModuleDescriptor filterDescriptor = new ServletFilterModuleDescriptorBuilder()
203 .with(testFilter)
204 .with(plugin)
205 .build();
206
207 return new DelegatingPluginFilter(filterDescriptor);
208 }
209
210 private Filter createExceptionThrowingFilter(final String name)
211 {
212 return new FilterAdapter()
213 {
214 @Override
215 public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
216 throws IOException, ServletException
217 {
218 throw new ServletException(name);
219 }
220 };
221 }
222 }