1   package com.atlassian.plugin.servlet;
2   
3   import static com.atlassian.plugin.servlet.DefaultServletModuleManager.sortedInsert;
4   import static com.atlassian.plugin.servlet.filter.FilterTestUtils.emptyChain;
5   
6   import java.util.ArrayList;
7   import java.util.Collections;
8   import java.util.Comparator;
9   import java.util.LinkedList;
10  import java.util.List;
11  import java.util.Vector;
12  import java.util.concurrent.atomic.AtomicReference;
13  import java.io.IOException;
14  
15  import javax.servlet.Filter;
16  import javax.servlet.FilterChain;
17  import javax.servlet.FilterConfig;
18  import javax.servlet.ServletConfig;
19  import javax.servlet.ServletContext;
20  import javax.servlet.ServletContextEvent;
21  import javax.servlet.ServletContextListener;
22  import javax.servlet.ServletRequest;
23  import javax.servlet.ServletResponse;
24  import javax.servlet.ServletException;
25  import javax.servlet.http.HttpServlet;
26  import javax.servlet.http.HttpServletRequest;
27  import javax.servlet.http.HttpServletResponse;
28  
29  import junit.framework.TestCase;
30  
31  import com.atlassian.plugin.Plugin;
32  import com.atlassian.plugin.event.PluginEventManager;
33  import com.atlassian.plugin.servlet.descriptors.ServletContextListenerModuleDescriptorBuilder;
34  import com.atlassian.plugin.servlet.descriptors.ServletContextParamDescriptorBuilder;
35  import com.atlassian.plugin.servlet.descriptors.ServletFilterModuleDescriptor;
36  import com.atlassian.plugin.servlet.descriptors.ServletFilterModuleDescriptorBuilder;
37  import com.atlassian.plugin.servlet.descriptors.ServletModuleDescriptor;
38  import com.atlassian.plugin.servlet.descriptors.ServletModuleDescriptorBuilder;
39  import com.atlassian.plugin.servlet.filter.FilterLocation;
40  import com.atlassian.plugin.servlet.filter.IteratingFilterChain;
41  import com.atlassian.plugin.servlet.filter.FilterTestUtils.FilterAdapter;
42  import com.atlassian.plugin.servlet.filter.FilterTestUtils.SoundOffFilter;
43  import com.mockobjects.dynamic.C;
44  import com.mockobjects.dynamic.Mock;
45  import static org.mockito.Mockito.mock;
46  import static org.mockito.Mockito.when;
47  
48  public class TestDefaultServletModuleManager extends TestCase
49  {
50      ServletModuleManager servletModuleManager;
51      
52      Mock mockPluginEventManager;
53      
54      public void setUp()
55      {
56          mockPluginEventManager = new Mock(PluginEventManager.class);
57          mockPluginEventManager.expect("register", C.anyArgs(1));
58          servletModuleManager = new DefaultServletModuleManager((PluginEventManager) mockPluginEventManager.proxy());
59      }
60      
61      public void testSortedInsertInsertsDistinctElementProperly()
62      {
63          List<String> list = newList("cat", "dog", "fish", "monkey");
64          List<String> endList = newList("cat", "dog", "elephant", "fish", "monkey");
65          sortedInsert(list, "elephant", naturalOrder(String.class));
66          assertEquals(endList, list); 
67      }
68      
69      public void testSortedInsertInsertsNonDistinctElementProperly()
70      {
71          List<WeightedValue> list = newList
72          (
73              new WeightedValue(10, "dog"), new WeightedValue(20, "monkey"), new WeightedValue(20, "tiger"),
74              new WeightedValue(30, "fish"), new WeightedValue(100, "cat")
75          );
76          List<WeightedValue> endList = newList
77          (
78              new WeightedValue(10, "dog"), new WeightedValue(20, "monkey"), new WeightedValue(20, "tiger"),
79              new WeightedValue(20, "elephant"), new WeightedValue(30, "fish"), new WeightedValue(100, "cat")
80          );
81          sortedInsert(list, new WeightedValue(20, "elephant"), WeightedValue.byWeight);
82          assertEquals(endList, list); 
83      }
84      
85      public void testGettingServletWithSimplePath() throws Exception
86      {
87          Mock mockServletContext = new Mock(ServletContext.class);
88          mockServletContext.expectAndReturn("getInitParameterNames", Collections.enumeration(Collections.emptyList()));
89          mockServletContext.expect("log", C.ANY_ARGS);
90          Mock mockServletConfig = new Mock(ServletConfig.class);
91          mockServletConfig.expectAndReturn("getServletContext", mockServletContext.proxy());
92          
93          Mock mockHttpServletRequest = new Mock(HttpServletRequest.class);
94          mockHttpServletRequest.expectAndReturn("getPathInfo", "/servlet");
95          Mock mockHttpServletResponse = new Mock(HttpServletResponse.class);
96          
97          TestHttpServlet servlet = new TestHttpServlet();
98          ServletModuleDescriptor descriptor = new ServletModuleDescriptorBuilder()
99              .with(servlet)
100             .withPath("/servlet")
101             .with(servletModuleManager)
102             .build();
103         
104         servletModuleManager.addServletModule(descriptor);
105         
106         HttpServlet wrappedServlet = servletModuleManager.getServlet("/servlet", (ServletConfig) mockServletConfig.proxy());
107         wrappedServlet.service((HttpServletRequest) mockHttpServletRequest.proxy(), (HttpServletResponse) mockHttpServletResponse.proxy());
108         assertTrue(servlet.serviceCalled);
109     }
110 
111     public void testGettingServletWithException() throws Exception
112     {
113         Mock mockServletContext = new Mock(ServletContext.class);
114         mockServletContext.expectAndReturn("getInitParameterNames", Collections.enumeration(Collections.emptyList()));
115         mockServletContext.expect("log", C.ANY_ARGS);
116         Mock mockServletConfig = new Mock(ServletConfig.class);
117         mockServletConfig.expectAndReturn("getServletContext", mockServletContext.proxy());
118 
119         Mock mockHttpServletRequest = new Mock(HttpServletRequest.class);
120         mockHttpServletRequest.expectAndReturn("getPathInfo", "/servlet");
121         Mock mockHttpServletResponse = new Mock(HttpServletResponse.class);
122 
123         TestHttpServletWithException servlet = new TestHttpServletWithException();
124         ServletModuleDescriptor descriptor = new ServletModuleDescriptorBuilder()
125             .with(servlet)
126             .withPath("/servlet")
127             .with(servletModuleManager)
128             .build();
129 
130         servletModuleManager.addServletModule(descriptor);
131 
132         assertNull(servletModuleManager.getServlet("/servlet", (ServletConfig) mockServletConfig.proxy()));
133     }
134 
135     public void testGettingFilterWithException() throws Exception
136     {
137         Mock mockServletContext = new Mock(ServletContext.class);
138         mockServletContext.expectAndReturn("getInitParameterNames", Collections.enumeration(Collections.emptyList()));
139         mockServletContext.expect("log", C.ANY_ARGS);
140         Mock mockFilterConfig = new Mock(FilterConfig.class);
141         mockFilterConfig.expectAndReturn("getServletContext", mockServletContext.proxy());
142 
143         Mock mockHttpServletRequest = new Mock(HttpServletRequest.class);
144         mockHttpServletRequest.expectAndReturn("getPathInfo", "/servlet");
145 
146         TestFilterWithException servlet = new TestFilterWithException();
147         ServletFilterModuleDescriptor descriptor = new ServletFilterModuleDescriptorBuilder()
148             .with(servlet)
149             .withPath("/servlet")
150             .with(servletModuleManager)
151             .at(FilterLocation.AFTER_ENCODING)
152             .build();
153 
154         servletModuleManager.addFilterModule(descriptor);
155 
156         assertEquals(false, servletModuleManager.getFilters(FilterLocation.AFTER_ENCODING, "/servlet", (FilterConfig) mockFilterConfig.proxy()).iterator().hasNext());
157     }
158     
159     public void testGettingServletWithComplexPath() throws Exception
160     {
161         Mock mockServletContext = new Mock(ServletContext.class);
162         mockServletContext.expectAndReturn("getInitParameterNames", Collections.enumeration(Collections.emptyList()));
163         mockServletContext.expect("log", C.ANY_ARGS);
164         Mock mockServletConfig = new Mock(ServletConfig.class);
165         mockServletConfig.expectAndReturn("getServletContext", mockServletContext.proxy());
166         
167         Mock mockHttpServletRequest = new Mock(HttpServletRequest.class);
168         mockHttpServletRequest.expectAndReturn("getPathInfo", "/servlet");
169         Mock mockHttpServletResponse = new Mock(HttpServletResponse.class);
170         
171         TestHttpServlet servlet = new TestHttpServlet();
172         ServletModuleDescriptor descriptor = new ServletModuleDescriptorBuilder()
173             .with(servlet)
174             .withPath("/servlet/*")
175             .with(servletModuleManager)
176             .build();
177         
178         servletModuleManager.addServletModule(descriptor);
179         
180         HttpServlet wrappedServlet = servletModuleManager.getServlet("/servlet/this/is/a/test", (ServletConfig) mockServletConfig.proxy());
181         wrappedServlet.service((HttpServletRequest) mockHttpServletRequest.proxy(), (HttpServletResponse) mockHttpServletResponse.proxy());
182         assertTrue(servlet.serviceCalled);
183     }
184 
185     public void testMultipleFitlersWithTheSameComplexPath() throws ServletException
186     {
187         ServletContext servletContext = mock(ServletContext.class);
188         FilterConfig filterConfig = mock(FilterConfig.class);
189         when(filterConfig.getServletContext()).thenReturn(servletContext);
190         when(servletContext.getInitParameterNames()).thenReturn(new Vector().elements());
191         Plugin plugin = new PluginBuilder().build();
192         ServletFilterModuleDescriptor filterDescriptor = new ServletFilterModuleDescriptorBuilder()
193             .with(plugin)
194             .withKey("foo")
195             .with(new FilterAdapter())
196             .withPath("/foo/*")
197             .with(servletModuleManager)
198             .build();
199 
200         ServletFilterModuleDescriptor filterDescriptor2 = new ServletFilterModuleDescriptorBuilder()
201             .with(plugin)
202             .withKey("bar")
203             .with(new FilterAdapter())
204             .withPath("/foo/*")
205             .with(servletModuleManager)
206             .build();
207         servletModuleManager.addFilterModule(filterDescriptor);
208         servletModuleManager.addFilterModule(filterDescriptor2);
209 
210         servletModuleManager.removeFilterModule(filterDescriptor);
211         assertTrue(servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo/jim", filterConfig).iterator().hasNext());
212     }
213 
214     public void testMultipleFitlersWithTheSameSimplePath() throws ServletException
215     {
216         ServletContext servletContext = mock(ServletContext.class);
217         FilterConfig filterConfig = mock(FilterConfig.class);
218         when(filterConfig.getServletContext()).thenReturn(servletContext);
219         when(servletContext.getInitParameterNames()).thenReturn(new Vector().elements());
220         Plugin plugin = new PluginBuilder().build();
221         ServletFilterModuleDescriptor filterDescriptor = new ServletFilterModuleDescriptorBuilder()
222             .with(plugin)
223             .withKey("foo")
224             .with(new FilterAdapter())
225             .withPath("/foo")
226             .with(servletModuleManager)
227             .build();
228 
229         ServletFilterModuleDescriptor filterDescriptor2 = new ServletFilterModuleDescriptorBuilder()
230             .with(plugin)
231             .withKey("bar")
232             .with(new FilterAdapter())
233             .withPath("/foo")
234             .with(servletModuleManager)
235             .build();
236         servletModuleManager.addFilterModule(filterDescriptor);
237         servletModuleManager.addFilterModule(filterDescriptor2);
238 
239         servletModuleManager.removeFilterModule(filterDescriptor);
240         assertTrue(servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/foo", filterConfig).iterator().hasNext());
241     }
242     
243     public void testPluginContextInitParamsGetMerged() throws Exception
244     {
245         Mock mockServletContext = new Mock(ServletContext.class);
246         mockServletContext.expectAndReturn("getInitParameterNames", Collections.enumeration(Collections.emptyList()));
247         mockServletContext.expect("log", C.ANY_ARGS);
248         Mock mockServletConfig = new Mock(ServletConfig.class);
249         mockServletConfig.expectAndReturn("getServletContext", mockServletContext.proxy());
250 
251         Plugin plugin = new PluginBuilder().build();
252 
253         new ServletContextParamDescriptorBuilder()
254             .with(plugin)
255             .withParam("param.name", "param.value")
256             .build();
257 
258         // a servlet that will check for param.name to be in the servlet context
259         ServletModuleDescriptor servletDescriptor = new ServletModuleDescriptorBuilder()
260             .with(plugin)
261             .with(new TestHttpServlet()
262             {
263                 @Override
264                 public void init(ServletConfig servletConfig)
265                 {
266                     assertEquals("param.value", servletConfig.getServletContext().getInitParameter("param.name"));
267                 }
268             })
269             .withPath("/servlet")
270             .with(servletModuleManager)
271             .build();
272         servletModuleManager.addServletModule(servletDescriptor);
273         
274         servletModuleManager.getServlet("/servlet", (ServletConfig) mockServletConfig.proxy());
275     }
276     
277     public void testServletListenerContextInitializedIsCalled() throws Exception
278     {
279         Mock mockServletContext = new Mock(ServletContext.class);
280         mockServletContext.expectAndReturn("getInitParameterNames", Collections.enumeration(Collections.emptyList()));
281         mockServletContext.expect("log", C.ANY_ARGS);
282         Mock mockServletConfig = new Mock(ServletConfig.class);
283         mockServletConfig.expectAndReturn("getServletContext", mockServletContext.proxy());
284         
285         final TestServletContextListener listener = new TestServletContextListener();
286         
287         Plugin plugin = new PluginBuilder().build();
288         
289         new ServletContextListenerModuleDescriptorBuilder()
290             .with(plugin)
291             .with(listener)
292             .build();
293         
294         ServletModuleDescriptor servletDescriptor = new ServletModuleDescriptorBuilder()
295             .with(plugin)
296             .with(new TestHttpServlet())
297             .withPath("/servlet")
298             .with(servletModuleManager)
299             .build();
300         
301         servletModuleManager.addServletModule(servletDescriptor);
302         servletModuleManager.getServlet("/servlet", (ServletConfig) mockServletConfig.proxy());
303         assertTrue(listener.initCalled);
304     }
305     
306     public void testServletListenerContextFilterAndServletUseTheSameServletContext() throws Exception
307     {
308         Plugin plugin = new PluginBuilder().build();
309 
310         final AtomicReference<ServletContext> contextRef = new AtomicReference<ServletContext>();
311         // setup a context listener to capture the context
312         new ServletContextListenerModuleDescriptorBuilder()
313             .with(plugin)
314             .with(new TestServletContextListener()
315             {
316                 @Override
317                 public void contextInitialized(ServletContextEvent event)
318                 {
319                     contextRef.set(event.getServletContext());
320                 }
321             })
322             .build();
323         
324         // a servlet that checks that the context is the same for it as it was for the context listener
325         ServletModuleDescriptor servletDescriptor = new ServletModuleDescriptorBuilder()
326             .with(plugin)
327             .with(new TestHttpServlet()
328             {
329                 @Override
330                 public void init(ServletConfig servletConfig)
331                 {
332                     assertSame(contextRef.get(), servletConfig.getServletContext());
333                 }
334             })
335             .withPath("/servlet")
336             .with(servletModuleManager)
337             .build();
338         servletModuleManager.addServletModule(servletDescriptor);
339         
340         // a filter that checks that the context is the same for it as it was for the context listener
341         ServletFilterModuleDescriptor filterDescriptor = new ServletFilterModuleDescriptorBuilder()
342             .with(plugin)
343             .with(new FilterAdapter()
344             {
345                 @Override
346                 public void init(FilterConfig filterConfig)
347                 {
348                     assertSame(contextRef.get(), filterConfig.getServletContext());
349                 }
350             })
351             .withPath("/*")
352             .with(servletModuleManager)
353             .build();
354         servletModuleManager.addFilterModule(filterDescriptor);
355         
356         Mock mockServletContext = new Mock(ServletContext.class);
357         mockServletContext.expectAndReturn("getInitParameterNames", Collections.enumeration(Collections.emptyList()));
358         mockServletContext.expect("log", C.ANY_ARGS);
359 
360         // get a servlet, this will initialize the servlet context for the first time in addition to the servlet itself.
361         // if the servlet doesn't get the same context as the context listener did, the assert will fail
362         Mock mockServletConfig = new Mock(ServletConfig.class);
363         mockServletConfig.expectAndReturn("getServletContext", mockServletContext.proxy());
364         servletModuleManager.getServlet("/servlet", (ServletConfig) mockServletConfig.proxy());
365         
366         // get the filters, if the filter doesn't get the same context as the context listener did, the assert will fail
367         Mock mockFilterConfig = new Mock(FilterConfig.class);
368         mockFilterConfig.expectAndReturn("getServletContext", mockServletContext.proxy());
369         servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/servlet", (FilterConfig) mockFilterConfig.proxy());
370     }
371     
372     public void testFiltersWithSameLocationAndWeightInTheSamePluginAppearInTheOrderTheyAreDeclared() throws Exception
373     {
374         Mock mockServletContext = new Mock(ServletContext.class);
375         mockServletContext.matchAndReturn("getInitParameterNames", Collections.enumeration(Collections.emptyList()));
376         mockServletContext.expect("log", C.ANY_ARGS);
377         Mock mockFilterConfig = new Mock(FilterConfig.class);
378         mockFilterConfig.matchAndReturn("getServletContext", mockServletContext.proxy());
379 
380         Plugin plugin = new PluginBuilder().build();
381         
382         List<Integer> filterCallOrder = new LinkedList<Integer>();
383         ServletFilterModuleDescriptor d1 = new ServletFilterModuleDescriptorBuilder()
384             .with(plugin)
385             .withKey("filter-1")
386             .with(new SoundOffFilter(filterCallOrder, 1))
387             .withPath("/*")
388             .build();
389         servletModuleManager.addFilterModule(d1);
390         
391         ServletFilterModuleDescriptor d2 = new ServletFilterModuleDescriptorBuilder()
392             .with(plugin)
393             .withKey("filter-2")
394             .with(new SoundOffFilter(filterCallOrder, 2))
395             .withPath("/*")
396             .build();
397         servletModuleManager.addFilterModule(d2);
398         
399         Mock mockHttpServletRequest = new Mock(HttpServletRequest.class);
400         mockHttpServletRequest.matchAndReturn("getPathInfo", "/servlet");
401         Mock mockHttpServletResponse = new Mock(HttpServletResponse.class);
402         
403         Iterable<Filter> filters = servletModuleManager.getFilters(FilterLocation.BEFORE_DISPATCH, "/some/path", (FilterConfig) mockFilterConfig.proxy());
404         FilterChain chain = new IteratingFilterChain(filters.iterator(), emptyChain);
405         
406         chain.doFilter((HttpServletRequest) mockHttpServletRequest.proxy(), (HttpServletResponse) mockHttpServletResponse.proxy());
407         assertEquals(newList(1, 2, 2, 1), filterCallOrder);
408     }        
409 
410     static class TestServletContextListener implements ServletContextListener
411     {
412         boolean initCalled = false;
413         
414         public void contextInitialized(ServletContextEvent event)
415         {
416             initCalled = true;
417         }
418 
419         public void contextDestroyed(ServletContextEvent event) {}
420     }
421     
422     static class TestHttpServlet extends HttpServlet
423     {
424         boolean serviceCalled = false;
425         
426         @Override
427         public void service(ServletRequest request, ServletResponse response)
428         {
429             serviceCalled = true;
430         }
431     }
432 
433     static class TestHttpServletWithException extends HttpServlet
434     {
435         @Override
436         public void init(ServletConfig servletConfig) throws ServletException
437         {
438             throw new RuntimeException("exception thrown");
439         }
440     }
441 
442     static class TestFilterWithException implements Filter
443     {
444         public void init(FilterConfig filterConfig) throws ServletException
445         {
446             throw new RuntimeException("exception thrown");
447         }
448 
449         public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException
450         {
451         }
452 
453         public void destroy()
454         {
455         }
456     }
457 
458     static final class WeightedValue
459     {
460         final int weight;
461         final String value;
462         
463         WeightedValue(int weight, String value)
464         {
465             this.weight = weight;
466             this.value = value;
467         }
468         
469         @Override
470         public boolean equals(Object o)
471         {
472             if (this == o)
473                 return true;
474             if (!(o instanceof WeightedValue))
475                 return false;
476             WeightedValue rhs = (WeightedValue) o;
477             return weight == rhs.weight && value.equals(rhs.value);
478         }
479         
480         @Override
481         public String toString()
482         {
483             return "[" + weight + ", " + value + "]";
484         }
485         
486         static final Comparator<WeightedValue> byWeight = new Comparator<WeightedValue>()
487         {
488             public int compare(WeightedValue o1, WeightedValue o2)
489             {
490                 return Integer.valueOf(o1.weight).compareTo(o2.weight);
491             }
492         };
493     }
494     
495     static <T> List<T> newList(T... elements)
496     {
497         List<T> list = new ArrayList<T>();
498         for (T e : elements)
499         {
500             list.add(e);
501         }
502         return list;
503     }
504     
505     static <T extends Comparable<T>> Comparator<T> naturalOrder(Class<T> type)
506     {
507         return new Comparator<T>()
508         {
509             public int compare(T o1, T o2)
510             {
511                 return o1.compareTo(o2);
512             }
513         };
514     }    
515 }