1 package com.atlassian.plugins.rest.common.security.jersey;
2
import com.atlassian.plugin.tracker.PluginModuleTracker;
import com.atlassian.plugins.rest.common.security.CorsPreflightCheckCompleteException;
import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaults;
import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaultsModuleDescriptor;
import com.google.common.collect.ImmutableSet;
import com.sun.jersey.spi.container.ContainerRequest;
import com.sun.jersey.spi.container.ContainerResponse;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Matchers;
import org.mockito.Mock;
import org.mockito.runners.MockitoJUnitRunner;

import javax.ws.rs.core.MultivaluedMap;
import javax.ws.rs.core.Response;
import java.util.Arrays;
import java.util.Map;

import static com.google.common.collect.Maps.newHashMap;
import static com.google.common.collect.Sets.newHashSet;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.*;
28
29 @RunWith (MockitoJUnitRunner.class)
30 public class TestCorsResourceFilter
31 {
32 @Mock
33 private CorsDefaults corsDefaults;
34
35 @Mock
36 private CorsDefaults corsDefaults2;
37
38 private CorsResourceFilter corsResourceFilter;
39 @Mock
40 private ContainerRequest request;
41 private Map<String,Object> requestProps;
42 @Mock
43 private ContainerResponse response;
44 @Mock
45 private PluginModuleTracker<CorsDefaults, CorsDefaultsModuleDescriptor> tracker;
46
47 @Before
48 public void setUp()
49 {
50 when(response.getResponse()).thenReturn(Response.ok().build());
51 requestProps = newHashMap();
52 when(request.getProperties()).thenReturn(requestProps);
53 corsResourceFilter = new CorsResourceFilter(tracker, "GET");
54 }
55
56 @Test
57 public void testSimplePreflightForGet()
58 {
59 String origin = "http://localhost";
60 requestProps.put("Cors-Preflight-Requested", "true");
61 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults));
62 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
63 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
64 when(request.getHeaderValue("Origin")).thenReturn(origin);
65 MultivaluedMap<String, Object> headers = execPreflight();
66 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin);
67 }
68
69 @Test
70 public void testPreflightSucceedsWhenOneCorsDefaultsAllowsOrigin()
71 {
72 String origin = "http://localhost";
73 MultivaluedMap<String, Object> headers = execPreflightWithTwoCorsDefaults(origin);
74 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin);
75 }
76
77 @Test
78 public void testSecondCorsDefaultsIsNotHitIfDoesntAllowOrigin()
79 {
80 String origin = "http://localhost";
81 MultivaluedMap<String, Object> headers = execPreflightWithTwoCorsDefaults(origin);
82 verify(corsDefaults2, never()).allowsCredentials(Matchers.<String>any());
83 verify(corsDefaults2, never()).getAllowedRequestHeaders(Matchers.<String>any());
84 verify(corsDefaults2, never()).getAllowedResponseHeaders(Matchers.<String>any());
85 }
86
87 @Test
88 public void testSimplePreflightForGetWrongDomain()
89 {
90 String origin = "http://localhost";
91 requestProps.put("Cors-Preflight-Requested", "true");
92 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults));
93 when(corsDefaults.allowsOrigin(origin)).thenReturn(false);
94 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
95 when(request.getHeaderValue("Origin")).thenReturn(origin);
96 execBadPreflight();
97 }
98
99 @Test
100 public void testSimplePreflightForGetWrongMethod()
101 {
102 String origin = "http://localhost";
103 requestProps.put("Cors-Preflight-Requested", "true");
104 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults));
105 when(corsDefaults.allowsOrigin(origin)).thenReturn(false);
106 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("POST");
107 when(request.getHeaderValue("Origin")).thenReturn(origin);
108 execBadPreflight();
109 }
110
111 @Test
112 public void testSimplePreflightForGetWrongHeaders()
113 {
114 String origin = "http://localhost";
115 requestProps.put("Cors-Preflight-Requested", "true");
116 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults));
117 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
118 when(corsDefaults.getAllowedRequestHeaders(origin)).thenReturn(ImmutableSet.<String>of("Foo-Header"));
119 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
120 when(request.getRequestHeader("Access-Control-Request-Headers")).thenReturn(Arrays.asList("Bar-Header"));
121 when(request.getHeaderValue("Origin")).thenReturn(origin);
122 execBadPreflight();
123 }
124
125 @Test
126 public void testSimpleGet()
127 {
128 String origin = "http://localhost";
129 when(request.getMethod()).thenReturn("GET");
130 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults));
131 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
132 when(request.getHeaderValue("Origin")).thenReturn(origin);
133 MultivaluedMap<String, Object> headers = execNoPreflightWithHeaders();
134 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin);
135 }
136
137 @Test
138 public void testSimpleGetWhenOneCorsDefaultsAllowsOrigin()
139 {
140 String origin = "http://localhost";
141 MultivaluedMap<String, Object> headers = execNoPreflightWithHeadersForTwoCorsDefaults(origin);
142 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin);
143 }
144
145 @Test
146 public void testSecondCorsDefaultIsNotCalledWhenItDoesntAllowOrigin()
147 {
148 String origin = "http://localhost";
149 execNoPreflightWithHeadersForTwoCorsDefaults(origin);
150 verify(corsDefaults2, never()).allowsCredentials(Matchers.<String>any());
151 verify(corsDefaults2, never()).getAllowedRequestHeaders(Matchers.<String>any());
152 verify(corsDefaults2, never()).getAllowedResponseHeaders(Matchers.<String>any());
153 }
154
155 @Test
156 public void testSimpleGetWrongOrigin()
157 {
158 String origin = "http://localhost";
159 when(request.getMethod()).thenReturn("GET");
160 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults));
161 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
162 when(request.getHeaderValue("Origin")).thenReturn("http://foo.com");
163 execNoPreflightNoHeaders();
164 }
165
166 @Test
167 public void testSimpleGetNoOrigin()
168 {
169 when(request.getMethod()).thenReturn("GET");
170 execNoPreflightNoHeaders();
171 }
172
173 private MultivaluedMap<String, Object> execPreflight()
174 {
175 try
176 {
177 corsResourceFilter.filter(request);
178 fail("Should have thrown preflight exception");
179 return null;
180 }
181 catch (CorsPreflightCheckCompleteException ex)
182 {
183 return ex.getResponse().getMetadata();
184 }
185 }
186
187 private MultivaluedMap<String, Object> execNoPreflightWithHeaders()
188 {
189 try
190 {
191 corsResourceFilter.filter(request);
192 ArgumentCaptor<Response> argument = ArgumentCaptor.forClass(Response.class);
193 corsResourceFilter.filter(request, response);
194 verify(response).setResponse(argument.capture());
195 return argument.getValue().getMetadata();
196 }
197 catch (CorsPreflightCheckCompleteException ex)
198 {
199 fail("Shouldn't have thrown preflight exception");
200 return null;
201 }
202 }
203
204 private void execNoPreflightNoHeaders()
205 {
206 try
207 {
208 corsResourceFilter.filter(request);
209 ArgumentCaptor<Response> argument = ArgumentCaptor.forClass(Response.class);
210 corsResourceFilter.filter(request, response);
211 verify(response, never()).setResponse(argument.capture());
212 }
213 catch (CorsPreflightCheckCompleteException ex)
214 {
215 fail("Shouldn't have thrown preflight exception");
216 }
217 }
218
219 private CorsPreflightCheckCompleteException execBadPreflight()
220 {
221 try
222 {
223 corsResourceFilter.filter(request);
224 corsResourceFilter.filter(request, response);
225 fail("Should have thrown preflight exception");
226 return null;
227 }
228 catch (CorsPreflightCheckCompleteException ex)
229 {
230 return ex;
231 }
232 }
233
234 private MultivaluedMap<String, Object> execPreflightWithTwoCorsDefaults(String origin)
235 {
236 requestProps.put("Cors-Preflight-Requested", "true");
237 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults, corsDefaults2));
238 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
239 when(corsDefaults2.allowsOrigin(origin)).thenReturn(false);
240 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
241 when(request.getHeaderValue("Origin")).thenReturn(origin);
242 MultivaluedMap<String, Object> headers = execPreflight();
243 return headers;
244 }
245
246 private MultivaluedMap<String, Object> execNoPreflightWithHeadersForTwoCorsDefaults(String origin)
247 {
248 when(request.getMethod()).thenReturn("GET");
249 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults, corsDefaults2));
250 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
251 when(corsDefaults2.allowsOrigin(origin)).thenReturn(false);
252 when(request.getHeaderValue("Origin")).thenReturn(origin);
253 MultivaluedMap<String, Object> headers = execNoPreflightWithHeaders();
254 return headers;
255 }
256 }