Commit 7ab77392 authored by Szabolcs Gyurko
Browse files

Added output format selection

parent 73aa4ec6
......@@ -20,26 +20,31 @@ const (
)
var (
BuildVersion = "No Date Specified"
anonimyser gdpr.Anonymiser
randomiser gdpr.Randomiser
dbFlavour DBDumpFlavour
stdoutMutex sync.Mutex
processWaitGroup sync.WaitGroup
databases []string
singleThreaded *bool
perDBConcurrency *bool
maxIdleConnections *int
BuildVersion = "No Date Specified"
anonimyser gdpr.Anonymiser
randomiser gdpr.Randomiser
dbDumpFlavour DBDumpFlavour
dbDriverFlavour DBDriverFlavour
stdoutMutex sync.Mutex
processWaitGroup sync.WaitGroup
databases []string
singleThreaded *bool
perDBConcurrency *bool
maxIdleConnections *int
maxConnectionLifetime *int
)
type DBDumpFlavour interface {
// DBDriverFlavour describes the source-database side of a dump: how to connect
// to the server and how to discover what is in it. It is kept separate from
// DBDumpFlavour so the database being read and the format being written can be
// chosen independently.
type DBDriverFlavour interface {
// GetDriverName returns the driver name handed to sql.Open in openDatabase.
GetDriverName() string
// GetValueSeparator returns a separator string.
// NOTE(review): every visible call site uses the dump flavour's
// GetValueSeparator (between the field values of a row) — confirm whether
// this method is still required on the driver side.
GetValueSeparator() string
// GetConnectionString builds the connection string passed to openDatabase.
// database is nil when the caller connects only to list the databases.
GetConnectionString(host, username, password string, database *string) string
// GetDirEnvironmentVariableName returns the name of the environment variable
// getConfigDir looks up; when that variable is not set, the directory of the
// executable is used instead.
GetDirEnvironmentVariableName() string
// GetAllTables lists the tables of database; it is called when no tables
// were supplied on the command line.
GetAllTables(db *sql.DB, database string) []string
// GetAllDatabases lists the databases to dump; it is called when no database
// was supplied on the command line.
GetAllDatabases(db *sql.DB) []string
// EscapeFieldName escapes an identifier for use in a query; dumpTable applies
// it to the table name in its "SELECT * FROM" statement.
EscapeFieldName(name string) string
}
type DBDumpFlavour interface {
GetValueSeparator() string
GetDumpHeader() []string
GetDumpFooter() []string
GetTableHeader(table string, cols []string) []string
......@@ -47,7 +52,6 @@ type DBDumpFlavour interface {
GetDatabaseHeader(db string) []string
GetDatabaseFooter(db string) []string
EscapeData(data string) string
EscapeFieldName(name string) string
GetRowHeader(table string, cols []string) string
GetRowFooter(table string, cols []string) string
ProcessField(data string, columnTypes []*sql.ColumnType, index int) (string, bool)
......@@ -57,8 +61,6 @@ type DBDumpFlavour interface {
GetBatchFooter(table string, cols []string) string
}
type createDBFlavour func() (DBDumpFlavour, error)
// Prints out usage help
func usage() {
println("Version:", version)
......@@ -118,7 +120,7 @@ func dumpTable(connectionString, dbname, table string, output chan<- string) {
defer close(output)
// Query data
statement, err := db.Prepare("SELECT * FROM " + dbFlavour.EscapeFieldName(table))
statement, err := db.Prepare("SELECT * FROM " + dbDriverFlavour.EscapeFieldName(table))
if err != nil {
panic(err.Error())
}
......@@ -146,12 +148,12 @@ func dumpTable(connectionString, dbname, table string, output chan<- string) {
}
// DB Flavour specific entries to the dump (eg.: locks)
for _, v := range dbFlavour.GetTableHeader(table, columns) {
for _, v := range dbDumpFlavour.GetTableHeader(table, columns) {
output <- v
}
// The first part of the INSERT statement
insert := dbFlavour.GetBatchHeader(table, columns)
insert := dbDumpFlavour.GetBatchHeader(table, columns)
for rows.Next() {
if count != 0 {
insert += ","
......@@ -173,38 +175,38 @@ func dumpTable(connectionString, dbname, table string, output chan<- string) {
row += "NULL"
} else {
// Handle data first by the DB Flavour. If authoritative is `true` we want the result to be in the dump and no further processing is needed.
if data, authoritative := dbFlavour.ProcessField(*col, columnTypes, i); authoritative {
if data, authoritative := dbDumpFlavour.ProcessField(*col, columnTypes, i); authoritative {
row += data
} else {
if data := anon.Anonymise(dbname, table, rowuuid, columns[i], *col); data == nil {
row += "NULL"
} else {
row += dbFlavour.EscapeData(*data)
row += dbDumpFlavour.EscapeData(*data)
}
}
}
if i < len(values) - 1 {
row += dbFlavour.GetValueSeparator()
row += dbDumpFlavour.GetValueSeparator()
}
}
insert += dbFlavour.GetRowHeader(table, columns) + row + dbFlavour.GetRowFooter(table, columns)
insert += dbDumpFlavour.GetRowHeader(table, columns) + row + dbDumpFlavour.GetRowFooter(table, columns)
count++
// Handle the Multi-insert assembly
if count == dbFlavour.GetNumberOfRowsInOneBatch() || len(insert) > dbFlavour.GetMaxLineLength() {
output <- insert + dbFlavour.GetBatchFooter(table, columns)
insert = dbFlavour.GetBatchHeader(table, columns)
if count == dbDumpFlavour.GetNumberOfRowsInOneBatch() || len(insert) > dbDumpFlavour.GetMaxLineLength() {
output <- insert + dbDumpFlavour.GetBatchFooter(table, columns)
insert = dbDumpFlavour.GetBatchHeader(table, columns)
count = 0
}
}
// If we have left-over data, print that too
if count != 0 {
output <- insert + dbFlavour.GetBatchFooter(table, columns)
output <- insert + dbDumpFlavour.GetBatchFooter(table, columns)
}
// Closing statements for the table
for _, v := range dbFlavour.GetTableFooter(table) {
for _, v := range dbDumpFlavour.GetTableFooter(table) {
output <- v
}
......@@ -237,7 +239,7 @@ func processOutput(input <-chan string) {
// Creates an sql.DB connection based on the passed-in connection string
func openDatabase(connectionString string) (*sql.DB, error) {
db, err := sql.Open(dbFlavour.GetDriverName(), connectionString)
db, err := sql.Open(dbDriverFlavour.GetDriverName(), connectionString)
if err != nil {
return nil, err
}
......@@ -262,11 +264,11 @@ func dumpDatabase(database string, tables []string, username, password, host *st
// Get all the tables from the current DB if none supplied via command line
if len(tables) == 0 {
db, err := openDatabase(dbFlavour.GetConnectionString(*host, *username, *password, &database))
db, err := openDatabase(dbDriverFlavour.GetConnectionString(*host, *username, *password, &database))
if err != nil {
panic(err.Error())
}
tables = dbFlavour.GetAllTables(db, database)
tables = dbDriverFlavour.GetAllTables(db, database)
db.Close()
}
......@@ -286,7 +288,7 @@ func dumpDatabase(database string, tables []string, username, password, host *st
gdpr.StateTable.New(database + ":" + tables[t])
// Start the sender and receiver
go dumpTable(dbFlavour.GetConnectionString(*host, *username, *password, &database), database, tables[t], channels[t])
go dumpTable(dbDriverFlavour.GetConnectionString(*host, *username, *password, &database), database, tables[t], channels[t])
go processOutput(channels[t])
if singleThreaded != nil && *singleThreaded {
......@@ -300,7 +302,7 @@ func dumpDatabase(database string, tables []string, username, password, host *st
// Gets the config directory
func getConfigDir() string {
// Get the directory for our JSON files
dir, exists := os.LookupEnv(dbFlavour.GetDirEnvironmentVariableName())
dir, exists := os.LookupEnv(dbDriverFlavour.GetDirEnvironmentVariableName())
if !exists {
dir, _ = filepath.Abs(filepath.Dir(os.Args[0]))
}
......@@ -309,7 +311,7 @@ func getConfigDir() string {
}
func doDump(createFlavour createDBFlavour) {
func doDump(createFlavour func(*string) (DBDriverFlavour, DBDumpFlavour)) {
var tables []string
var err error
......@@ -320,9 +322,10 @@ func doDump(createFlavour createDBFlavour) {
perDBConcurrency = flag.Bool("c", false, "If set all table dumps will be concurrent within a DB, but DBs are processed sequentially")
maxIdleConnections = flag.Int("i", -1, "Max DB idle connections (-1 means default by the driver)")
maxConnectionLifetime = flag.Int("l", -1, "Max DB connection lifetime in seconds (-1 means default by driver)")
outputFlavour := flag.String("o", strings.Replace(filepath.Base(os.Args[0]), "anondump", "", -1), "Output flavour. One of the supported types")
flag.Parse()
dbFlavour, err = createFlavour()
dbDriverFlavour, dbDumpFlavour = createFlavour(outputFlavour)
// Check if username and password is correctly passed in or parsed from .my.cnf
if username == nil || password == nil || *username == "" || *password == "" {
......@@ -349,12 +352,12 @@ func doDump(createFlavour createDBFlavour) {
// Get the list of databases we need to dump
if len(flag.Args()) == 0 {
db, err := openDatabase(dbFlavour.GetConnectionString(*host, *username, *password, nil))
db, err := openDatabase(dbDriverFlavour.GetConnectionString(*host, *username, *password, nil))
if err != nil {
panic(err.Error())
}
defer db.Close()
databases = dbFlavour.GetAllDatabases(db)
databases = dbDriverFlavour.GetAllDatabases(db)
} else {
databases = flag.Args()[0:1]
}
......@@ -365,7 +368,7 @@ func doDump(createFlavour createDBFlavour) {
}
// Add the dump header (if any)
for _, v := range dbFlavour.GetDumpHeader() {
for _, v := range dbDumpFlavour.GetDumpHeader() {
fmt.Println(v)
}
......@@ -376,12 +379,12 @@ func doDump(createFlavour createDBFlavour) {
}
// If we have a database header, we need to wait for the running dumps
if len(dbFlavour.GetDatabaseHeader(databases[i])) > 0 {
if len(dbDumpFlavour.GetDatabaseHeader(databases[i])) > 0 {
// Wait for the concurrent dumps to complete
processWaitGroup.Wait()
// Add database header
for _, v := range dbFlavour.GetDatabaseHeader(databases[i]) {
for _, v := range dbDumpFlavour.GetDatabaseHeader(databases[i]) {
fmt.Println(v)
}
}
......@@ -389,12 +392,12 @@ func doDump(createFlavour createDBFlavour) {
dumpDatabase(databases[i], tables, username, password, host)
// If we have a database footer or we are running in per-db-concurrency, wait for the dump to complete
if len(dbFlavour.GetDatabaseFooter(databases[i])) > 0 || (perDBConcurrency != nil && *perDBConcurrency) {
if len(dbDumpFlavour.GetDatabaseFooter(databases[i])) > 0 || (perDBConcurrency != nil && *perDBConcurrency) {
// Wait for the concurrent dumps to complete
processWaitGroup.Wait()
// Add database footer (if any)
for _, v := range dbFlavour.GetDatabaseFooter(databases[i]) {
for _, v := range dbDumpFlavour.GetDatabaseFooter(databases[i]) {
fmt.Println(v)
}
}
......@@ -404,7 +407,7 @@ func doDump(createFlavour createDBFlavour) {
processWaitGroup.Wait()
// Add the dump footer (if any)
for _, v := range dbFlavour.GetDumpFooter() {
for _, v := range dbDumpFlavour.GetDumpFooter() {
fmt.Println(v)
}
}
......@@ -6,5 +6,8 @@ import (
)
func main() {
doDump(func () (DBDumpFlavour, error) { return mssql.NewMSSQLDBDumpFlavour() })
// This main always uses the MSSQL flavour for both roles: the same value is
// returned as driver flavour and as dump flavour, and the output flavour
// argument supplied by doDump is ignored.
doDump(func(*string) (DBDriverFlavour, DBDumpFlavour) {
	flavour, err := mssql.NewMSSQLDBDumpFlavour()
	if err != nil {
		// Fail here with the constructor's own error instead of handing
		// doDump a flavour that could not be built.
		panic(err.Error())
	}
	return flavour, flavour
})
}
......@@ -4,23 +4,44 @@ import (
_ "github.com/lib/pq"
_ "github.com/go-sql-driver/mysql"
_ "github.com/denisenkom/go-mssqldb"
"os"
"path/filepath"
"./postgres"
"./mysql"
"./mssql"
"path/filepath"
"os"
)
func main() {
switch filepath.Base(os.Args[0]) {
case "pganondump":
doDump(func () (DBDumpFlavour, error) { return postgres.NewPostgresDBDumpFlavour() })
case "mysqlanondump":
doDump(func () (DBDumpFlavour, error) { return mysql.NewMySQLDBDumpFlavour() })
case "mssqlanondump":
doDump(func () (DBDumpFlavour, error) { return mssql.NewMSSQLDBDumpFlavour() })
default:
println("This is a multi-call binary which should be called by one of the supported names")
os.Exit(1)
}
// Multi-call entry point. The driver (input) flavour is selected by the name
// the binary was invoked as; the dump (output) flavour is selected by the
// value doDump passes in. Any unsupported name, unsupported output flavour or
// constructor failure aborts with a panic.
doDump(func(outputFlavour *string) (DBDriverFlavour, DBDumpFlavour) {
	var inputFlavour interface{}
	var err error

	// Pick the input flavour from the executable name
	switch filepath.Base(os.Args[0]) {
	case "pganondump":
		inputFlavour, err = postgres.NewPostgresDBDumpFlavour()
	case "mysqlanondump":
		inputFlavour, err = mysql.NewMySQLDBDumpFlavour()
	case "mssqlanondump":
		inputFlavour, err = mssql.NewMSSQLDBDumpFlavour()
	default:
		panic("This is a multi-call binary which should be called by one of the supported names")
	}
	if err != nil {
		// Report the constructor's own error rather than failing later on a
		// type assertion or a method call on an unusable flavour.
		panic(err.Error())
	}
	driverFlavour := inputFlavour.(DBDriverFlavour)

	// Without an explicit output flavour the input flavour serves both roles
	if outputFlavour == nil {
		return driverFlavour, inputFlavour.(DBDumpFlavour)
	}

	// Build the requested output flavour; a distinct name is used so the
	// outputFlavour parameter is no longer shadowed.
	var dumpFlavour DBDumpFlavour
	switch *outputFlavour {
	case "pg":
		dumpFlavour, err = postgres.NewPostgresDBDumpFlavour()
	case "mysql":
		dumpFlavour, err = mysql.NewMySQLDBDumpFlavour()
	case "mssql":
		dumpFlavour, err = mssql.NewMSSQLDBDumpFlavour()
	default:
		panic("Output flavour '" + *outputFlavour + "' is not supported")
	}
	if err != nil {
		panic(err.Error())
	}
	return driverFlavour, dumpFlavour
})
}
......@@ -6,5 +6,8 @@ import (
)
func main() {
doDump(func () (DBDumpFlavour, error) { return mysql.NewMySQLDBDumpFlavour() })
// This main always uses the MySQL flavour for both roles: the same value is
// returned as driver flavour and as dump flavour, and the output flavour
// argument supplied by doDump is ignored.
doDump(func(*string) (DBDriverFlavour, DBDumpFlavour) {
	flavour, err := mysql.NewMySQLDBDumpFlavour()
	if err != nil {
		// Fail here with the constructor's own error instead of handing
		// doDump a flavour that could not be built.
		panic(err.Error())
	}
	return flavour, flavour
})
}
......@@ -6,5 +6,8 @@ import (
)
func main() {
doDump(func () (DBDumpFlavour, error) { return postgres.NewPostgresDBDumpFlavour() })
// This main always uses the Postgres flavour for both roles: the same value is
// returned as driver flavour and as dump flavour, and the output flavour
// argument supplied by doDump is ignored.
doDump(func(*string) (DBDriverFlavour, DBDumpFlavour) {
	flavour, err := postgres.NewPostgresDBDumpFlavour()
	if err != nil {
		// Fail here with the constructor's own error instead of handing
		// doDump a flavour that could not be built.
		panic(err.Error())
	}
	return flavour, flavour
})
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment