1 package com.atlassian.plugins.rest.common.security.jersey;
2
import com.atlassian.plugin.tracker.PluginModuleTracker;
import com.atlassian.plugins.rest.common.security.CorsHeaders;
import com.atlassian.plugins.rest.common.security.CorsPreflightCheckCompleteException;
import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaults;
import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaultsModuleDescriptor;
import com.google.common.collect.ImmutableSet;
import com.sun.jersey.spi.container.ContainerRequest;
import com.sun.jersey.spi.container.ContainerResponse;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;

import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.google.common.collect.Maps.newHashMap;
import static com.google.common.collect.Sets.newHashSet;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.*;
36
37 @RunWith(MockitoJUnitRunner.class)
38 public class TestCorsResourceFilter {
39 @Mock
40 private CorsDefaults corsDefaults;
41
42 @Mock
43 private CorsDefaults corsDefaults2;
44
45 private CorsResourceFilter corsResourceFilter;
46 @Mock
47 private ContainerRequest request;
48 private Map<String, Object> requestProps;
49 @Mock
50 private ContainerResponse response;
51 @Mock
52 private PluginModuleTracker<CorsDefaults, CorsDefaultsModuleDescriptor> tracker;
53
54 @Rule
55 public ExpectedException exception = ExpectedException.none();
56
57 @Before
58 public void setUp() {
59 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults));
60 when(response.getResponse()).thenReturn(Response.ok().build());
61 requestProps = newHashMap();
62 when(request.getProperties()).thenReturn(requestProps);
63 corsResourceFilter = new CorsResourceFilter(tracker, "GET");
64 }
65
66 @Test
67 public void testSimplePreflightForGet() {
68 String origin = "http://localhost";
69 requestProps.put("Cors-Preflight-Requested", "true");
70 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
71 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
72 when(request.getHeaderValue("Origin")).thenReturn(origin);
73 MultivaluedMap<String, Object> headers = execPreflight();
74 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin);
75 }
76
77 @Test
78 public void testPreflightSucceedsWhenOneCorsDefaultsAllowsOrigin() {
79 String origin = "http://localhost";
80 MultivaluedMap<String, Object> headers = execPreflightWithTwoCorsDefaults(origin);
81 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin);
82 }
83
84 @Test
85 public void testSecondCorsDefaultsIsNotHitIfDoesntAllowOrigin() {
86 String origin = "http://localhost";
87 MultivaluedMap<String, Object> headers = execPreflightWithTwoCorsDefaults(origin);
88 verify(corsDefaults2, never()).allowsCredentials(Matchers.<String>any());
89 verify(corsDefaults2, never()).getAllowedRequestHeaders(Matchers.<String>any());
90 verify(corsDefaults2, never()).getAllowedResponseHeaders(Matchers.<String>any());
91 }
92
93 @Test
94 public void relativeOriginUriIsNotAllowed() {
95 MultivaluedMap<String, Object> headers = execPreflightWithTwoCorsDefaults(
96 "/not-absolute.com");
97 assertEquals(headers.get(CorsHeaders.ACCESS_CONTROL_ALLOW_ORIGIN.value()),
98 null);
99 }
100
101 @Test
102 public void opaqueOriginUriIsNotAllowed() {
103 MultivaluedMap<String, Object> headers = execPreflightWithTwoCorsDefaults(
104 "opaque:test.com");
105 assertEquals(headers.get(CorsHeaders.ACCESS_CONTROL_ALLOW_ORIGIN.value()),
106 null);
107 }
108
109 @Test
110 public void nullOriginIsNotAllowed() {
111 MultivaluedMap<String, Object> headers = execPreflightWithTwoCorsDefaults(
112 "null");
113 assertEquals(headers.get(CorsHeaders.ACCESS_CONTROL_ALLOW_ORIGIN.value()),
114 null);
115 }
116
117 @Test
118 public void testSimplePreflightForGetWrongDomain() {
119 String origin = "http://localhost";
120 requestProps.put("Cors-Preflight-Requested", "true");
121 when(corsDefaults.allowsOrigin(origin)).thenReturn(false);
122 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
123 when(request.getHeaderValue("Origin")).thenReturn(origin);
124 exception.expect(CorsPreflightCheckCompleteException.class);
125 execBadPreflight();
126 }
127
128 @Test
129 public void testSimplePreflightForGetWrongMethod() {
130 String origin = "http://localhost";
131 requestProps.put("Cors-Preflight-Requested", "true");
132 when(corsDefaults.allowsOrigin(origin)).thenReturn(false);
133 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("POST");
134 when(request.getHeaderValue("Origin")).thenReturn(origin);
135 exception.expect(CorsPreflightCheckCompleteException.class);
136 execBadPreflight();
137 }
138
139 @Test
140 public void testSimplePreflightForGetWrongHeaders() {
141 String origin = "http://localhost";
142 requestProps.put("Cors-Preflight-Requested", "true");
143 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
144 when(corsDefaults.getAllowedRequestHeaders(origin)).thenReturn(ImmutableSet.<String>of("Foo-Header"));
145 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
146 when(request.getRequestHeader("Access-Control-Request-Headers")).thenReturn(Arrays.asList("Bar-Header"));
147 when(request.getHeaderValue("Origin")).thenReturn(origin);
148 exception.expect(CorsPreflightCheckCompleteException.class);
149 execBadPreflight();
150 }
151
152 @Test
153 public void testSimpleGet() {
154 String origin = "http://localhost";
155 when(request.getMethod()).thenReturn("GET");
156 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
157 when(request.getHeaderValue("Origin")).thenReturn(origin);
158 MultivaluedMap<String, Object> headers = execNoPreflightWithHeaders();
159 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin);
160 }
161
162 @Test
163 public void testSimpleGetWhenOneCorsDefaultsAllowsOrigin() {
164 String origin = "http://localhost";
165 MultivaluedMap<String, Object> headers = execNoPreflightWithHeadersForTwoCorsDefaults(origin);
166 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin);
167 }
168
169 @Test
170 public void testSecondCorsDefaultIsNotCalledWhenItDoesntAllowOrigin() {
171 String origin = "http://localhost";
172 execNoPreflightWithHeadersForTwoCorsDefaults(origin);
173 verify(corsDefaults2, never()).allowsCredentials(Matchers.<String>any());
174 verify(corsDefaults2, never()).getAllowedRequestHeaders(Matchers.<String>any());
175 verify(corsDefaults2, never()).getAllowedResponseHeaders(Matchers.<String>any());
176 }
177
178 @Test
179 public void testSimpleGetWrongOrigin() {
180 String origin = "http://localhost";
181 when(request.getMethod()).thenReturn("GET");
182 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
183 when(request.getHeaderValue("Origin")).thenReturn("http://foo.com");
184 execNoPreflightNoHeaders();
185 }
186
187 @Test
188 public void testSimpleGetNoOrigin() {
189 when(request.getMethod()).thenReturn("GET");
190 execNoPreflightNoHeaders();
191 }
192
193 @Test
194 public void testSimplePreflightWithMultipleRequestHeaderValues() {
195 execPreflightWithRequestHeaders(Collections.singletonList("foo-header,bar-header"));
196 assertEquals("true", requestProps.get("Cors-Preflight-Succeeded"));
197 }
198
199 @Test
200 public void simplePreflightWithMixedCasingRequestHeaderValues() {
201 execPreflightWithRequestHeaders(
202 Arrays.asList("fOO-header,bar-header"),
203 ImmutableSet.of("Foo-Header", "Bar-Header")
204 );
205 assertEquals("true", requestProps.get("Cors-Preflight-Succeeded"));
206 }
207
208 @Test
209 public void testSimplePreflightWithMultipleRequestHeaders() {
210 execPreflightWithRequestHeaders(Arrays.asList("foo-header", "bar-header"));
211 assertEquals("true", requestProps.get("Cors-Preflight-Succeeded"));
212 }
213
214 @Test
215 public void testSimplePreflightWithSpacesInRequestHeaders() {
216 execPreflightWithRequestHeaders(Collections.singletonList(" foo-header, bar-header "));
217 assertEquals("true", requestProps.get("Cors-Preflight-Succeeded"));
218 }
219
220 @Test
221 public void preflightWithMultipleRequestHeadersHasExpectedAccessControlAllowHeaders() {
222 final String headerName = CorsHeaders.ACCESS_CONTROL_ALLOW_HEADERS.value();
223 Set<String> values = ImmutableSet.of("foo-header", "bar-header");
224 MultivaluedMap<String, Object> headers = execPreflightWithRequestHeaders(
225 Collections.<String>emptyList(), values);
226 assertThat(headers.get(headerName).size(), is(1));
227 assertThat(ImmutableSet.copyOf(
228 headers.getFirst(headerName).toString().split(", ")),
229 is(values));
230 }
231
232 @Test
233 public void corsResponseHasExpectedAccessControlExposeHeaders() {
234 final String origin = "http://localhost";
235 final String headerName = CorsHeaders.ACCESS_CONTROL_EXPOSE_HEADERS.value();
236 Set<String> values = ImmutableSet.of("foo-header", "bar-header");
237
238 when(corsDefaults.getAllowedResponseHeaders(origin)).thenReturn(values);
239 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
240 when(request.getHeaderValue("Origin")).thenReturn(origin);
241
242 MultivaluedMap<String, Object> headers = execNoPreflightWithHeaders();
243 assertThat(headers.get(headerName).size(), is(1));
244 assertThat(ImmutableSet.copyOf(
245 headers.getFirst(headerName).toString().split(", ")),
246 is(values));
247 }
248
249 private MultivaluedMap<String, Object> execPreflight() {
250 try {
251 corsResourceFilter.filter(request);
252 fail("Should have thrown preflight exception");
253 return null;
254 } catch (CorsPreflightCheckCompleteException ex) {
255 return ex.getResponse().getMetadata();
256 }
257 }
258
259 private MultivaluedMap<String, Object> execNoPreflightWithHeaders() {
260 corsResourceFilter.filter(request);
261 ArgumentCaptor<Response> argument = ArgumentCaptor.forClass(Response.class);
262 corsResourceFilter.filter(request, response);
263 verify(response).setResponse(argument.capture());
264 return argument.getValue().getMetadata();
265 }
266
267 private void execNoPreflightNoHeaders() {
268 try {
269 corsResourceFilter.filter(request);
270 ArgumentCaptor<Response> argument = ArgumentCaptor.forClass(Response.class);
271 corsResourceFilter.filter(request, response);
272 verify(response, never()).setResponse(argument.capture());
273 } catch (CorsPreflightCheckCompleteException ex) {
274 fail("Shouldn't have thrown preflight exception");
275 }
276 }
277
278 private MultivaluedMap<String, Object> execPreflightWithRequestHeaders(
279 List<String> headers,
280 Set<String> allowedHeaders) {
281 String origin = "http://localhost";
282 requestProps.put("Cors-Preflight-Requested", "true");
283 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
284 when(corsDefaults.getAllowedRequestHeaders(origin)).thenReturn(allowedHeaders);
285 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
286 when(request.getRequestHeader("Access-Control-Request-Headers")).thenReturn(headers);
287 when(request.getHeaderValue("Origin")).thenReturn(origin);
288 return execPreflight();
289 }
290
291 private MultivaluedMap<String, Object> execPreflightWithRequestHeaders(
292 List<String> headers) {
293 return execPreflightWithRequestHeaders(headers,
294 ImmutableSet.of("foo-header", "bar-header"));
295 }
296
297 private void execBadPreflight()
298 throws CorsPreflightCheckCompleteException {
299 corsResourceFilter.filter(request);
300 corsResourceFilter.filter(request, response);
301 }
302
303 private MultivaluedMap<String, Object> execPreflightWithTwoCorsDefaults(String origin) {
304 requestProps.put("Cors-Preflight-Requested", "true");
305 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults, corsDefaults2));
306 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
307 when(corsDefaults2.allowsOrigin(origin)).thenReturn(false);
308 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
309 when(request.getHeaderValue("Origin")).thenReturn(origin);
310 MultivaluedMap<String, Object> headers = execPreflight();
311 return headers;
312 }
313
314 private MultivaluedMap<String, Object> execNoPreflightWithHeadersForTwoCorsDefaults(String origin) {
315 when(request.getMethod()).thenReturn("GET");
316 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults, corsDefaults2));
317 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
318 when(corsDefaults2.allowsOrigin(origin)).thenReturn(false);
319 when(request.getHeaderValue("Origin")).thenReturn(origin);
320 MultivaluedMap<String, Object> headers = execNoPreflightWithHeaders();
321 return headers;
322 }
323 }