View Javadoc
1   package com.atlassian.dbexporter.importer;
2   
3   import com.atlassian.dbexporter.Column;
4   import com.atlassian.dbexporter.Context;
5   import com.atlassian.dbexporter.DatabaseInformations;
6   import com.atlassian.dbexporter.ImportExportErrorService;
7   import com.atlassian.dbexporter.Table;
8   import com.google.common.base.Predicate;
9   import com.google.common.collect.Iterables;
10  
11  import java.sql.Connection;
12  import java.sql.SQLException;
13  import java.sql.Statement;
14  
15  import static com.atlassian.dbexporter.DatabaseInformations.database;
16  import static com.atlassian.dbexporter.jdbc.JdbcUtils.closeQuietly;
17  import static com.atlassian.dbexporter.jdbc.JdbcUtils.quote;
18  import static com.google.common.base.Preconditions.checkNotNull;
19  
20  public final class SqlServerAroundTableImporter implements DataImporter.AroundTableImporter {
21      private final ImportExportErrorService errorService;
22      private final String schema;
23  
24      public SqlServerAroundTableImporter(ImportExportErrorService errorService, String schema) {
25          this.errorService = checkNotNull(errorService);
26          this.schema = schema;
27      }
28  
29      @Override
30      public void before(ImportConfiguration configuration, Context context, String table, Connection connection) {
31          setIdentityInsert(configuration, context, connection, table, "ON");
32      }
33  
34      @Override
35      public void after(ImportConfiguration configuration, Context context, String table, Connection connection) {
36          setIdentityInsert(configuration, context, connection, table, "OFF");
37      }
38  
39      private void setIdentityInsert(ImportConfiguration configuration, Context context, Connection connection, String table, String onOff) {
40          if (isSqlServer(configuration) && isAutoIncrementTable(context, table)) {
41              setIdentityInsert(connection, table, onOff);
42          }
43      }
44  
45      private boolean isAutoIncrementTable(Context context, final String tableName) {
46          return hasAnyAutoIncrementColumn(findTable(context, tableName));
47      }
48  
49      private boolean hasAnyAutoIncrementColumn(Table table) {
50          return Iterables.any(table.getColumns(), new Predicate<Column>() {
51              @Override
52              public boolean apply(Column c) {
53                  return c.isAutoIncrement();
54              }
55          });
56      }
57  
58      private Table findTable(Context context, final String tableName) {
59          return Iterables.find(context.getAll(Table.class), new Predicate<Table>() {
60              @Override
61              public boolean apply(Table t) {
62                  return t.getName().equals(tableName);
63              }
64          });
65      }
66  
67      private void setIdentityInsert(Connection connection, String table, String onOff) {
68          Statement s = null;
69          try {
70              s = connection.createStatement();
71              s.execute(setIdentityInsertSql(quote(errorService, table, connection, table), onOff));
72          } catch (SQLException e) {
73              throw errorService.newImportExportSqlException(table, "", e);
74          } finally {
75              closeQuietly(s);
76          }
77      }
78  
79      private String setIdentityInsertSql(String table, String onOff) {
80          return String.format("SET IDENTITY_INSERT %s %s", schema != null ? schema + "." + table : table, onOff);
81      }
82  
83      private boolean isSqlServer(ImportConfiguration configuration) {
84          return DatabaseInformations.Database.Type.MSSQL.equals(database(configuration.getDatabaseInformation()).getType());
85      }
86  }