locationtech / rasterframes

Geospatial Raster support for Spark DataFrames
http://rasterframes.io

Hello, I am doing research on Active Learning under Spark. When I run my code in IntelliJ IDEA, I hit the same issue. I have been trying to resolve this problem for two days. Can you tell me the answer? #628

Open ViViBao opened 6 months ago

ViViBao commented 6 months ago


import org.apache.spark.sql.{SparkSession, DataFrame}
import org.apache.spark.ml.classification.{DecisionTreeClassifier, DecisionTreeClassificationModel}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.functions._

object ActiveLearningExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("ActiveLearningExample")
      .master("local[*]")
      .config("spark.driver.host", "localhost")
      .getOrCreate()

    // Print the current Spark configuration
    println("Current Spark Configuration:")
    spark.conf.getAll.foreach(println)

    // Load the data
    val data = spark.read.parquet("/home/hadoop/桌面/cleaned_data")
    println("Initial Data:")
    data.show(5)
    println(s"Count of initial data: ${data.count()}")

    // Drop rows containing null values
    val dataWithoutNulls = data.na.drop()
    println("Data without nulls:")
    dataWithoutNulls.show(5)
    println(s"Count after removing nulls: ${dataWithoutNulls.count()}")

    // Cast the columns to numeric types
    val dataWithNumeric = dataWithoutNulls
      .withColumn("sepal_length", col("sepal_length").cast("double"))
      .withColumn("sepal_width", col("sepal_width").cast("double"))
      .withColumn("petal_length", col("petal_length").cast("double"))
      .withColumn("petal_width", col("petal_width").cast("double"))
      .withColumn("species", col("species").cast("double"))

println("Data with Numeric Columns:")
dataWithNumeric.show(5)
println(s"Count after type conversion: ${dataWithNumeric.count()}")

    // Assemble the feature vector
    val featureCols = Array("sepal_length", "sepal_width", "petal_length", "petal_width")
    val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
    val assembledData = assembler.transform(dataWithNumeric)

    println("Data with Features Vector:")
    assembledData.select("features", "species").show(5)
    println(s"Count after assembling features: ${assembledData.count()}")

    // Use the numeric species column as the label, then split into train/test sets
    val labeledData = assembledData.withColumn("label", col("species"))
    val Array(trainingData, testData) = labeledData.randomSplit(Array(0.8, 0.2), seed = 1234L)

    println(s"Training data count: ${trainingData.count()}")
    println(s"Test data count: ${testData.count()}")

    // Train a decision tree classifier
    val dt = new DecisionTreeClassifier().setFeaturesCol("features").setLabelCol("label")
    val dtModel = dt.fit(trainingData)

    // Print model information
    println("Decision Tree Model:")
    println(s"Number of nodes: ${dtModel.numNodes}")
    println(s"Depth of tree: ${dtModel.depth}")

    // Apply the active learning strategy: select samples for labeling
    val samplesToLabel = ActiveLearningStrategy.selectSamplesForLabeling(testData, dtModel, 5)

    // Print the selected samples
    println("Selected Samples for Labeling:")
    samplesToLabel.show()

    spark.stop()
  }
}

object ActiveLearningStrategy {
  import org.apache.spark.ml.linalg.Vector // make sure the correct (ml, not mllib) Vector type is in scope

  // Shannon entropy of a probability vector: H(p) = -sum(p_i * ln(p_i))
  def calculateEntropy(probabilities: Vector): Double = {
    probabilities.toArray.map(p => if (p == 0) 0.0 else -p * Math.log(p)).sum
  }
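  // Sanity check (illustrative values, natural log; Vectors is
  // org.apache.spark.ml.linalg.Vectors):
  //   calculateEntropy(Vectors.dense(1.0, 0.0, 0.0)) == 0.0            (fully confident)
  //   calculateEntropy(Vectors.dense(0.5, 0.5))      == ln 2 ≈ 0.693   (maximally uncertain for 2 classes)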

  def selectSamplesForLabeling(data: DataFrame, model: DecisionTreeClassificationModel, k: Int): DataFrame = {
    val predictions = model.transform(data)

    // Print the prediction results
    println("Predictions:")
    predictions.select("features", "probability", "prediction").show(5)

    val entropyUDF = udf((probability: Vector) => calculateEntropy(probability))
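    // The "probability" column produced by Spark ML classifiers is an
    // org.apache.spark.ml.linalg.Vector; if the UDF parameter instead resolves
    // to the legacy org.apache.spark.mllib.linalg.Vector, applying the UDF
    // typically fails with a type mismatch. A common pitfall in code like this.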

    val dataWithEntropy = predictions.withColumn("entropy", entropyUDF(col("probability")))

    // Print the data with entropy values
    println("Data with Entropy:")
    dataWithEntropy.select("features", "probability", "entropy").show(5)

    // Return the k most uncertain samples (highest entropy)
    dataWithEntropy.orderBy(desc("entropy")).limit(k)

  }
}
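
For context, uncertainty sampling like this is normally run as a loop: label the selected samples, add them to the training set, retrain, and select again. A minimal sketch of that loop under the same setup, assuming a hypothetical labelSamples oracle that supplies true labels (it is not defined anywhere above):

// Hypothetical iterative loop around the code above; labelSamples stands in
// for whatever process (e.g. a human annotator) provides the true labels.
var training = trainingData
var pool = testData
for (round <- 1 to 3) {
  val model = dt.fit(training)
  val selected = ActiveLearningStrategy
    .selectSamplesForLabeling(pool, model, 5)
    .select(pool.columns.map(col): _*) // drop prediction/entropy columns so schemas match
  val labeled = labelSamples(selected) // hypothetical oracle: same rows with trusted labels
  training = training.unionByName(labeled)
  pool = pool.exceptAll(selected)
  println(s"Round $round: training set size = ${training.count()}")
}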