1 package com.atlassian.plugin.servlet.filter;
2
3 import com.atlassian.plugin.IllegalPluginStateException;
4 import com.atlassian.plugin.Plugin;
5 import com.atlassian.plugin.PluginArtifact;
6 import com.atlassian.plugin.classloader.PluginClassLoader;
7 import com.atlassian.plugin.impl.DefaultDynamicPlugin;
8 import com.atlassian.plugin.servlet.descriptors.ServletFilterModuleDescriptor;
9 import com.atlassian.plugin.servlet.descriptors.ServletFilterModuleDescriptorBuilder;
10 import com.atlassian.plugin.servlet.filter.FilterTestUtils.FilterAdapter;
11 import com.atlassian.plugin.test.PluginJarBuilder;
12 import org.junit.Before;
13 import org.junit.Rule;
14 import org.junit.Test;
15 import org.junit.rules.ExpectedException;
16 import org.junit.runner.RunWith;
17 import org.mockito.Mock;
18 import org.mockito.junit.MockitoJUnitRunner;
19
20 import javax.servlet.Filter;
21 import javax.servlet.FilterChain;
22 import javax.servlet.ServletException;
23 import javax.servlet.ServletRequest;
24 import javax.servlet.ServletResponse;
25 import javax.servlet.http.HttpServletRequest;
26 import javax.servlet.http.HttpServletResponse;
27 import java.io.File;
28 import java.io.IOException;
29
30 import static com.atlassian.plugin.servlet.filter.FilterTestUtils.emptyChain;
31 import static com.atlassian.plugin.test.PluginTestUtils.getFileForResource;
32 import static com.google.common.collect.Lists.newArrayList;
33 import static org.hamcrest.MatcherAssert.assertThat;
34 import static org.hamcrest.Matchers.is;
35 import static org.mockito.Mockito.mock;
36 import static org.mockito.Mockito.when;
37
38 @RunWith(MockitoJUnitRunner.Silent.class)
39 public class TestDelegatingPluginFilter {
40 @Rule
41 public final ExpectedException expectedException = ExpectedException.none();
42
43 @Mock
44 private HttpServletRequest httpServletRequest;
45 @Mock
46 private HttpServletResponse httpServletResponse;
47
48 @Before
49 public void setUp() {
50 when(httpServletRequest.getPathInfo()).thenReturn("/servlet");
51 }
52
53 @Test
54 public void testPluginClassLoaderIsThreadContextClassLoaderWhenFiltering() throws Exception {
55 createClassLoaderCheckingFilter("filter").doFilter(httpServletRequest, httpServletResponse, emptyChain);
56 }
57
58 @Test
59 public void testClassLoaderResetDuringFilterChainExecution() throws Exception {
60 final ClassLoader initialClassLoader = Thread.currentThread().getContextClassLoader();
61 final FilterChain chain = (servletRequest, servletResponse) ->
62 assertThat(Thread.currentThread().getContextClassLoader(), is(initialClassLoader));
63 createClassLoaderCheckingFilter("filter").doFilter(httpServletRequest, httpServletResponse, chain);
64 }
65
66 @Test
67 public void testPluginClassLoaderIsThreadContextLoaderWhenFiltersInChainAreFromDifferentPlugins() throws Exception {
68 final Iterable<Filter> filters = newArrayList(
69 createClassLoaderCheckingFilter("filter-1"),
70 createClassLoaderCheckingFilter("filter-2"),
71 createClassLoaderCheckingFilter("filter-3")
72 );
73 final FilterChain chain = new IteratingFilterChain(filters.iterator(), emptyChain);
74 chain.doFilter(httpServletRequest, httpServletResponse);
75 }
76
77 @Test
78 public void testPluginClassLoaderIsRestoredProperlyWhenAnExceptionIsThrownFromFilter() throws Exception {
79 final Iterable<Filter> filters = newArrayList(
80 createClassLoaderCheckingFilter("filter-1"),
81 createClassLoaderCheckingFilter("filter-2"),
82 createExceptionThrowingFilter("exception-filter"),
83 createClassLoaderCheckingFilter("filter-3")
84 );
85 final FilterChain chain = new IteratingFilterChain(filters.iterator(), (request, response) -> {
86 throw new ServletException("Exception should be thrown before reaching here.");
87 });
88 expectedException.expect(ServletException.class);
89 expectedException.expectMessage("exception-filter");
90 chain.doFilter(httpServletRequest, httpServletResponse);
91 }
92
93 @Test
94 public void pluginCanBeUninstalledFromFilterChain() throws Exception {
95 final Plugin plugin = mock(Plugin.class);
96 when(plugin.getClassLoader()).thenReturn(Thread.currentThread().getContextClassLoader());
97
98 final FilterAdapter filter = new FilterAdapter() {
99 @Override
100 public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
101 throws IOException, ServletException {
102 chain.doFilter(request, response);
103 }
104 };
105
106 final ServletFilterModuleDescriptor servletFilterModuleDescriptor = mock(ServletFilterModuleDescriptor.class);
107 when(servletFilterModuleDescriptor.getPlugin()).thenReturn(plugin);
108 when(servletFilterModuleDescriptor.getModule()).thenReturn(filter);
109
110
111
112 final boolean chainCalled[] = {false};
113 final FilterChain chain = (servletRequest, servletResponse) -> {
114
115
116 when(plugin.getClassLoader()).thenThrow(new IllegalPluginStateException("Plugin Uninstalled"));
117 chainCalled[0] = true;
118 };
119
120 final DelegatingPluginFilter delegatingPluginFilter = new DelegatingPluginFilter(servletFilterModuleDescriptor);
121
122 delegatingPluginFilter.doFilter(httpServletRequest, httpServletResponse, chain);
123 assertThat(chainCalled[0], is(true));
124 }
125
126 private Filter createClassLoaderCheckingFilter(final String name) throws Exception {
127 final File pluginFile = new PluginJarBuilder()
128 .addFormattedJava("my.SimpleFilter",
129 "package my;" +
130 "import java.io.IOException;" +
131 "import javax.servlet.Filter;" +
132 "import javax.servlet.FilterChain;" +
133 "import javax.servlet.FilterConfig;" +
134 "import javax.servlet.ServletException;" +
135 "import javax.servlet.ServletRequest;" +
136 "import javax.servlet.ServletResponse;" +
137 "" +
138 "public class SimpleFilter implements Filter" +
139 "{" +
140 " String name;" +
141 " public void init(FilterConfig filterConfig) throws ServletException" +
142 " {" +
143 " name = filterConfig.getInitParameter('name');" +
144 " }" +
145 "" +
146 " public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException" +
147 " {" +
148 " response.getWriter().write('entered: ' + name + '\');" +
149 " chain.doFilter(request, response);" +
150 " response.getWriter().write('exiting: ' + name + '\');" +
151 " }" +
152 " public void destroy() {}" +
153 "}")
154 .addFile("atlassian-plugin.xml", getFileForResource("com/atlassian/plugin/servlet/filter/atlassian-plugin-filter.xml"))
155 .build();
156 final PluginClassLoader pluginClassLoader = new PluginClassLoader(pluginFile);
157 final PluginArtifact pluginArtifact = mock(PluginArtifact.class);
158 final Plugin plugin = new DefaultDynamicPlugin(pluginArtifact, pluginClassLoader);
159 final FilterAdapter testFilter = new FilterAdapter() {
160 @Override
161 public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
162 throws IOException, ServletException {
163 assertThat(name + " plugin ClassLoader should be current when entering",
164 Thread.currentThread().getContextClassLoader(), is(pluginClassLoader));
165 chain.doFilter(request, response);
166 assertThat(name + " plugin ClassLoader should be current when exiting",
167 Thread.currentThread().getContextClassLoader(), is(pluginClassLoader));
168 }
169 };
170
171 final ServletFilterModuleDescriptor filterDescriptor = new ServletFilterModuleDescriptorBuilder()
172 .with(testFilter)
173 .with(plugin)
174 .build();
175
176 return new DelegatingPluginFilter(filterDescriptor);
177 }
178
179 private Filter createExceptionThrowingFilter(final String name) {
180 return new FilterAdapter() {
181 @Override
182 public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
183 throws ServletException {
184 throw new ServletException(name);
185 }
186 };
187 }
188 }