1 package com.atlassian.plugin.servlet.filter;
2
3 import com.atlassian.plugin.servlet.filter.FilterTestUtils.FilterAdapter;
4 import com.atlassian.plugin.servlet.filter.FilterTestUtils.SoundOffFilter;
5 import org.junit.Rule;
6 import org.junit.Test;
7 import org.junit.rules.ExpectedException;
8
9 import javax.servlet.Filter;
10 import javax.servlet.FilterChain;
11 import javax.servlet.ServletException;
12 import javax.servlet.ServletRequest;
13 import javax.servlet.ServletResponse;
14 import javax.servlet.http.HttpServletRequest;
15 import javax.servlet.http.HttpServletResponse;
16 import java.io.IOException;
17 import java.util.ArrayList;
18 import java.util.LinkedList;
19 import java.util.List;
20
21 import static com.atlassian.plugin.servlet.filter.FilterTestUtils.singletonFilterChain;
22 import static org.hamcrest.MatcherAssert.assertThat;
23 import static org.hamcrest.Matchers.contains;
24 import static org.junit.Assert.fail;
25 import static org.mockito.Mockito.mock;
26 import static org.mockito.Mockito.when;
27
28 public class TestIteratingFilterChain {
29 @Rule
30 public final ExpectedException expectedException = ExpectedException.none();
31
32 @Test
33 public void testFiltersCalledInProperOrder() throws IOException, ServletException {
34 List<Integer> filterCallOrder = new LinkedList<>();
35 List<Filter> filters = new ArrayList<>();
36 for (int i = 0; i < 5; i++) {
37 filters.add(new SoundOffFilter(filterCallOrder, i));
38 }
39
40 FilterChain chain = new IteratingFilterChain(filters.iterator(), singletonFilterChain(new SoundOffFilter(filterCallOrder, 100)));
41
42 HttpServletRequest mockRequest = mock(HttpServletRequest.class);
43 when(mockRequest.getPathInfo()).thenReturn("some/path");
44 HttpServletResponse mockResponse = mock(HttpServletResponse.class);
45
46 chain.doFilter(mockRequest, mockResponse);
47
48
49 assertThat(filterCallOrder, contains(0, 1, 2, 3, 4, 100, 100, 4, 3, 2, 1, 0));
50 }
51
52 @Test
53 public void testFilterCanAbortChain() throws IOException, ServletException {
54 final List<Integer> filterCallOrder = new LinkedList<>();
55 List<Filter> filters = new ArrayList<>();
56 for (int i = 0; i < 2; i++) {
57 filters.add(new SoundOffFilter(filterCallOrder, i));
58 }
59 filters.add(new FilterAdapter() {
60 @Override
61 public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) {
62 filterCallOrder.add(50);
63 }
64 });
65 for (int i = 3; i < 5; i++) {
66 filters.add(new SoundOffFilter(filterCallOrder, i));
67 }
68
69 FilterChain chain = new IteratingFilterChain(filters.iterator(), singletonFilterChain(new SoundOffFilter(filterCallOrder, 100)));
70
71 HttpServletRequest mockRequest = mock(HttpServletRequest.class);
72 when(mockRequest.getPathInfo()).thenReturn("some/path");
73 HttpServletResponse mockResponse = mock(HttpServletResponse.class);
74
75 chain.doFilter(mockRequest, mockResponse);
76
77
78 assertThat(filterCallOrder, contains(0, 1, 50, 1, 0));
79 }
80
81 @Test
82 public void testExceptionFiltersUpWhenFilterThrowsException() throws IOException {
83 final List<Integer> filterCallOrder = new LinkedList<>();
84 List<Filter> filters = new ArrayList<>();
85 for (int i = 0; i < 2; i++) {
86 filters.add(new SoundOffFilter(filterCallOrder, i));
87 }
88 filters.add(new FilterAdapter() {
89 @Override
90 public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
91 throws ServletException {
92 throw new ServletException();
93 }
94 });
95 for (int i = 3; i < 5; i++) {
96 filters.add(new SoundOffFilter(filterCallOrder, i));
97 }
98
99 FilterChain chain = new IteratingFilterChain(filters.iterator(), singletonFilterChain(new SoundOffFilter(filterCallOrder, 100)));
100
101 HttpServletRequest mockRequest = mock(HttpServletRequest.class);
102 when(mockRequest.getPathInfo()).thenReturn("some/path");
103 HttpServletResponse mockResponse = mock(HttpServletResponse.class);
104
105 try {
106 chain.doFilter(mockRequest, mockResponse);
107 fail("ServletException should filter up");
108 } catch (ServletException e) {
109
110 assertThat(filterCallOrder, contains(0, 1));
111 }
112 }
113 }