1 package com.atlassian.plugin.servlet;
2
3 import com.atlassian.plugin.ModuleDescriptor;
4 import com.atlassian.plugin.Plugin;
5 import com.atlassian.plugin.event.PluginEventListener;
6 import com.atlassian.plugin.event.PluginEventManager;
7 import com.atlassian.plugin.event.events.PluginDisabledEvent;
8 import com.atlassian.plugin.event.events.PluginFrameworkShutdownEvent;
9 import com.atlassian.plugin.servlet.descriptors.ServletContextListenerModuleDescriptor;
10 import com.atlassian.plugin.servlet.descriptors.ServletContextParamModuleDescriptor;
11 import com.atlassian.plugin.servlet.descriptors.ServletFilterModuleDescriptor;
12 import com.atlassian.plugin.servlet.descriptors.ServletModuleDescriptor;
13 import com.atlassian.plugin.servlet.filter.DelegatingPluginFilter;
14 import com.atlassian.plugin.servlet.filter.FilterDispatcherCondition;
15 import com.atlassian.plugin.servlet.filter.FilterLocation;
16 import com.atlassian.plugin.servlet.filter.PluginFilterConfig;
17 import com.atlassian.plugin.servlet.util.DefaultPathMapper;
18 import com.atlassian.plugin.servlet.util.PathMapper;
19 import com.atlassian.plugin.servlet.util.ServletContextServletModuleManagerAccessor;
20 import com.atlassian.plugin.util.ClassLoaderStack;
21 import com.atlassian.util.concurrent.LazyReference;
22 import org.slf4j.Logger;
23 import org.slf4j.LoggerFactory;
24
25 import javax.servlet.Filter;
26 import javax.servlet.FilterConfig;
27 import javax.servlet.ServletConfig;
28 import javax.servlet.ServletContext;
29 import javax.servlet.ServletContextEvent;
30 import javax.servlet.ServletContextListener;
31 import javax.servlet.ServletException;
32 import javax.servlet.http.HttpServlet;
33 import java.util.ArrayList;
34 import java.util.Arrays;
35 import java.util.Collections;
36 import java.util.Comparator;
37 import java.util.Enumeration;
38 import java.util.HashMap;
39 import java.util.HashSet;
40 import java.util.LinkedList;
41 import java.util.List;
42 import java.util.Map;
43 import java.util.Set;
44 import java.util.concurrent.ConcurrentHashMap;
45 import java.util.concurrent.ConcurrentMap;
46
47 import static com.atlassian.plugin.servlet.descriptors.ServletFilterModuleDescriptor.byWeight;
48 import static com.google.common.base.Preconditions.checkNotNull;
49
50
51
52
53
54
55
56 public class DefaultServletModuleManager implements ServletModuleManager
57 {
58 private static final Logger log = LoggerFactory.getLogger(DefaultServletModuleManager.class);
59
60 private final PathMapper servletMapper;
61 private final Map<String, ServletModuleDescriptor> servletDescriptors = new HashMap<String, ServletModuleDescriptor>();
62 private final ConcurrentMap<String, LazyReference<HttpServlet>> servletRefs = new ConcurrentHashMap<String, LazyReference<HttpServlet>>();
63
64 private final PathMapper filterMapper;
65 private final Map<String, ServletFilterModuleDescriptor> filterDescriptors = new HashMap<String, ServletFilterModuleDescriptor>();
66 private final ConcurrentMap<String, LazyReference<Filter>> filterRefs = new ConcurrentHashMap<String, LazyReference<Filter>>();
67
68 private final ConcurrentMap<Plugin, ContextLifecycleReference> pluginContextRefs = new ConcurrentHashMap<Plugin, ContextLifecycleReference>();
69
70
71
72
73
74
75
76
77
78 public DefaultServletModuleManager(final ServletContext servletContext, final PluginEventManager pluginEventManager)
79 {
80 this(pluginEventManager);
81 ServletContextServletModuleManagerAccessor.setServletModuleManager(servletContext, this);
82 }
83
84
85
86
87
88
89
90
91
92 public DefaultServletModuleManager(final PluginEventManager pluginEventManager)
93 {
94 this(pluginEventManager, new DefaultPathMapper(), new DefaultPathMapper());
95 }
96
97
98
99
100
101
102
103
104
105
106
107
108 public DefaultServletModuleManager(final PluginEventManager pluginEventManager, final PathMapper servletPathMapper, final PathMapper filterPathMapper)
109 {
110 servletMapper = servletPathMapper;
111 filterMapper = filterPathMapper;
112 pluginEventManager.register(this);
113 }
114
115 public void addServletModule(final ServletModuleDescriptor descriptor)
116 {
117 servletDescriptors.put(descriptor.getCompleteKey(), descriptor);
118
119
120
121 final List<String> paths = descriptor.getPaths();
122 for (final String path : paths)
123 {
124 servletMapper.put(descriptor.getCompleteKey(), path);
125 }
126 final LazyReference<HttpServlet> servletRef = servletRefs.remove(descriptor.getCompleteKey());
127 if (servletRef != null)
128 {
129 servletRef.get().destroy();
130 }
131 }
132
133 public HttpServlet getServlet(final String path, final ServletConfig servletConfig) throws ServletException
134 {
135 final String completeKey = servletMapper.get(path);
136
137 if (completeKey == null)
138 {
139 return null;
140 }
141 final ServletModuleDescriptor descriptor = servletDescriptors.get(completeKey);
142 if (descriptor == null)
143 {
144 return null;
145 }
146
147 final HttpServlet servlet = getServlet(descriptor, servletConfig);
148 if (servlet == null)
149 {
150 servletRefs.remove(descriptor.getCompleteKey());
151 }
152 return servlet;
153 }
154
155 public void removeServletModule(final ServletModuleDescriptor descriptor)
156 {
157 servletDescriptors.remove(descriptor.getCompleteKey());
158 servletMapper.put(descriptor.getCompleteKey(), null);
159
160 final LazyReference<HttpServlet> servletRef = servletRefs.remove(descriptor.getCompleteKey());
161 if (servletRef != null)
162 {
163 servletRef.get().destroy();
164 }
165 }
166
167 public void addFilterModule(final ServletFilterModuleDescriptor descriptor)
168 {
169 filterDescriptors.put(descriptor.getCompleteKey(), descriptor);
170
171 for (final String path : descriptor.getPaths())
172 {
173 filterMapper.put(descriptor.getCompleteKey(), path);
174 }
175 final LazyReference<Filter> filterRef = filterRefs.remove(descriptor.getCompleteKey());
176 if (filterRef != null)
177 {
178 filterRef.get().destroy();
179 }
180 }
181
182 public Iterable<Filter> getFilters(final FilterLocation location, final String path, final FilterConfig filterConfig) throws ServletException
183 {
184 return getFilters(location, path, filterConfig, FilterDispatcherCondition.REQUEST);
185 }
186
187 public Iterable<Filter> getFilters(FilterLocation location, String path, FilterConfig filterConfig, FilterDispatcherCondition condition) throws ServletException
188 {
189 checkNotNull(condition);
190 final List<ServletFilterModuleDescriptor> matchingFilterDescriptors = new ArrayList<ServletFilterModuleDescriptor>();
191
192 for (final String completeKey : filterMapper.getAll(path))
193 {
194 final ServletFilterModuleDescriptor descriptor = filterDescriptors.get(completeKey);
195 if (!descriptor.getDispatcherConditions().contains(condition))
196 {
197 if (log.isTraceEnabled())
198 {
199 log.trace("Skipping filter " + descriptor.getCompleteKey() + " as condition " + condition +
200 " doesn't match list:" + Arrays.asList(descriptor.getDispatcherConditions()));
201 }
202 continue;
203 }
204
205 if (location.equals(descriptor.getLocation()))
206 {
207 sortedInsert(matchingFilterDescriptors, descriptor, byWeight);
208 }
209 }
210 final List<Filter> filters = new LinkedList<Filter>();
211 for (final ServletFilterModuleDescriptor descriptor : matchingFilterDescriptors)
212 {
213 final Filter filter = getFilter(descriptor, filterConfig);
214 if (filter == null)
215 {
216 filterRefs.remove(descriptor.getCompleteKey());
217 }
218 else
219 {
220 filters.add(getFilter(descriptor, filterConfig));
221 }
222 }
223
224 return filters;
225 }
226
227 static <T> void sortedInsert(final List<T> list, final T e, final Comparator<T> comparator)
228 {
229 int insertIndex = Collections.binarySearch(list, e, comparator);
230 if (insertIndex < 0)
231 {
232
233
234 insertIndex = -insertIndex - 1;
235 }
236 else
237 {
238
239
240 while ((insertIndex < list.size()) && (comparator.compare(list.get(insertIndex), e) == 0))
241 {
242 insertIndex++;
243 }
244 }
245 list.add(insertIndex, e);
246 }
247
248 public void removeFilterModule(final ServletFilterModuleDescriptor descriptor)
249 {
250 filterDescriptors.remove(descriptor.getCompleteKey());
251 filterMapper.put(descriptor.getCompleteKey(), null);
252
253 final LazyReference<Filter> filterRef = filterRefs.remove(descriptor.getCompleteKey());
254 if (filterRef != null)
255 {
256 filterRef.get().destroy();
257 }
258 }
259
260
261
262
263
264
265 @PluginEventListener
266 public void onPluginDisabled(final PluginDisabledEvent event)
267 {
268 final Plugin plugin = event.getPlugin();
269 final ContextLifecycleReference context = pluginContextRefs.remove(plugin);
270 if (context == null)
271 {
272 return;
273 }
274
275 context.get().contextDestroyed();
276 }
277
278 @PluginEventListener
279 public void onPluginFrameworkShutdown(final PluginFrameworkShutdownEvent event)
280 {
281 for (ServletModuleDescriptor md : new ArrayList<ServletModuleDescriptor>(servletDescriptors.values()))
282 {
283 if (md != null)
284 {
285 md.destroy(md.getPlugin());
286 }
287 }
288 for (ServletFilterModuleDescriptor md : new ArrayList<ServletFilterModuleDescriptor>(filterDescriptors.values()))
289 {
290 if (md != null)
291 {
292 md.destroy(md.getPlugin());
293 }
294 }
295 for (ContextLifecycleReference context : pluginContextRefs.values())
296 {
297 if (context != null)
298 {
299 ContextLifecycleManager lifecycleManager = context.get();
300 if (lifecycleManager != null)
301 {
302 lifecycleManager.contextDestroyed();
303 }
304 }
305 }
306 }
307
308
309
310
311
312
313
314
315
316
317
318
319
320 HttpServlet getServlet(final ServletModuleDescriptor descriptor, final ServletConfig servletConfig)
321 {
322
323
324
325
326 LazyReference<HttpServlet> servletRef = servletRefs.get(descriptor.getCompleteKey());
327 if (servletRef == null)
328 {
329
330 final ServletContext servletContext = getWrappedContext(descriptor.getPlugin(), servletConfig.getServletContext());
331 servletRef = new LazyLoadedServletReference(descriptor, servletContext);
332
333
334
335
336 if (servletRefs.putIfAbsent(descriptor.getCompleteKey(), servletRef) != null)
337 {
338 servletRef = servletRefs.get(descriptor.getCompleteKey());
339 }
340 }
341 HttpServlet servlet = null;
342 try
343 {
344 servlet = servletRef.get();
345 }
346 catch (final RuntimeException ex)
347 {
348 log.error("Unable to create servlet", ex);
349 }
350 return servlet;
351 }
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366 Filter getFilter(final ServletFilterModuleDescriptor descriptor, final FilterConfig filterConfig)
367 {
368
369
370
371
372 LazyReference<Filter> filterRef = filterRefs.get(descriptor.getCompleteKey());
373 if (filterRef == null)
374 {
375
376 final ServletContext servletContext = getWrappedContext(descriptor.getPlugin(), filterConfig.getServletContext());
377 filterRef = new LazyLoadedFilterReference(descriptor, servletContext);
378
379
380
381
382 if (filterRefs.putIfAbsent(descriptor.getCompleteKey(), filterRef) != null)
383 {
384 filterRef = filterRefs.get(descriptor.getCompleteKey());
385 }
386 }
387 try
388 {
389 return filterRef.get();
390 }
391 catch (final RuntimeException ex)
392 {
393 log.error("Unable to create filter", ex);
394 return null;
395 }
396 }
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413 private ServletContext getWrappedContext(final Plugin plugin, final ServletContext baseContext)
414 {
415 ContextLifecycleReference pluginContextRef = pluginContextRefs.get(plugin);
416 if (pluginContextRef == null)
417 {
418 pluginContextRef = new ContextLifecycleReference(plugin, baseContext);
419 if (pluginContextRefs.putIfAbsent(plugin, pluginContextRef) != null)
420 {
421 pluginContextRef = pluginContextRefs.get(plugin);
422 }
423 }
424 return pluginContextRef.get().servletContext;
425 }
426
427 private static final class LazyLoadedFilterReference extends LazyReference<Filter>
428 {
429 private final ServletFilterModuleDescriptor descriptor;
430 private final ServletContext servletContext;
431
432 private LazyLoadedFilterReference(final ServletFilterModuleDescriptor descriptor, final ServletContext servletContext)
433 {
434 this.descriptor = descriptor;
435 this.servletContext = servletContext;
436 }
437
438 @Override
439 protected Filter create() throws Exception
440 {
441 final Filter filter = new DelegatingPluginFilter(descriptor);
442 filter.init(new PluginFilterConfig(descriptor, servletContext));
443 return filter;
444 }
445 }
446
447 private static final class LazyLoadedServletReference extends LazyReference<HttpServlet>
448 {
449 private final ServletModuleDescriptor descriptor;
450 private final ServletContext servletContext;
451
452 private LazyLoadedServletReference(final ServletModuleDescriptor descriptor, final ServletContext servletContext)
453 {
454 this.descriptor = descriptor;
455 this.servletContext = servletContext;
456 }
457
458 @Override
459 protected HttpServlet create() throws Exception
460 {
461 final HttpServlet servlet = new DelegatingPluginServlet(descriptor);
462 servlet.init(new PluginServletConfig(descriptor, servletContext));
463 return servlet;
464 }
465 }
466
467 private static final class ContextLifecycleReference extends LazyReference<ContextLifecycleManager>
468 {
469 private final Plugin plugin;
470 private final ServletContext baseContext;
471
472 private ContextLifecycleReference(final Plugin plugin, final ServletContext baseContext)
473 {
474 this.plugin = plugin;
475 this.baseContext = baseContext;
476 }
477
478 @Override
479 protected ContextLifecycleManager create() throws Exception
480 {
481 final ConcurrentMap<String, Object> contextAttributes = new ConcurrentHashMap<String, Object>();
482 final Map<String, String> initParams = mergeInitParams(baseContext, plugin);
483 final ServletContext context = new PluginServletContextWrapper(plugin, baseContext, contextAttributes, initParams);
484
485 ClassLoaderStack.push(plugin.getClassLoader());
486 final List<ServletContextListener> listeners = new ArrayList<ServletContextListener>();
487 try
488 {
489 for (final ServletContextListenerModuleDescriptor descriptor : findModuleDescriptorsByType(ServletContextListenerModuleDescriptor.class, plugin))
490 {
491 listeners.add(descriptor.getModule());
492 }
493 }
494 finally
495 {
496 ClassLoaderStack.pop();
497 }
498
499 return new ContextLifecycleManager(context, listeners);
500 }
501
502 private Map<String, String> mergeInitParams(final ServletContext baseContext, final Plugin plugin)
503 {
504 final Map<String, String> mergedInitParams = new HashMap<String, String>();
505 @SuppressWarnings("unchecked")
506 final Enumeration<String> e = baseContext.getInitParameterNames();
507 while (e.hasMoreElements())
508 {
509 final String paramName = e.nextElement();
510 mergedInitParams.put(paramName, baseContext.getInitParameter(paramName));
511 }
512 for (final ServletContextParamModuleDescriptor descriptor : findModuleDescriptorsByType(ServletContextParamModuleDescriptor.class, plugin))
513 {
514 mergedInitParams.put(descriptor.getParamName(), descriptor.getParamValue());
515 }
516 return Collections.unmodifiableMap(mergedInitParams);
517 }
518 }
519
520 static <T extends ModuleDescriptor<?>> Iterable<T> findModuleDescriptorsByType(final Class<T> type, final Plugin plugin)
521 {
522 final Set<T> descriptors = new HashSet<T>();
523 for (final ModuleDescriptor<?> descriptor : plugin.getModuleDescriptors())
524 {
525 if (type.isAssignableFrom(descriptor.getClass()))
526 {
527 descriptors.add(type.cast(descriptor));
528 }
529 }
530 return descriptors;
531 }
532
533 static final class ContextLifecycleManager
534 {
535 private final ServletContext servletContext;
536 private final Iterable<ServletContextListener> listeners;
537
538 ContextLifecycleManager(final ServletContext servletContext, final Iterable<ServletContextListener> listeners)
539 {
540 this.servletContext = servletContext;
541 this.listeners = listeners;
542 for (final ServletContextListener listener : listeners)
543 {
544 listener.contextInitialized(new ServletContextEvent(servletContext));
545 }
546 }
547
548 ServletContext getServletContext()
549 {
550 return servletContext;
551 }
552
553 void contextDestroyed()
554 {
555 final ServletContextEvent event = new ServletContextEvent(servletContext);
556 for (final ServletContextListener listener : listeners)
557 {
558 listener.contextDestroyed(event);
559 }
560 }
561 }
562 }