1 package com.atlassian.plugin.servlet;
2
3 import com.atlassian.plugin.ModuleDescriptor;
4 import com.atlassian.plugin.Plugin;
5 import com.atlassian.plugin.PluginAccessor;
6 import com.atlassian.plugin.PluginController;
7 import com.atlassian.plugin.PluginException;
8 import com.atlassian.plugin.event.PluginEventManager;
9 import com.atlassian.plugin.event.events.PluginFrameworkShutdownEvent;
10 import com.atlassian.plugin.event.events.PluginFrameworkShuttingDownEvent;
11 import com.atlassian.plugin.event.events.PluginFrameworkStartedEvent;
12 import com.atlassian.plugin.servlet.DefaultServletModuleManager.LazyLoadedServletReference;
13 import com.atlassian.plugin.servlet.descriptors.ServletContextListenerModuleDescriptor;
14 import com.atlassian.plugin.servlet.descriptors.ServletContextListenerModuleDescriptorBuilder;
15 import com.atlassian.plugin.servlet.descriptors.ServletContextParamDescriptorBuilder;
16 import com.atlassian.plugin.servlet.descriptors.ServletFilterModuleDescriptor;
17 import com.atlassian.plugin.servlet.descriptors.ServletFilterModuleDescriptorBuilder;
18 import com.atlassian.plugin.servlet.descriptors.ServletModuleDescriptor;
19 import com.atlassian.plugin.servlet.descriptors.ServletModuleDescriptorBuilder;
20 import com.atlassian.plugin.servlet.filter.DelegatingPluginFilter;
21 import com.atlassian.plugin.servlet.filter.FilterLocation;
22 import com.atlassian.plugin.servlet.filter.FilterTestUtils.FilterAdapter;
23 import com.atlassian.plugin.servlet.filter.FilterTestUtils.SoundOffFilter;
24 import com.atlassian.plugin.servlet.filter.IteratingFilterChain;
25 import com.atlassian.plugin.servlet.util.DefaultPathMapper;
26 import com.atlassian.plugin.servlet.util.PathMapper;
27 import com.atlassian.plugin.test.CapturedLogging;
28 import com.google.common.collect.ImmutableList;
29 import com.google.common.collect.Iterables;
30 import com.google.common.collect.Iterators;
31 import org.dom4j.Element;
32 import org.dom4j.dom.DOMElement;
33 import org.junit.Before;
34 import org.junit.Rule;
35 import org.junit.Test;
36 import org.junit.rules.ExpectedException;
37 import org.junit.runner.RunWith;
38 import org.mockito.ArgumentCaptor;
39 import org.mockito.Mock;
40 import org.mockito.Mockito;
41 import org.mockito.invocation.InvocationOnMock;
42 import org.mockito.junit.MockitoJUnitRunner;
43 import org.mockito.stubbing.Answer;
44
45 import javax.servlet.DispatcherType;
46 import javax.servlet.Filter;
47 import javax.servlet.FilterChain;
48 import javax.servlet.FilterConfig;
49 import javax.servlet.ServletConfig;
50 import javax.servlet.ServletContext;
51 import javax.servlet.ServletContextEvent;
52 import javax.servlet.ServletContextListener;
53 import javax.servlet.ServletException;
54 import javax.servlet.ServletRequest;
55 import javax.servlet.ServletResponse;
56 import javax.servlet.http.HttpServlet;
57 import javax.servlet.http.HttpServletRequest;
58 import javax.servlet.http.HttpServletResponse;
59 import java.io.IOException;
60 import java.util.Arrays;
61 import java.util.Collection;
62 import java.util.Collections;
63 import java.util.HashSet;
64 import java.util.LinkedList;
65 import java.util.List;
66 import java.util.Optional;
67 import java.util.Set;
68 import java.util.Vector;
69 import java.util.concurrent.atomic.AtomicReference;
70
71 import static com.atlassian.plugin.servlet.filter.FilterLocation.BEFORE_DISPATCH;
72 import static com.atlassian.plugin.servlet.filter.FilterTestUtils.emptyChain;
73 import static com.atlassian.plugin.test.CapturedLogging.didLogWarn;
74 import static com.atlassian.plugin.test.Matchers.isElement;
75 import static com.google.common.collect.Lists.newArrayList;
76 import static java.util.Collections.emptyList;
77 import static java.util.Collections.enumeration;
78 import static javax.servlet.DispatcherType.ASYNC;
79 import static javax.servlet.DispatcherType.ERROR;
80 import static javax.servlet.DispatcherType.FORWARD;
81 import static javax.servlet.DispatcherType.INCLUDE;
82 import static javax.servlet.DispatcherType.REQUEST;
83 import static org.hamcrest.Matchers.contains;
84 import static org.hamcrest.Matchers.instanceOf;
85 import static org.hamcrest.Matchers.not;
86 import static org.hamcrest.collection.IsMapContaining.hasKey;
87 import static org.junit.Assert.assertEquals;
88 import static org.junit.Assert.assertNull;
89 import static org.junit.Assert.assertSame;
90 import static org.junit.Assert.assertThat;
91 import static org.junit.Assert.assertTrue;
92 import static org.junit.Assert.fail;
93 import static org.mockito.ArgumentMatchers.any;
94 import static org.mockito.ArgumentMatchers.same;
95 import static org.mockito.Mockito.doAnswer;
96 import static org.mockito.Mockito.mock;
97 import static org.mockito.Mockito.verify;
98 import static org.mockito.Mockito.when;
99
100 @RunWith(MockitoJUnitRunner.Silent.class)
101 public class TestDefaultServletModuleManager {
// Captures log output of DefaultServletModuleManager so tests can assert on (absence of) warnings.
@Rule
public final CapturedLogging capturedLogging = new CapturedLogging(DefaultServletModuleManager.class);
// Lets individual tests declare the exception type/message they expect to be thrown.
@Rule
public final ExpectedException expectedException = ExpectedException.none();

// Event manager handed to every DefaultServletModuleManager constructed in this class.
@Mock
private PluginEventManager mockPluginEventManager;
// Mocked collaborators for tests that build their own manager with fully stubbed mapping (e.g. testFilterSorting).
@Mock
private PathMapper mockServletMapper;
@Mock
private PathMapper mockFilterMapper;
@Mock
private FilterFactory mockFilterFactory;
// Controller carried by the framework-started event in setUp(); some tests stub addDynamicModule on it.
@Mock
private PluginController pluginController;

// Manager under test; rebuilt in setUp() with real DefaultPathMappers and a real FilterFactory.
private DefaultServletModuleManager servletModuleManager;
119
120 @Before
121 public void setUp() {
122 servletModuleManager = new DefaultServletModuleManager(mockPluginEventManager,
123 new DefaultPathMapper(),
124 new DefaultPathMapper(),
125 new FilterFactory());
126 final PluginFrameworkStartedEvent pluginFrameworkStartedEvent = new PluginFrameworkStartedEvent(pluginController, mock(PluginAccessor.class));
127 servletModuleManager.onPluginFrameworkStartingEvent(pluginFrameworkStartedEvent);
128 }
129
130 @Test
131 public void testGettingServletWithSimplePath() throws Exception {
132 final ServletContext mockServletContext = mock(ServletContext.class);
133 when(mockServletContext.getInitParameterNames()).thenReturn(Iterators.asEnumeration(Collections.emptyIterator()));
134 final ServletConfig mockServletConfig = mock(ServletConfig.class);
135 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
136
137 final HttpServletRequest mockHttpServletRequest = mock(HttpServletRequest.class);
138 when(mockHttpServletRequest.getPathInfo()).thenReturn("/servlet");
139 final HttpServletResponse mockHttpServletResponse = mock(HttpServletResponse.class);
140
141 TestHttpServlet servlet = new TestHttpServlet();
142 ServletModuleDescriptor descriptor = new ServletModuleDescriptorBuilder()
143 .with(servlet)
144 .withPath("/servlet")
145 .with(servletModuleManager)
146 .build();
147
148 servletModuleManager.addServletModule(descriptor);
149
150 HttpServlet wrappedServlet = servletModuleManager.getServlet("/servlet", mockServletConfig);
151 wrappedServlet.service(mockHttpServletRequest, mockHttpServletResponse);
152 assertTrue(servlet.serviceCalled);
153 }
154
155 @Test
156 public void testGettingServlet() {
157 getServletTwice(false);
158 }
159
160 private void getServletTwice(boolean expectNewServletEachCall) {
161 DefaultServletModuleManager mgr = new DefaultServletModuleManager(mockPluginEventManager);
162
163 AtomicReference<HttpServlet> servletRef = new AtomicReference<>();
164 TestHttpServlet firstServlet = new TestHttpServlet();
165 servletRef.set(firstServlet);
166 ServletModuleDescriptor descriptor = new ServletModuleDescriptorBuilder()
167 .withFactory(ObjectFactories.createMutable(servletRef))
168 .withPath("/servlet")
169 .with(mgr)
170 .build();
171
172 final ServletConfig mockServletConfig = mock(ServletConfig.class);
173 final ServletContext mockServletContext = mock(ServletContext.class);
174 when(mockServletContext.getInitParameterNames()).thenReturn(enumeration(emptyList()));
175 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
176
177 assertTrue(firstServlet == ((DelegatingPluginServlet) mgr.getServlet(descriptor, mockServletConfig)).getDelegatingServlet());
178
179 TestHttpServlet secondServlet = new TestHttpServlet();
180 servletRef.set(secondServlet);
181 HttpServlet expectedServlet = (expectNewServletEachCall ? secondServlet : firstServlet);
182 assertTrue(expectedServlet == ((DelegatingPluginServlet) mgr.getServlet(descriptor, mockServletConfig)).getDelegatingServlet());
183 }
184
185 @Test
186 public void testGettingFilter() {
187 getFilterTwice(false);
188 }
189
190 private void getFilterTwice(boolean expectNewFilterEachCall) {
191 DefaultServletModuleManager mgr = new DefaultServletModuleManager(mockPluginEventManager);
192
193 AtomicReference<Filter> filterRef = new AtomicReference<>();
194 TestHttpFilter firstFilter = new TestHttpFilter();
195 filterRef.set(firstFilter);
196 ServletFilterModuleDescriptor descriptor = new ServletFilterModuleDescriptorBuilder()
197 .withFactory(ObjectFactories.createMutable(filterRef))
198 .withPath("/servlet")
199 .with(mgr)
200 .build();
201
202 final FilterConfig mockFilterConfig = mock(FilterConfig.class);
203 final ServletContext mockServletContext = mock(ServletContext.class);
204 when(mockServletContext.getInitParameterNames()).thenReturn(enumeration(emptyList()));
205 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
206
207 assertTrue(firstFilter == ((DelegatingPluginFilter) mgr.getFilter(descriptor, mockFilterConfig)).getDelegatingFilter());
208
209 TestHttpFilter secondFilter = new TestHttpFilter();
210 filterRef.set(secondFilter);
211 Filter expectedFilter = (expectNewFilterEachCall ? secondFilter : firstFilter);
212 assertTrue(expectedFilter == ((DelegatingPluginFilter) mgr.getFilter(descriptor, mockFilterConfig)).getDelegatingFilter());
213 }
214
215 @Test
216 public void testGettingServletWithException() throws Exception {
217 ServletContext mockServletContext = mock(ServletContext.class);
218 when(mockServletContext.getInitParameterNames()).thenReturn(enumeration(emptyList()));
219 ServletConfig mockServletConfig = mock(ServletConfig.class);
220 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
221
222 HttpServletRequest mockHttpServletRequest = mock(HttpServletRequest.class);
223 when(mockHttpServletRequest.getPathInfo()).thenReturn("/servlet");
224
225 TestHttpServletWithException servlet = new TestHttpServletWithException();
226 ServletModuleDescriptor descriptor = new ServletModuleDescriptorBuilder()
227 .with(servlet)
228 .withPath("/servlet")
229 .with(servletModuleManager)
230 .build();
231
232 servletModuleManager.addServletModule(descriptor);
233
234 assertNull(servletModuleManager.getServlet("/servlet", mockServletConfig));
235 }
236
237 @Test
238 public void testGettingFilterWithException() throws Exception {
239 ServletContext mockServletContext = mock(ServletContext.class);
240 when(mockServletContext.getInitParameterNames()).thenReturn(enumeration(emptyList()));
241 FilterConfig mockFilterConfig = mock(FilterConfig.class);
242 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
243
244 HttpServletRequest mockHttpServletRequest = mock(HttpServletRequest.class);
245 when(mockHttpServletRequest.getPathInfo()).thenReturn("/servlet");
246
247 TestFilterWithException servlet = new TestFilterWithException();
248 ServletFilterModuleDescriptor descriptor = new ServletFilterModuleDescriptorBuilder()
249 .with(servlet)
250 .withPath("/servlet")
251 .with(servletModuleManager)
252 .at(FilterLocation.AFTER_ENCODING)
253 .build();
254
255 servletModuleManager.addFilterModule(descriptor);
256
257 assertEquals(false, servletModuleManager.getFilters(FilterLocation.AFTER_ENCODING, "/servlet", mockFilterConfig, REQUEST).iterator().hasNext());
258 }
259
260 @Test
261 public void testGettingServletWithComplexPath() throws Exception {
262 ServletContext mockServletContext = mock(ServletContext.class);
263 when(mockServletContext.getInitParameterNames()).thenReturn(enumeration(emptyList()));
264 ServletConfig mockServletConfig = mock(ServletConfig.class);
265 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
266
267 HttpServletRequest mockHttpServletRequest = mock(HttpServletRequest.class);
268 when(mockHttpServletRequest.getPathInfo()).thenReturn("/servlet");
269 HttpServletResponse mockHttpServletResponse = mock(HttpServletResponse.class);
270
271 TestHttpServlet servlet = new TestHttpServlet();
272 ServletModuleDescriptor descriptor = new ServletModuleDescriptorBuilder()
273 .with(servlet)
274 .withPath("/servlet/*")
275 .with(servletModuleManager)
276 .build();
277
278 servletModuleManager.addServletModule(descriptor);
279
280 HttpServlet wrappedServlet = servletModuleManager.getServlet("/servlet/this/is/a/test", mockServletConfig);
281 wrappedServlet.service(mockHttpServletRequest, mockHttpServletResponse);
282 assertTrue(servlet.serviceCalled);
283 }
284
285 @Test
286 public void testMultipleFiltersWithTheSameComplexPath() throws ServletException {
287 ServletContext mockServletContext = mock(ServletContext.class);
288 FilterConfig mockFilterConfig = mock(FilterConfig.class);
289 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
290 when(mockServletContext.getInitParameterNames()).thenReturn(new Vector().elements());
291 Plugin plugin = new PluginBuilder().build();
292 ServletFilterModuleDescriptor filterDescriptor = new ServletFilterModuleDescriptorBuilder()
293 .with(plugin)
294 .withKey("foo")
295 .with(new FilterAdapter())
296 .withPath("/foo/*")
297 .with(servletModuleManager)
298 .build();
299
300 ServletFilterModuleDescriptor filterDescriptor2 = new ServletFilterModuleDescriptorBuilder()
301 .with(plugin)
302 .withKey("bar")
303 .with(new FilterAdapter())
304 .withPath("/foo/*")
305 .with(servletModuleManager)
306 .build();
307 servletModuleManager.addFilterModule(filterDescriptor);
308 servletModuleManager.addFilterModule(filterDescriptor2);
309
310 servletModuleManager.removeFilterModule(filterDescriptor);
311 assertTrue(servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo/jim", mockFilterConfig, REQUEST).iterator().hasNext());
312 }
313
314 @Test
315 public void testMultipleFiltersWithTheSameSimplePath() throws ServletException {
316 ServletContext mockServletContext = mock(ServletContext.class);
317 FilterConfig mockFilterConfig = mock(FilterConfig.class);
318 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
319 when(mockServletContext.getInitParameterNames()).thenReturn(new Vector().elements());
320 Plugin plugin = new PluginBuilder().build();
321 ServletFilterModuleDescriptor filterDescriptor = new ServletFilterModuleDescriptorBuilder()
322 .with(plugin)
323 .withKey("foo")
324 .with(new FilterAdapter())
325 .withPath("/foo")
326 .with(servletModuleManager)
327 .build();
328
329 ServletFilterModuleDescriptor filterDescriptor2 = new ServletFilterModuleDescriptorBuilder()
330 .with(plugin)
331 .withKey("bar")
332 .with(new FilterAdapter())
333 .withPath("/foo")
334 .with(servletModuleManager)
335 .build();
336 servletModuleManager.addFilterModule(filterDescriptor);
337 servletModuleManager.addFilterModule(filterDescriptor2);
338
339 servletModuleManager.removeFilterModule(filterDescriptor);
340 assertTrue(servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo", mockFilterConfig, REQUEST).iterator().hasNext());
341 }
342
343 @Test
344 public void testPluginContextInitParamsGetMerged() throws Exception {
345 ServletContext mockServletContext = mock(ServletContext.class);
346 when(mockServletContext.getInitParameterNames()).thenReturn(enumeration(emptyList()));
347 ServletConfig mockServletConfig = mock(ServletConfig.class);
348 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
349
350 Plugin plugin = new PluginBuilder().build();
351
352 new ServletContextParamDescriptorBuilder()
353 .with(plugin)
354 .withParam("param.name", "param.value")
355 .build();
356
357
358 ServletModuleDescriptor servletDescriptor = new ServletModuleDescriptorBuilder()
359 .with(plugin)
360 .with(new TestHttpServlet() {
361 @Override
362 public void init(ServletConfig servletConfig) {
363 assertEquals("param.value", servletConfig.getServletContext().getInitParameter("param.name"));
364 }
365 })
366 .withPath("/servlet")
367 .with(servletModuleManager)
368 .build();
369 servletModuleManager.addServletModule(servletDescriptor);
370
371 servletModuleManager.getServlet("/servlet", mockServletConfig);
372 }
373
374 @Test
375 public void testServletListenerContextInitializedIsCalled() throws Exception {
376 ServletContext mockServletContext = mock(ServletContext.class);
377 when(mockServletContext.getInitParameterNames()).thenReturn(enumeration(emptyList()));
378 ServletConfig mockServletConfig = mock(ServletConfig.class);
379 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
380
381 final TestServletContextListener listener = new TestServletContextListener();
382
383 Plugin plugin = new PluginBuilder().build();
384
385 new ServletContextListenerModuleDescriptorBuilder()
386 .with(plugin)
387 .with(listener)
388 .build();
389
390 ServletModuleDescriptor servletDescriptor = new ServletModuleDescriptorBuilder()
391 .with(plugin)
392 .with(new TestHttpServlet())
393 .withPath("/servlet")
394 .with(servletModuleManager)
395 .build();
396
397 servletModuleManager.addServletModule(servletDescriptor);
398 servletModuleManager.getServlet("/servlet", mockServletConfig);
399 assertTrue(listener.initCalled);
400 }
401
402 @Test
403 public void testServletListenerContextFilterAndServletUseTheSameServletContext() throws Exception {
404 Plugin plugin = new PluginBuilder().build();
405
406 final AtomicReference<ServletContext> contextRef = new AtomicReference<>();
407
408 new ServletContextListenerModuleDescriptorBuilder()
409 .with(plugin)
410 .with(new TestServletContextListener() {
411 @Override
412 public void contextInitialized(ServletContextEvent event) {
413 contextRef.set(event.getServletContext());
414 }
415 })
416 .build();
417
418
419 ServletModuleDescriptor servletDescriptor = new ServletModuleDescriptorBuilder()
420 .with(plugin)
421 .with(new TestHttpServlet() {
422 @Override
423 public void init(ServletConfig mockServletConfig) {
424 assertSame(contextRef.get(), mockServletConfig.getServletContext());
425 }
426 })
427 .withPath("/servlet")
428 .with(servletModuleManager)
429 .build();
430 servletModuleManager.addServletModule(servletDescriptor);
431
432
433 ServletFilterModuleDescriptor filterDescriptor = new ServletFilterModuleDescriptorBuilder()
434 .with(plugin)
435 .with(new FilterAdapter() {
436 @Override
437 public void init(FilterConfig mockFilterConfig) {
438 assertSame(contextRef.get(), mockFilterConfig.getServletContext());
439 }
440 })
441 .withPath("/*")
442 .with(servletModuleManager)
443 .build();
444 servletModuleManager.addFilterModule(filterDescriptor);
445
446 ServletContext mockServletContext = mock(ServletContext.class);
447 when(mockServletContext.getInitParameterNames()).thenReturn(enumeration(emptyList()));
448
449
450
451 ServletConfig mockServletConfig = mock(ServletConfig.class);
452 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
453 servletModuleManager.getServlet("/servlet", mockServletConfig);
454
455
456 FilterConfig mockFilterConfig = mock(FilterConfig.class);
457 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
458 servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/servlet", mockFilterConfig, REQUEST);
459 }
460
461 @Test
462 public void testFiltersWithSameLocationAndWeightInTheSamePluginAppearInTheOrderTheyAreDeclared() throws Exception {
463 ServletContext mockServletContext = mock(ServletContext.class);
464 when(mockServletContext.getInitParameterNames()).thenReturn(enumeration(emptyList()));
465 FilterConfig mockFilterConfig = mock(FilterConfig.class);
466 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
467
468 Plugin plugin = new PluginBuilder().build();
469
470 List<Integer> filterCallOrder = new LinkedList<>();
471 ServletFilterModuleDescriptor d1 = new ServletFilterModuleDescriptorBuilder()
472 .with(plugin)
473 .withKey("filter-1")
474 .with(new SoundOffFilter(filterCallOrder, 1))
475 .withPath("/*")
476 .build();
477 servletModuleManager.addFilterModule(d1);
478
479 ServletFilterModuleDescriptor d2 = new ServletFilterModuleDescriptorBuilder()
480 .with(plugin)
481 .withKey("filter-2")
482 .with(new SoundOffFilter(filterCallOrder, 2))
483 .withPath("/*")
484 .build();
485 servletModuleManager.addFilterModule(d2);
486
487 HttpServletRequest mockHttpServletRequest = mock(HttpServletRequest.class);
488 when(mockHttpServletRequest.getPathInfo()).thenReturn("/servlet");
489 HttpServletResponse mockHttpServletResponse = mock(HttpServletResponse.class);
490
491 Iterable<Filter> filters = servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/some/path", mockFilterConfig, REQUEST);
492 FilterChain chain = new IteratingFilterChain(filters.iterator(), emptyChain);
493
494 chain.doFilter(mockHttpServletRequest, mockHttpServletResponse);
495 assertEquals(newArrayList(1, 2, 2, 1), filterCallOrder);
496 }
497
498 @Test
499 public void testGetFiltersWithDispatcher() throws Exception {
500 ServletContext mockServletContext = mock(ServletContext.class);
501 FilterConfig mockFilterConfig = mock(FilterConfig.class);
502 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
503 when(mockServletContext.getInitParameterNames()).thenReturn(new Vector().elements());
504 Plugin plugin = new PluginBuilder().build();
505
506 ServletFilterModuleDescriptor filterDescriptor = new ServletFilterModuleDescriptorBuilder()
507 .with(plugin)
508 .withKey("foo")
509 .with(new FilterAdapter())
510 .withPath("/foo")
511 .with(servletModuleManager)
512 .withDispatcher(REQUEST)
513 .withDispatcher(FORWARD)
514 .build();
515
516 ServletFilterModuleDescriptor filterDescriptor2 = new ServletFilterModuleDescriptorBuilder()
517 .with(plugin)
518 .withKey("bar")
519 .with(new FilterAdapter())
520 .withPath("/foo")
521 .with(servletModuleManager)
522 .withDispatcher(REQUEST)
523 .withDispatcher(INCLUDE)
524 .build();
525
526 ServletFilterModuleDescriptor filterDescriptorDefaults = new ServletFilterModuleDescriptorBuilder()
527 .with(plugin)
528 .withKey("baz")
529 .with(new FilterAdapter())
530 .withPath("/foo")
531 .with(servletModuleManager)
532 .build();
533
534 servletModuleManager.addFilterModule(filterDescriptor);
535 servletModuleManager.addFilterModule(filterDescriptor2);
536 servletModuleManager.addFilterModule(filterDescriptorDefaults);
537
538 assertEquals(3, Iterables.size(servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo", mockFilterConfig, REQUEST)));
539 assertEquals(1, Iterables.size(servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo", mockFilterConfig, INCLUDE)));
540 assertEquals(1, Iterables.size(servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo", mockFilterConfig, FORWARD)));
541 assertEquals(0, Iterables.size(servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo", mockFilterConfig, ERROR)));
542 assertEquals(1, Iterables.size(servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo", mockFilterConfig, ASYNC)));
543
544 try {
545 servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo", mockFilterConfig, (DispatcherType) null);
546 fail("Shouldn't accept nulls");
547 } catch (NullPointerException ex) {
548
549 }
550 }
551
552 @Test
553 public void testShuttingDownDoesDestroyServletsFiltersAndContexts() throws Exception {
554 final String servletPath = "/servlet";
555 final String pathInfo = "/pathInfo";
556
557 final Plugin mockPlugin = mock(Plugin.class);
558 final ServletModuleDescriptor mockServletModuleDescriptor = mock(ServletModuleDescriptor.class);
559 when(mockServletModuleDescriptor.getPlugin()).thenReturn(mockPlugin);
560 when(mockServletModuleDescriptor.getCompleteKey()).thenReturn("plugin:servlet");
561 when(mockServletModuleDescriptor.getScopeKey()).thenReturn(Optional.empty());
562
563 when(mockServletModuleDescriptor.getPaths()).thenReturn(ImmutableList.of(servletPath));
564 final ServletFilterModuleDescriptor mockServletFilterModuleDescriptor = mock(ServletFilterModuleDescriptor.class);
565 when(mockServletFilterModuleDescriptor.getCompleteKey()).thenReturn("plugin:filter");
566
567 ServletContextListener mockServletContextListener = mock(ServletContextListener.class);
568 ServletContextListenerModuleDescriptor mockServletContextListenerModuleDescriptor = mock(ServletContextListenerModuleDescriptor.class);
569 when(mockServletContextListenerModuleDescriptor.getModule()).thenReturn(mockServletContextListener);
570
571 final Collection<ModuleDescriptor<?>> moduleDescriptors = ImmutableList.of(
572 mockServletModuleDescriptor, mockServletFilterModuleDescriptor, mockServletContextListenerModuleDescriptor);
573 when(mockPlugin.getModuleDescriptors()).thenReturn(moduleDescriptors);
574
575 final ServletContext mockServletContext = mock(ServletContext.class);
576 when(mockServletContext.getInitParameterNames()).thenReturn(Iterators.asEnumeration(Collections.emptyIterator()));
577 final ServletConfig mockServletConfig = mock(ServletConfig.class);
578 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
579 final FilterConfig mockFilterConfig = mock(FilterConfig.class);
580 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
581
582 servletModuleManager.addServletModule(mockServletModuleDescriptor);
583 servletModuleManager.addFilterModule(mockServletFilterModuleDescriptor);
584 servletModuleManager.getServlet(servletPath, mockServletConfig);
585 servletModuleManager.getFilters(FilterLocation.AFTER_ENCODING, pathInfo, mockFilterConfig, DispatcherType.REQUEST);
586
587 final PluginFrameworkShuttingDownEvent pluginFrameworkShuttingDownEvent = mock(PluginFrameworkShuttingDownEvent.class);
588
589 servletModuleManager.onPluginFrameworkBeforeShutdown(pluginFrameworkShuttingDownEvent);
590
591 verify(mockServletModuleDescriptor).destroy();
592 verify(mockServletFilterModuleDescriptor).destroy();
593 verify(mockServletContextListener).contextDestroyed(any(ServletContextEvent.class));
594 }
595
596 @Test
597 public void testFilterOrModuleRemovalDuringShutdown() throws Exception {
598 final String servletPath = "/servlet";
599 final String pathInfo = "/pathInfo";
600
601 final ServletModuleDescriptor servletModuleDescriptor = mock(ServletModuleDescriptor.class);
602 when(servletModuleDescriptor.getCompleteKey()).thenReturn("plugin:servlet");
603 doAnswer(new Answer() {
604 @Override
605 public Object answer(final InvocationOnMock invocation) throws Throwable {
606 servletModuleManager.removeServletModule(servletModuleDescriptor);
607 return null;
608 }
609 }).when(servletModuleDescriptor).destroy();
610 final ServletFilterModuleDescriptor servletFilterModuleDescriptor = mock(ServletFilterModuleDescriptor.class);
611 when(servletFilterModuleDescriptor.getCompleteKey()).thenReturn("plugin:filter");
612 doAnswer(new Answer() {
613 @Override
614 public Object answer(final InvocationOnMock invocation) throws Throwable {
615 servletModuleManager.removeFilterModule(servletFilterModuleDescriptor);
616 return null;
617 }
618 }).when(servletFilterModuleDescriptor).destroy();
619
620 final ServletContext mockServletContext = mock(ServletContext.class);
621 when(mockServletContext.getInitParameterNames()).thenReturn(Iterators.asEnumeration(Collections.emptyIterator()));
622 final ServletConfig mockServletConfig = mock(ServletConfig.class);
623 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
624 final FilterConfig mockFilterConfig = mock(FilterConfig.class);
625 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
626
627 servletModuleManager.addServletModule(servletModuleDescriptor);
628 servletModuleManager.addFilterModule(servletFilterModuleDescriptor);
629
630 final PluginFrameworkShuttingDownEvent pluginFrameworkShuttingDownEvent = mock(PluginFrameworkShuttingDownEvent.class);
631
632 servletModuleManager.onPluginFrameworkBeforeShutdown(pluginFrameworkShuttingDownEvent);
633
634 verify(servletModuleDescriptor).destroy();
635 verify(servletFilterModuleDescriptor).destroy();
636 }
637
638 @Test
639 public void addServletByClassName() {
640 final Plugin plugin = mock(Plugin.class);
641 final ArgumentCaptor<Element> elementCaptor = ArgumentCaptor.forClass(Element.class);
642 final ModuleDescriptor moduleDescriptor = mock(ModuleDescriptor.class);
643
644
645 when(pluginController.addDynamicModule(same(plugin), elementCaptor.capture())).thenReturn(moduleDescriptor);
646
647 servletModuleManager.addServlet(plugin, "roger", "com.americandad.Roger");
648
649 assertThat(elementCaptor.getValue(), isElement(rogerServletElement().addAttribute("class", "com.americandad.Roger")));
650 }
651
652 @Test
653 public void addServletWithHttpServletUnexpectedModuleDescriptor() {
654 expectedException.expect(PluginException.class);
655 expectedException.expectMessage("com.atlassian.plugin.servlet.descriptors.ServletModuleDescriptor;");
656
657 servletModuleManager.addServlet(mock(Plugin.class), "roger", mock(HttpServlet.class), mock(ServletContext.class));
658 }
659
660 @Test
661 public void addServletWithHttpServlet() {
662 final Plugin plugin = mock(Plugin.class);
663 final ArgumentCaptor<Element> elementCaptor = ArgumentCaptor.forClass(Element.class);
664 final ModuleDescriptor moduleDescriptor = mock(ServletModuleDescriptor.class);
665 final HttpServlet httpServlet = mock(HttpServlet.class);
666 final ServletContext servletContext = mock(ServletContext.class);
667
668
669 when(pluginController.addDynamicModule(same(plugin), elementCaptor.capture())).thenReturn(moduleDescriptor);
670 when(moduleDescriptor.getCompleteKey()).thenReturn("roger.key");
671
672 servletModuleManager.addServlet(plugin, "roger", httpServlet, servletContext);
673
674 assertThat(elementCaptor.getValue(), isElement(rogerServletElement()));
675
676 assertThat(servletModuleManager.getServletRefs(), hasKey("roger.key"));
677 assertThat(servletModuleManager.getServletRefs().get("roger.key"), instanceOf(LazyLoadedServletReference.class));
678 }
679
680 @Test
681 public void addSameServletTwice() {
682 expectedException.expect(IllegalStateException.class);
683 expectedException.expectMessage("roger.key");
684
685 addServletWithHttpServlet();
686 addServletWithHttpServlet();
687 }
688
689 private Element rogerServletElement() {
690 final Element e = new DOMElement("servlet");
691 e.addAttribute("key", "roger-servlet");
692 e.addAttribute("name", "rogerServlet");
693 Element url = new DOMElement("url-pattern");
694 url.setText("/roger");
695 e.add(url);
696
697 return e;
698 }
699
700 @Test
701 public void samePluginController() {
702 servletModuleManager.onPluginFrameworkStartingEvent(new PluginFrameworkStartedEvent(pluginController, mock(PluginAccessor.class)));
703
704 servletModuleManager.onPluginFrameworkShutdownEvent(new PluginFrameworkShutdownEvent(pluginController, mock(PluginAccessor.class)));
705
706 assertThat(capturedLogging, not(didLogWarn()));
707 }
708
709 @Test
710 public void differentPluginControllers() {
711 servletModuleManager.onPluginFrameworkStartingEvent(new PluginFrameworkStartedEvent(pluginController, mock(PluginAccessor.class)));
712
713 servletModuleManager.onPluginFrameworkShutdownEvent(new PluginFrameworkShutdownEvent(mock(PluginController.class), mock(PluginAccessor.class)));
714
715 assertThat(capturedLogging, didLogWarn("did not match"));
716 }
717
718 @Test
719 public void testFilterSorting() throws Exception {
720 servletModuleManager = new DefaultServletModuleManager(
721 mockPluginEventManager,
722 mockServletMapper,
723 mockFilterMapper,
724 mockFilterFactory);
725
726 final Set<DispatcherType> dispatcherTypes = new HashSet<>();
727 final DispatcherType dispatcherType = REQUEST;
728 final FilterLocation location = BEFORE_DISPATCH;
729 dispatcherTypes.add(dispatcherType);
730 final ServletFilterModuleDescriptor descriptorAlpha = stubFilterDescriptor(
731 "alpha",
732 dispatcherTypes,
733 location,
734 10
735 );
736 final ServletFilterModuleDescriptor descriptorBeta = stubFilterDescriptor(
737 "beta",
738 dispatcherTypes,
739 location,
740 17
741 );
742 final ServletFilterModuleDescriptor descriptorGamma = stubFilterDescriptor(
743 "gamma",
744 dispatcherTypes,
745 location,
746 8
747 );
748 final Filter filterAlpha = stubFilter(descriptorAlpha);
749 final Filter filterBeta = stubFilter(descriptorBeta);
750 final Filter filterGamma = stubFilter(descriptorGamma);
751 final String dummyPath = "/dummy/path";
752 when(mockFilterMapper.getAll(dummyPath)).thenReturn(Arrays.asList("alpha", "beta", "gamma"));
753 final FilterConfig mockFilterConfig = mock(FilterConfig.class, Mockito.RETURNS_DEEP_STUBS);
754
755 final Iterable<Filter> filters =
756 servletModuleManager.getFilters(location, dummyPath, mockFilterConfig, dispatcherType);
757
758 assertThat(filters, contains(filterGamma, filterAlpha, filterBeta));
759 }
760
761 private Filter stubFilter(final ServletFilterModuleDescriptor descriptor) {
762 final Filter filter = mock(Filter.class);
763 final String filterMockDescription = "Filter " + descriptor.getCompleteKey();
764 when(filter.toString()).thenReturn(filterMockDescription);
765 when(mockFilterFactory.newFilter(descriptor)).thenReturn(filter);
766 return filter;
767 }
768
769 private ServletFilterModuleDescriptor stubFilterDescriptor(
770 final String descriptorCompleteKey,
771 final Set<DispatcherType> dispatcherTypes,
772 final FilterLocation location,
773 int weight
774 ) {
775 final ServletFilterModuleDescriptor descriptor = mock(ServletFilterModuleDescriptor.class);
776 final Plugin plugin = mock(Plugin.class);
777 when(descriptor.toString()).thenReturn("Filter descriptor " + descriptorCompleteKey);
778 when(descriptor.getCompleteKey()).thenReturn(descriptorCompleteKey);
779 when(descriptor.getDispatcherTypes()).thenReturn(dispatcherTypes);
780 when(descriptor.getLocation()).thenReturn(location);
781 when(descriptor.getPlugin()).thenReturn(plugin);
782 when(descriptor.getWeight()).thenReturn(weight);
783
784 servletModuleManager.addFilterModule(descriptor);
785 return descriptor;
786 }
787
788 static class TestServletContextListener implements ServletContextListener {
789 boolean initCalled = false;
790
791 public void contextInitialized(ServletContextEvent event) {
792 initCalled = true;
793 }
794
795 public void contextDestroyed(ServletContextEvent event) {
796 }
797 }
798
799 static class TestHttpServlet extends HttpServlet {
800 boolean serviceCalled = false;
801
802 @Override
803 public void service(ServletRequest request, ServletResponse response) {
804 serviceCalled = true;
805 }
806 }
807
808 static class TestHttpServletWithException extends HttpServlet {
809 @Override
810 public void init(ServletConfig mockServletConfig) throws ServletException {
811 throw new RuntimeException("exception thrown");
812 }
813 }
814
815 static class TestFilterWithException implements Filter {
816 public void init(FilterConfig mockFilterConfig) throws ServletException {
817 throw new RuntimeException("exception thrown");
818 }
819
820 public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
821 throws IOException, ServletException {
822 }
823
824 public void destroy() {
825 }
826 }
827
828 static class TestHttpFilter implements Filter {
829 public void init(FilterConfig mockFilterConfig) throws ServletException {
830 }
831
832 public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
833 throws IOException, ServletException {
834 }
835
836 public void destroy() {
837 }
838 }
839
840 static final class WeightedValue {
841 final int weight;
842 final String value;
843
844 WeightedValue(int weight, String value) {
845 this.weight = weight;
846 this.value = value;
847 }
848
849 @Override
850 public boolean equals(Object o) {
851 if (this == o) {
852 return true;
853 }
854 if (!(o instanceof WeightedValue)) {
855 return false;
856 }
857 WeightedValue rhs = (WeightedValue) o;
858 return weight == rhs.weight && value.equals(rhs.value);
859 }
860
861 @Override
862 public String toString() {
863 return "[" + weight + ", " + value + "]";
864 }
865 }
866 }