Status: Open — issue opened by ViViBao 6 months ago
import org.apache.spark.sql.{SparkSession, DataFrame} import org.apache.spark.ml.classification.{DecisionTreeClassifier, DecisionTreeClassificationModel} import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.ml.linalg.Vector import org.apache.spark.sql.functions._
/**
 * End-to-end demo: load a cleaned iris-style Parquet dataset, train a
 * decision-tree classifier, then use an entropy-based active-learning
 * strategy to pick the test rows most worth labeling by hand.
 *
 * Fix: the SparkSession is now stopped in a `finally` block so it is
 * released even when any stage of the pipeline throws.
 */
object ActiveLearningExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("ActiveLearningExample")
      .master("local[*]")
      .config("spark.driver.host", "localhost")
      .getOrCreate()

    try {
      // Print the current Spark configuration
      println("Current Spark Configuration:")
      spark.conf.getAll.foreach(println)

      // Load the pre-cleaned dataset (path is machine-specific — TODO: make configurable)
      val data = spark.read.parquet("/home/hadoop/桌面/cleaned_data")
      println("Initial Data:")
      data.show(5)
      println(s"Count of initial data: ${data.count()}")

      // Drop any rows containing nulls
      val dataWithoutNulls = data.na.drop()
      println("Data without nulls:")
      dataWithoutNulls.show(5)
      println(s"Count after removing nulls: ${dataWithoutNulls.count()}")

      // Cast the feature and label columns to double.
      // NOTE(review): casting `species` straight to double assumes it is already
      // numeric-encoded; a textual label (e.g. "setosa") would become null here — confirm.
      val dataWithNumeric = dataWithoutNulls
        .withColumn("sepal_length", col("sepal_length").cast("double"))
        .withColumn("sepal_width", col("sepal_width").cast("double"))
        .withColumn("petal_length", col("petal_length").cast("double"))
        .withColumn("petal_width", col("petal_width").cast("double"))
        .withColumn("species", col("species").cast("double"))
      println("Data with Numeric Columns:")
      dataWithNumeric.show(5)
      println(s"Count after type conversion: ${dataWithNumeric.count()}")

      // Assemble the four measurements into a single ML feature vector
      val featureCols = Array("sepal_length", "sepal_width", "petal_length", "petal_width")
      val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
      val assembledData = assembler.transform(dataWithNumeric)
      println("Data with Features Vector:")
      assembledData.select("features", "species").show(5)
      println(s"Count after assembling features: ${assembledData.count()}")

      // Use the numeric species column as the label, then split 80/20 (fixed seed for reproducibility)
      val labeledData = assembledData.withColumn("label", col("species"))
      val Array(trainingData, testData) = labeledData.randomSplit(Array(0.8, 0.2), seed = 1234L)
      println(s"Training data count: ${trainingData.count()}")
      println(s"Test data count: ${testData.count()}")

      // Train the decision-tree classifier
      val dt = new DecisionTreeClassifier().setFeaturesCol("features").setLabelCol("label")
      val dtModel = dt.fit(trainingData)

      // Print basic model statistics
      println("Decision Tree Model:")
      println(s"Number of nodes: ${dtModel.numNodes}")
      println(s"Depth of tree: ${dtModel.depth}")

      // Active-learning step: select the 5 most uncertain test samples for labeling
      val samplesToLabel = ActiveLearningStrategy.selectSamplesForLabeling(testData, dtModel, 5)
      println("Selected Samples for Labeling:")
      samplesToLabel.show()
    } finally {
      // Always release the session, even if any stage above failed.
      spark.stop()
    }
  }
}
/**
 * Uncertainty-sampling strategy for active learning: rank unlabeled rows by
 * the entropy of the model's predicted class-probability vector and return
 * the most uncertain ones.
 *
 * Fix: in the pasted source this object was never closed (the closing braces
 * for `selectSamplesForLabeling` and the object were lost to a truncated
 * duplicate paste); the block is reconstructed here as a complete definition.
 */
object ActiveLearningStrategy {
  import org.apache.spark.ml.linalg.Vector // make sure the ML Vector type is the one in scope

  /**
   * Shannon entropy (natural log) of a probability vector.
   * Zero probabilities contribute 0 (lim p→0 of -p·ln p), avoiding NaN from ln(0).
   */
  def calculateEntropy(probabilities: Vector): Double = {
    probabilities.toArray.map(p => if (p == 0) 0 else -p * Math.log(p)).sum
  }

  /**
   * Scores each row of `data` with `model`, computes the entropy of the
   * per-class probability column, and returns the `k` rows with the highest
   * entropy — i.e. the rows the model is least certain about.
   *
   * @param data  unlabeled (or held-out) rows with a "features" column
   * @param model trained decision-tree model used to produce "probability"
   * @param k     number of most-uncertain rows to return
   */
  def selectSamplesForLabeling(data: DataFrame, model: DecisionTreeClassificationModel, k: Int): DataFrame = {
    val predictions = model.transform(data)

    // Debug output: raw model predictions
    println("Predictions:")
    predictions.select("features", "probability", "prediction").show(5)

    // Attach an entropy score computed from the probability vector
    val entropyUDF = udf((probability: Vector) => calculateEntropy(probability))
    val dataWithEntropy = predictions.withColumn("entropy", entropyUDF(col("probability")))

    // Debug output: rows with their entropy scores
    println("Data with Entropy:")
    dataWithEntropy.select("features", "probability", "entropy").show(5)

    // Highest entropy first = most uncertain = best candidates for manual labeling
    dataWithEntropy.orderBy(desc("entropy")).limit(k)
  }
}
import org.apache.spark.sql.{SparkSession, DataFrame} import org.apache.spark.ml.classification.{DecisionTreeClassifier, DecisionTreeClassificationModel} import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.ml.linalg.Vector import org.apache.spark.sql.functions._
// NOTE(review): this object appears to be a truncated duplicate of the
// ActiveLearningExample defined earlier in this paste — the copy cuts off
// right after the SparkSession is built, and the session is never stopped.
// Two objects with the same name in one file will not compile; confirm this
// is a paste artifact and remove it.
object ActiveLearningExample { def main(args: Array[String]): Unit = { val spark = SparkSession.builder() .appName("ActiveLearningExample") .master("local[*]") .config("spark.driver.host", "localhost") .getOrCreate()
} }
// NOTE(review): truncated duplicate of the ActiveLearningStrategy object that
// appears earlier in this paste. In this copy, selectSamplesForLabeling is
// declared to return DataFrame but its body ends after a `val` definition
// (type Unit), so this copy cannot compile — almost certainly a paste
// artifact; confirm and delete.
object ActiveLearningStrategy { import org.apache.spark.ml.linalg.Vector // make sure the correct Vector type is imported

// Shannon entropy (natural log) of a probability vector; zero probabilities contribute 0.
def calculateEntropy(probabilities: Vector): Double = { probabilities.toArray.map(p => if (p == 0) 0 else -p * Math.log(p)).sum }

// Truncated: only the model.transform step survives here; the entropy scoring
// and top-k selection from the first copy are missing.
def selectSamplesForLabeling(data: DataFrame, model: DecisionTreeClassificationModel, k: Int): DataFrame = { val predictions = model.transform(data)
} }