1 package com.atlassian.plugins.rest.common.security.jersey;
2
3 import com.atlassian.plugin.tracker.PluginModuleTracker;
4 import com.atlassian.plugins.rest.common.security.CorsPreflightCheckCompleteException;
5 import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaults;
6 import com.atlassian.plugins.rest.common.security.descriptor.CorsDefaultsModuleDescriptor;
7 import com.google.common.base.Predicate;
8 import com.google.common.collect.Iterables;
9 import com.sun.jersey.spi.container.ContainerRequest;
10 import com.sun.jersey.spi.container.ContainerRequestFilter;
11 import com.sun.jersey.spi.container.ContainerResponse;
12 import com.sun.jersey.spi.container.ContainerResponseFilter;
13 import com.sun.jersey.spi.container.ResourceFilter;
14 import org.slf4j.Logger;
15 import org.slf4j.LoggerFactory;
16
import javax.ws.rs.core.Response;
import java.net.URI;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.regex.Pattern;
22
23 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS;
24 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_HEADERS;
25 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_METHODS;
26 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_ALLOW_ORIGIN;
27 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_EXPOSE_HEADERS;
28 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_MAX_AGE;
29 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_REQUEST_HEADERS;
30 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ACCESS_CONTROL_REQUEST_METHOD;
31 import static com.atlassian.plugins.rest.common.security.CorsHeaders.ORIGIN;
32 import static com.google.common.collect.Lists.newArrayList;
33 import static com.google.common.collect.Sets.newHashSet;
34
35
36
37
38
39
40
41
42
43 public class CorsResourceFilter implements ResourceFilter, ContainerRequestFilter, ContainerResponseFilter
44 {
45 private static final String CORS_PREFLIGHT_FAILED = "Cors-Preflight-Failed";
46 private static final String CORS_PREFLIGHT_SUCCEEDED = "Cors-Preflight-Succeeded";
47 public static final String CORS_PREFLIGHT_REQUESTED = "Cors-Preflight-Requested";
48 private static final Logger log = LoggerFactory.getLogger(CorsResourceFilter.class);
49
50 private final PluginModuleTracker<CorsDefaults, CorsDefaultsModuleDescriptor> pluginModuleTracker;
51 private final String allowMethod;
52
53 public CorsResourceFilter(PluginModuleTracker<CorsDefaults, CorsDefaultsModuleDescriptor> pluginModuleTracker, String allowMethod)
54 {
55 this.allowMethod = allowMethod;
56 this.pluginModuleTracker = pluginModuleTracker;
57 }
58
59 public ContainerRequest filter(final ContainerRequest request)
60 {
61
62
63
64
65
66
67
68
69
70
71
72 if (request.getProperties().containsKey(CORS_PREFLIGHT_REQUESTED))
73 {
74 Iterable<CorsDefaults> defaults = pluginModuleTracker.getModules();
75 try
76 {
77 String origin = validateSingleOriginInWhitelist(defaults, request);
78 Iterable<CorsDefaults> defaultsWithAllowedOrigin = allowsOrigin(defaults, origin);
79
80 Response.ResponseBuilder response = Response.ok();
81 validateAccessControlRequestMethod(allowMethod, request);
82 Set<String> allowedRequestHeaders = getAllowedRequestHeaders(defaultsWithAllowedOrigin, origin);
83 validateAccessControlRequestHeaders(allowedRequestHeaders, request);
84
85 addAccessControlAllowOrigin(response, origin);
86 conditionallyAddAccessControlAllowCredentials(response, origin, defaultsWithAllowedOrigin);
87 addAccessControlMaxAge(response);
88 addAccessControlAllowMethods(response, allowMethod);
89 addAccessControlAllowHeaders(response, allowedRequestHeaders);
90
91 request.getProperties().put(CORS_PREFLIGHT_SUCCEEDED, "true");
92
93 throw new CorsPreflightCheckCompleteException(response.build());
94 }
95 catch (PreflightFailedException ex)
96 {
97 Response.ResponseBuilder response = Response.ok();
98 request.getProperties().put(CORS_PREFLIGHT_FAILED, "true");
99 log.info("CORS preflight failed: " + ex.getMessage());
100 throw new CorsPreflightCheckCompleteException(response.build());
101 }
102 }
103 else
104 {
105 return request;
106 }
107 }
108
109 public ContainerResponse filter(ContainerRequest request, ContainerResponse containerResponse)
110 {
111
112
113
114
115
116 if (!request.getProperties().containsKey(CORS_PREFLIGHT_FAILED) &&
117 !request.getProperties().containsKey(CORS_PREFLIGHT_SUCCEEDED) &&
118 extractOrigin(request) != null)
119 {
120 Iterable<CorsDefaults> defaults = pluginModuleTracker.getModules();
121 try
122 {
123 String origin = validateAnyOriginInListInWhitelist(defaults, request);
124 Iterable<CorsDefaults> defaultsWithAllowedOrigin = allowsOrigin(defaults, origin);
125
126 Response.ResponseBuilder response = Response.fromResponse(containerResponse.getResponse());
127 addAccessControlAllowOrigin(response, origin);
128 conditionallyAddAccessControlAllowCredentials(response, origin, defaultsWithAllowedOrigin);
129 addAccessControlExposeHeaders(response, getAllowedResponseHeaders(defaultsWithAllowedOrigin, origin));
130 containerResponse.setResponse(response.build());
131 return containerResponse;
132 }
133 catch (PreflightFailedException ex)
134 {
135 log.info("Unable to add CORS headers to response: " + ex.getMessage());
136 }
137 }
138 return containerResponse;
139 }
140
141 private void addAccessControlExposeHeaders(Response.ResponseBuilder response, Set<String> allowedHeaders)
142 {
143 for (String header : allowedHeaders)
144 {
145 response.header(ACCESS_CONTROL_EXPOSE_HEADERS.value(), header);
146 }
147 }
148
149 private void addAccessControlAllowHeaders(Response.ResponseBuilder response, Set<String> allowedHeaders)
150 {
151 for (String header : allowedHeaders)
152 {
153 response.header(ACCESS_CONTROL_ALLOW_HEADERS.value(), header);
154 }
155 }
156
157 private void addAccessControlAllowMethods(Response.ResponseBuilder response, String allowMethod)
158 {
159 response.header(ACCESS_CONTROL_ALLOW_METHODS.value(), allowMethod);
160 }
161
162 private void addAccessControlMaxAge(Response.ResponseBuilder response)
163 {
164 response.header(ACCESS_CONTROL_MAX_AGE.value(), 60 * 60);
165 }
166
167 private void addAccessControlAllowOrigin(Response.ResponseBuilder response, String origin)
168 {
169 response.header(ACCESS_CONTROL_ALLOW_ORIGIN.value(), origin);
170 }
171
172 private void conditionallyAddAccessControlAllowCredentials(Response.ResponseBuilder response, String origin, Iterable<CorsDefaults> defaultsWithAllowedOrigin)
173 {
174 if (anyAllowsCredentials(defaultsWithAllowedOrigin, origin))
175 {
176 response.header(ACCESS_CONTROL_ALLOW_CREDENTIALS.value(), "true");
177 }
178 }
179
180 private void validateAccessControlRequestHeaders(Set<String> allowedHeaders, ContainerRequest request) throws PreflightFailedException
181 {
182
183 List<String> requestedHeaders = request.getRequestHeader(ACCESS_CONTROL_REQUEST_HEADERS.value());
184 requestedHeaders = requestedHeaders != null ? requestedHeaders : Collections.<String>emptyList();
185 if (!allowedHeaders.containsAll(requestedHeaders))
186 {
187 List<String> unexpectedHeaders = newArrayList(requestedHeaders);
188 unexpectedHeaders.removeAll(allowedHeaders);
189
190 throw new PreflightFailedException("Unexpected headers in CORS request: " + unexpectedHeaders);
191 }
192 }
193
194 private void validateAccessControlRequestMethod(String allowMethod, ContainerRequest request) throws PreflightFailedException
195 {
196 String requestedMethod = request.getHeaderValue(ACCESS_CONTROL_REQUEST_METHOD.value());
197 if (!allowMethod.equals(requestedMethod))
198 {
199 throw new PreflightFailedException("Invalid method: " + requestedMethod);
200 }
201 }
202
203 private String validateAnyOriginInListInWhitelist(Iterable<CorsDefaults> defaults, ContainerRequest request) throws PreflightFailedException
204 {
205 String originRaw = extractOrigin(request);
206 String[] originList = originRaw.split(" ");
207 for (String origin : originList)
208 {
209 validateOriginAsUri(origin);
210 if (! Iterables.isEmpty(allowsOrigin(defaults, origin)))
211 {
212 return origin;
213 }
214 }
215 throw new PreflightFailedException("Origins '" + originRaw + "' not in whitelist");
216 }
217
218 private String validateSingleOriginInWhitelist(Iterable<CorsDefaults> defaults, ContainerRequest request) throws PreflightFailedException
219 {
220 String origin = extractOrigin(request);
221 validateOriginAsUri(origin);
222
223 if (Iterables.isEmpty(allowsOrigin(defaults, origin)))
224 {
225 throw new PreflightFailedException("Origin '" + origin + "' not in whitelist");
226 }
227 return origin;
228 }
229
230 private void validateOriginAsUri(String origin) throws PreflightFailedException
231 {
232 try
233 {
234 URI.create(origin);
235 }
236 catch (IllegalArgumentException ex)
237 {
238 throw new PreflightFailedException("Origin '" + origin + "' is not a valid URI");
239 }
240 }
241
242 public static String extractOrigin(ContainerRequest request)
243 {
244 return request.getHeaderValue(ORIGIN.value());
245 }
246
247 public ContainerRequestFilter getRequestFilter()
248 {
249 return this;
250 }
251
252 public ContainerResponseFilter getResponseFilter()
253 {
254 return this;
255 }
256
257
258
259
260 private static class PreflightFailedException extends Exception
261 {
262 private PreflightFailedException(String message)
263 {
264 super(message);
265 }
266 }
267
268 private static Iterable<CorsDefaults> allowsOrigin(Iterable<CorsDefaults> delegates, final String uri)
269 {
270 return Iterables.filter(delegates, new Predicate<CorsDefaults>()
271 {
272 public boolean apply(CorsDefaults delegate)
273 {
274 return delegate.allowsOrigin(uri);
275 }
276 });
277 }
278
279 private static boolean anyAllowsCredentials(Iterable<CorsDefaults> delegatesWhichAllowOrigin, final String uri)
280 {
281 for (CorsDefaults defs : delegatesWhichAllowOrigin)
282 {
283 if (defs.allowsCredentials(uri))
284 {
285 return true;
286 }
287 }
288 return false;
289 }
290
291
292 private static Set<String> getAllowedRequestHeaders(Iterable<CorsDefaults> delegatesWhichAllowOrigin, String uri)
293 {
294 Set<String> result = newHashSet();
295 for (CorsDefaults defs : delegatesWhichAllowOrigin)
296 {
297 result.addAll(defs.getAllowedRequestHeaders(uri));
298 }
299 return result;
300 }
301
302 private static Set<String> getAllowedResponseHeaders(Iterable<CorsDefaults> delegatesWithAllowedOrigin, String uri)
303 {
304 Set<String> result = newHashSet();
305 for (CorsDefaults defs : delegatesWithAllowedOrigin)
306 {
307 result.addAll(defs.getAllowedResponseHeaders(uri));
308 }
309 return result;
310 }
311
312 }