1 package com.atlassian.plugins.rest.common.security.jersey;
2
3 import com.atlassian.plugin.tracker.PluginModuleTracker;
4 import com.atlassian.plugins.rest.common.security.CorsPreflightCheckCompleteException;
5 import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaults;
6 import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaultsModuleDescriptor;
7 import com.sun.jersey.spi.container.ContainerRequest;
8 import com.sun.jersey.spi.container.ContainerRequestFilter;
9 import com.sun.jersey.spi.container.ContainerResponse;
10 import com.sun.jersey.spi.container.ContainerResponseFilter;
11 import com.sun.jersey.spi.container.ResourceFilter;
12 import org.slf4j.Logger;
13 import org.slf4j.LoggerFactory;
14
15 import javax.ws.rs.core.Response;
16 import java.net.URI;
17 import java.util.Collections;
18 import java.util.List;
19 import java.util.Set;
20
21 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS;
22 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_HEADERS;
23 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_METHODS;
24 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_ORIGIN;
25 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_EXPOSE_HEADERS;
26 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_MAX_AGE;
27 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_REQUEST_HEADERS;
28 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_REQUEST_METHOD;
29 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ORIGIN;
30 import static com.google.common.collect.Lists.newArrayList;
31 import static com.google.common.collect.Sets.newHashSet;
32
33
34
35
36
37
38
39
40
41 public class CorsResourceFilter implements ResourceFilter, ContainerRequestFilter, ContainerResponseFilter
42 {
43 private static final String CORS_PREFLIGHT_FAILED = "Cors-Preflight-Failed";
44 private static final String CORS_PREFLIGHT_SUCCEEDED = "Cors-Preflight-Succeeded";
45 public static final String CORS_PREFLIGHT_REQUESTED = "Cors-Preflight-Requested";
46 private static final Logger log = LoggerFactory.getLogger(CorsResourceFilter.class);
47
48 private final PluginModuleTracker<CorsDefaults, CorsDefaultsModuleDescriptor> pluginModuleTracker;
49 private final String allowMethod;
50
51 public CorsResourceFilter(PluginModuleTracker<CorsDefaults, CorsDefaultsModuleDescriptor> pluginModuleTracker, String allowMethod)
52 {
53 this.allowMethod = allowMethod;
54 this.pluginModuleTracker = pluginModuleTracker;
55 }
56
57 public ContainerRequest filter(final ContainerRequest request)
58 {
59
60
61
62
63
64
65
66
67
68
69
70 if (request.getProperties().containsKey(CORS_PREFLIGHT_REQUESTED))
71 {
72 CorsDefaults defaults = new AggregateCorsDefaults(pluginModuleTracker.getModules());
73
74 try
75 {
76 Response.ResponseBuilder response = Response.ok();
77 String origin = validateSingleOriginInWhitelist(defaults, request);
78 validateAccessControlRequestMethod(allowMethod, request);
79 Set<String> allowedRequestHeaders = defaults.getAllowedRequestHeaders(origin);
80 validateAccessControlRequestHeaders(allowedRequestHeaders, request);
81
82 addAccessControlAllowOrigin(response, origin);
83 conditionallyAddAccessControlAllowCredentials(response, origin, defaults);
84 addAccessControlMaxAge(response);
85 addAccessControlAllowMethods(response, allowMethod);
86 addAccessControlAllowHeaders(response, allowedRequestHeaders);
87
88 request.getProperties().put(CORS_PREFLIGHT_SUCCEEDED, "true");
89
90 throw new CorsPreflightCheckCompleteException(response.build());
91 }
92 catch (PreflightFailedException ex)
93 {
94 Response.ResponseBuilder response = Response.ok();
95 request.getProperties().put(CORS_PREFLIGHT_FAILED, "true");
96 log.info("CORS preflight failed: " + ex.getMessage());
97 throw new CorsPreflightCheckCompleteException(response.build());
98 }
99 }
100 else
101 {
102 return request;
103 }
104 }
105
106 public ContainerResponse filter(ContainerRequest request, ContainerResponse containerResponse)
107 {
108
109
110
111
112
113 if (!request.getProperties().containsKey(CORS_PREFLIGHT_FAILED) &&
114 !request.getProperties().containsKey(CORS_PREFLIGHT_SUCCEEDED) &&
115 extractOrigin(request) != null)
116 {
117 CorsDefaults defaults = new AggregateCorsDefaults(pluginModuleTracker.getModules());
118 try
119 {
120 String origin = validateAnyOriginInListInWhitelist(defaults, request);
121
122 Response.ResponseBuilder response = Response.fromResponse(containerResponse.getResponse());
123 addAccessControlAllowOrigin(response, origin);
124 conditionallyAddAccessControlAllowCredentials(response, origin, defaults);
125 addAccessControlExposeHeaders(response, defaults.getAllowedResponseHeaders(origin));
126 containerResponse.setResponse(response.build());
127 return containerResponse;
128 }
129 catch (PreflightFailedException ex)
130 {
131 log.info("Unable to add CORS headers to response: " + ex.getMessage());
132 }
133 }
134 return containerResponse;
135 }
136
137 private void addAccessControlExposeHeaders(Response.ResponseBuilder response, Set<String> allowedHeaders)
138 {
139 for (String header : allowedHeaders)
140 {
141 response.header(ACCESS_CONTROL_EXPOSE_HEADERS.value(), header);
142 }
143 }
144
145 private void addAccessControlAllowHeaders(Response.ResponseBuilder response, Set<String> allowedHeaders)
146 {
147 for (String header : allowedHeaders)
148 {
149 response.header(ACCESS_CONTROL_ALLOW_HEADERS.value(), header);
150 }
151 }
152
153 private void addAccessControlAllowMethods(Response.ResponseBuilder response, String allowMethod)
154 {
155 response.header(ACCESS_CONTROL_ALLOW_METHODS.value(), allowMethod);
156 }
157
158 private void addAccessControlMaxAge(Response.ResponseBuilder response)
159 {
160 response.header(ACCESS_CONTROL_MAX_AGE.value(), 60 * 60);
161 }
162
163 private void addAccessControlAllowOrigin(Response.ResponseBuilder response, String origin)
164 {
165 response.header(ACCESS_CONTROL_ALLOW_ORIGIN.value(), origin);
166 }
167
168 private void conditionallyAddAccessControlAllowCredentials(Response.ResponseBuilder response, String origin, CorsDefaults defaults)
169 {
170 if (defaults.allowsCredentials(origin))
171 {
172 response.header(ACCESS_CONTROL_ALLOW_CREDENTIALS.value(), "true");
173 }
174 }
175
176 private void validateAccessControlRequestHeaders(Set<String> allowedHeaders, ContainerRequest request) throws PreflightFailedException
177 {
178
179 List<String> requestedHeaders = request.getRequestHeader(ACCESS_CONTROL_REQUEST_HEADERS.value());
180 requestedHeaders = requestedHeaders != null ? requestedHeaders : Collections.<String>emptyList();
181 if (!allowedHeaders.containsAll(requestedHeaders))
182 {
183 List<String> unexpectedHeaders = newArrayList(requestedHeaders);
184 unexpectedHeaders.removeAll(allowedHeaders);
185
186 throw new PreflightFailedException("Unexpected headers in CORS request: " + unexpectedHeaders);
187 }
188 }
189
190 private void validateAccessControlRequestMethod(String allowMethod, ContainerRequest request) throws PreflightFailedException
191 {
192 String requestedMethod = request.getHeaderValue(ACCESS_CONTROL_REQUEST_METHOD.value());
193 if (!allowMethod.equals(requestedMethod))
194 {
195 throw new PreflightFailedException("Invalid method: " + requestedMethod);
196 }
197 }
198
199 private String validateAnyOriginInListInWhitelist(CorsDefaults defaults, ContainerRequest request) throws PreflightFailedException
200 {
201 String originRaw = extractOrigin(request);
202 String[] originList = originRaw.split(" ");
203 for (String origin : originList)
204 {
205 validateOriginAsUri(origin);
206 if (defaults.allowsOrigin(origin))
207 {
208 return origin;
209 }
210 }
211 throw new PreflightFailedException("Origins '" + originRaw + "' not in whitelist");
212 }
213
214 private String validateSingleOriginInWhitelist(CorsDefaults defaults, ContainerRequest request) throws PreflightFailedException
215 {
216 String origin = extractOrigin(request);
217 validateOriginAsUri(origin);
218
219 if (!defaults.allowsOrigin(origin))
220 {
221 throw new PreflightFailedException("Origin '" + origin + "' not in whitelist");
222 }
223 return origin;
224 }
225
226 private void validateOriginAsUri(String origin) throws PreflightFailedException
227 {
228 try
229 {
230 URI.create(origin);
231 }
232 catch (IllegalArgumentException ex)
233 {
234 throw new PreflightFailedException("Origin '" + origin + "' is not a valid URI");
235 }
236 }
237
238 public static String extractOrigin(ContainerRequest request)
239 {
240 return request.getHeaderValue(ORIGIN.value());
241 }
242
243 public ContainerRequestFilter getRequestFilter()
244 {
245 return this;
246 }
247
248 public ContainerResponseFilter getResponseFilter()
249 {
250 return this;
251 }
252
253
254
255
256 private static class PreflightFailedException extends Exception
257 {
258 private PreflightFailedException(String message)
259 {
260 super(message);
261 }
262 }
263
264 private static class AggregateCorsDefaults implements CorsDefaults
265 {
266 private final Iterable<CorsDefaults> delegates;
267
268 public AggregateCorsDefaults(Iterable<CorsDefaults> delegates)
269 {
270 this.delegates = delegates;
271 }
272
273 public boolean allowsCredentials(String uri)
274 {
275 for (CorsDefaults defs : delegates)
276 {
277 if (defs.allowsCredentials(uri))
278 {
279 return true;
280 }
281 }
282 return false;
283 }
284
285 public boolean allowsOrigin(String uri)
286 {
287 for (CorsDefaults defs : delegates)
288 {
289 if (defs.allowsOrigin(uri))
290 {
291 return true;
292 }
293 }
294 return false;
295 }
296
297 public Set<String> getAllowedRequestHeaders(String uri)
298 {
299 Set<String> result = newHashSet();
300 for (CorsDefaults defs : delegates)
301 {
302 result.addAll(defs.getAllowedRequestHeaders(uri));
303 }
304 return result;
305 }
306
307 public Set<String> getAllowedResponseHeaders(String uri)
308 {
309 Set<String> result = newHashSet();
310 for (CorsDefaults defs : delegates)
311 {
312 result.addAll(defs.getAllowedResponseHeaders(uri));
313 }
314 return result;
315 }
316 }
317
318 }