View Javadoc
1   package com.atlassian.activeobjects.backup;
2   
3   import com.atlassian.dbexporter.Context;
4   import com.atlassian.dbexporter.EntityNameProcessor;
5   import com.atlassian.dbexporter.ImportExportErrorService;
6   import com.atlassian.dbexporter.Table;
7   import com.atlassian.dbexporter.importer.ImportConfiguration;
8   import com.atlassian.dbexporter.importer.NoOpAroundImporter;
9   import com.atlassian.dbexporter.node.NodeParser;
10  import net.java.ao.DatabaseProvider;
11  
12  import java.sql.Connection;
13  import java.sql.ResultSet;
14  import java.sql.SQLException;
15  import java.sql.Statement;
16  
17  import static com.atlassian.activeobjects.backup.SqlUtils.executeQuery;
18  import static com.atlassian.activeobjects.backup.SqlUtils.executeUpdate;
19  import static com.atlassian.activeobjects.backup.SqlUtils.getIntFromResultSet;
20  import static com.atlassian.activeobjects.backup.SqlUtils.tableColumnPairs;
21  import static com.atlassian.dbexporter.DatabaseInformations.Database;
22  import static com.atlassian.dbexporter.DatabaseInformations.database;
23  import static com.atlassian.dbexporter.jdbc.JdbcUtils.closeQuietly;
24  import static com.atlassian.dbexporter.jdbc.JdbcUtils.quote;
25  import static com.google.common.base.Preconditions.checkNotNull;
26  
27  /**
28   * Updates the auto-increment sequences so that they start are the correct min value after some data has been 'manually'
29   * imported into the database.
30   */
31  public final class PostgresSequencesAroundImporter extends NoOpAroundImporter {
32      private final ImportExportErrorService errorService;
33      private final DatabaseProvider provider;
34  
35      public PostgresSequencesAroundImporter(ImportExportErrorService errorService, DatabaseProvider provider) {
36          this.errorService = checkNotNull(errorService);
37          this.provider = checkNotNull(provider);
38      }
39  
40      @Override
41      public void after(NodeParser node, ImportConfiguration configuration, Context context) {
42          if (isPostgres(configuration)) {
43              updateSequences(configuration, context);
44          }
45      }
46  
47      private boolean isPostgres(ImportConfiguration configuration) {
48          return Database.Type.POSTGRES.equals(database(configuration.getDatabaseInformation()).getType());
49      }
50  
51      private void updateSequences(ImportConfiguration configuration, Context context) {
52          final EntityNameProcessor entityNameProcessor = configuration.getEntityNameProcessor();
53          for (SqlUtils.TableColumnPair tableColumnPair : tableColumnPairs(context.getAll(Table.class))) {
54              final String tableName = entityNameProcessor.tableName(tableColumnPair.table.getName());
55              final String columnName = entityNameProcessor.columnName(tableColumnPair.column.getName());
56              updateSequence(tableName, columnName);
57          }
58      }
59  
60      private void updateSequence(String tableName, String columnName) {
61          Connection connection = null;
62          Statement maxStmt = null;
63          Statement alterSeqStmt = null;
64          try {
65              connection = provider.getConnection();
66              maxStmt = connection.createStatement();
67  
68              final ResultSet res = executeQuery(errorService, tableName, maxStmt, max(connection, tableName, columnName));
69  
70              final int max = getIntFromResultSet(errorService, tableName, res);
71              alterSeqStmt = connection.createStatement();
72              executeUpdate(errorService, tableName, alterSeqStmt, alterSequence(connection, tableName, columnName, max + 1));
73          } catch (SQLException e) {
74              throw errorService.newImportExportSqlException(tableName, "", e);
75          } finally {
76              closeQuietly(maxStmt, alterSeqStmt);
77              closeQuietly(connection);
78          }
79      }
80  
81      private String max(Connection connection, String tableName, String columnName) {
82          return "SELECT MAX(" + quote(errorService, tableName, connection, columnName) + ") FROM " + tableName(connection, tableName);
83      }
84  
85      private String tableName(Connection connection, String tableName) {
86          final String schema = isBlank(provider.getSchema()) ? null : provider.getSchema();
87          final String quoted = quote(errorService, tableName, connection, tableName);
88          return schema != null ? schema + "." + quoted : quoted;
89      }
90  
91      private String alterSequence(Connection connection, String tableName, String columnName, int val) {
92          return "ALTER SEQUENCE " + sequenceName(connection, tableName, columnName) + " RESTART WITH " + val;
93      }
94  
95      private String sequenceName(Connection connection, String tableName, String columnName) {
96          final String schema = isBlank(provider.getSchema()) ? null : provider.getSchema();
97          final String quoted = quote(errorService, tableName, connection, tableName + "_" + columnName + "_" + "seq");
98          return schema != null ? schema + "." + quoted : quoted;
99      }
100 
101     private static boolean isBlank(String str) {
102         int strLen;
103         if (str == null || (strLen = str.length()) == 0) {
104             return true;
105         }
106         for (int i = 0; i < strLen; i++) {
107             if (!Character.isWhitespace(str.charAt(i))) {
108                 return false;
109             }
110         }
111         return true;
112     }
113 }