1 package com.atlassian.dbexporter.importer;
2
3 import com.atlassian.dbexporter.ImportExportErrorService;
4 import com.atlassian.dbexporter.node.NodeParser;
5 import com.atlassian.dbexporter.node.NodeStreamReader;
6 import com.atlassian.dbexporter.node.stax.StaxStreamReader;
7 import com.google.common.base.Predicate;
8 import com.google.common.collect.Iterables;
9 import org.junit.rules.MethodRule;
10 import org.junit.runners.model.FrameworkMethod;
11 import org.junit.runners.model.Statement;
12
13 import java.io.StringReader;
14 import java.lang.reflect.Field;
15
16 import static com.google.common.collect.Lists.newArrayList;
17 import static org.junit.Assert.assertFalse;
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 public final class NodeParserRule implements MethodRule {
39 private NodeStreamReader streamReader;
40 private NodeParser node;
41
42 public NodeParser getNode() {
43 return node;
44 }
45
46 public Statement apply(final Statement statement, final FrameworkMethod frameworkMethod, final Object o) {
47 return new Statement() {
48 @Override
49 public void evaluate() throws Throwable {
50 before(frameworkMethod, o);
51 try {
52 statement.evaluate();
53 } finally {
54 after();
55 }
56 }
57 };
58 }
59
60 private void before(FrameworkMethod method, Object o) {
61 final Xml xml = method.getAnnotation(Xml.class);
62 if (xml != null) {
63 streamReader = new StaxStreamReader(findErrorService(method.getMethod().getDeclaringClass(), o), new StringReader(xml.value()));
64 node = streamReader.getRootNode();
65 assertFalse(node.isClosed());
66 }
67 }
68
69 private ImportExportErrorService findErrorService(Class aClass, Object o) {
70 try {
71 return getValue(ImportExportErrorService.class, o, findFieldOfType(aClass, ImportExportErrorService.class));
72 } catch (IllegalAccessException e) {
73 throw new RuntimeException(e);
74 }
75 }
76
77 private <T> T getValue(Class<T> type, Object o, Field f) throws IllegalAccessException {
78 final boolean accessible = f.isAccessible();
79 try {
80 f.setAccessible(true);
81 return type.cast(f.get(o));
82 } finally {
83 f.setAccessible(accessible);
84 }
85 }
86
87 private Field findFieldOfType(Class aClass, final Class<ImportExportErrorService> type) {
88 return Iterables.find(newArrayList(aClass.getDeclaredFields()), new Predicate<Field>() {
89 @Override
90 public boolean apply(Field f) {
91 return type.isAssignableFrom(f.getType());
92 }
93 });
94 }
95
96 private void after() {
97 if (streamReader != null) {
98 streamReader.close();
99 }
100 streamReader = null;
101 node = null;
102 }
103 }