1 package com.atlassian.activeobjects.backup;
2
3 import com.atlassian.dbexporter.Context;
4 import com.atlassian.dbexporter.EntityNameProcessor;
5 import com.atlassian.dbexporter.ImportExportErrorService;
6 import com.atlassian.dbexporter.Table;
7 import com.atlassian.dbexporter.importer.ImportConfiguration;
8 import com.atlassian.dbexporter.importer.NoOpAroundImporter;
9 import com.atlassian.dbexporter.node.NodeParser;
10 import net.java.ao.DatabaseProvider;
11
12 import java.sql.Connection;
13 import java.sql.ResultSet;
14 import java.sql.SQLException;
15 import java.sql.Statement;
16
17 import static com.atlassian.activeobjects.backup.SqlUtils.executeQuery;
18 import static com.atlassian.activeobjects.backup.SqlUtils.executeUpdate;
19 import static com.atlassian.activeobjects.backup.SqlUtils.getIntFromResultSet;
20 import static com.atlassian.activeobjects.backup.SqlUtils.tableColumnPairs;
21 import static com.atlassian.dbexporter.DatabaseInformations.Database;
22 import static com.atlassian.dbexporter.DatabaseInformations.database;
23 import static com.atlassian.dbexporter.jdbc.JdbcUtils.closeQuietly;
24 import static com.atlassian.dbexporter.jdbc.JdbcUtils.quote;
25 import static com.google.common.base.Preconditions.checkNotNull;
26
27
28
29
30
31 public final class PostgresSequencesAroundImporter extends NoOpAroundImporter {
32 private final ImportExportErrorService errorService;
33 private final DatabaseProvider provider;
34
35 public PostgresSequencesAroundImporter(ImportExportErrorService errorService, DatabaseProvider provider) {
36 this.errorService = checkNotNull(errorService);
37 this.provider = checkNotNull(provider);
38 }
39
40 @Override
41 public void after(NodeParser node, ImportConfiguration configuration, Context context) {
42 if (isPostgres(configuration)) {
43 updateSequences(configuration, context);
44 }
45 }
46
47 private boolean isPostgres(ImportConfiguration configuration) {
48 return Database.Type.POSTGRES.equals(database(configuration.getDatabaseInformation()).getType());
49 }
50
51 private void updateSequences(ImportConfiguration configuration, Context context) {
52 final EntityNameProcessor entityNameProcessor = configuration.getEntityNameProcessor();
53 for (SqlUtils.TableColumnPair tableColumnPair : tableColumnPairs(context.getAll(Table.class))) {
54 final String tableName = entityNameProcessor.tableName(tableColumnPair.table.getName());
55 final String columnName = entityNameProcessor.columnName(tableColumnPair.column.getName());
56 updateSequence(tableName, columnName);
57 }
58 }
59
60 private void updateSequence(String tableName, String columnName) {
61 Connection connection = null;
62 Statement maxStmt = null;
63 Statement alterSeqStmt = null;
64 try {
65 connection = provider.getConnection();
66 maxStmt = connection.createStatement();
67
68 final ResultSet res = executeQuery(errorService, tableName, maxStmt, max(connection, tableName, columnName));
69
70 final int max = getIntFromResultSet(errorService, tableName, res);
71 alterSeqStmt = connection.createStatement();
72 executeUpdate(errorService, tableName, alterSeqStmt, alterSequence(connection, tableName, columnName, max + 1));
73 } catch (SQLException e) {
74 throw errorService.newImportExportSqlException(tableName, "", e);
75 } finally {
76 closeQuietly(maxStmt, alterSeqStmt);
77 closeQuietly(connection);
78 }
79 }
80
81 private String max(Connection connection, String tableName, String columnName) {
82 return "SELECT MAX(" + quote(errorService, tableName, connection, columnName) + ") FROM " + tableName(connection, tableName);
83 }
84
85 private String tableName(Connection connection, String tableName) {
86 final String schema = isBlank(provider.getSchema()) ? null : provider.getSchema();
87 final String quoted = quote(errorService, tableName, connection, tableName);
88 return schema != null ? schema + "." + quoted : quoted;
89 }
90
91 private String alterSequence(Connection connection, String tableName, String columnName, int val) {
92 return "ALTER SEQUENCE " + sequenceName(connection, tableName, columnName) + " RESTART WITH " + val;
93 }
94
95 private String sequenceName(Connection connection, String tableName, String columnName) {
96 final String schema = isBlank(provider.getSchema()) ? null : provider.getSchema();
97 final String quoted = quote(errorService, tableName, connection, tableName + "_" + columnName + "_" + "seq");
98 return schema != null ? schema + "." + quoted : quoted;
99 }
100
101 private static boolean isBlank(String str) {
102 int strLen;
103 if (str == null || (strLen = str.length()) == 0) {
104 return true;
105 }
106 for (int i = 0; i < strLen; i++) {
107 if (!Character.isWhitespace(str.charAt(i))) {
108 return false;
109 }
110 }
111 return true;
112 }
113 }