Recommendation algorithm push
This commit is contained in:
parent 5d94fa03dc
commit 5dcd1bee74
@@ -9,26 +9,30 @@ import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.apache.spark.mllib.recommendation.Rating;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.ALS;

import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.ALS;


@Service
public class RecommendServiceImpl implements RecommendService {

    private static final Logger logger = LoggerFactory.getLogger(RecommendServiceImpl.class);

    @Autowired
-   private SparkSession spark; // SparkSession is injected automatically
+   private SparkSession spark;

    @Value("${spark.mysql.url}")
    private String url;

    @Value("${spark.mysql.table}")
    private String table;

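One wiring note: the `@Autowired SparkSession` field above only resolves if a `SparkSession` bean is registered somewhere in the Spring context. Below is a minimal sketch of such a configuration, assuming a local deployment; the class name, app name, and master URL are illustrative and not part of this commit.

```java
import org.apache.spark.sql.SparkSession;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

// Hypothetical configuration class that provides the SparkSession bean
// injected into RecommendServiceImpl; the values below are placeholders.
@Configuration
public class SparkConfig {

    @Bean
    public SparkSession sparkSession() {
        return SparkSession.builder()
                .appName("recommend-service") // assumed application name
                .master("local[*]")           // assumed master; a real cluster would differ
                .getOrCreate();
    }
}
```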
@@ -40,15 +44,24 @@ public class RecommendServiceImpl implements RecommendService {

    @Override
    public List<ProductRemRating> AllRecommend(Integer total) {
        logger.info("Starting recommendation service, requested number of recommendations: {}", total);
        logger.debug("MySQL connection config - URL: {}, table: {}, user: {}", url, table, user);

        // Set the MySQL connection properties
        Properties properties = new Properties();
-       properties.put("user", user); // change to your username
-       properties.put("password", password); // change to your password
+       properties.put("user", user);
+       properties.put("password", password);
        properties.put("driver", "com.mysql.cj.jdbc.Driver");
        logger.debug("MySQL connection properties set");

        // Read data from MySQL
        logger.info("Reading data from MySQL...");
        Dataset<Row> df = readFromMySQL(spark, url, table, properties);
        // Rename the columns with alias so the field names match
        logger.info("Read data from MySQL successfully, row count: {}", df.count());
        logger.debug("Raw data schema: {}", df.schema().treeString());

        // Rename the columns
        logger.info("Renaming columns...");
        Dataset<Row> renamedDF = df
                .select(
                        df.col("user_id").alias("userId"),
@@ -56,62 +69,85 @@ public class RecommendServiceImpl implements RecommendService {
                        df.col("star").alias("score"),
                        df.col("create_time").alias("createTime")
                )
-               .na().drop(); // drop rows that contain null values
-       // Use Spark Encoders to convert the Dataset<Row> into ProductRating objects
+               .na().drop();
        logger.info("Column renaming done, row count after processing: {}", renamedDF.count());
        logger.debug("Schema after renaming: {}", renamedDF.schema().treeString());

        // Convert to ProductRating
        logger.info("Converting data to ProductRating...");
        Dataset<ProductRating> productRatingDataset = renamedDF.as(Encoders.bean(ProductRating.class));
        JavaRDD<ProductRating> ratingRDD = productRatingDataset.javaRDD();
        logger.info("Data conversion done, RDD partition count: {}", ratingRDD.getNumPartitions());

        // Create the user and product RDDs
        logger.info("Building the distinct user and product ID lists...");
        JavaRDD<Integer> userRDD = ratingRDD.map(ProductRating::getUserId).distinct();
        JavaRDD<Integer> productRDD = ratingRDD.map(ProductRating::getProductId).distinct();
-       // user IDs
        logger.info("Distinct users: {}, distinct products: {}", userRDD.count(), productRDD.count());

        // Build the training dataset
-       JavaRDD<Rating> trainData =
-               ratingRDD.map(rating -> new Rating(rating.getUserId(), rating.getProductId(), rating.getScore())
+       logger.info("Preparing training data...");
+       JavaRDD<Rating> trainData = ratingRDD.map(rating ->
+               new Rating(rating.getUserId(), rating.getProductId(), rating.getScore())
        );
-       // Print the training dataset
-       trainData.collect().forEach(train -> {
-           System.out.println("UserId: " + train.user() + ", ProductId: " + train.product() + ", Rating: " + train.rating());
+       logger.debug("Training data sample:");
+       trainData.take(5).forEach(train -> {
+           logger.debug("Training sample - UserId: {}, ProductId: {}, Rating: {}",
+                   train.user(), train.product(), train.rating());
        });

-       // Train the latent factor model with the ALS algorithm
-       int rank = 50; // number of latent factors
-       int iterations = 5; // number of iterations
-       double lambda = 0.01; // regularization parameter
+       // ALS model parameters
+       int rank = 50;
+       int iterations = 5;
+       double lambda = 0.01;
        logger.info("Training ALS model, parameters - rank: {}, iterations: {}, lambda: {}",
                rank, iterations, lambda);

        // Train the ALS model
        MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(trainData), rank, iterations, lambda);
        logger.info("ALS model training complete");

        // Compute the user recommendation matrix
        logger.info("Computing the user-product Cartesian product...");
        JavaPairRDD<Integer, Integer> userProducts = userRDD.cartesian(productRDD);
        logger.info("Total user-product pairs: {}", userProducts.count());

-       // The model is trained; passing the IDs in yields the predicted ratings as RDD[Rating] (userId, productId, rating)
+       // Predict ratings
        logger.info("Predicting ratings...");
        JavaRDD<Rating> preRatings = model.predict(userProducts);
        logger.info("Prediction complete, number of predictions: {}", preRatings.count());

        preRatings.foreach(userId -> {
            System.out.println(userId);
        });
-       // Keep only recommendations with a rating greater than 0
+       // Filter and sort
        logger.info("Filtering and sorting recommendation results...");
        JavaRDD<Rating> filteredRatings = preRatings.filter(rating -> rating.rating() > 0);
-       // Sort by rating in descending order and take the top 10 recommendations
-       List<Rating> topRatings = filteredRatings
-               .sortBy((Rating rating) -> rating.rating(), false, 1) // sort by rating; false means descending
-               .take(total); // take the top 10 highest-rated recommendations
        logger.info("Valid recommendations after filtering: {}", filteredRatings.count());

+       List<Rating> topRatings = filteredRatings
+               .sortBy((Rating rating) -> rating.rating(), false, 1)
+               .take(total);
        logger.info("Got top {} recommendation results", topRatings.size());

        // Convert to the return objects
        logger.info("Converting recommendation results to business objects...");
        List<ProductRemRating> list = new ArrayList<>();
-       // Print the top 10 recommendations
        topRatings.forEach(rating -> {
            ProductRemRating productRemRating = new ProductRemRating();
            productRemRating.setUserId(rating.user());
            productRemRating.setProductId(rating.product());
            productRemRating.setRating(rating.rating());
            list.add(productRemRating);
-           System.out.println("UserId: " + rating.user() + ", ProductId: " + rating.product() + ", Rating: " + rating.rating());
+           logger.debug("Recommendation - UserId: {}, ProductId: {}, Rating: {}",
+                   rating.user(), rating.product(), rating.rating());
        });

        logger.info("Recommendation service finished, number of results returned: {}", list.size());
        return list;
    }

    public static Dataset<Row> readFromMySQL(SparkSession spark, String url, String table, Properties properties) {
-       // Read data from MySQL and return the Dataset
-       return spark.read().jdbc(url, table, properties);
+       logger.debug("Reading from MySQL - URL: {}, Table: {}", url, table);
+       Dataset<Row> df = spark.read().jdbc(url, table, properties);
+       logger.debug("MySQL data read complete");
+       return df;
    }
}
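For reference, a minimal, self-contained sketch of the MLlib ALS flow used above: build a `Rating` RDD, train with `ALS.train`, then retrieve top-N products. The in-memory sample data and parameter values are assumptions for illustration; `recommendProducts` is shown as a lighter alternative to scoring the full `userRDD.cartesian(productRDD)` pair set with `model.predict`.

```java
import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.recommendation.ALS;
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;

public class AlsSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("als-sketch").setMaster("local[*]");
        JavaSparkContext sc = new JavaSparkContext(conf);

        // Tiny in-memory training set: (userId, productId, score). Purely illustrative.
        List<Rating> ratings = Arrays.asList(
                new Rating(1, 10, 5.0),
                new Rating(1, 11, 3.0),
                new Rating(2, 10, 4.0),
                new Rating(2, 12, 1.0));
        JavaRDD<Rating> trainData = sc.parallelize(ratings);

        // Same call shape as the service; rank, iterations, lambda are assumed values.
        MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(trainData), 10, 5, 0.01);

        // Instead of predicting over the full user x product Cartesian product,
        // the model can return the top-N products for one user directly.
        Rating[] top = model.recommendProducts(1, 2);
        for (Rating r : top) {
            System.out.println(r.user() + " -> " + r.product() + " : " + r.rating());
        }

        sc.stop();
    }
}
```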