# [Open] Issue opened by DrPepper8888 3 months ago.
# Fit a Gaussian Mixture Model on three numeric columns of features.csv and
# attach each row's assigned mixture component to the data as "Cluster".
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.clustering import GaussianMixture
from pyspark.ml import Pipeline  # not used in this section
from pyspark.sql.functions import col

# Initialise (or reuse an existing) SparkSession.
spark = SparkSession.builder.appName("GaussianMixtureExample").getOrCreate()

# Load the data: first row is the header, column types are inferred.
df = spark.read.csv('features.csv', header=True, inferSchema=True)

# Preprocessing: convert the transaction date to a timestamp.
# NOTE(review): TRAN_DT is not one of the clustering features below; it is
# presumably converted for downstream use. If inferSchema reads it as a
# number (e.g. yyyyMMdd), cast('timestamp') interprets the value as epoch
# seconds; unparseable strings become null (or raise under ANSI mode).
# Confirm the source format -- to_timestamp with an explicit pattern may be
# what is actually wanted.
df = df.withColumn('TRAN_DT', col('TRAN_DT').cast('timestamp'))

# Feature selection: pack transaction amount, holding time and net worth
# into the single vector column that Spark ML estimators consume.
# NOTE(review): VectorAssembler defaults to handleInvalid="error", so any
# null/NaN in these columns makes the job fail when the data is evaluated.
# The three columns are also fed in unscaled; if their magnitudes differ
# widely (TODO confirm), consider a StandardScaler step before the GMM.
assembler = VectorAssembler(
    inputCols=['TRAN_AMT', 'HOLDING_TIME', 'NET_WORTH'],
    outputCol='features',
)
df = assembler.transform(df)

# Clustering: 5 mixture components, fixed seed so the run is repeatable,
# at most 100 EM iterations. Keyword arguments are equivalent to the
# setK/setSeed/setMaxIter/setFeaturesCol/setPredictionCol chain.
gmm = GaussianMixture(
    k=5,
    seed=0,
    maxIter=100,
    featuresCol="features",
    predictionCol="prediction",
)

# Train the model; `model` is a GaussianMixtureModel.
model = gmm.fit(df)

# Attach the cluster assignment to the data and rename the prediction
# column to 'Cluster'. The transform also adds the default 'probability'
# column holding per-component membership probabilities.
df = model.transform(df).withColumnRenamed('prediction', 'Cluster')