1 package com.atlassian.plugins.rest.common.security.jersey;
2
3 import com.atlassian.plugin.tracker.PluginModuleTracker;
4 import com.atlassian.plugins.rest.common.security.CorsPreflightCheckCompleteException;
5 import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaults;
6 import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaultsModuleDescriptor;
7 import com.google.common.collect.ImmutableSet;
8 import com.sun.jersey.spi.container.ContainerRequest;
9 import com.sun.jersey.spi.container.ContainerResponse;
10 import org.junit.Before;
11 import org.junit.Test;
12 import org.junit.runner.RunWith;
13 import org.mockito.ArgumentCaptor;
14 import org.mockito.Mock;
15 import org.mockito.runners.MockitoJUnitRunner;
16
17 import javax.ws.rs.core.MultivaluedMap;
18 import javax.ws.rs.core.Response;
19 import java.util.Arrays;
20 import java.util.Map;
21
import static com.google.common.collect.Maps.newHashMap;
import static com.google.common.collect.Sets.newHashSet;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.*;
27
28 @RunWith (MockitoJUnitRunner.class)
29 public class TestCorsResourceFilter
30 {
31 @Mock
32 private CorsDefaults corsDefaults;
33 private CorsResourceFilter corsResourceFilter;
34 @Mock
35 private ContainerRequest request;
36 private Map<String,Object> requestProps;
37 @Mock
38 private ContainerResponse response;
39 @Mock
40 private PluginModuleTracker<CorsDefaults, CorsDefaultsModuleDescriptor> tracker;
41
42 @Before
43 public void setUp()
44 {
45 when(tracker.getModules()).thenReturn(newHashSet(corsDefaults));
46 when(response.getResponse()).thenReturn(Response.ok().build());
47 requestProps = newHashMap();
48 when(request.getProperties()).thenReturn(requestProps);
49 corsResourceFilter = new CorsResourceFilter(tracker, "GET");
50 }
51
52 @Test
53 public void testSimplePreflightForGet()
54 {
55 String origin = "http://localhost";
56 requestProps.put("Cors-Preflight-Requested", "true");
57 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
58 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
59 when(request.getHeaderValue("Origin")).thenReturn(origin);
60 MultivaluedMap<String, Object> headers = execPreflight();
61 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin);
62 }
63
64 @Test
65 public void testSimplePreflightForGetWrongDomain()
66 {
67 String origin = "http://localhost";
68 requestProps.put("Cors-Preflight-Requested", "true");
69 when(corsDefaults.allowsOrigin(origin)).thenReturn(false);
70 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
71 when(request.getHeaderValue("Origin")).thenReturn(origin);
72 execBadPreflight();
73 }
74
75 @Test
76 public void testSimplePreflightForGetWrongMethod()
77 {
78 String origin = "http://localhost";
79 requestProps.put("Cors-Preflight-Requested", "true");
80 when(corsDefaults.allowsOrigin(origin)).thenReturn(false);
81 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("POST");
82 when(request.getHeaderValue("Origin")).thenReturn(origin);
83 execBadPreflight();
84 }
85
86 @Test
87 public void testSimplePreflightForGetWrongHeaders()
88 {
89 String origin = "http://localhost";
90 requestProps.put("Cors-Preflight-Requested", "true");
91 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
92 when(corsDefaults.getAllowedRequestHeaders(origin)).thenReturn(ImmutableSet.<String>of("Foo-Header"));
93 when(request.getHeaderValue("Access-Control-Request-Method")).thenReturn("GET");
94 when(request.getRequestHeader("Access-Control-Request-Headers")).thenReturn(Arrays.asList("Bar-Header"));
95 when(request.getHeaderValue("Origin")).thenReturn(origin);
96 execBadPreflight();
97 }
98
99 @Test
100 public void testSimpleGet()
101 {
102 String origin = "http://localhost";
103 when(request.getMethod()).thenReturn("GET");
104 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
105 when(request.getHeaderValue("Origin")).thenReturn(origin);
106 MultivaluedMap<String, Object> headers = execNoPreflightWithHeaders();
107 assertEquals(headers.getFirst("Access-Control-Allow-Origin"), origin);
108 }
109
110 @Test
111 public void testSimpleGetWrongOrigin()
112 {
113 String origin = "http://localhost";
114 when(request.getMethod()).thenReturn("GET");
115 when(corsDefaults.allowsOrigin(origin)).thenReturn(true);
116 when(request.getHeaderValue("Origin")).thenReturn("http://foo.com");
117 execNoPreflightNoHeaders();
118 }
119
120 @Test
121 public void testSimpleGetNoOrigin()
122 {
123 when(request.getMethod()).thenReturn("GET");
124 execNoPreflightNoHeaders();
125 }
126
127 private MultivaluedMap<String, Object> execPreflight()
128 {
129 try
130 {
131 corsResourceFilter.filter(request);
132 fail("Should have thrown preflight exception");
133 return null;
134 }
135 catch (CorsPreflightCheckCompleteException ex)
136 {
137 return ex.getResponse().getMetadata();
138 }
139 }
140
141 private MultivaluedMap<String, Object> execNoPreflightWithHeaders()
142 {
143 try
144 {
145 corsResourceFilter.filter(request);
146 ArgumentCaptor<Response> argument = ArgumentCaptor.forClass(Response.class);
147 corsResourceFilter.filter(request, response);
148 verify(response).setResponse(argument.capture());
149 return argument.getValue().getMetadata();
150 }
151 catch (CorsPreflightCheckCompleteException ex)
152 {
153 fail("Shouldn't have thrown preflight exception");
154 return null;
155 }
156 }
157
158 private void execNoPreflightNoHeaders()
159 {
160 try
161 {
162 corsResourceFilter.filter(request);
163 ArgumentCaptor<Response> argument = ArgumentCaptor.forClass(Response.class);
164 corsResourceFilter.filter(request, response);
165 verify(response, never()).setResponse(argument.capture());
166 }
167 catch (CorsPreflightCheckCompleteException ex)
168 {
169 fail("Shouldn't have thrown preflight exception");
170 }
171 }
172
173 private CorsPreflightCheckCompleteException execBadPreflight()
174 {
175 try
176 {
177 corsResourceFilter.filter(request);
178 corsResourceFilter.filter(request, response);
179 fail("Should have thrown preflight exception");
180 return null;
181 }
182 catch (CorsPreflightCheckCompleteException ex)
183 {
184 return ex;
185 }
186 }
187 }