1 package com.atlassian.plugin.servlet;
2
3 import java.io.IOException;
4 import java.util.ArrayList;
5 import java.util.Arrays;
6 import java.util.Collection;
7 import java.util.Collections;
8 import java.util.Comparator;
9 import java.util.HashSet;
10 import java.util.LinkedList;
11 import java.util.List;
12 import java.util.Set;
13 import java.util.Vector;
14 import java.util.concurrent.atomic.AtomicReference;
15
16 import javax.servlet.Filter;
17 import javax.servlet.FilterChain;
18 import javax.servlet.FilterConfig;
19 import javax.servlet.ServletConfig;
20 import javax.servlet.ServletContext;
21 import javax.servlet.ServletContextEvent;
22 import javax.servlet.ServletContextListener;
23 import javax.servlet.ServletException;
24 import javax.servlet.ServletRequest;
25 import javax.servlet.ServletResponse;
26 import javax.servlet.http.HttpServlet;
27 import javax.servlet.http.HttpServletRequest;
28 import javax.servlet.http.HttpServletResponse;
29
30 import com.atlassian.plugin.ModuleDescriptor;
31 import com.atlassian.plugin.Plugin;
32 import com.atlassian.plugin.event.PluginEventManager;
33 import com.atlassian.plugin.event.events.PluginFrameworkShutdownEvent;
34 import com.atlassian.plugin.event.events.PluginFrameworkShuttingDownEvent;
35 import com.atlassian.plugin.servlet.descriptors.ServletContextListenerModuleDescriptor;
36 import com.atlassian.plugin.servlet.descriptors.ServletContextListenerModuleDescriptorBuilder;
37 import com.atlassian.plugin.servlet.descriptors.ServletContextParamDescriptorBuilder;
38 import com.atlassian.plugin.servlet.descriptors.ServletFilterModuleDescriptor;
39 import com.atlassian.plugin.servlet.descriptors.ServletFilterModuleDescriptorBuilder;
40 import com.atlassian.plugin.servlet.descriptors.ServletModuleDescriptor;
41 import com.atlassian.plugin.servlet.descriptors.ServletModuleDescriptorBuilder;
42 import com.atlassian.plugin.servlet.filter.DelegatingPluginFilter;
43 import com.atlassian.plugin.servlet.filter.FilterDispatcherCondition;
44 import com.atlassian.plugin.servlet.filter.FilterLocation;
45 import com.atlassian.plugin.servlet.filter.FilterTestUtils.FilterAdapter;
46 import com.atlassian.plugin.servlet.filter.FilterTestUtils.SoundOffFilter;
47 import com.atlassian.plugin.servlet.filter.IteratingFilterChain;
48
49 import com.atlassian.plugin.servlet.util.PathMapper;
50
51 import com.google.common.collect.ImmutableList;
52 import com.google.common.collect.Iterables;
53 import com.google.common.collect.Iterators;
54
55 import org.junit.Before;
56 import org.junit.Test;
57 import org.junit.runner.RunWith;
58 import org.mockito.Mock;
59 import org.mockito.Mockito;
60 import org.mockito.invocation.InvocationOnMock;
61 import org.mockito.runners.MockitoJUnitRunner;
62 import org.mockito.stubbing.Answer;
63
64 import static com.atlassian.plugin.servlet.filter.FilterDispatcherCondition.ERROR;
65 import static com.atlassian.plugin.servlet.filter.FilterDispatcherCondition.FORWARD;
66 import static com.atlassian.plugin.servlet.filter.FilterDispatcherCondition.INCLUDE;
67 import static com.atlassian.plugin.servlet.filter.FilterDispatcherCondition.REQUEST;
68 import static com.atlassian.plugin.servlet.filter.FilterLocation.BEFORE_DISPATCH;
69 import static com.atlassian.plugin.servlet.filter.FilterTestUtils.emptyChain;
70 import static org.hamcrest.Matchers.contains;
71 import static org.junit.Assert.assertEquals;
72 import static org.junit.Assert.assertNull;
73 import static org.junit.Assert.assertSame;
74 import static org.junit.Assert.assertThat;
75 import static org.junit.Assert.assertTrue;
76 import static org.junit.Assert.fail;
77 import static org.mockito.Matchers.any;
78 import static org.mockito.Mockito.doAnswer;
79 import static org.mockito.Mockito.mock;
80 import static org.mockito.Mockito.verify;
81 import static org.mockito.Mockito.when;
82
83 @RunWith(MockitoJUnitRunner.class)
84 public class TestDefaultServletModuleManager
85 {
86 private DefaultServletModuleManager servletModuleManager;
87
88 @Mock
89 private PluginEventManager mockPluginEventManager;
90
91 @Mock
92 private PathMapper mockServletMapper;
93
94 @Mock
95 private PathMapper mockFilterMapper;
96
97 @Mock
98 private FilterFactory mockFilterFactory;
99
100 @Before
101 public void setUp()
102 {
103 servletModuleManager = new DefaultServletModuleManager(mockPluginEventManager);
104 }
105
106 @Test
107 public void testGettingServletWithSimplePath() throws Exception
108 {
109 final ServletContext mockServletContext = mock(ServletContext.class);
110 when(mockServletContext.getInitParameterNames()).thenReturn(Iterators.asEnumeration(Iterators.<String>emptyIterator()));
111 final ServletConfig mockServletConfig = mock(ServletConfig.class);
112 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
113
114 final HttpServletRequest mockHttpServletRequest = mock(HttpServletRequest.class);
115 when(mockHttpServletRequest.getPathInfo()).thenReturn("/servlet");
116 final HttpServletResponse mockHttpServletResponse = mock(HttpServletResponse.class);
117
118 TestHttpServlet servlet = new TestHttpServlet();
119 ServletModuleDescriptor descriptor = new ServletModuleDescriptorBuilder()
120 .with(servlet)
121 .withPath("/servlet")
122 .with(servletModuleManager)
123 .build();
124
125 servletModuleManager.addServletModule(descriptor);
126
127 HttpServlet wrappedServlet = servletModuleManager.getServlet("/servlet", mockServletConfig);
128 wrappedServlet.service(mockHttpServletRequest, mockHttpServletResponse);
129 assertTrue(servlet.serviceCalled);
130 }
131
132 @Test
133 public void testGettingServlet()
134 {
135 getServletTwice(false);
136 }
137
138 private void getServletTwice(boolean expectNewServletEachCall)
139 {
140 DefaultServletModuleManager mgr = new DefaultServletModuleManager(mockPluginEventManager);
141
142 AtomicReference<HttpServlet> servletRef = new AtomicReference<HttpServlet>();
143 TestHttpServlet firstServlet = new TestHttpServlet();
144 servletRef.set(firstServlet);
145 ServletModuleDescriptor descriptor = new ServletModuleDescriptorBuilder()
146 .withFactory(ObjectFactories.createMutable(servletRef))
147 .withPath("/servlet")
148 .with(mgr)
149 .build();
150
151 final ServletConfig mockServletConfig = mock(ServletConfig.class);
152 final ServletContext mockServletContext = mock(ServletContext.class);
153 when(mockServletContext.getInitParameterNames()).thenReturn(Collections.enumeration(Collections.<String>emptyList()));
154 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
155
156 assertTrue(firstServlet == ((DelegatingPluginServlet) mgr.getServlet(descriptor, mockServletConfig)).getDelegatingServlet());
157
158 TestHttpServlet secondServlet = new TestHttpServlet();
159 servletRef.set(secondServlet);
160 HttpServlet expectedServlet = (expectNewServletEachCall ? secondServlet : firstServlet);
161 assertTrue(expectedServlet == ((DelegatingPluginServlet) mgr.getServlet(descriptor, mockServletConfig)).getDelegatingServlet());
162 }
163
164 @Test
165 public void testGettingFilter()
166 {
167 getFilterTwice(false);
168 }
169
170 private void getFilterTwice(boolean expectNewFilterEachCall)
171 {
172 DefaultServletModuleManager mgr = new DefaultServletModuleManager(mockPluginEventManager);
173
174 AtomicReference<Filter> filterRef = new AtomicReference<Filter>();
175 TestHttpFilter firstFilter = new TestHttpFilter();
176 filterRef.set(firstFilter);
177 ServletFilterModuleDescriptor descriptor = new ServletFilterModuleDescriptorBuilder()
178 .withFactory(ObjectFactories.createMutable(filterRef))
179 .withPath("/servlet")
180 .with(mgr)
181 .build();
182
183 final FilterConfig mockFilterConfig = mock(FilterConfig.class);
184 final ServletContext mockServletContext = mock(ServletContext.class);
185 when(mockServletContext.getInitParameterNames()).thenReturn(Collections.enumeration(Collections.<String>emptyList()));
186 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
187
188 assertTrue(firstFilter == ((DelegatingPluginFilter) mgr.getFilter(descriptor, mockFilterConfig)).getDelegatingFilter());
189
190 TestHttpFilter secondFilter = new TestHttpFilter();
191 filterRef.set(secondFilter);
192 Filter expectedFilter = (expectNewFilterEachCall ? secondFilter : firstFilter);
193 assertTrue(expectedFilter == ((DelegatingPluginFilter) mgr.getFilter(descriptor, mockFilterConfig)).getDelegatingFilter());
194 }
195
196 @Test
197 public void testGettingServletWithException() throws Exception
198 {
199 ServletContext mockServletContext = mock(ServletContext.class);
200 when(mockServletContext.getInitParameterNames()).thenReturn(Collections.enumeration(Collections.<String>emptyList()));
201 ServletConfig mockServletConfig = mock(ServletConfig.class);
202 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
203
204 HttpServletRequest mockHttpServletRequest = mock(HttpServletRequest.class);
205 when(mockHttpServletRequest.getPathInfo()).thenReturn("/servlet");
206
207 TestHttpServletWithException servlet = new TestHttpServletWithException();
208 ServletModuleDescriptor descriptor = new ServletModuleDescriptorBuilder()
209 .with(servlet)
210 .withPath("/servlet")
211 .with(servletModuleManager)
212 .build();
213
214 servletModuleManager.addServletModule(descriptor);
215
216 assertNull(servletModuleManager.getServlet("/servlet", mockServletConfig));
217 }
218
219 @Test
220 public void testGettingFilterWithException() throws Exception
221 {
222 ServletContext mockServletContext = mock(ServletContext.class);
223 when(mockServletContext.getInitParameterNames()).thenReturn(Collections.enumeration(Collections.<String>emptyList()));
224 FilterConfig mockFilterConfig = mock(FilterConfig.class);
225 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
226
227 HttpServletRequest mockHttpServletRequest = mock(HttpServletRequest.class);
228 when(mockHttpServletRequest.getPathInfo()).thenReturn("/servlet");
229
230 TestFilterWithException servlet = new TestFilterWithException();
231 ServletFilterModuleDescriptor descriptor = new ServletFilterModuleDescriptorBuilder()
232 .with(servlet)
233 .withPath("/servlet")
234 .with(servletModuleManager)
235 .at(FilterLocation.AFTER_ENCODING)
236 .build();
237
238 servletModuleManager.addFilterModule(descriptor);
239
240 assertEquals(false, servletModuleManager.getFilters(FilterLocation.AFTER_ENCODING, "/servlet", mockFilterConfig).iterator().hasNext());
241 }
242
243 @Test
244 public void testGettingServletWithComplexPath() throws Exception
245 {
246 ServletContext mockServletContext = mock(ServletContext.class);
247 when(mockServletContext.getInitParameterNames()).thenReturn(Collections.enumeration(Collections.<String>emptyList()));
248 ServletConfig mockServletConfig = mock(ServletConfig.class);
249 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
250
251 HttpServletRequest mockHttpServletRequest = mock(HttpServletRequest.class);
252 when(mockHttpServletRequest.getPathInfo()).thenReturn("/servlet");
253 HttpServletResponse mockHttpServletResponse = mock(HttpServletResponse.class);
254
255 TestHttpServlet servlet = new TestHttpServlet();
256 ServletModuleDescriptor descriptor = new ServletModuleDescriptorBuilder()
257 .with(servlet)
258 .withPath("/servlet/*")
259 .with(servletModuleManager)
260 .build();
261
262 servletModuleManager.addServletModule(descriptor);
263
264 HttpServlet wrappedServlet = servletModuleManager.getServlet("/servlet/this/is/a/test", mockServletConfig);
265 wrappedServlet.service(mockHttpServletRequest, mockHttpServletResponse);
266 assertTrue(servlet.serviceCalled);
267 }
268
269 @Test
270 public void testMultipleFiltersWithTheSameComplexPath() throws ServletException
271 {
272 ServletContext mockServletContext = mock(ServletContext.class);
273 FilterConfig mockFilterConfig = mock(FilterConfig.class);
274 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
275 when(mockServletContext.getInitParameterNames()).thenReturn(new Vector().elements());
276 Plugin plugin = new PluginBuilder().build();
277 ServletFilterModuleDescriptor filterDescriptor = new ServletFilterModuleDescriptorBuilder()
278 .with(plugin)
279 .withKey("foo")
280 .with(new FilterAdapter())
281 .withPath("/foo/*")
282 .with(servletModuleManager)
283 .build();
284
285 ServletFilterModuleDescriptor filterDescriptor2 = new ServletFilterModuleDescriptorBuilder()
286 .with(plugin)
287 .withKey("bar")
288 .with(new FilterAdapter())
289 .withPath("/foo/*")
290 .with(servletModuleManager)
291 .build();
292 servletModuleManager.addFilterModule(filterDescriptor);
293 servletModuleManager.addFilterModule(filterDescriptor2);
294
295 servletModuleManager.removeFilterModule(filterDescriptor);
296 assertTrue(servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo/jim", mockFilterConfig).iterator().hasNext());
297 }
298
299 @Test
300 public void testMultipleFiltersWithTheSameSimplePath() throws ServletException
301 {
302 ServletContext mockServletContext = mock(ServletContext.class);
303 FilterConfig mockFilterConfig = mock(FilterConfig.class);
304 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
305 when(mockServletContext.getInitParameterNames()).thenReturn(new Vector().elements());
306 Plugin plugin = new PluginBuilder().build();
307 ServletFilterModuleDescriptor filterDescriptor = new ServletFilterModuleDescriptorBuilder()
308 .with(plugin)
309 .withKey("foo")
310 .with(new FilterAdapter())
311 .withPath("/foo")
312 .with(servletModuleManager)
313 .build();
314
315 ServletFilterModuleDescriptor filterDescriptor2 = new ServletFilterModuleDescriptorBuilder()
316 .with(plugin)
317 .withKey("bar")
318 .with(new FilterAdapter())
319 .withPath("/foo")
320 .with(servletModuleManager)
321 .build();
322 servletModuleManager.addFilterModule(filterDescriptor);
323 servletModuleManager.addFilterModule(filterDescriptor2);
324
325 servletModuleManager.removeFilterModule(filterDescriptor);
326 assertTrue(servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo", mockFilterConfig).iterator().hasNext());
327 }
328
329 @Test
330 public void testPluginContextInitParamsGetMerged() throws Exception
331 {
332 ServletContext mockServletContext = mock(ServletContext.class);
333 when(mockServletContext.getInitParameterNames()).thenReturn(Collections.enumeration(Collections.<String>emptyList()));
334 ServletConfig mockServletConfig = mock(ServletConfig.class);
335 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
336
337 Plugin plugin = new PluginBuilder().build();
338
339 new ServletContextParamDescriptorBuilder()
340 .with(plugin)
341 .withParam("param.name", "param.value")
342 .build();
343
344
345 ServletModuleDescriptor servletDescriptor = new ServletModuleDescriptorBuilder()
346 .with(plugin)
347 .with(new TestHttpServlet()
348 {
349 @Override
350 public void init(ServletConfig servletConfig)
351 {
352 assertEquals("param.value", servletConfig.getServletContext().getInitParameter("param.name"));
353 }
354 })
355 .withPath("/servlet")
356 .with(servletModuleManager)
357 .build();
358 servletModuleManager.addServletModule(servletDescriptor);
359
360 servletModuleManager.getServlet("/servlet", mockServletConfig);
361 }
362
363 @Test
364 public void testServletListenerContextInitializedIsCalled() throws Exception
365 {
366 ServletContext mockServletContext = mock(ServletContext.class);
367 when(mockServletContext.getInitParameterNames()).thenReturn(Collections.enumeration(Collections.<String>emptyList()));
368 ServletConfig mockServletConfig = mock(ServletConfig.class);
369 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
370
371 final TestServletContextListener listener = new TestServletContextListener();
372
373 Plugin plugin = new PluginBuilder().build();
374
375 new ServletContextListenerModuleDescriptorBuilder()
376 .with(plugin)
377 .with(listener)
378 .build();
379
380 ServletModuleDescriptor servletDescriptor = new ServletModuleDescriptorBuilder()
381 .with(plugin)
382 .with(new TestHttpServlet())
383 .withPath("/servlet")
384 .with(servletModuleManager)
385 .build();
386
387 servletModuleManager.addServletModule(servletDescriptor);
388 servletModuleManager.getServlet("/servlet", mockServletConfig);
389 assertTrue(listener.initCalled);
390 }
391
392 @Test
393 public void testServletListenerContextFilterAndServletUseTheSameServletContext() throws Exception
394 {
395 Plugin plugin = new PluginBuilder().build();
396
397 final AtomicReference<ServletContext> contextRef = new AtomicReference<ServletContext>();
398
399 new ServletContextListenerModuleDescriptorBuilder()
400 .with(plugin)
401 .with(new TestServletContextListener()
402 {
403 @Override
404 public void contextInitialized(ServletContextEvent event)
405 {
406 contextRef.set(event.getServletContext());
407 }
408 })
409 .build();
410
411
412 ServletModuleDescriptor servletDescriptor = new ServletModuleDescriptorBuilder()
413 .with(plugin)
414 .with(new TestHttpServlet()
415 {
416 @Override
417 public void init(ServletConfig mockServletConfig)
418 {
419 assertSame(contextRef.get(), mockServletConfig.getServletContext());
420 }
421 })
422 .withPath("/servlet")
423 .with(servletModuleManager)
424 .build();
425 servletModuleManager.addServletModule(servletDescriptor);
426
427
428 ServletFilterModuleDescriptor filterDescriptor = new ServletFilterModuleDescriptorBuilder()
429 .with(plugin)
430 .with(new FilterAdapter()
431 {
432 @Override
433 public void init(FilterConfig mockFilterConfig)
434 {
435 assertSame(contextRef.get(), mockFilterConfig.getServletContext());
436 }
437 })
438 .withPath("/*")
439 .with(servletModuleManager)
440 .build();
441 servletModuleManager.addFilterModule(filterDescriptor);
442
443 ServletContext mockServletContext = mock(ServletContext.class);
444 when(mockServletContext.getInitParameterNames()).thenReturn(Collections.enumeration(Collections.<String>emptyList()));
445
446
447
448 ServletConfig mockServletConfig = mock(ServletConfig.class);
449 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
450 servletModuleManager.getServlet("/servlet", mockServletConfig);
451
452
453 FilterConfig mockFilterConfig = mock(FilterConfig.class);
454 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
455 servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/servlet", mockFilterConfig);
456 }
457
458 @Test
459 public void testFiltersWithSameLocationAndWeightInTheSamePluginAppearInTheOrderTheyAreDeclared() throws Exception
460 {
461 ServletContext mockServletContext = mock(ServletContext.class);
462 when(mockServletContext.getInitParameterNames()).thenReturn(Collections.enumeration(Collections.<String>emptyList()));
463 FilterConfig mockFilterConfig = mock(FilterConfig.class);
464 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
465
466 Plugin plugin = new PluginBuilder().build();
467
468 List<Integer> filterCallOrder = new LinkedList<Integer>();
469 ServletFilterModuleDescriptor d1 = new ServletFilterModuleDescriptorBuilder()
470 .with(plugin)
471 .withKey("filter-1")
472 .with(new SoundOffFilter(filterCallOrder, 1))
473 .withPath("/*")
474 .build();
475 servletModuleManager.addFilterModule(d1);
476
477 ServletFilterModuleDescriptor d2 = new ServletFilterModuleDescriptorBuilder()
478 .with(plugin)
479 .withKey("filter-2")
480 .with(new SoundOffFilter(filterCallOrder, 2))
481 .withPath("/*")
482 .build();
483 servletModuleManager.addFilterModule(d2);
484
485 HttpServletRequest mockHttpServletRequest = mock(HttpServletRequest.class);
486 when(mockHttpServletRequest.getPathInfo()).thenReturn("/servlet");
487 HttpServletResponse mockHttpServletResponse = mock(HttpServletResponse.class);
488
489 Iterable<Filter> filters = servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/some/path", mockFilterConfig);
490 FilterChain chain = new IteratingFilterChain(filters.iterator(), emptyChain);
491
492 chain.doFilter(mockHttpServletRequest, mockHttpServletResponse);
493 assertEquals(newList(1, 2, 2, 1), filterCallOrder);
494 }
495
496 @Test
497 public void testGetFiltersWithDispatcher() throws Exception
498 {
499 ServletContext mockServletContext = mock(ServletContext.class);
500 FilterConfig mockFilterConfig = mock(FilterConfig.class);
501 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
502 when(mockServletContext.getInitParameterNames()).thenReturn(new Vector().elements());
503 Plugin plugin = new PluginBuilder().build();
504
505 ServletFilterModuleDescriptor filterDescriptor = new ServletFilterModuleDescriptorBuilder()
506 .with(plugin)
507 .withKey("foo")
508 .with(new FilterAdapter())
509 .withPath("/foo")
510 .with(servletModuleManager)
511 .withDispatcher(REQUEST)
512 .withDispatcher(FORWARD)
513 .build();
514
515 ServletFilterModuleDescriptor filterDescriptor2 = new ServletFilterModuleDescriptorBuilder()
516 .with(plugin)
517 .withKey("bar")
518 .with(new FilterAdapter())
519 .withPath("/foo")
520 .with(servletModuleManager)
521 .withDispatcher(REQUEST)
522 .withDispatcher(INCLUDE)
523 .build();
524
525 servletModuleManager.addFilterModule(filterDescriptor);
526 servletModuleManager.addFilterModule(filterDescriptor2);
527
528 assertEquals(2, Iterables.size(servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo", mockFilterConfig, REQUEST)));
529 assertEquals(1, Iterables.size(servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo", mockFilterConfig, INCLUDE)));
530 assertEquals(1, Iterables.size(servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo", mockFilterConfig, FORWARD)));
531 assertEquals(0, Iterables.size(servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo", mockFilterConfig, ERROR)));
532
533 try
534 {
535 servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo", mockFilterConfig, null);
536 fail("Shouldn't accept nulls");
537 }
538 catch (NullPointerException ex)
539 {
540
541 }
542 }
543
544 @Test
545 public void testShuttingDownDoesDestroyServletsFiltersAndContexts() throws Exception
546 {
547 final String servletPath = "/servlet";
548 final String pathInfo = "/pathInfo";
549
550 final Plugin mockPlugin = mock(Plugin.class);
551 final ServletModuleDescriptor mockServletModuleDescriptor = mock(ServletModuleDescriptor.class);
552 when(mockServletModuleDescriptor.getCompleteKey()).thenReturn("plugin:servlet");
553 when(mockServletModuleDescriptor.getPlugin()).thenReturn(mockPlugin);
554 when(mockServletModuleDescriptor.getPaths()).thenReturn(ImmutableList.of(servletPath));
555 final ServletFilterModuleDescriptor mockServletFilterModuleDescriptor = mock(ServletFilterModuleDescriptor.class);
556 when(mockServletFilterModuleDescriptor.getCompleteKey()).thenReturn("plugin:filter");
557
558 ServletContextListener mockServletContextListener = mock(ServletContextListener.class);
559 ServletContextListenerModuleDescriptor mockServletContextListenerModuleDescriptor = mock(ServletContextListenerModuleDescriptor.class);
560 when(mockServletContextListenerModuleDescriptor.getModule()).thenReturn(mockServletContextListener);
561
562 final Collection<ModuleDescriptor<?>> moduleDescriptors = ImmutableList.<ModuleDescriptor<?>>of(
563 mockServletModuleDescriptor, mockServletFilterModuleDescriptor, mockServletContextListenerModuleDescriptor);
564 when(mockPlugin.getModuleDescriptors()).thenReturn(moduleDescriptors);
565
566 final ServletContext mockServletContext = mock(ServletContext.class);
567 when(mockServletContext.getInitParameterNames()).thenReturn(Iterators.asEnumeration(Iterators.<String>emptyIterator()));
568 final ServletConfig mockServletConfig = mock(ServletConfig.class);
569 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
570 final FilterConfig mockFilterConfig = mock(FilterConfig.class);
571 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
572
573 servletModuleManager.addServletModule(mockServletModuleDescriptor);
574 servletModuleManager.addFilterModule(mockServletFilterModuleDescriptor);
575 servletModuleManager.getServlet(servletPath, mockServletConfig);
576 servletModuleManager.getFilters(FilterLocation.AFTER_ENCODING, pathInfo, mockFilterConfig, FilterDispatcherCondition.REQUEST);
577
578 final PluginFrameworkShuttingDownEvent pluginFrameworkShuttingDownEvent = mock(PluginFrameworkShuttingDownEvent.class);
579
580 servletModuleManager.onPluginFrameworkBeforeShutdown(pluginFrameworkShuttingDownEvent);
581
582 verify(mockServletModuleDescriptor).destroy();
583 verify(mockServletFilterModuleDescriptor).destroy();
584 verify(mockServletContextListener).contextDestroyed(any(ServletContextEvent.class));
585
586
587 final PluginFrameworkShutdownEvent mockPluginFrameworkShutdownEvent = mock(PluginFrameworkShutdownEvent.class);
588
589 servletModuleManager.onPluginFrameworkShutdown(mockPluginFrameworkShutdownEvent);
590
591
592 verify(mockServletModuleDescriptor).destroy();
593 verify(mockServletFilterModuleDescriptor).destroy();
594 verify(mockServletContextListener).contextDestroyed(any(ServletContextEvent.class));
595 }
596
597 @Test
598 public void testFilterOrModuleRemovalDuringShutdown() throws Exception
599 {
600 final String servletPath = "/servlet";
601 final String pathInfo = "/pathInfo";
602
603 final ServletModuleDescriptor servletModuleDescriptor = mock(ServletModuleDescriptor.class);
604 when(servletModuleDescriptor.getCompleteKey()).thenReturn("plugin:servlet");
605 doAnswer(new Answer()
606 {
607 @Override
608 public Object answer(final InvocationOnMock invocation) throws Throwable
609 {
610 servletModuleManager.removeServletModule(servletModuleDescriptor);
611 return null;
612 }
613 }).when(servletModuleDescriptor).destroy();
614 final ServletFilterModuleDescriptor servletFilterModuleDescriptor = mock(ServletFilterModuleDescriptor.class);
615 when(servletFilterModuleDescriptor.getCompleteKey()).thenReturn("plugin:filter");
616 doAnswer(new Answer()
617 {
618 @Override
619 public Object answer(final InvocationOnMock invocation) throws Throwable
620 {
621 servletModuleManager.removeFilterModule(servletFilterModuleDescriptor);
622 return null;
623 }
624 }).when(servletFilterModuleDescriptor).destroy();
625
626 final ServletContext mockServletContext = mock(ServletContext.class);
627 when(mockServletContext.getInitParameterNames()).thenReturn(Iterators.asEnumeration(Iterators.emptyIterator()));
628 final ServletConfig mockServletConfig = mock(ServletConfig.class);
629 when(mockServletConfig.getServletContext()).thenReturn(mockServletContext);
630 final FilterConfig mockFilterConfig = mock(FilterConfig.class);
631 when(mockFilterConfig.getServletContext()).thenReturn(mockServletContext);
632
633 servletModuleManager.addServletModule(servletModuleDescriptor);
634 servletModuleManager.addFilterModule(servletFilterModuleDescriptor);
635
636 final PluginFrameworkShuttingDownEvent pluginFrameworkShuttingDownEvent = mock(PluginFrameworkShuttingDownEvent.class);
637
638 servletModuleManager.onPluginFrameworkBeforeShutdown(pluginFrameworkShuttingDownEvent);
639
640 verify(servletModuleDescriptor).destroy();
641 verify(servletFilterModuleDescriptor).destroy();
642 }
643
644
645
646
647 @Test
648 public void testFilterSorting() throws Exception
649 {
650 servletModuleManager = new DefaultServletModuleManager(
651 mockPluginEventManager,
652 mockServletMapper,
653 mockFilterMapper,
654 mockFilterFactory
655 );
656 final Set<FilterDispatcherCondition> dispatcherConditions = new HashSet<FilterDispatcherCondition>();
657 final FilterDispatcherCondition dispatcherCondition = REQUEST;
658 final FilterLocation location = BEFORE_DISPATCH;
659 dispatcherConditions.add(dispatcherCondition);
660 final ServletFilterModuleDescriptor descriptorAlpha = stubFilterDescriptor(
661 "alpha",
662 dispatcherConditions,
663 location,
664 10
665 );
666 final ServletFilterModuleDescriptor descriptorBeta = stubFilterDescriptor(
667 "beta",
668 dispatcherConditions,
669 location,
670 17
671 );
672 final ServletFilterModuleDescriptor descriptorGamma = stubFilterDescriptor(
673 "gamma",
674 dispatcherConditions,
675 location,
676 8
677 );
678 final Filter filterAlpha = stubFilter(descriptorAlpha);
679 final Filter filterBeta = stubFilter(descriptorBeta);
680 final Filter filterGamma = stubFilter(descriptorGamma);
681 final String dummyPath = "/dummy/path";
682 when(mockFilterMapper.getAll(dummyPath)).thenReturn(Arrays.asList("alpha", "beta", "gamma"));
683 final FilterConfig mockFilterConfig = mock(FilterConfig.class, Mockito.RETURNS_DEEP_STUBS);
684
685 final Iterable<Filter> filters =
686 servletModuleManager.getFilters(location, dummyPath, mockFilterConfig, dispatcherCondition);
687
688 assertThat(filters, contains(filterGamma, filterAlpha, filterBeta));
689 }
690
691 private Filter stubFilter(final ServletFilterModuleDescriptor descriptor)
692 {
693 final Filter filter = mock(Filter.class);
694 final String filterMockDescription = "Filter " + descriptor.getCompleteKey();
695 when(filter.toString()).thenReturn(filterMockDescription);
696 when(mockFilterFactory.newFilter(descriptor)).thenReturn(filter);
697 return filter;
698 }
699
700 private ServletFilterModuleDescriptor stubFilterDescriptor(
701 final String descriptorCompleteKey,
702 final Set<FilterDispatcherCondition> dispatcherConditions,
703 final FilterLocation location,
704 int weight
705 )
706 {
707 final ServletFilterModuleDescriptor descriptor = mock(ServletFilterModuleDescriptor.class);
708 final Plugin plugin = mock(Plugin.class);
709 when(descriptor.toString()).thenReturn("Filter descriptor " + descriptorCompleteKey);
710 when(descriptor.getCompleteKey()).thenReturn(descriptorCompleteKey);
711 when(descriptor.getDispatcherConditions()).thenReturn(dispatcherConditions);
712 when(descriptor.getLocation()).thenReturn(location);
713 when(descriptor.getPlugin()).thenReturn(plugin);
714 when(descriptor.getWeight()).thenReturn(weight);
715 servletModuleManager.addFilterModule(descriptor);
716 return descriptor;
717 }
718
719 static class TestServletContextListener implements ServletContextListener
720 {
721 boolean initCalled = false;
722
723 public void contextInitialized(ServletContextEvent event)
724 {
725 initCalled = true;
726 }
727
728 public void contextDestroyed(ServletContextEvent event)
729 {
730 }
731 }
732
733 static class TestHttpServlet extends HttpServlet
734 {
735 boolean serviceCalled = false;
736
737 @Override
738 public void service(ServletRequest request, ServletResponse response)
739 {
740 serviceCalled = true;
741 }
742 }
743
744 static class TestHttpServletWithException extends HttpServlet
745 {
746 @Override
747 public void init(ServletConfig mockServletConfig) throws ServletException
748 {
749 throw new RuntimeException("exception thrown");
750 }
751 }
752
753 static class TestFilterWithException implements Filter
754 {
755 public void init(FilterConfig mockFilterConfig) throws ServletException
756 {
757 throw new RuntimeException("exception thrown");
758 }
759
760 public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
761 throws IOException, ServletException
762 {
763 }
764
765 public void destroy()
766 {
767 }
768 }
769
770 static class TestHttpFilter implements Filter
771 {
772 public void init(FilterConfig mockFilterConfig) throws ServletException
773 {
774 }
775
776 public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
777 throws IOException, ServletException
778 {
779 }
780
781 public void destroy()
782 {
783 }
784 }
785
786 static final class WeightedValue
787 {
788 final int weight;
789 final String value;
790
791 WeightedValue(int weight, String value)
792 {
793 this.weight = weight;
794 this.value = value;
795 }
796
797 @Override
798 public boolean equals(Object o)
799 {
800 if (this == o)
801 {
802 return true;
803 }
804 if (!(o instanceof WeightedValue))
805 {
806 return false;
807 }
808 WeightedValue rhs = (WeightedValue) o;
809 return weight == rhs.weight && value.equals(rhs.value);
810 }
811
812 @Override
813 public String toString()
814 {
815 return "[" + weight + ", " + value + "]";
816 }
817
818 static final Comparator<WeightedValue> byWeight = new Comparator<WeightedValue>()
819 {
820 public int compare(WeightedValue o1, WeightedValue o2)
821 {
822 return Integer.valueOf(o1.weight).compareTo(o2.weight);
823 }
824 };
825 }
826
827 static <T> List<T> newList(T... elements)
828 {
829 List<T> list = new ArrayList<T>();
830 for (T e : elements)
831 {
832 list.add(e);
833 }
834 return list;
835 }
836
837 static <T extends Comparable<T>> Comparator<T> naturalOrder(Class<T> type)
838 {
839 return new Comparator<T>()
840 {
841 public int compare(T o1, T o2)
842 {
843 return o1.compareTo(o2);
844 }
845 };
846 }
847 }