DrPepper8888 / IPMN

0 stars 0 forks source link

gussian #9

Open DrPepper8888 opened 3 months ago

DrPepper8888 commented 3 months ago
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.clustering import GaussianMixture
from pyspark.ml import Pipeline
from pyspark.sql.functions import col

# 初始化SparkSession
spark = SparkSession.builder.appName("GaussianMixtureExample").getOrCreate()

# 读取数据
df = spark.read.csv('features.csv', header=True, inferSchema=True)

# 数据预处理
# 将交易日期转换为日期时间
df = df.withColumn('TRAN_DT', df['TRAN_DT'].cast('timestamp'))

# 特征选择
# 选择交易金额、持有时间和净资产作为特征
assembler = VectorAssembler(inputCols=['TRAN_AMT', 'HOLDING_TIME', 'NET_WORTH'], outputCol='features')
df = assembler.transform(df)

# 聚类分析
# 创建 GaussianMixture 模型
gmm = GaussianMixture().setK(5).setSeed(0).setMaxIter(100).setFeaturesCol("features").setPredictionCol("prediction")

# 训练模型
model = gmm.fit(df)

# 将聚类结果添加到原始数据框中,并将新列重命名为'Cluster'
df = model.transform(df).withColumnRenamed('prediction', 'Cluster')