1 package com.atlassian.activeobjects.backup;
2
3 import com.atlassian.dbexporter.Context;
4 import com.atlassian.dbexporter.DatabaseInformations;
5 import com.atlassian.dbexporter.ImportExportErrorService;
6 import com.atlassian.dbexporter.Table;
7 import com.atlassian.dbexporter.importer.ImportConfiguration;
8 import com.atlassian.dbexporter.importer.NoOpAroundImporter;
9 import com.atlassian.dbexporter.node.NodeParser;
10 import net.java.ao.DatabaseProvider;
11 import net.java.ao.schema.NameConverters;
12
13 import java.sql.Connection;
14 import java.sql.ResultSet;
15 import java.sql.SQLException;
16 import java.sql.Statement;
17 import java.util.Collection;
18
19 import static com.atlassian.activeobjects.backup.SqlUtils.TableColumnPair;
20 import static com.atlassian.activeobjects.backup.SqlUtils.executeQuery;
21 import static com.atlassian.activeobjects.backup.SqlUtils.executeUpdate;
22 import static com.atlassian.activeobjects.backup.SqlUtils.getIntFromResultSet;
23 import static com.atlassian.activeobjects.backup.SqlUtils.tableColumnPairs;
24 import static com.atlassian.dbexporter.DatabaseInformations.database;
25 import static com.atlassian.dbexporter.jdbc.JdbcUtils.closeQuietly;
26 import static com.atlassian.dbexporter.jdbc.JdbcUtils.quote;
27 import static com.google.common.base.Preconditions.checkNotNull;
28
29 public final class OracleSequencesAroundImporter extends NoOpAroundImporter {
30 private final ImportExportErrorService errorService;
31 private final DatabaseProvider provider;
32 private final NameConverters nameConverters;
33
34 public OracleSequencesAroundImporter(ImportExportErrorService errorService, DatabaseProvider provider, NameConverters nameConverters) {
35 this.errorService = checkNotNull(errorService);
36 this.provider = checkNotNull(provider);
37 this.nameConverters = checkNotNull(nameConverters);
38 }
39
40 @Override
41 public void before(NodeParser node, ImportConfiguration configuration, Context context) {
42 if (isOracle(configuration)) {
43 doBefore(context);
44 }
45 }
46
47 @Override
48 public void after(NodeParser node, ImportConfiguration configuration, Context context) {
49 if (isOracle(configuration)) {
50 doAfter(context);
51 }
52 }
53
54 private boolean isOracle(ImportConfiguration configuration) {
55 return DatabaseInformations.Database.Type.ORACLE.equals(database(configuration.getDatabaseInformation()).getType());
56 }
57
58 private void doBefore(Context context) {
59 final Collection<Table> tables = context.getAll(Table.class);
60 disableAllTriggers(tables);
61 dropAllSequences(tables);
62 }
63
64 private void doAfter(Context context) {
65 final Collection<Table> tables = context.getAll(Table.class);
66 createAllSequences(tables);
67 enableAllTriggers(tables);
68 }
69
70 private void disableAllTriggers(Collection<Table> tables) {
71 Connection connection = null;
72 try {
73 connection = provider.getConnection();
74 for (Table table : tables) {
75 executeUpdate(errorService, table.getName(), connection, "ALTER TABLE " + tableName(connection, table.getName()) + " DISABLE ALL TRIGGERS");
76 }
77 } catch (SQLException e) {
78 throw errorService.newImportExportSqlException(null, "", e);
79 } finally {
80 closeQuietly(connection);
81 }
82 }
83
84 private void dropAllSequences(Collection<Table> tables) {
85 Connection connection = null;
86 try {
87 connection = provider.getConnection();
88 for (TableColumnPair tcp : tableColumnPairs(tables)) {
89 dropSequence(connection, tcp);
90 }
91 } catch (SQLException e) {
92 throw errorService.newImportExportSqlException(null, "", e);
93 } finally {
94 closeQuietly(connection);
95 }
96 }
97
98 private void dropSequence(Connection connection, TableColumnPair tcp) {
99 executeUpdate(errorService, tcp.table.getName(), connection, "DROP SEQUENCE " + sequenceName(connection, tcp));
100 }
101
102 private void createAllSequences(Collection<Table> tables) {
103 Connection connection = null;
104 try {
105 connection = provider.getConnection();
106 for (TableColumnPair tcp : tableColumnPairs(tables)) {
107 createSequence(connection, tcp);
108 }
109 } catch (SQLException e) {
110 throw errorService.newImportExportSqlException(null, "", e);
111 } finally {
112 closeQuietly(connection);
113 }
114 }
115
116 private void createSequence(Connection connection, TableColumnPair tcp) {
117 Statement maxStmt = null;
118 final String tableName = tcp.table.getName();
119 try {
120 maxStmt = connection.createStatement();
121 final ResultSet res = executeQuery(errorService, tableName, maxStmt,
122 "SELECT MAX(" + quote(errorService, tableName, connection, tcp.column.getName()) + ")" +
123 " FROM " + tableName(connection, tableName));
124 final int max = getIntFromResultSet(errorService, tableName, res);
125 executeUpdate(errorService, tableName, connection, "CREATE SEQUENCE " + sequenceName(connection, tcp)
126 + " INCREMENT BY 1 START WITH " + (max + 1) + " NOMAXVALUE MINVALUE " + (max + 1));
127 } catch (SQLException e) {
128 throw errorService.newImportExportSqlException(tableName, "", e);
129 } finally {
130 closeQuietly(maxStmt);
131 }
132 }
133
134 private void enableAllTriggers(Collection<Table> tables) {
135 Connection connection = null;
136 try {
137 connection = provider.getConnection();
138 for (Table table : tables) {
139 executeUpdate(errorService, table.getName(), connection, "ALTER TABLE " + tableName(connection, table.getName()) + " ENABLE ALL TRIGGERS");
140 }
141 } catch (SQLException e) {
142 throw errorService.newImportExportSqlException(null, "", e);
143 } finally {
144 closeQuietly(connection);
145 }
146 }
147
148 private String tableName(Connection connection, String tableName) {
149 final String schema = isBlank(provider.getSchema()) ? null : provider.getSchema();
150 final String quoted = quote(errorService, tableName, connection, tableName);
151 return schema != null ? schema + "." + quoted : quoted;
152 }
153
154 private String sequenceName(Connection connection, TableColumnPair tcp) {
155 final String schema = isBlank(provider.getSchema()) ? null : provider.getSchema();
156 final String quoted = quote(errorService, tcp.table.getName(), connection, nameConverters.getSequenceNameConverter().getName(tcp.table.getName(), tcp.column.getName()));
157 return schema != null ? schema + "." + quoted : quoted;
158 }
159
160 private static boolean isBlank(String str) {
161 int strLen;
162 if (str == null || (strLen = str.length()) == 0) {
163 return true;
164 }
165 for (int i = 0; i < strLen; i++) {
166 if (!Character.isWhitespace(str.charAt(i))) {
167 return false;
168 }
169 }
170 return true;
171 }
172 }