diff --git a/src/main/java/com/service/impl/RecommendServiceImpl.java b/src/main/java/com/service/impl/RecommendServiceImpl.java index 218368e..1dcd97b 100644 --- a/src/main/java/com/service/impl/RecommendServiceImpl.java +++ b/src/main/java/com/service/impl/RecommendServiceImpl.java @@ -9,46 +9,59 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.Row; import org.apache.spark.sql.SparkSession; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import org.apache.spark.mllib.recommendation.Rating; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.ALS; import java.util.ArrayList; import java.util.List; import java.util.Properties; -import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; -import org.apache.spark.mllib.recommendation.ALS; - @Service public class RecommendServiceImpl implements RecommendService { + private static final Logger logger = LoggerFactory.getLogger(RecommendServiceImpl.class); + @Autowired - private SparkSession spark; // 自动注入 SparkSession + private SparkSession spark; @Value("${spark.mysql.url}") - private String url; + private String url; + @Value("${spark.mysql.table}") - private String table; + private String table; @Value("${spark.mysql.user}") - private String user; + private String user; @Value("${spark.mysql.password}") - private String password; + private String password; @Override public List AllRecommend(Integer total) { + logger.info("开始执行推荐服务,请求推荐数量: {}", total); + logger.debug("MySQL连接配置 - URL: {}, 表名: {}, 用户: {}", url, table, user); + // 设置 MySQL 连接的属性 Properties properties = new Properties(); - properties.put("user", user); // 修改为你的用户名 - properties.put("password",password); // 修改为你的密码 + properties.put("user", user); + properties.put("password", password); properties.put("driver", "com.mysql.cj.jdbc.Driver"); + logger.debug("MySQL连接属性已设置"); // 读取 MySQL 数据 + logger.info("开始从MySQL读取数据..."); Dataset df = readFromMySQL(spark, url, table, properties); - // 使用 alias 重命名字段,确保字段名匹配 + logger.info("成功从MySQL读取数据,数据量: {}", df.count()); + logger.debug("原始数据Schema: {}", df.schema().treeString()); + + // 重命名字段 + logger.info("开始重命名字段..."); Dataset renamedDF = df .select( df.col("user_id").alias("userId"), @@ -56,62 +69,85 @@ public class RecommendServiceImpl implements RecommendService { df.col("star").alias("score"), df.col("create_time").alias("createTime") ) - .na().drop(); // 删除包含空值的行 - // 使用 Spark 的 Encoders 将 Dataset 转换为 ProductRating 类型 + .na().drop(); + logger.info("字段重命名完成,处理后数据量: {}", renamedDF.count()); + logger.debug("重命名后Schema: {}", renamedDF.schema().treeString()); + + // 转换为ProductRating类型 + logger.info("开始转换数据为ProductRating类型..."); Dataset productRatingDataset = renamedDF.as(Encoders.bean(ProductRating.class)); JavaRDD ratingRDD = productRatingDataset.javaRDD(); - // 创建用户和产品的 RDD + logger.info("数据转换完成,RDD分区数: {}", ratingRDD.getNumPartitions()); + + // 创建用户和产品的RDD + logger.info("创建用户和产品的唯一ID列表..."); JavaRDD userRDD = ratingRDD.map(ProductRating::getUserId).distinct(); JavaRDD productRDD = ratingRDD.map(ProductRating::getProductId).distinct(); - //用户编号 + logger.info("唯一用户数: {}, 唯一产品数: {}", userRDD.count(), productRDD.count()); + // 创建训练数据集 - JavaRDD trainData = - ratingRDD.map(rating -> new Rating(rating.getUserId(), rating.getProductId(), rating.getScore()) - ); - // 打印训练数据集 - trainData.collect().forEach(train -> { - System.out.println("UserId: " + train.user() + ", ProductId: " + train.product() + ", Rating: " + train.rating()); + logger.info("准备训练数据..."); + JavaRDD trainData = ratingRDD.map(rating -> + new Rating(rating.getUserId(), rating.getProductId(), rating.getScore()) + ); + logger.debug("训练数据样例:"); + trainData.take(5).forEach(train -> { + logger.debug("训练样本 - UserId: {}, ProductId: {}, Rating: {}", + train.user(), train.product(), train.rating()); }); - // 调用 ALS 算法训练隐语义模型 - int rank = 50; // 隐语义因子的个数 - int iterations = 5; // 迭代次数 - double lambda = 0.01; // 正则化参数 + // ALS模型参数 + int rank = 50; + int iterations = 5; + double lambda = 0.01; + logger.info("开始训练ALS模型,参数 - rank: {}, iterations: {}, lambda: {}", + rank, iterations, lambda); - // 训练 ALS 模型 + // 训练ALS模型 MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(trainData), rank, iterations, lambda); + logger.info("ALS模型训练完成"); // 计算用户推荐矩阵 + logger.info("计算用户-产品笛卡尔积..."); JavaPairRDD userProducts = userRDD.cartesian(productRDD); + logger.info("用户-产品组合总数: {}", userProducts.count()); - // model已训练好,把id传进去就可以得到预测评分列表RDD[Rating] (userId,productId,rating) + // 预测评分 + logger.info("开始预测评分..."); JavaRDD preRatings = model.predict(userProducts); + logger.info("预测评分完成,预测结果数: {}", preRatings.count()); - preRatings.foreach(userId -> { - System.out.println(userId); - }); - // 过滤出评分大于0的推荐 + // 过滤和排序 + logger.info("过滤和排序推荐结果..."); JavaRDD filteredRatings = preRatings.filter(rating -> rating.rating() > 0); - // 按照评分降序排序,并取前10个推荐 - List topRatings = filteredRatings - .sortBy((Rating rating) -> rating.rating(), false, 1) // 按评分降序排序,false表示降序 - .take(total); // 取前10个评分最高的推荐 + logger.info("过滤后有效推荐数: {}", filteredRatings.count()); + List topRatings = filteredRatings + .sortBy((Rating rating) -> rating.rating(), false, 1) + .take(total); + logger.info("获取到Top {}推荐结果", topRatings.size()); + + // 转换为返回对象 + logger.info("转换推荐结果为业务对象..."); List list = new ArrayList<>(); - // 打印前10个推荐 topRatings.forEach(rating -> { ProductRemRating productRemRating = new ProductRemRating(); productRemRating.setUserId(rating.user()); productRemRating.setProductId(rating.product()); productRemRating.setRating(rating.rating()); list.add(productRemRating); - System.out.println("UserId: " + rating.user() + ", ProductId: " + rating.product() + ", Rating: " + rating.rating()); + logger.debug("推荐结果 - UserId: {}, ProductId: {}, Rating: {}", + rating.user(), rating.product(), rating.rating()); }); + + logger.info("推荐服务执行完成,返回结果数: {}", list.size()); return list; } public static Dataset readFromMySQL(SparkSession spark, String url, String table, Properties properties) { - // 从 MySQL 读取数据并返回 Dataset - return spark.read().jdbc(url, table, properties); + logger.debug("从MySQL读取数据 - URL: {}, Table: {}", url, table); + Dataset df = spark.read().jdbc(url, table, properties); + logger.debug("MySQL数据读取完成"); + return df; } } \ No newline at end of file