hannesmiller opened 7 years ago
There's a flaw in my proposal using rownum: each rownum predicate performs a full table scan. Spark does this differently, in a more efficient manner.
You can, for example, apply a hash function to a column (e.g. the primary key) to return a long - most databases come with some kind of hash function.
I think Spark just says x < col < y, doesn't it? So it might not be uniformly distributed.
You mentioned switching this to use a mod?
Yeah, will explain more tomorrow - it's not that straightforward, but doable.
Basically you need to use hash and mod functions together on the selected column - Oracle supports both, e.g.:
SELECT * FROM ( SELECT blah, MOD(HASH(primary_col), 8) + 1 AS part FROM table ) WHERE part = 1
Therefore I think the onus should be on the user to tell EEL what the partition-number column name is, so that EEL can literally wrap the SQL and add the where predicate, e.g. WHERE part_num = 1.
The hardest bit, I suppose, is kicking off the parallel threads and joining the results as and when each thread finishes.
package hannesmiller
import java.io.{File, PrintWriter}
import java.util.concurrent.{Callable, Executors}
import com.sksamuel.exts.Logging
import org.apache.commons.dbcp2.BasicDataSource
import scala.collection.mutable.ListBuffer
object MultiHashDcfQuery extends App with Logging {
private def generateStatsFile(fileName: String, stats: ListBuffer[String]): Unit = {
val statsFile = new File(fileName)
println(s"Generating ${statsFile.getAbsolutePath} ...")
val statsFileWriter = new PrintWriter(statsFile)
stats.foreach { s => statsFileWriter.write(s + "\n"); statsFileWriter.flush() }
statsFileWriter.close()
println(s"${statsFile.getAbsolutePath} done!")
}
val recordCount = 49510353L
val partitionsStartNumber = 2
val numberOfPartitions = 8
val numberOfRuns = 1
val sql =
s"""SELECT MY_PRIMARY_KEY, COL2, COL3, COL4, COL5
FROM MY_TABLE
WHERE COL2 in (8682)"""
def buildPartitionSql(bindExpression: String, bindExpressionAlias: String): String = {
s"""
|SELECT *
|FROM (
| SELECT eel_tmp.*, $bindExpression AS $bindExpressionAlias
| FROM ( $sql ) eel_tmp
|)
|WHERE $bindExpressionAlias = ?
|""".stripMargin
}
// Set up the database connection pool equal to the number of partitions - could be fewer depending on your connection
// resource limit on the database server.
val dataSource = new BasicDataSource()
dataSource.setDriverClassName("oracle.jdbc.OracleDriver")
dataSource.setUrl("jdbc:oracle:thin:@//myhost:1901/myservice")
dataSource.setUsername("username")
dataSource.setPassword("username1234")
dataSource.setPoolPreparedStatements(false)
dataSource.setInitialSize(numberOfPartitions)
dataSource.setDefaultAutoCommit(false)
dataSource.setMaxOpenPreparedStatements(numberOfPartitions)
val stats = ListBuffer[String]()
for (numPartitions <- partitionsStartNumber to numberOfPartitions) {
for (runNumber <- 1 to numberOfRuns) {
// Kick off a number of threads equal to the number of partitions so each partitioned query is executed in parallel.
val threadPool = Executors.newFixedThreadPool(numPartitions)
val startTime = System.currentTimeMillis()
val fetchSize = 100600
val futures = for (i <- 1 to numPartitions) yield {
threadPool.submit(new Callable[(Long, Long, Long, Long)] {
override def call(): (Long, Long, Long, Long) = {
var rowCount = 0L
// Capture metrics about acquiring connection
val connectionIdleTimeStart = System.currentTimeMillis()
val connection = dataSource.getConnection
val connectionIdleTime = System.currentTimeMillis() - connectionIdleTimeStart
val partSql = buildPartitionSql(s"MOD(ORA_HASH(MY_PRIMARY_KEY),$numPartitions) + 1", "PARTITION_NUMBER")
val prepareStatement = connection.prepareStatement(partSql)
prepareStatement.setFetchSize(fetchSize)
prepareStatement.setLong(1, i)
// Capture metrics for query execution
val executeQueryTimeStart = System.currentTimeMillis()
val rs = prepareStatement.executeQuery()
val executeQueryTime = (System.currentTimeMillis() - executeQueryTimeStart) / 1000
// Capture metrics for fetching data
val fetchTimeStart = System.currentTimeMillis()
while (rs.next()) {
rowCount += 1
if (rowCount % fetchSize == 0) logger.info(s"RowCount = $rowCount")
}
val fetchTime = (System.currentTimeMillis() - fetchTimeStart) / 1000
rs.close()
prepareStatement.close()
connection.close()
(connectionIdleTime, executeQueryTime, fetchTime, rowCount)
}
})
}
// Total up all the rows
var totalRowCount = 0L
var totalConnectionIdleTime = 0L
futures.foreach { f =>
val (connectionIdleTime, executeQueryTime, fetchTime, rowCount) = f.get
logger.info(s"connectionIdleTime=$connectionIdleTime, executeQueryTime=$executeQueryTime, fetchTime=$fetchTime, rowCount=$rowCount")
totalConnectionIdleTime += connectionIdleTime
totalRowCount += rowCount
}
val elapsedTime = (System.currentTimeMillis() - startTime) / 1000.0
logger.info(s"Run $runNumber with $numPartitions partition(s): Took $elapsedTime second(s) for RowCount = $totalRowCount, totalConnectionIdleTime = $totalConnectionIdleTime")
threadPool.shutdownNow()
stats += s"$numPartitions\t$runNumber\t$elapsedTime"
}
}
generateStatsFile("multi_partition_stats.csv", stats)
}
val partSql = buildPartitionSql(s"MOD(ORA_HASH(F_CASH_FLOW_ID),$numberOfPartitions) + 1", "PARTITION_NUMBER")
...
...
def buildPartitionSql(bindExpression: String, bindExpressionAlias: String): String = {
s"""
|SELECT *
|FROM (
| SELECT eel_tmp.*, $bindExpression AS $bindExpressionAlias
| FROM ( $sql ) eel_tmp
|)
|WHERE $bindExpressionAlias = ?
|""".stripMargin
}
val prepareStatement = connection.prepareStatement(partSql)
prepareStatement.setFetchSize(fetchSize)
prepareStatement.setLong(1, i)
I have experimented with my own custom JdbcSource, based on the original, with some new arguments supporting 2 different partitioning strategies.
case class JdbcSource(connFn: () => Connection,
query: String,
partHashFuncExpr: Option[String] = None,
partColumnAlias: Option[String] = None,
partRangeColumn: Option[String] = None,
minVal: Option[Long] = None,
maxVal: Option[Long] = None,
numberOfParts: Int = 1,
bind: (PreparedStatement) => Unit = stmt => (),
fetchSize: Int = 100,
providedSchema: Option[StructType] = None,
providedDialect: Option[JdbcDialect] = None,
bucketing: Option[Bucketing] = None)
val numberOfPartitions = 4
JdbcSource(() => dataSource.getConnection(), query)
.withHashPartitioning(s"MOD(ORA_HASH(ID),$numberOfPartitions) + 1", "PARTITION_NUMBER", numberOfPartitions)
Note this example uses Oracle's modulus and hash functions - you can use equivalent functions for another database dialect, e.g. SQL Server.
val numberOfPartitions = 4
JdbcSource(() => dataSource.getConnection(), query)
.withRangePartitioning("ID", 1, 201786, numberOfPartitions)
override def parts(): List[JdbcPart] = {
if (partHashFuncExpr.nonEmpty || partRangeColumn.nonEmpty) {
val jdbcParts: Seq[JdbcPart] = for (i <- 1 to numberOfParts) yield new JdbcPart(connFn, buildPartSql(i), bind, fetchSize, dialect())
jdbcParts.toList
} else List(new JdbcPart(connFn, query, bind, fetchSize, dialect()))
}
private def buildPartSql(partitionNumber: Int): String = {
if (partHashFuncExpr.nonEmpty) {
s"""
|SELECT *
|FROM (
| SELECT eel_tmp.*, ${partHashFuncExpr.get} AS ${partColumnAlias.get}
| FROM ( $query ) eel_tmp
|)
|WHERE ${partColumnAlias.get} = $partitionNumber
|""".stripMargin
} else if (partRangeColumn.nonEmpty) {
val partitionRanges = generatePartRanges(minVal.get, maxVal.get, numberOfParts)(partitionNumber - 1)
s"""
|SELECT *
|FROM (
| SELECT *
| FROM ( $query )
|)
|WHERE ${partRangeColumn.get} BETWEEN ${partitionRanges.min} AND ${partitionRanges.max}
|""".stripMargin
}
else query
}
case class PartRange(min: Long, max: Long)
private def generatePartRanges(min: Long, max: Long, numberOfPartitions: Int): Array[PartRange] = {
val partitionRanges = new Array[PartRange](numberOfPartitions)
val bucketSizes = new Array[Long](numberOfPartitions)
val evenLength = (max - min + 1) / numberOfPartitions
for (i <- 0 until numberOfPartitions) bucketSizes(i) = evenLength
// distribute surplus as evenly as possible across buckets
var surplus = (max - min + 1) % numberOfPartitions
var i: Int = 0
while (surplus > 0) {
bucketSizes(i) += 1
surplus -= 1
i = (i + 1) % numberOfPartitions
}
i = 0
var k = min
while (i < numberOfPartitions && k <= max) {
partitionRanges(i) = PartRange(k, k + bucketSizes(i) - 1)
k += bucketSizes(i)
i += 1
}
partitionRanges
}
}
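To illustrate the even-split logic, here is a standalone sketch of the same algorithm (reimplemented for illustration only, since generatePartRanges above is private): any surplus rows left over from the integer division are spread across the first buckets, one each.

```scala
// Standalone copy of the even-split logic above, for illustration only.
case class PartRange(min: Long, max: Long)

def generatePartRanges(min: Long, max: Long, numberOfPartitions: Int): Array[PartRange] = {
  val bucketSizes = Array.fill(numberOfPartitions)((max - min + 1) / numberOfPartitions)
  // Distribute the surplus as evenly as possible across the first buckets.
  var surplus = (max - min + 1) % numberOfPartitions
  var i = 0
  while (surplus > 0) {
    bucketSizes(i) += 1
    surplus -= 1
    i = (i + 1) % numberOfPartitions
  }
  // Turn the bucket sizes into contiguous [min, max] ranges.
  val ranges = new Array[PartRange](numberOfPartitions)
  var k = min
  for (j <- 0 until numberOfPartitions) {
    ranges(j) = PartRange(k, k + bucketSizes(j) - 1)
    k += bucketSizes(j)
  }
  ranges
}

// 10 ids across 3 partitions: the surplus of 1 goes to the first bucket.
val ranges = generatePartRanges(1, 10, 3)
// ranges: [1-4], [5-7], [8-10]
```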
def fetchSchema(): StructType = {
using(connFn()) { conn =>
val schemaQuery = s"SELECT * FROM (${buildPartSql(1)}) tmp WHERE 1=0"
using(conn.prepareStatement(schemaQuery)) { stmt =>
stmt.setFetchSize(fetchSize)
bind(stmt)
val rs = timed(s"Executing query $query") {
stmt.executeQuery()
}
val schema = schemaFor(dialect(), rs)
rs.close()
schema
}
}
}
Hi, I have 40 million records in a table in Oracle. I want to use Spark JDBC to write this data to CSV files. Can you please help me with the lowerBound, upperBound, numPartitions and a hashing function to split the data equally across all the partitions?
Hi Vennapuc,
Let me look into this and I will get back to you with an answer - there are a few JdbcSource partitioning strategies you can use - I have actually used the HashPartitioning strategy on Oracle with a similar population to yours...
If your table has a primary key that is a number (i.e. a sequence), then with the hash strategy you can specify:
What the above does is split your main query into multiple part queries where each part is assigned to a separate thread.
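To make that concrete, here is a sketch that prints the part queries the hash strategy would produce - one per thread, differing only in the bound partition number. The helper and query below are illustrative, mirroring the buildPartitionSql wrapping earlier in this thread, not EEL's actual API:

```scala
// Sketch: wrap a user query so each of N threads fetches one hash bucket.
val userQuery = "SELECT ID, COL2 FROM MY_TABLE"  // illustrative query
val numberOfPartitions = 4

def partQuery(partitionNumber: Int): String =
  s"""SELECT *
     |FROM (
     |  SELECT eel_tmp.*, MOD(ORA_HASH(ID), $numberOfPartitions) + 1 AS PARTITION_NUMBER
     |  FROM ( $userQuery ) eel_tmp
     |)
     |WHERE PARTITION_NUMBER = $partitionNumber""".stripMargin

// Each of the 4 queries scans its own hash bucket; together they cover every row.
(1 to numberOfPartitions).foreach(p => println(partQuery(p) + "\n"))
```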
For the RangeBound strategy you will need to do something similar to the above, but you MUST know up front what the MAX value is, which can be determined with a SQL query beforehand.
I think this strategy requires some kind of SQL expression like a row_number analytical function.
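The bounds can be fetched up front with a trivial query before constructing the source. A sketch, with illustrative table/column names (the commented JdbcSource call assumes the withRangePartitioning signature shown earlier in this thread):

```scala
// Sketch: determine the range bounds with a query first.
def boundsQuery(table: String, column: String): String =
  s"SELECT MIN($column), MAX($column) FROM $table"

// Against a real connection, roughly:
// val rs = conn.createStatement().executeQuery(boundsQuery("MY_TABLE", "ID"))
// rs.next(); val (minVal, maxVal) = (rs.getLong(1), rs.getLong(2))
// JdbcSource(connFn, query).withRangePartitioning("ID", minVal, maxVal, numberOfPartitions)
```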
I think for a next major release (1.3) we should put some examples together.
Regards, Hannes
Thanks hannesmiller for your response. I see some hash functions above, e.g. MOD(ORA_HASH(ID),$numberOfPartitions) + 1. You are adding 1 to the hash code - is there any reason for this? I was able to do this with the hash function (1 + mod(hash(fa_id), %(5)s)) as hash_code in Oracle, where fa_id is a numeric column. I chose 5 as the bucket number above - any suggestions on how to choose this number? My concern is how to choose the number of partitions and the lower and upper bounds for ~40 million records. Can you please help with choosing the right parameters?
Thanks,
Hi Vennapuc, unfortunately there isn't an exact science for how many partitions you should have, as there are too many factors to consider (the profile of your Oracle server, how many cores you have available on the client machine, etc.) - the best way is trial and error, i.e. experiment on your target environment.
Can you tell me what version of EEL you are using? I ask because in the latest alpha release you can simply specify the Hash partition strategy on the JdbcSource, e.g.:
val partitions = 4
JdbcSource(connFn, query)
.withPartitionStrategy(HashPartitionStrategy(s"MOD(ORA_HASH(key), $partitions)", partitions))
.withProvidedDialect(new OracleJdbcDialect)
.withFetchSize(...)
I am using Spark 1.6.2 and Scala 2.10.5...
Hi Vennapuc, I suggest you try posting your query to a Spark forum.
Our product EEL is a lightweight Big Data Scala library for ingesting data into environments such as Hadoop.
Regards, Hannes
Hello, hannesmiller! I am experimenting with connecting to Oracle 12c via the Spark JDBC driver. The table in Oracle has 100 million rows. I launch the Spark application in local mode; the laptop has 4 CPUs. Code fragment:
val df = spark.read
.format("jdbc")
.option("url", "jdbc:oracle:thin:login/password@172.17.0.2:1521:ORCLCDB")
.option("dbtable", "C##BUSER.CUSTOMERS")
.option("driver", "oracle.jdbc.OracleDriver")
.option("numPartitions", 4)
.option("partitionColumn", "CUST_ID")
.option("lowerBound", 1)
.option("upperBound", 100000000)
.load()
df.write.csv("/path_to_file_to_save")
I tried launching this application with numPartitions=1, numPartitions=4 and numPartitions=10. Reading data from Oracle and writing locally with partitions (4 or 10) takes 3 times longer than without partitions.
Could you please help me to understand how to increase speed of reading using param "numPartitions" ?
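One possible cause, sketched under the assumption that CUST_ID is sparse or skewed (not verified against this setup): Spark's lowerBound/upperBound stride partitioning splits the ID space, not the rows, so an uneven CUST_ID distribution can leave most rows in one partition. Reusing this thread's hash-bucket idea as the partition column spreads rows evenly; the subquery alias and bucket column below are illustrative:

```scala
// Sketch: expose a uniform hash bucket as the partition column via a dbtable subquery.
val numPartitions = 4
val dbtable =
  s"(SELECT t.*, MOD(ORA_HASH(CUST_ID), $numPartitions) AS PART_COL FROM C##BUSER.CUSTOMERS t) q"

// With bounds 0..numPartitions, Spark's stride predicates on PART_COL line up
// exactly with the hash buckets, giving near-equal partition sizes:
// spark.read.format("jdbc")
//   .option("url", "jdbc:oracle:thin:login/password@172.17.0.2:1521:ORCLCDB")
//   .option("driver", "oracle.jdbc.OracleDriver")
//   .option("dbtable", dbtable)
//   .option("numPartitions", numPartitions)
//   .option("partitionColumn", "PART_COL")
//   .option("lowerBound", 0)
//   .option("upperBound", numPartitions)
//   .load()
```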
Part 1:
Part 2:
Part 3:
Part 4:
Part 5:
For partition 5, just return the remainder of rows.
Note it may not be necessary to create N connections for each partition - simply return N JDBC result sets - one for each partition - investigating this...
Proposal
Example of returning N JDBC result sets
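A minimal sketch of that idea: one connection serving N open statements, each yielding the result set for one hash bucket. The wrapping SQL mirrors the earlier ORA_HASH/MOD pattern; nothing here has been run against a real database:

```scala
import java.sql.{Connection, PreparedStatement, ResultSet}

// Builds the wrapped SQL for one hash bucket (illustrative query shape).
def partSql(query: String, n: Int, p: Int): String =
  s"""SELECT * FROM (
     |  SELECT eel_tmp.*, MOD(ORA_HASH(ID), $n) + 1 AS PART FROM ( $query ) eel_tmp
     |) WHERE PART = $p""".stripMargin

// One connection, N open statements - each call returns one bucket's result set.
// The caller is responsible for closing the result sets, statements, and connection.
def partResultSets(conn: Connection, query: String, n: Int, fetchSize: Int): Seq[ResultSet] =
  (1 to n).map { p =>
    val stmt: PreparedStatement = conn.prepareStatement(partSql(query, n, p))
    stmt.setFetchSize(fetchSize)
    stmt.executeQuery()
  }
```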