1 package com.atlassian.plugins.rest.common.security.jersey;
2
import com.atlassian.plugin.tracker.PluginModuleTracker;
import com.atlassian.plugins.rest.common.security.CorsPreflightCheckCompleteException;
import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaults;
import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaultsModuleDescriptor;
import com.google.common.collect.ImmutableSet;
import com.sun.jersey.spi.container.ContainerRequest;
import com.sun.jersey.spi.container.ContainerResponse;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;

import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import java.util.Arrays;
import java.util.Map;

import static com.google.common.collect.Maps.newHashMap;
import static com.google.common.collect.Sets.newHashSet;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.*;
28
29 @RunWith (MockitoJUnitRunner.class)
30 public class TestCorsResourceFilter
31 {
32 @Mock
33 private CorsDefaults corsDefaults;
34
35 @Mock
36 private CorsDefaults corsDefaults2;
37
38 private CorsResourceFilter corsResourceFilter;
39 @Mock
40 private ContainerRequest request;
41 private Map<String,Object> requestProps;
42 @Mock
43 private ContainerResponse response;
44 @Mock
45 private PluginModuleTracker<CorsDefaults, CorsDefaultsModuleDescriptor> tracker;
46
47 @Before
48 public void setUp()
49 {
50 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults));
51 when(response.getResponse()).thenReturn(Response.ok().build());
52 requestProps = newHashMap();
53 when(request.getProperties()).thenReturn(requestProps);
54 corsResourceFilter = new CorsResourceFilter(tracker, "GET");
55 }
56
57 @Test
58 public void testSimplePreflightForGet()
59 {
60 String origin = "http://localhost";
61 requestProps.put("Cors-Preflight-Requested", "true");
62 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
63 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
64 when(request.getHeaderValue("Origin")).thenReturn(origin);
65 MultivaluedMap<String, Object> headers = execPreflight();
66 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin);
67 }
68
69 @Test
70 public void testPreflightSucceedsWhenOneCorsDefaultsAllowsOrigin()
71 {
72 String origin = "http://localhost";
73 MultivaluedMap<String, Object> headers = execPreflightWithTwoCorsDefaults(origin);
74 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin);
75 }
76
77 @Test
78 public void testSecondCorsDefaultsIsNotHitIfDoesntAllowOrigin()
79 {
80 String origin = "http://localhost";
81 MultivaluedMap<String, Object> headers = execPreflightWithTwoCorsDefaults(origin);
82 verify(corsDefaults2, never()).allowsCredentials(Matchers.<String>any());
83 verify(corsDefaults2, never()).getAllowedRequestHeaders(Matchers.<String>any());
84 verify(corsDefaults2, never()).getAllowedResponseHeaders(Matchers.<String>any());
85 }
86
87 @Test
88 public void testSimplePreflightForGetWrongDomain()
89 {
90 String origin = "http://localhost";
91 requestProps.put("Cors-Preflight-Requested", "true");
92 when(corsDefaults.allowsOrigin(origin)).thenReturn(false);
93 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
94 when(request.getHeaderValue("Origin")).thenReturn(origin);
95 execBadPreflight();
96 }
97
98 @Test
99 public void testSimplePreflightForGetWrongMethod()
100 {
101 String origin = "http://localhost";
102 requestProps.put("Cors-Preflight-Requested", "true");
103 when(corsDefaults.allowsOrigin(origin)).thenReturn(false);
104 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("POST");
105 when(request.getHeaderValue("Origin")).thenReturn(origin);
106 execBadPreflight();
107 }
108
109 @Test
110 public void testSimplePreflightForGetWrongHeaders()
111 {
112 String origin = "http://localhost";
113 requestProps.put("Cors-Preflight-Requested", "true");
114 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
115 when(corsDefaults.getAllowedRequestHeaders(origin)).thenReturn(ImmutableSet.<String>of("Foo-Header"));
116 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
117 when(request.getRequestHeader("Access-Control-Request-Headers")).thenReturn(Arrays.asList("Bar-Header"));
118 when(request.getHeaderValue("Origin")).thenReturn(origin);
119 execBadPreflight();
120 }
121
122 @Test
123 public void testSimpleGet()
124 {
125 String origin = "http://localhost";
126 when(request.getMethod()).thenReturn("GET");
127 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
128 when(request.getHeaderValue("Origin")).thenReturn(origin);
129 MultivaluedMap<String, Object> headers = execNoPreflightWithHeaders();
130 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin);
131 }
132
133 @Test
134 public void testSimpleGetWhenOneCorsDefaultsAllowsOrigin()
135 {
136 String origin = "http://localhost";
137 MultivaluedMap<String, Object> headers = execNoPreflightWithHeadersForTwoCorsDefaults(origin);
138 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin);
139 }
140
141 @Test
142 public void testSecondCorsDefaultIsNotCalledWhenItDoesntAllowOrigin()
143 {
144 String origin = "http://localhost";
145 execNoPreflightWithHeadersForTwoCorsDefaults(origin);
146 verify(corsDefaults2, never()).allowsCredentials(Matchers.<String>any());
147 verify(corsDefaults2, never()).getAllowedRequestHeaders(Matchers.<String>any());
148 verify(corsDefaults2, never()).getAllowedResponseHeaders(Matchers.<String>any());
149 }
150
151 @Test
152 public void testSimpleGetWrongOrigin()
153 {
154 String origin = "http://localhost";
155 when(request.getMethod()).thenReturn("GET");
156 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
157 when(request.getHeaderValue("Origin")).thenReturn("http://foo.com");
158 execNoPreflightNoHeaders();
159 }
160
161 @Test
162 public void testSimpleGetNoOrigin()
163 {
164 when(request.getMethod()).thenReturn("GET");
165 execNoPreflightNoHeaders();
166 }
167
168 private MultivaluedMap<String, Object> execPreflight()
169 {
170 try
171 {
172 corsResourceFilter.filter(request);
173 fail("Should have thrown preflight exception");
174 return null;
175 }
176 catch (CorsPreflightCheckCompleteException ex)
177 {
178 return ex.getResponse().getMetadata();
179 }
180 }
181
182 private MultivaluedMap<String, Object> execNoPreflightWithHeaders()
183 {
184 try
185 {
186 corsResourceFilter.filter(request);
187 ArgumentCaptor<Response> argument = ArgumentCaptor.forClass(Response.class);
188 corsResourceFilter.filter(request, response);
189 verify(response).setResponse(argument.capture());
190 return argument.getValue().getMetadata();
191 }
192 catch (CorsPreflightCheckCompleteException ex)
193 {
194 fail("Shouldn't have thrown preflight exception");
195 return null;
196 }
197 }
198
199 private void execNoPreflightNoHeaders()
200 {
201 try
202 {
203 corsResourceFilter.filter(request);
204 ArgumentCaptor<Response> argument = ArgumentCaptor.forClass(Response.class);
205 corsResourceFilter.filter(request, response);
206 verify(response, never()).setResponse(argument.capture());
207 }
208 catch (CorsPreflightCheckCompleteException ex)
209 {
210 fail("Shouldn't have thrown preflight exception");
211 }
212 }
213
214 private CorsPreflightCheckCompleteException execBadPreflight()
215 {
216 try
217 {
218 corsResourceFilter.filter(request);
219 corsResourceFilter.filter(request, response);
220 fail("Should have thrown preflight exception");
221 return null;
222 }
223 catch (CorsPreflightCheckCompleteException ex)
224 {
225 return ex;
226 }
227 }
228
229 private MultivaluedMap<String, Object> execPreflightWithTwoCorsDefaults(String origin)
230 {
231 requestProps.put("Cors-Preflight-Requested", "true");
232 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults, corsDefaults2));
233 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
234 when(corsDefaults2.allowsOrigin(origin)).thenReturn(false);
235 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
236 when(request.getHeaderValue("Origin")).thenReturn(origin);
237 MultivaluedMap<String, Object> headers = execPreflight();
238 return headers;
239 }
240
241 private MultivaluedMap<String, Object> execNoPreflightWithHeadersForTwoCorsDefaults(String origin)
242 {
243 when(request.getMethod()).thenReturn("GET");
244 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults, corsDefaults2));
245 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
246 when(corsDefaults2.allowsOrigin(origin)).thenReturn(false);
247 when(request.getHeaderValue("Origin")).thenReturn(origin);
248 MultivaluedMap<String, Object> headers = execNoPreflightWithHeaders();
249 return headers;
250 }
251 }