View Javadoc
1   package com.atlassian.plugin.servlet.filter;
2   
3   import com.atlassian.plugin.IllegalPluginStateException;
4   import com.atlassian.plugin.Plugin;
5   import com.atlassian.plugin.PluginArtifact;
6   import com.atlassian.plugin.classloader.PluginClassLoader;
7   import com.atlassian.plugin.impl.DefaultDynamicPlugin;
8   import com.atlassian.plugin.servlet.descriptors.ServletFilterModuleDescriptor;
9   import com.atlassian.plugin.servlet.descriptors.ServletFilterModuleDescriptorBuilder;
10  import com.atlassian.plugin.servlet.filter.FilterTestUtils.FilterAdapter;
11  import com.atlassian.plugin.test.PluginJarBuilder;
12  import org.junit.Before;
13  import org.junit.Rule;
14  import org.junit.Test;
15  import org.junit.rules.ExpectedException;
16  import org.junit.runner.RunWith;
17  import org.mockito.Mock;
18  import org.mockito.junit.MockitoJUnitRunner;
19  
20  import javax.servlet.Filter;
21  import javax.servlet.FilterChain;
22  import javax.servlet.ServletException;
23  import javax.servlet.ServletRequest;
24  import javax.servlet.ServletResponse;
25  import javax.servlet.http.HttpServletRequest;
26  import javax.servlet.http.HttpServletResponse;
27  import java.io.File;
28  import java.io.IOException;
29  
30  import static com.atlassian.plugin.servlet.filter.FilterTestUtils.emptyChain;
31  import static com.atlassian.plugin.test.PluginTestUtils.getFileForResource;
32  import static com.google.common.collect.Lists.newArrayList;
33  import static org.hamcrest.MatcherAssert.assertThat;
34  import static org.hamcrest.Matchers.is;
35  import static org.mockito.Mockito.mock;
36  import static org.mockito.Mockito.when;
37  
38  @RunWith(MockitoJUnitRunner.Silent.class)
39  public class TestDelegatingPluginFilter {
40      @Rule
41      public final ExpectedException expectedException = ExpectedException.none();
42  
43      @Mock
44      private HttpServletRequest httpServletRequest;
45      @Mock
46      private HttpServletResponse httpServletResponse;
47  
48      @Before
49      public void setUp() {
50          when(httpServletRequest.getPathInfo()).thenReturn("/servlet");
51      }
52  
53      @Test
54      public void testPluginClassLoaderIsThreadContextClassLoaderWhenFiltering() throws Exception {
55          createClassLoaderCheckingFilter("filter").doFilter(httpServletRequest, httpServletResponse, emptyChain);
56      }
57  
58      @Test
59      public void testClassLoaderResetDuringFilterChainExecution() throws Exception {
60          final ClassLoader initialClassLoader = Thread.currentThread().getContextClassLoader();
61          final FilterChain chain = (servletRequest, servletResponse) ->
62                  assertThat(Thread.currentThread().getContextClassLoader(), is(initialClassLoader));
63          createClassLoaderCheckingFilter("filter").doFilter(httpServletRequest, httpServletResponse, chain);
64      }
65  
66      @Test
67      public void testPluginClassLoaderIsThreadContextLoaderWhenFiltersInChainAreFromDifferentPlugins() throws Exception {
68          final Iterable<Filter> filters = newArrayList(
69                  createClassLoaderCheckingFilter("filter-1"),
70                  createClassLoaderCheckingFilter("filter-2"),
71                  createClassLoaderCheckingFilter("filter-3")
72          );
73          final FilterChain chain = new IteratingFilterChain(filters.iterator(), emptyChain);
74          chain.doFilter(httpServletRequest, httpServletResponse);
75      }
76  
77      @Test
78      public void testPluginClassLoaderIsRestoredProperlyWhenAnExceptionIsThrownFromFilter() throws Exception {
79          final Iterable<Filter> filters = newArrayList(
80                  createClassLoaderCheckingFilter("filter-1"),
81                  createClassLoaderCheckingFilter("filter-2"),
82                  createExceptionThrowingFilter("exception-filter"),
83                  createClassLoaderCheckingFilter("filter-3")
84          );
85          final FilterChain chain = new IteratingFilterChain(filters.iterator(), (request, response) -> {
86              throw new ServletException("Exception should be thrown before reaching here.");
87          });
88          expectedException.expect(ServletException.class);
89          expectedException.expectMessage("exception-filter");
90          chain.doFilter(httpServletRequest, httpServletResponse);
91      }
92  
93      @Test
94      public void pluginCanBeUninstalledFromFilterChain() throws Exception {
95          final Plugin plugin = mock(Plugin.class);
96          when(plugin.getClassLoader()).thenReturn(Thread.currentThread().getContextClassLoader());
97  
98          final FilterAdapter filter = new FilterAdapter() {
99              @Override
100             public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
101                     throws IOException, ServletException {
102                 chain.doFilter(request, response);
103             }
104         };
105 
106         final ServletFilterModuleDescriptor servletFilterModuleDescriptor = mock(ServletFilterModuleDescriptor.class);
107         when(servletFilterModuleDescriptor.getPlugin()).thenReturn(plugin);
108         when(servletFilterModuleDescriptor.getModule()).thenReturn(filter);
109 
110         // This boolean is a poor man's spy to check that the meat of the test is called. There doesn't seem to be a
111         // natural mock to verify, and adding a spy for it feels perverse when the code is just here.
112         final boolean chainCalled[] = {false};
113         final FilterChain chain = (servletRequest, servletResponse) -> {
114             // Pretend this chain uninstalls the plugin, which means, among other things, that you
115             // can't get the classloader any more
116             when(plugin.getClassLoader()).thenThrow(new IllegalPluginStateException("Plugin Uninstalled"));
117             chainCalled[0] = true;
118         };
119 
120         final DelegatingPluginFilter delegatingPluginFilter = new DelegatingPluginFilter(servletFilterModuleDescriptor);
121 
122         delegatingPluginFilter.doFilter(httpServletRequest, httpServletResponse, chain);
123         assertThat(chainCalled[0], is(true));
124     }
125 
126     private Filter createClassLoaderCheckingFilter(final String name) throws Exception {
127         final File pluginFile = new PluginJarBuilder()
128                 .addFormattedJava("my.SimpleFilter",
129                         "package my;" +
130                                 "import java.io.IOException;" +
131                                 "import javax.servlet.Filter;" +
132                                 "import javax.servlet.FilterChain;" +
133                                 "import javax.servlet.FilterConfig;" +
134                                 "import javax.servlet.ServletException;" +
135                                 "import javax.servlet.ServletRequest;" +
136                                 "import javax.servlet.ServletResponse;" +
137                                 "" +
138                                 "public class SimpleFilter implements Filter" +
139                                 "{" +
140                                 "    String name;" +
141                                 "    public void init(FilterConfig filterConfig) throws ServletException" +
142                                 "    {" +
143                                 "        name = filterConfig.getInitParameter('name');" +
144                                 "    }" +
145                                 "" +
146                                 "    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException" +
147                                 "    {" +
148                                 "        response.getWriter().write('entered: ' + name + '\');" +
149                                 "        chain.doFilter(request, response);" +
150                                 "        response.getWriter().write('exiting: ' + name + '\');" +
151                                 "    }" +
152                                 "    public void destroy() {}" +
153                                 "}")
154                 .addFile("atlassian-plugin.xml", getFileForResource("com/atlassian/plugin/servlet/filter/atlassian-plugin-filter.xml"))
155                 .build();
156         final PluginClassLoader pluginClassLoader = new PluginClassLoader(pluginFile);
157         final PluginArtifact pluginArtifact = mock(PluginArtifact.class);
158         final Plugin plugin = new DefaultDynamicPlugin(pluginArtifact, pluginClassLoader);
159         final FilterAdapter testFilter = new FilterAdapter() {
160             @Override
161             public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
162                     throws IOException, ServletException {
163                 assertThat(name + " plugin ClassLoader should be current when entering",
164                         Thread.currentThread().getContextClassLoader(), is(pluginClassLoader));
165                 chain.doFilter(request, response);
166                 assertThat(name + " plugin ClassLoader should be current when exiting",
167                         Thread.currentThread().getContextClassLoader(), is(pluginClassLoader));
168             }
169         };
170 
171         final ServletFilterModuleDescriptor filterDescriptor = new ServletFilterModuleDescriptorBuilder()
172                 .with(testFilter)
173                 .with(plugin)
174                 .build();
175 
176         return new DelegatingPluginFilter(filterDescriptor);
177     }
178 
179     private Filter createExceptionThrowingFilter(final String name) {
180         return new FilterAdapter() {
181             @Override
182             public void doFilter(final ServletRequest request, final ServletResponse response, final FilterChain chain)
183                     throws ServletException {
184                 throw new ServletException(name);
185             }
186         };
187     }
188 }