When we use JdbcRDD, its primary constructor takes the following parameters by default:
sc: SparkContext,
getConnection: () => Connection,
sql: String,
lowerBound: Long,
upperBound: Long,
numPartitions: Int,
mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _
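For orientation, a typical call to the stock JdbcRDD looks roughly like the sketch below. This example is not from the original post; the database URL, credentials and the books table are made up for illustration:

import java.sql.{DriverManager, ResultSet}

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.JdbcRDD

object JdbcRDDDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("JdbcRDDDemo").setMaster("local[2]"))
    val rdd = new JdbcRDD(
      sc,
      () => DriverManager.getConnection("jdbc:mysql://127.0.0.1:3306/test", "root", "root"),
      // the SQL must contain the two ? placeholders for the partition bounds
      "select title, author from books where ? <= id and id <= ?",
      1L,  // lowerBound
      20L, // upperBound
      2,   // numPartitions
      (r: ResultSet) => (r.getString(1), r.getString(2))
    )
    println(rdd.collect().toBuffer)
    sc.stop()
  }
}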
According to the Scaladoc in its source:
* e.g. select title, author from books where ? <= id and id <= ?
* @param lowerBound the minimum value of the first placeholder
* @param upperBound the maximum value of the second placeholder
* The lower and upper bounds are inclusive.
* @param numPartitions the number of partitions.
* Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
* the query would be executed twice, once with (1, 10) and once with (11, 20)
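To make that splitting rule concrete, here is a small standalone sketch (not from the original post) of the arithmetic the stock JdbcRDD applies to the bounds; it reproduces the (1, 10) and (11, 20) ranges from the example above:

object JdbcBoundsDemo {
  def main(args: Array[String]): Unit = {
    val lowerBound = 1L
    val upperBound = 20L
    val numPartitions = 2
    // length of the inclusive range, computed as BigInt to avoid overflow
    val length = BigInt(1) + upperBound - lowerBound
    for (i <- 0 until numPartitions) {
      val start = lowerBound + ((i * length) / numPartitions)
      val end = lowerBound + (((i + 1) * length) / numPartitions) - 1
      println(s"partition $i queries ids in [$start, $end]")
    }
  }
}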
From the above we can see that these parameters of JdbcRDD's primary constructor are all mandatory, and there is no auxiliary constructor to fall back on. So when querying we have no choice but to supply the lower and upper bounds, i.e. the SQL must carry a query condition with the two ? placeholders, and the bounds must be passed into the primary constructor. In real use, or simply in testing, we sometimes want to run a query without these parameters, and the stock JdbcRDD leaves us no way to do so. What can we do about it?
There are probably many ways. For us, two simple options are: implement a JdbcRDD of our own, or inherit from JdbcRDD and define our own auxiliary constructor (a rough sketch of the inheritance route is given below). This article only implements the first option, redefining JdbcRDD ourselves, which also reduces the coupling of the program.
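For completeness, the inheritance route might look roughly like the following sketch (not from the original post, and not what this article implements): a subclass that hides the mandatory bounds by wrapping the caller's SQL in a derived table with two dummy placeholders and a single partition. The class name SimpleJdbcRDD and the SQL-wrapping trick are illustrative only.

import java.sql.{Connection, ResultSet}

import scala.reflect.ClassTag

import org.apache.spark.SparkContext
import org.apache.spark.rdd.JdbcRDD

// Illustrative sketch: reuse the stock JdbcRDD but satisfy its mandatory parameters
// with a dummy "? <= 1 and 1 <= ?" condition, bounds of (1, 1) and one partition,
// so callers can pass a plain SQL statement without placeholders.
class SimpleJdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    mapRow: (ResultSet) => T)
  extends JdbcRDD[T](
    sc,
    getConnection,
    s"select * from ($sql) as t where ? <= 1 and 1 <= ?",
    1L, 1L, 1,
    mapRow)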
Looking at JdbcRDD's source, we find that lowerBound defines the lower bound of the query, upperBound defines the upper bound, and numPartitions defines the number of partitions. In a real production environment these three parameters can be very useful: together they determine the range of data each partition queries, which is presumably why the Spark developers made them mandatory.

Note: this article simply removes these three parameters. Keep in mind that this is not the only possible approach, and because the partition parameter is gone, the code below defaults to a single partition; you can manually configure more partitions in the code.

Besides modifying JdbcRDD's source, we also need NextIterator.scala. It is copied over unchanged, only moved, because the original is private to the org.apache.spark package and cannot be referenced from any other package, so that file is not reproduced here. (Note that org.apache.spark.internal.Logging is likewise private[spark] in Spark 2.x; depending on your package layout you may need to copy it in the same way or drop the Logging mix-in.) The modified JdbcRDD.scala is renamed to JDBCRDD.scala, and NextIterator.scala is placed in the same package as JDBCRDD.scala. Below is the JDBCRDD.scala source:
import java.sql.{Connection, ResultSet}

import scala.reflect.ClassTag

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD

/**
 * Created by Administrator on 2017/9/8.
 */
class JDBCPartition(idx: Int) extends Partition {
  override def index: Int = idx
}
class JDBCRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    mapRow: (ResultSet) => T = JDBCRDD.resultSetToObjectArray _)
  extends RDD[T](sc, Nil) with Logging {

  // A single partition by default; adjust the range here if you need more partitions.
  override def getPartitions: Array[Partition] = {
    (0 until 1).map { i => new JDBCPartition(i) }.toArray
  }

  override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T] {
    context.addTaskCompletionListener { context => closeIfNeeded() }
    val part = thePart.asInstanceOf[JDBCPartition]
    val conn = getConnection()
    val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)

    val url = conn.getMetaData.getURL
    if (url.startsWith("jdbc:mysql:")) {
      // MySQL-specific trick: a fetch size of Integer.MIN_VALUE forces the driver to
      // stream results instead of pulling the whole result set into memory.
      stmt.setFetchSize(Integer.MIN_VALUE)
    } else {
      stmt.setFetchSize(100)
    }
    logInfo(s"statement fetch size set to: ${stmt.getFetchSize}")

    val rs = stmt.executeQuery()

    override def getNext(): T = {
      if (rs.next()) {
        mapRow(rs)
      } else {
        finished = true
        null.asInstanceOf[T]
      }
    }

    override def close() {
      try {
        if (null != rs) {
          rs.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing resultset", e)
      }
      try {
        if (null != stmt) {
          stmt.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing statement", e)
      }
      try {
        if (null != conn) {
          conn.close()
        }
        logInfo("closed connection")
      } catch {
        case e: Exception => logWarning("Exception closing connection", e)
      }
    }
  }
}
object JDBCRDD {

  def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
    Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
  }

  trait ConnectionFactory extends Serializable {
    @throws[Exception]
    def getConnection: Connection
  }

  def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]

  def create[T](
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String,
      mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
    val jdbcRDD = new JDBCRDD[T](
      sc.sc,
      () => connectionFactory.getConnection,
      sql,
      (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
    new JavaRDD[T](jdbcRDD)(fakeClassTag)
  }

  def create(
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String): JavaRDD[Array[Object]] = {
    val mapRow = new JFunction[ResultSet, Array[Object]] {
      override def call(resultSet: ResultSet): Array[Object] = {
        resultSetToObjectArray(resultSet)
      }
    }
    create(sc, connectionFactory, sql, mapRow)
  }
}
Below is an example that tests the JDBCRDD.scala above:
import java.sql.DriverManager

import org.apache.spark.{SparkConf, SparkContext}

/**
 * Created by Administrator on 2017/9/8.
 */
object TestJDBC {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("TestJDBC").setMaster("local[2]")
    val sc = new SparkContext(conf)
    try {
      val connection = () => {
        Class.forName("com.mysql.jdbc.Driver").newInstance()
        DriverManager.getConnection("jdbc:mysql://192.168.0.4:3306/spark", "root", "root")
      }
      val jdbcRDD = new JDBCRDD(
        sc,
        connection,
        "SELECT * FROM result",
        r => {
          val id = r.getInt(1)
          val code = r.getString(2)
          (id, code)
        }
      )
      println(jdbcRDD.collect().toBuffer)
    } catch {
      case e: Exception => e.printStackTrace()
    } finally {
      sc.stop()
    }
  }
}
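To compile and run this test, the MySQL JDBC driver has to be on the classpath alongside spark-core. With sbt the dependencies would look something like the following (the versions are only illustrative, use whatever matches your environment; a 5.1.x driver matches the com.mysql.jdbc.Driver class name used above):

libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-core" % "2.2.0",
  "mysql" % "mysql-connector-java" % "5.1.38"
)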
That completes this simple modification of the JdbcRDD source. I hope you find it useful.