DrPepper8888 / IPMN

0 stars 0 forks source link

XGBOOST FOR NEWS #8

Open DrPepper8888 opened 3 months ago

DrPepper8888 commented 3 months ago
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import GBTRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml import Pipeline

# Fits a gradient-boosted-trees regression of "forex_trade_volume" on
# "news_index", reports the RMSE of the fitted model, and shows the
# per-row predictions.

# Initialise the SparkSession.
spark = SparkSession.builder.appName("NewsIndexSensitivityAnalysis").getOrCreate()

try:
    # A DataFrame named `df` with the columns "news_index" and
    # "forex_trade_volume" must be defined before the pipeline is fitted.
    # df = ...

    # Assemble the input column(s) into the single vector column that
    # Spark ML estimators expect.
    assembler = VectorAssembler(inputCols=["news_index"], outputCol="features")

    # Spark ML's gradient-boosted-trees regressor is `GBTRegressor`, and its
    # learning-rate (shrinkage) parameter is called `stepSize`.
    gbt = GBTRegressor(
        featuresCol="features",
        labelCol="forex_trade_volume",
        maxIter=10,  # number of boosting iterations (trees)
        maxDepth=3,  # maximum depth of each tree
        stepSize=0.1,  # learning rate
        maxBins=32,  # max number of bins used to discretise continuous features
    )

    # Chain feature assembly and the regressor into one pipeline.
    pipeline = Pipeline(stages=[assembler, gbt])

    # Train the model.
    model = pipeline.fit(df)

    # Score the data once and reuse the result for evaluation and display.
    predictions = model.transform(df)

    # Evaluate the model.
    # NOTE(review): the RMSE below is computed on the same rows the model was
    # trained on, so it is an in-sample figure; hold out a test split
    # (e.g. df.randomSplit) for an estimate of generalisation error.
    evaluator = RegressionEvaluator(
        labelCol="forex_trade_volume",
        predictionCol="prediction",
        metricName="rmse",
    )
    rmse = evaluator.evaluate(predictions)
    print(f"Root Mean Squared Error (RMSE): {rmse}")

    # Show the predictions next to the inputs and the true label.
    predictions.select("news_index", "forex_trade_volume", "prediction").show()
finally:
    # Stop the SparkSession even if a step above failed.
    spark.stop()