Recommendation algorithm push

闵宪瑞 2025-04-07 10:52:39 +08:00
parent 5d94fa03dc
commit 5dcd1bee74


@@ -9,26 +9,30 @@ import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Value;
 import org.springframework.stereotype.Service;
 import org.apache.spark.mllib.recommendation.Rating;
+import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
+import org.apache.spark.mllib.recommendation.ALS;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Properties;
-import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
-import org.apache.spark.mllib.recommendation.ALS;
 @Service
 public class RecommendServiceImpl implements RecommendService {
+    private static final Logger logger = LoggerFactory.getLogger(RecommendServiceImpl.class);
     @Autowired
-    private SparkSession spark; // SparkSession injected automatically
+    private SparkSession spark;
     @Value("${spark.mysql.url}")
     private String url;
     @Value("${spark.mysql.table}")
     private String table;
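
For context on the logging changes in the hunk above: the new logger field and the parameterized log calls used throughout the rest of the diff follow the standard SLF4J pattern. A minimal, self-contained sketch, assuming an SLF4J binding such as Logback is on the classpath; the class name and messages are illustrative only, not part of the project:

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class LoggingSketch {
    // One static logger per class; the concrete backend is resolved by the SLF4J binding at runtime.
    private static final Logger logger = LoggerFactory.getLogger(LoggingSketch.class);

    public static void main(String[] args) {
        int total = 10;
        // {} placeholders are only rendered when the corresponding log level is enabled.
        logger.info("Starting recommendation service, requested count: {}", total);
        logger.debug("Debug output appears only if DEBUG is enabled for this logger");
    }
}
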
@@ -40,15 +44,24 @@ public class RecommendServiceImpl implements RecommendService {
     @Override
     public List<ProductRemRating> AllRecommend(Integer total) {
+        logger.info("Starting the recommendation service, requested recommendation count: {}", total);
+        logger.debug("MySQL connection config - URL: {}, table: {}, user: {}", url, table, user);
         // Set up the MySQL connection properties
         Properties properties = new Properties();
-        properties.put("user", user);          // change to your username
-        properties.put("password",password);   // change to your password
+        properties.put("user", user);
+        properties.put("password", password);
         properties.put("driver", "com.mysql.cj.jdbc.Driver");
+        logger.debug("MySQL connection properties set");
         // Read the MySQL data
+        logger.info("Reading data from MySQL...");
         Dataset<Row> df = readFromMySQL(spark, url, table, properties);
-        // Rename the fields with alias to make sure the field names match
+        logger.info("Successfully read data from MySQL, row count: {}", df.count());
+        logger.debug("Raw data schema: {}", df.schema().treeString());
+        // Rename the fields
+        logger.info("Renaming fields...");
         Dataset<Row> renamedDF = df
                 .select(
                         df.col("user_id").alias("userId"),
@@ -56,62 +69,85 @@ public class RecommendServiceImpl implements RecommendService {
df.col("star").alias("score"), df.col("star").alias("score"),
df.col("create_time").alias("createTime") df.col("create_time").alias("createTime")
) )
.na().drop(); // 删除包含空值的行 .na().drop();
// 使用 Spark Encoders Dataset<Row> 转换为 ProductRating 类型 logger.info("字段重命名完成,处理后数据量: {}", renamedDF.count());
logger.debug("重命名后Schema: {}", renamedDF.schema().treeString());
// 转换为ProductRating类型
logger.info("开始转换数据为ProductRating类型...");
Dataset<ProductRating> productRatingDataset = renamedDF.as(Encoders.bean(ProductRating.class)); Dataset<ProductRating> productRatingDataset = renamedDF.as(Encoders.bean(ProductRating.class));
JavaRDD<ProductRating> ratingRDD = productRatingDataset.javaRDD(); JavaRDD<ProductRating> ratingRDD = productRatingDataset.javaRDD();
// 创建用户和产品的 RDD logger.info("数据转换完成RDD分区数: {}", ratingRDD.getNumPartitions());
// 创建用户和产品的RDD
logger.info("创建用户和产品的唯一ID列表...");
JavaRDD<Integer> userRDD = ratingRDD.map(ProductRating::getUserId).distinct(); JavaRDD<Integer> userRDD = ratingRDD.map(ProductRating::getUserId).distinct();
JavaRDD<Integer> productRDD = ratingRDD.map(ProductRating::getProductId).distinct(); JavaRDD<Integer> productRDD = ratingRDD.map(ProductRating::getProductId).distinct();
//用户编号 logger.info("唯一用户数: {}, 唯一产品数: {}", userRDD.count(), productRDD.count());
// 创建训练数据集 // 创建训练数据集
JavaRDD<Rating> trainData = logger.info("准备训练数据...");
ratingRDD.map(rating -> new Rating(rating.getUserId(), rating.getProductId(), rating.getScore()) JavaRDD<Rating> trainData = ratingRDD.map(rating ->
new Rating(rating.getUserId(), rating.getProductId(), rating.getScore())
); );
// 打印训练数据集 logger.debug("训练数据样例:");
trainData.collect().forEach(train -> { trainData.take(5).forEach(train -> {
System.out.println("UserId: " + train.user() + ", ProductId: " + train.product() + ", Rating: " + train.rating()); logger.debug("训练样本 - UserId: {}, ProductId: {}, Rating: {}",
train.user(), train.product(), train.rating());
}); });
// 调用 ALS 算法训练隐语义模型 // ALS模型参数
int rank = 50; // 隐语义因子的个数 int rank = 50;
int iterations = 5; // 迭代次数 int iterations = 5;
double lambda = 0.01; // 正则化参数 double lambda = 0.01;
logger.info("开始训练ALS模型参数 - rank: {}, iterations: {}, lambda: {}",
rank, iterations, lambda);
// 训练 ALS 模型 // 训练ALS模型
MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(trainData), rank, iterations, lambda); MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(trainData), rank, iterations, lambda);
logger.info("ALS模型训练完成");
// 计算用户推荐矩阵 // 计算用户推荐矩阵
logger.info("计算用户-产品笛卡尔积...");
JavaPairRDD<Integer, Integer> userProducts = userRDD.cartesian(productRDD); JavaPairRDD<Integer, Integer> userProducts = userRDD.cartesian(productRDD);
logger.info("用户-产品组合总数: {}", userProducts.count());
// model已训练好把id传进去就可以得到预测评分列表RDD[Rating] (userId,productId,rating) // 预测评分
logger.info("开始预测评分...");
JavaRDD<Rating> preRatings = model.predict(userProducts); JavaRDD<Rating> preRatings = model.predict(userProducts);
logger.info("预测评分完成,预测结果数: {}", preRatings.count());
preRatings.foreach(userId -> { // 过滤和排序
System.out.println(userId); logger.info("过滤和排序推荐结果...");
});
// 过滤出评分大于0的推荐
JavaRDD<Rating> filteredRatings = preRatings.filter(rating -> rating.rating() > 0); JavaRDD<Rating> filteredRatings = preRatings.filter(rating -> rating.rating() > 0);
// 按照评分降序排序并取前10个推荐 logger.info("过滤后有效推荐数: {}", filteredRatings.count());
List<Rating> topRatings = filteredRatings
.sortBy((Rating rating) -> rating.rating(), false, 1) // 按评分降序排序false表示降序
.take(total); // 取前10个评分最高的推荐
List<Rating> topRatings = filteredRatings
.sortBy((Rating rating) -> rating.rating(), false, 1)
.take(total);
logger.info("获取到Top {}推荐结果", topRatings.size());
// 转换为返回对象
logger.info("转换推荐结果为业务对象...");
List<ProductRemRating> list = new ArrayList<>(); List<ProductRemRating> list = new ArrayList<>();
// 打印前10个推荐
topRatings.forEach(rating -> { topRatings.forEach(rating -> {
ProductRemRating productRemRating = new ProductRemRating(); ProductRemRating productRemRating = new ProductRemRating();
productRemRating.setUserId(rating.user()); productRemRating.setUserId(rating.user());
productRemRating.setProductId(rating.product()); productRemRating.setProductId(rating.product());
productRemRating.setRating(rating.rating()); productRemRating.setRating(rating.rating());
list.add(productRemRating); list.add(productRemRating);
System.out.println("UserId: " + rating.user() + ", ProductId: " + rating.product() + ", Rating: " + rating.rating()); logger.debug("推荐结果 - UserId: {}, ProductId: {}, Rating: {}",
rating.user(), rating.product(), rating.rating());
}); });
logger.info("推荐服务执行完成,返回结果数: {}", list.size());
return list; return list;
} }
public static Dataset<Row> readFromMySQL(SparkSession spark, String url, String table, Properties properties) { public static Dataset<Row> readFromMySQL(SparkSession spark, String url, String table, Properties properties) {
// MySQL 读取数据并返回 Dataset logger.debug("从MySQL读取数据 - URL: {}, Table: {}", url, table);
return spark.read().jdbc(url, table, properties); Dataset<Row> df = spark.read().jdbc(url, table, properties);
logger.debug("MySQL数据读取完成");
return df;
} }
} }
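
For reference, the core MLlib flow that the service wires together (build a Rating RDD, train with ALS.train, then score the user-product pairs) can be exercised in isolation. Below is a minimal sketch with hard-coded toy ratings and the same hyperparameters as the committed code (rank 50, 5 iterations, lambda 0.01); note that MatrixFactorizationModel.recommendProducts is an alternative to the cartesian-product-plus-predict approach used in the service and returns the top-N items per user directly:

import java.util.Arrays;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;

public class AlsSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("als-sketch").setMaster("local[*]");
        JavaSparkContext sc = new JavaSparkContext(conf);

        // Toy ratings: (userId, productId, score); the real service builds this RDD from MySQL.
        JavaRDD<Rating> ratings = sc.parallelize(Arrays.asList(
                new Rating(1, 101, 5.0), new Rating(1, 102, 3.0),
                new Rating(2, 101, 4.0), new Rating(2, 103, 1.0),
                new Rating(3, 102, 4.5), new Rating(3, 103, 2.0)));

        // Train the matrix factorization model with the same parameters as the committed code.
        MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(ratings), 50, 5, 0.01);

        // Top 3 predicted products for user 1.
        for (Rating r : model.recommendProducts(1, 3)) {
            System.out.printf("user=%d product=%d predicted=%.3f%n", r.user(), r.product(), r.rating());
        }

        sc.stop();
    }
}
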