diff --git a/.gitignore b/.gitignore index 9154f4c..5ff6309 100644 --- a/.gitignore +++ b/.gitignore @@ -1,26 +1,38 @@ -# ---> Java -# Compiled class file -*.class +target/ +!.mvn/wrapper/maven-wrapper.jar +!**/src/main/**/target/ +!**/src/test/**/target/ -# Log file -*.log +### IntelliJ IDEA ### +.idea/modules.xml +.idea/jarRepositories.xml +.idea/compiler.xml +.idea/libraries/ +*.iws +*.iml +*.ipr -# BlueJ files -*.ctxt +### Eclipse ### +.apt_generated +.classpath +.factorypath +.project +.settings +.springBeans +.sts4-cache -# Mobile Tools for Java (J2ME) -.mtj.tmp/ +### NetBeans ### +/nbproject/private/ +/nbbuild/ +/dist/ +/nbdist/ +/.nb-gradle/ +build/ +!**/src/main/**/build/ +!**/src/test/**/build/ -# Package Files # -*.jar -*.war -*.nar -*.ear -*.zip -*.tar.gz -*.rar - -# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml -hs_err_pid* -replay_pid* +### VS Code ### +.vscode/ +### Mac OS ### +.DS_Store \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/codeStyles/Project.xml b/.idea/codeStyles/Project.xml new file mode 100644 index 0000000..919ce1f --- /dev/null +++ b/.idea/codeStyles/Project.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/codeStyles/codeStyleConfig.xml b/.idea/codeStyles/codeStyleConfig.xml new file mode 100644 index 0000000..a55e7a1 --- /dev/null +++ b/.idea/codeStyles/codeStyleConfig.xml @@ -0,0 +1,5 @@ + + + + \ No newline at end of file diff --git a/.idea/encodings.xml b/.idea/encodings.xml new file mode 100644 index 0000000..63e9001 --- /dev/null +++ b/.idea/encodings.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..03d9549 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..5ed0e80 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,12 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/README.md b/README.md index 81b87b1..c914fa8 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,3 @@ # spark-rem -推荐算法模板 \ No newline at end of file +推荐算法 \ No newline at end of file diff --git a/pom.xml b/pom.xml new file mode 100644 index 0000000..dd5c3aa --- /dev/null +++ b/pom.xml @@ -0,0 +1,77 @@ + + + 4.0.0 + + + org.springframework.boot + spring-boot-starter-parent + 2.6.2 + + + + org.example + spark-rem + 1.0-SNAPSHOT + + + 8 + 8 + UTF-8 + + + + + + + org.springframework.boot + spring-boot-starter-web + + + + org.projectlombok + lombok + + + org.apache.spark + spark-core_2.12 + 3.1.2 + + + org.slf4j + slf4j-log4j12 + + + + + + + org.apache.spark + spark-sql_2.12 + 3.1.2 + + + + org.apache.spark + spark-mllib_2.12 + 3.1.2 + + + + org.codehaus.janino + janino + 3.0.15 + + + + + + + mysql + mysql-connector-java + 8.0.28 + + + + \ No newline at end of file diff --git a/src/main/java/com/OfflineRecommender.java b/src/main/java/com/OfflineRecommender.java new file mode 100644 index 0000000..9b082d1 --- /dev/null +++ b/src/main/java/com/OfflineRecommender.java @@ -0,0 +1,112 @@ +//package com; +// +//import com.model.ProductRating; +//import org.apache.spark.SparkConf; +//import org.apache.spark.sql.Dataset; +//import org.apache.spark.sql.Encoders; +//import org.apache.spark.sql.SparkSession; +//import org.apache.spark.sql.Row; +//import org.apache.spark.api.java.JavaRDD; +//import org.apache.spark.api.java.JavaPairRDD; +//import org.apache.spark.mllib.recommendation.ALS; +//import org.apache.spark.mllib.recommendation.Rating; +//import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +//import java.util.Properties; +//import java.util.List; +// +//public class OfflineRecommender { +// +// // 读取 MySQL 数据的工具方法 +// public static Dataset readFromMySQL(SparkSession spark, String url, String table, Properties properties) { +// // 从 MySQL 读取数据并返回 Dataset +// return spark.read().jdbc(url, table, properties); +// } +// +// public static void main(String[] args) { +// +// // 创建 SparkConf 对象并设置 Spark 配置 +// SparkConf sparkConf = new SparkConf() +// .setMaster("local[*]") +// .setAppName("OfflineRecommender"); +// +// // 创建 SparkSession 对象 +// SparkSession spark = SparkSession.builder().config(sparkConf).getOrCreate(); +// +// +// //配置 MySQL 连接信息 +// String url = "jdbc:mysql://localhost:3306/test_rem"; // 修改为你的数据库连接 URL +// String table = "rating"; // 替换为你要读取的表名 +// +// // 设置 MySQL 连接的属性 +// Properties properties = new Properties(); +// properties.put("user", "root"); // 修改为你的用户名 +// properties.put("password", "123456"); // 修改为你的密码 +// properties.put("driver", "com.mysql.cj.jdbc.Driver"); +// +// // 读取 MySQL 数据 +// Dataset df = readFromMySQL(spark, url, table, properties); +// +// // 使用 alias 重命名字段,确保字段名匹配 +// Dataset renamedDF = df +// .select( +// df.col("user_id").alias("userId"), +// df.col("product_id").alias("productId"), +// df.col("score"), +// df.col("create_time").alias("createTime") +// ) +// .na().drop(); // 删除包含空值的行 +// +// // 使用 Spark 的 Encoders 将 Dataset 转换为 ProductRating 类型 +// Dataset productRatingDataset = renamedDF.as(Encoders.bean(ProductRating.class)); +// +// // 显示映射后的实体类数据 +// productRatingDataset.show(); +// +// // 将 Dataset 转换为 Java RDD +// JavaRDD ratingRDD = productRatingDataset.javaRDD(); +// +// // 创建用户和产品的 RDD +// JavaRDD userRDD = ratingRDD.map(ProductRating::getUserId).distinct(); +// JavaRDD productRDD = ratingRDD.map(ProductRating::getProductId).distinct(); +// +// // 创建训练数据集 +// JavaRDD trainData = +// ratingRDD.map(rating -> new Rating(rating.getUserId(), rating.getProductId(), rating.getScore()) +// ); +// +// // 打印训练数据集 +// trainData.collect().forEach(train -> { +// System.out.println("UserId: " + train.user() + ", ProductId: " + train.product() + ", Rating: " + train.rating()); +// }); +// +// // 调用 ALS 算法训练隐语义模型 +// int rank = 50; // 隐语义因子的个数 +// int iterations = 5; // 迭代次数 +// double lambda = 0.01; // 正则化参数 +// +// // 训练 ALS 模型 +// MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(trainData), rank, iterations, lambda); +// +// // 计算用户推荐矩阵 +// JavaPairRDD userProducts = userRDD.cartesian(productRDD); +// +// // model已训练好,把id传进去就可以得到预测评分列表RDD[Rating] (userId,productId,rating) +// JavaRDD preRatings = model.predict(userProducts); +// +// preRatings.foreach(userId -> { +// System.out.println(userId); +// }); +// // 过滤出评分大于0的推荐 +// JavaRDD filteredRatings = preRatings.filter(rating -> rating.rating() > 0); +// // 按照评分降序排序,并取前10个推荐 +// List top10Ratings = filteredRatings +// .sortBy((Rating rating) -> rating.rating(), false, 1) // 按评分降序排序,false表示降序 +// .take(10); // 取前10个评分最高的推荐 +// // 打印前10个推荐 +// top10Ratings.forEach(rating -> { +// System.out.println("UserId: " + rating.user() + ", ProductId: " + rating.product() + ", Rating: " + rating.rating()); +// }); +// +// } +// +//} diff --git a/src/main/java/com/SpringBootRecommendApplication.java b/src/main/java/com/SpringBootRecommendApplication.java new file mode 100644 index 0000000..383c0b6 --- /dev/null +++ b/src/main/java/com/SpringBootRecommendApplication.java @@ -0,0 +1,12 @@ +package com; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; + +@SpringBootApplication +public class SpringBootRecommendApplication { + public static void main(String[] args) { + SpringApplication.run(SpringBootRecommendApplication.class, args); + } +} + diff --git a/src/main/java/com/config/SparkConfig.java b/src/main/java/com/config/SparkConfig.java new file mode 100644 index 0000000..3eaa73a --- /dev/null +++ b/src/main/java/com/config/SparkConfig.java @@ -0,0 +1,23 @@ +package com.config; + +import org.apache.spark.SparkConf; +import org.apache.spark.sql.SparkSession; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class SparkConfig { + + @Bean + public SparkSession sparkSession() { + // 创建 SparkConf 对象并设置 Spark 配置 + SparkConf sparkConf = new SparkConf() + .setMaster("local[*]") // 本地模式,使用所有可用核心 + .setAppName("OfflineRecommender"); // Spark 应用名称 + + // 创建并返回 SparkSession 对象 + return SparkSession.builder() + .config(sparkConf) + .getOrCreate(); // 初始化 SparkSession + } +} diff --git a/src/main/java/com/controller/RecommendController.java b/src/main/java/com/controller/RecommendController.java new file mode 100644 index 0000000..08c48fc --- /dev/null +++ b/src/main/java/com/controller/RecommendController.java @@ -0,0 +1,34 @@ +package com.controller; + + +import com.model.ProductRemRating; +import com.service.RecommendService; +import org.apache.spark.mllib.recommendation.Rating; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.*; + +import java.util.List; + +/** + * 推荐 + */ +@RestController +@RequestMapping("/api/recommend") +@CrossOrigin +public class RecommendController { + + @Autowired private RecommendService recommendService; + + /** + * 推荐全部商品 + * @param total + * @return 返回 用户Id,商品编号 评估推荐分数 + */ + @GetMapping("all") + public List allUserRecommend(@RequestParam Integer total){ + List ratings = recommendService.AllRecommend(total); + return ratings; + } + + +} diff --git a/src/main/java/com/model/ProductRating.java b/src/main/java/com/model/ProductRating.java new file mode 100644 index 0000000..d27599b --- /dev/null +++ b/src/main/java/com/model/ProductRating.java @@ -0,0 +1,68 @@ +package com.model; + +/** + * 评分表 + */ +public class ProductRating { + + private Integer userId; + + private Integer productId; + + private Double score; + private String createTime; + + public ProductRating() { + // 无参构造函数 + } + + public ProductRating(Integer userId, Integer productId, Double score, String createTime) { + this.userId = userId; + this.productId = productId; + this.score = score; + this.createTime = createTime; + } + + public Integer getProductId() { + return productId; + } + + public Double getScore() { + return score; + } + + public void setProductId(Integer productId) { + this.productId = productId; + } + + public void setScore(Double score) { + this.score = score; + } + + public Integer getUserId() { + return userId; + } + + public String getCreateTime() { + return createTime; + } + + public void setUserId(Integer userId) { + this.userId = userId; + } + + public void setCreateTime(String createTime) { + this.createTime = createTime; + } + + @Override + public String toString() { + return "ProductRating{" + + "userId=" + userId + + ", productId=" + productId + + ", score=" + score + + ", createTime='" + createTime + '\'' + + '}'; + } + +} diff --git a/src/main/java/com/model/ProductRemRating.java b/src/main/java/com/model/ProductRemRating.java new file mode 100644 index 0000000..52d9c8d --- /dev/null +++ b/src/main/java/com/model/ProductRemRating.java @@ -0,0 +1,17 @@ +package com.model; + +import lombok.Data; + +/** + * 推荐返回表 + */ +@Data +public class ProductRemRating { + + private Integer userId; + + private Integer productId; + + private Double rating; + +} diff --git a/src/main/java/com/service/RecommendService.java b/src/main/java/com/service/RecommendService.java new file mode 100644 index 0000000..ede974a --- /dev/null +++ b/src/main/java/com/service/RecommendService.java @@ -0,0 +1,16 @@ +package com.service; + +import com.model.ProductRemRating; +import org.apache.spark.mllib.recommendation.Rating; +import java.util.List; + +public interface RecommendService { + + + /** + * 推荐全部 + * @return + */ + List AllRecommend(Integer total); + +} diff --git a/src/main/java/com/service/impl/RecommendServiceImpl.java b/src/main/java/com/service/impl/RecommendServiceImpl.java new file mode 100644 index 0000000..218368e --- /dev/null +++ b/src/main/java/com/service/impl/RecommendServiceImpl.java @@ -0,0 +1,117 @@ +package com.service.impl; + +import com.model.ProductRating; +import com.model.ProductRemRating; +import com.service.RecommendService; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Service; +import org.apache.spark.mllib.recommendation.Rating; + +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; +import org.apache.spark.mllib.recommendation.MatrixFactorizationModel; +import org.apache.spark.mllib.recommendation.ALS; + + +@Service +public class RecommendServiceImpl implements RecommendService { + + @Autowired + private SparkSession spark; // 自动注入 SparkSession + + @Value("${spark.mysql.url}") + private String url; + @Value("${spark.mysql.table}") + private String table; + + @Value("${spark.mysql.user}") + private String user; + + @Value("${spark.mysql.password}") + private String password; + + @Override + public List AllRecommend(Integer total) { + // 设置 MySQL 连接的属性 + Properties properties = new Properties(); + properties.put("user", user); // 修改为你的用户名 + properties.put("password",password); // 修改为你的密码 + properties.put("driver", "com.mysql.cj.jdbc.Driver"); + + // 读取 MySQL 数据 + Dataset df = readFromMySQL(spark, url, table, properties); + // 使用 alias 重命名字段,确保字段名匹配 + Dataset renamedDF = df + .select( + df.col("user_id").alias("userId"), + df.col("item_id").alias("productId"), + df.col("star").alias("score"), + df.col("create_time").alias("createTime") + ) + .na().drop(); // 删除包含空值的行 + // 使用 Spark 的 Encoders 将 Dataset 转换为 ProductRating 类型 + Dataset productRatingDataset = renamedDF.as(Encoders.bean(ProductRating.class)); + JavaRDD ratingRDD = productRatingDataset.javaRDD(); + // 创建用户和产品的 RDD + JavaRDD userRDD = ratingRDD.map(ProductRating::getUserId).distinct(); + JavaRDD productRDD = ratingRDD.map(ProductRating::getProductId).distinct(); + //用户编号 + // 创建训练数据集 + JavaRDD trainData = + ratingRDD.map(rating -> new Rating(rating.getUserId(), rating.getProductId(), rating.getScore()) + ); + // 打印训练数据集 + trainData.collect().forEach(train -> { + System.out.println("UserId: " + train.user() + ", ProductId: " + train.product() + ", Rating: " + train.rating()); + }); + + // 调用 ALS 算法训练隐语义模型 + int rank = 50; // 隐语义因子的个数 + int iterations = 5; // 迭代次数 + double lambda = 0.01; // 正则化参数 + + // 训练 ALS 模型 + MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(trainData), rank, iterations, lambda); + + // 计算用户推荐矩阵 + JavaPairRDD userProducts = userRDD.cartesian(productRDD); + + // model已训练好,把id传进去就可以得到预测评分列表RDD[Rating] (userId,productId,rating) + JavaRDD preRatings = model.predict(userProducts); + + preRatings.foreach(userId -> { + System.out.println(userId); + }); + // 过滤出评分大于0的推荐 + JavaRDD filteredRatings = preRatings.filter(rating -> rating.rating() > 0); + // 按照评分降序排序,并取前10个推荐 + List topRatings = filteredRatings + .sortBy((Rating rating) -> rating.rating(), false, 1) // 按评分降序排序,false表示降序 + .take(total); // 取前10个评分最高的推荐 + + List list = new ArrayList<>(); + // 打印前10个推荐 + topRatings.forEach(rating -> { + ProductRemRating productRemRating = new ProductRemRating(); + productRemRating.setUserId(rating.user()); + productRemRating.setProductId(rating.product()); + productRemRating.setRating(rating.rating()); + list.add(productRemRating); + System.out.println("UserId: " + rating.user() + ", ProductId: " + rating.product() + ", Rating: " + rating.rating()); + }); + return list; + } + + public static Dataset readFromMySQL(SparkSession spark, String url, String table, Properties properties) { + // 从 MySQL 读取数据并返回 Dataset + return spark.read().jdbc(url, table, properties); + } +} \ No newline at end of file diff --git a/src/main/java/com/util/MySQLReader.java b/src/main/java/com/util/MySQLReader.java new file mode 100644 index 0000000..be81f3a --- /dev/null +++ b/src/main/java/com/util/MySQLReader.java @@ -0,0 +1,4 @@ +package com.util; + +public class MySQLReader { +} diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml new file mode 100644 index 0000000..3f65f01 --- /dev/null +++ b/src/main/resources/application.yml @@ -0,0 +1,8 @@ +server: + port: 9090 +spark: + mysql: + url: jdbc:mysql://localhost:3306/racket + user: root + password: 123456 + table: tb_order \ No newline at end of file diff --git a/src/main/resources/log4j.properties b/src/main/resources/log4j.properties new file mode 100644 index 0000000..5039e33 --- /dev/null +++ b/src/main/resources/log4j.properties @@ -0,0 +1,4 @@ +log4j.rootLogger=info, stdout +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SSS} %5p --- [%50t] %-80c(line:%5L) : %m%n