推荐算法push
This commit is contained in:
parent
9041756e80
commit
5d94fa03dc
54
.gitignore
vendored
54
.gitignore
vendored
@ -1,26 +1,38 @@
|
|||||||
# ---> Java
|
target/
|
||||||
# Compiled class file
|
!.mvn/wrapper/maven-wrapper.jar
|
||||||
*.class
|
!**/src/main/**/target/
|
||||||
|
!**/src/test/**/target/
|
||||||
|
|
||||||
# Log file
|
### IntelliJ IDEA ###
|
||||||
*.log
|
.idea/modules.xml
|
||||||
|
.idea/jarRepositories.xml
|
||||||
|
.idea/compiler.xml
|
||||||
|
.idea/libraries/
|
||||||
|
*.iws
|
||||||
|
*.iml
|
||||||
|
*.ipr
|
||||||
|
|
||||||
# BlueJ files
|
### Eclipse ###
|
||||||
*.ctxt
|
.apt_generated
|
||||||
|
.classpath
|
||||||
|
.factorypath
|
||||||
|
.project
|
||||||
|
.settings
|
||||||
|
.springBeans
|
||||||
|
.sts4-cache
|
||||||
|
|
||||||
# Mobile Tools for Java (J2ME)
|
### NetBeans ###
|
||||||
.mtj.tmp/
|
/nbproject/private/
|
||||||
|
/nbbuild/
|
||||||
|
/dist/
|
||||||
|
/nbdist/
|
||||||
|
/.nb-gradle/
|
||||||
|
build/
|
||||||
|
!**/src/main/**/build/
|
||||||
|
!**/src/test/**/build/
|
||||||
|
|
||||||
# Package Files #
|
### VS Code ###
|
||||||
*.jar
|
.vscode/
|
||||||
*.war
|
|
||||||
*.nar
|
|
||||||
*.ear
|
|
||||||
*.zip
|
|
||||||
*.tar.gz
|
|
||||||
*.rar
|
|
||||||
|
|
||||||
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
|
|
||||||
hs_err_pid*
|
|
||||||
replay_pid*
|
|
||||||
|
|
||||||
|
### Mac OS ###
|
||||||
|
.DS_Store
|
8
.idea/.gitignore
generated
vendored
Normal file
8
.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
# Default ignored files
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# Editor-based HTTP Client requests
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
7
.idea/codeStyles/Project.xml
generated
Normal file
7
.idea/codeStyles/Project.xml
generated
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
<component name="ProjectCodeStyleConfiguration">
|
||||||
|
<code_scheme name="Project" version="173">
|
||||||
|
<ScalaCodeStyleSettings>
|
||||||
|
<option name="MULTILINE_STRING_CLOSING_QUOTES_ON_NEW_LINE" value="true" />
|
||||||
|
</ScalaCodeStyleSettings>
|
||||||
|
</code_scheme>
|
||||||
|
</component>
|
5
.idea/codeStyles/codeStyleConfig.xml
generated
Normal file
5
.idea/codeStyles/codeStyleConfig.xml
generated
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
<component name="ProjectCodeStyleConfiguration">
|
||||||
|
<state>
|
||||||
|
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
|
||||||
|
</state>
|
||||||
|
</component>
|
6
.idea/encodings.xml
generated
Normal file
6
.idea/encodings.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Encoding">
|
||||||
|
<file url="file://$PROJECT_DIR$/src/main/java" charset="UTF-8" />
|
||||||
|
</component>
|
||||||
|
</project>
|
6
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
6
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<profile version="1.0">
|
||||||
|
<option name="myName" value="Project Default" />
|
||||||
|
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||||
|
</profile>
|
||||||
|
</component>
|
12
.idea/misc.xml
generated
Normal file
12
.idea/misc.xml
generated
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ExternalStorageConfigurationManager" enabled="true" />
|
||||||
|
<component name="MavenProjectsManager">
|
||||||
|
<option name="originalFiles">
|
||||||
|
<list>
|
||||||
|
<option value="$PROJECT_DIR$/pom.xml" />
|
||||||
|
</list>
|
||||||
|
</option>
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="corretto-21" project-jdk-type="JavaSDK" />
|
||||||
|
</project>
|
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
77
pom.xml
Normal file
77
pom.xml
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||||
|
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
|
<parent>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-starter-parent</artifactId>
|
||||||
|
<version>2.6.2</version>
|
||||||
|
<relativePath/>
|
||||||
|
</parent>
|
||||||
|
|
||||||
|
<groupId>org.example</groupId>
|
||||||
|
<artifactId>spark-rem</artifactId>
|
||||||
|
<version>1.0-SNAPSHOT</version>
|
||||||
|
|
||||||
|
<properties>
|
||||||
|
<maven.compiler.source>8</maven.compiler.source>
|
||||||
|
<maven.compiler.target>8</maven.compiler.target>
|
||||||
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
|
</properties>
|
||||||
|
|
||||||
|
<!-- spark引入-->
|
||||||
|
<dependencies>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.springframework.boot</groupId>
|
||||||
|
<artifactId>spring-boot-starter-web</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.projectlombok</groupId>
|
||||||
|
<artifactId>lombok</artifactId>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.spark</groupId>
|
||||||
|
<artifactId>spark-core_2.12</artifactId>
|
||||||
|
<version>3.1.2</version>
|
||||||
|
<exclusions>
|
||||||
|
<exclusion>
|
||||||
|
<groupId>org.slf4j</groupId>
|
||||||
|
<artifactId>slf4j-log4j12</artifactId>
|
||||||
|
</exclusion>
|
||||||
|
</exclusions>
|
||||||
|
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.spark</groupId>
|
||||||
|
<artifactId>spark-sql_2.12</artifactId>
|
||||||
|
<version>3.1.2</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.spark</groupId>
|
||||||
|
<artifactId>spark-mllib_2.12</artifactId>
|
||||||
|
<version>3.1.2</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.codehaus.janino</groupId>
|
||||||
|
<artifactId>janino</artifactId>
|
||||||
|
<version>3.0.15</version> <!-- 使用适当版本 -->
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
<!-- MySQL JDBC Driver -->
|
||||||
|
<dependency>
|
||||||
|
<groupId>mysql</groupId>
|
||||||
|
<artifactId>mysql-connector-java</artifactId>
|
||||||
|
<version>8.0.28</version> <!-- 使用适合的 MySQL 驱动版本 -->
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
|
||||||
|
</project>
|
112
src/main/java/com/OfflineRecommender.java
Normal file
112
src/main/java/com/OfflineRecommender.java
Normal file
@ -0,0 +1,112 @@
|
|||||||
|
//package com;
|
||||||
|
//
|
||||||
|
//import com.model.ProductRating;
|
||||||
|
//import org.apache.spark.SparkConf;
|
||||||
|
//import org.apache.spark.sql.Dataset;
|
||||||
|
//import org.apache.spark.sql.Encoders;
|
||||||
|
//import org.apache.spark.sql.SparkSession;
|
||||||
|
//import org.apache.spark.sql.Row;
|
||||||
|
//import org.apache.spark.api.java.JavaRDD;
|
||||||
|
//import org.apache.spark.api.java.JavaPairRDD;
|
||||||
|
//import org.apache.spark.mllib.recommendation.ALS;
|
||||||
|
//import org.apache.spark.mllib.recommendation.Rating;
|
||||||
|
//import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
|
||||||
|
//import java.util.Properties;
|
||||||
|
//import java.util.List;
|
||||||
|
//
|
||||||
|
//public class OfflineRecommender {
|
||||||
|
//
|
||||||
|
// // 读取 MySQL 数据的工具方法
|
||||||
|
// public static Dataset<Row> readFromMySQL(SparkSession spark, String url, String table, Properties properties) {
|
||||||
|
// // 从 MySQL 读取数据并返回 Dataset
|
||||||
|
// return spark.read().jdbc(url, table, properties);
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// public static void main(String[] args) {
|
||||||
|
//
|
||||||
|
// // 创建 SparkConf 对象并设置 Spark 配置
|
||||||
|
// SparkConf sparkConf = new SparkConf()
|
||||||
|
// .setMaster("local[*]")
|
||||||
|
// .setAppName("OfflineRecommender");
|
||||||
|
//
|
||||||
|
// // 创建 SparkSession 对象
|
||||||
|
// SparkSession spark = SparkSession.builder().config(sparkConf).getOrCreate();
|
||||||
|
//
|
||||||
|
//
|
||||||
|
// //配置 MySQL 连接信息
|
||||||
|
// String url = "jdbc:mysql://localhost:3306/test_rem"; // 修改为你的数据库连接 URL
|
||||||
|
// String table = "rating"; // 替换为你要读取的表名
|
||||||
|
//
|
||||||
|
// // 设置 MySQL 连接的属性
|
||||||
|
// Properties properties = new Properties();
|
||||||
|
// properties.put("user", "root"); // 修改为你的用户名
|
||||||
|
// properties.put("password", "123456"); // 修改为你的密码
|
||||||
|
// properties.put("driver", "com.mysql.cj.jdbc.Driver");
|
||||||
|
//
|
||||||
|
// // 读取 MySQL 数据
|
||||||
|
// Dataset<Row> df = readFromMySQL(spark, url, table, properties);
|
||||||
|
//
|
||||||
|
// // 使用 alias 重命名字段,确保字段名匹配
|
||||||
|
// Dataset<Row> renamedDF = df
|
||||||
|
// .select(
|
||||||
|
// df.col("user_id").alias("userId"),
|
||||||
|
// df.col("product_id").alias("productId"),
|
||||||
|
// df.col("score"),
|
||||||
|
// df.col("create_time").alias("createTime")
|
||||||
|
// )
|
||||||
|
// .na().drop(); // 删除包含空值的行
|
||||||
|
//
|
||||||
|
// // 使用 Spark 的 Encoders 将 Dataset<Row> 转换为 ProductRating 类型
|
||||||
|
// Dataset<ProductRating> productRatingDataset = renamedDF.as(Encoders.bean(ProductRating.class));
|
||||||
|
//
|
||||||
|
// // 显示映射后的实体类数据
|
||||||
|
// productRatingDataset.show();
|
||||||
|
//
|
||||||
|
// // 将 Dataset 转换为 Java RDD
|
||||||
|
// JavaRDD<ProductRating> ratingRDD = productRatingDataset.javaRDD();
|
||||||
|
//
|
||||||
|
// // 创建用户和产品的 RDD
|
||||||
|
// JavaRDD<Integer> userRDD = ratingRDD.map(ProductRating::getUserId).distinct();
|
||||||
|
// JavaRDD<Integer> productRDD = ratingRDD.map(ProductRating::getProductId).distinct();
|
||||||
|
//
|
||||||
|
// // 创建训练数据集
|
||||||
|
// JavaRDD<Rating> trainData =
|
||||||
|
// ratingRDD.map(rating -> new Rating(rating.getUserId(), rating.getProductId(), rating.getScore())
|
||||||
|
// );
|
||||||
|
//
|
||||||
|
// // 打印训练数据集
|
||||||
|
// trainData.collect().forEach(train -> {
|
||||||
|
// System.out.println("UserId: " + train.user() + ", ProductId: " + train.product() + ", Rating: " + train.rating());
|
||||||
|
// });
|
||||||
|
//
|
||||||
|
// // 调用 ALS 算法训练隐语义模型
|
||||||
|
// int rank = 50; // 隐语义因子的个数
|
||||||
|
// int iterations = 5; // 迭代次数
|
||||||
|
// double lambda = 0.01; // 正则化参数
|
||||||
|
//
|
||||||
|
// // 训练 ALS 模型
|
||||||
|
// MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(trainData), rank, iterations, lambda);
|
||||||
|
//
|
||||||
|
// // 计算用户推荐矩阵
|
||||||
|
// JavaPairRDD<Integer, Integer> userProducts = userRDD.cartesian(productRDD);
|
||||||
|
//
|
||||||
|
// // model已训练好,把id传进去就可以得到预测评分列表RDD[Rating] (userId,productId,rating)
|
||||||
|
// JavaRDD<Rating> preRatings = model.predict(userProducts);
|
||||||
|
//
|
||||||
|
// preRatings.foreach(userId -> {
|
||||||
|
// System.out.println(userId);
|
||||||
|
// });
|
||||||
|
// // 过滤出评分大于0的推荐
|
||||||
|
// JavaRDD<Rating> filteredRatings = preRatings.filter(rating -> rating.rating() > 0);
|
||||||
|
// // 按照评分降序排序,并取前10个推荐
|
||||||
|
// List<Rating> top10Ratings = filteredRatings
|
||||||
|
// .sortBy((Rating rating) -> rating.rating(), false, 1) // 按评分降序排序,false表示降序
|
||||||
|
// .take(10); // 取前10个评分最高的推荐
|
||||||
|
// // 打印前10个推荐
|
||||||
|
// top10Ratings.forEach(rating -> {
|
||||||
|
// System.out.println("UserId: " + rating.user() + ", ProductId: " + rating.product() + ", Rating: " + rating.rating());
|
||||||
|
// });
|
||||||
|
//
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
//}
|
12
src/main/java/com/SpringBootRecommendApplication.java
Normal file
12
src/main/java/com/SpringBootRecommendApplication.java
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package com;
|
||||||
|
|
||||||
|
import org.springframework.boot.SpringApplication;
|
||||||
|
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||||
|
|
||||||
|
@SpringBootApplication
|
||||||
|
public class SpringBootRecommendApplication {
|
||||||
|
public static void main(String[] args) {
|
||||||
|
SpringApplication.run(SpringBootRecommendApplication.class, args);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
23
src/main/java/com/config/SparkConfig.java
Normal file
23
src/main/java/com/config/SparkConfig.java
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
package com.config;
|
||||||
|
|
||||||
|
import org.apache.spark.SparkConf;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
import org.springframework.context.annotation.Bean;
|
||||||
|
import org.springframework.context.annotation.Configuration;
|
||||||
|
|
||||||
|
@Configuration
|
||||||
|
public class SparkConfig {
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
public SparkSession sparkSession() {
|
||||||
|
// 创建 SparkConf 对象并设置 Spark 配置
|
||||||
|
SparkConf sparkConf = new SparkConf()
|
||||||
|
.setMaster("local[*]") // 本地模式,使用所有可用核心
|
||||||
|
.setAppName("OfflineRecommender"); // Spark 应用名称
|
||||||
|
|
||||||
|
// 创建并返回 SparkSession 对象
|
||||||
|
return SparkSession.builder()
|
||||||
|
.config(sparkConf)
|
||||||
|
.getOrCreate(); // 初始化 SparkSession
|
||||||
|
}
|
||||||
|
}
|
34
src/main/java/com/controller/RecommendController.java
Normal file
34
src/main/java/com/controller/RecommendController.java
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package com.controller;
|
||||||
|
|
||||||
|
|
||||||
|
import com.model.ProductRemRating;
|
||||||
|
import com.service.RecommendService;
|
||||||
|
import org.apache.spark.mllib.recommendation.Rating;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.web.bind.annotation.*;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 推荐
|
||||||
|
*/
|
||||||
|
@RestController
|
||||||
|
@RequestMapping("/api/recommend")
|
||||||
|
@CrossOrigin
|
||||||
|
public class RecommendController {
|
||||||
|
|
||||||
|
@Autowired private RecommendService recommendService;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 推荐全部商品
|
||||||
|
* @param total
|
||||||
|
* @return 返回 用户Id,商品编号 评估推荐分数
|
||||||
|
*/
|
||||||
|
@GetMapping("all")
|
||||||
|
public List<ProductRemRating> allUserRecommend(@RequestParam Integer total){
|
||||||
|
List<ProductRemRating> ratings = recommendService.AllRecommend(total);
|
||||||
|
return ratings;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
68
src/main/java/com/model/ProductRating.java
Normal file
68
src/main/java/com/model/ProductRating.java
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
package com.model;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 评分表
|
||||||
|
*/
|
||||||
|
public class ProductRating {
|
||||||
|
|
||||||
|
private Integer userId;
|
||||||
|
|
||||||
|
private Integer productId;
|
||||||
|
|
||||||
|
private Double score;
|
||||||
|
private String createTime;
|
||||||
|
|
||||||
|
public ProductRating() {
|
||||||
|
// 无参构造函数
|
||||||
|
}
|
||||||
|
|
||||||
|
public ProductRating(Integer userId, Integer productId, Double score, String createTime) {
|
||||||
|
this.userId = userId;
|
||||||
|
this.productId = productId;
|
||||||
|
this.score = score;
|
||||||
|
this.createTime = createTime;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getProductId() {
|
||||||
|
return productId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Double getScore() {
|
||||||
|
return score;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setProductId(Integer productId) {
|
||||||
|
this.productId = productId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setScore(Double score) {
|
||||||
|
this.score = score;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer getUserId() {
|
||||||
|
return userId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getCreateTime() {
|
||||||
|
return createTime;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setUserId(Integer userId) {
|
||||||
|
this.userId = userId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setCreateTime(String createTime) {
|
||||||
|
this.createTime = createTime;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "ProductRating{" +
|
||||||
|
"userId=" + userId +
|
||||||
|
", productId=" + productId +
|
||||||
|
", score=" + score +
|
||||||
|
", createTime='" + createTime + '\'' +
|
||||||
|
'}';
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
17
src/main/java/com/model/ProductRemRating.java
Normal file
17
src/main/java/com/model/ProductRemRating.java
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
package com.model;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 推荐返回表
|
||||||
|
*/
|
||||||
|
@Data
|
||||||
|
public class ProductRemRating {
|
||||||
|
|
||||||
|
private Integer userId;
|
||||||
|
|
||||||
|
private Integer productId;
|
||||||
|
|
||||||
|
private Double rating;
|
||||||
|
|
||||||
|
}
|
16
src/main/java/com/service/RecommendService.java
Normal file
16
src/main/java/com/service/RecommendService.java
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
package com.service;
|
||||||
|
|
||||||
|
import com.model.ProductRemRating;
|
||||||
|
import org.apache.spark.mllib.recommendation.Rating;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public interface RecommendService {
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 推荐全部
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
List<ProductRemRating> AllRecommend(Integer total);
|
||||||
|
|
||||||
|
}
|
117
src/main/java/com/service/impl/RecommendServiceImpl.java
Normal file
117
src/main/java/com/service/impl/RecommendServiceImpl.java
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
package com.service.impl;
|
||||||
|
|
||||||
|
import com.model.ProductRating;
|
||||||
|
import com.model.ProductRemRating;
|
||||||
|
import com.service.RecommendService;
|
||||||
|
import org.apache.spark.api.java.JavaPairRDD;
|
||||||
|
import org.apache.spark.api.java.JavaRDD;
|
||||||
|
import org.apache.spark.sql.Dataset;
|
||||||
|
import org.apache.spark.sql.Encoders;
|
||||||
|
import org.apache.spark.sql.Row;
|
||||||
|
import org.apache.spark.sql.SparkSession;
|
||||||
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
|
import org.springframework.stereotype.Service;
|
||||||
|
import org.apache.spark.mllib.recommendation.Rating;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Properties;
|
||||||
|
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
|
||||||
|
import org.apache.spark.mllib.recommendation.ALS;
|
||||||
|
|
||||||
|
|
||||||
|
@Service
|
||||||
|
public class RecommendServiceImpl implements RecommendService {
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
private SparkSession spark; // 自动注入 SparkSession
|
||||||
|
|
||||||
|
@Value("${spark.mysql.url}")
|
||||||
|
private String url;
|
||||||
|
@Value("${spark.mysql.table}")
|
||||||
|
private String table;
|
||||||
|
|
||||||
|
@Value("${spark.mysql.user}")
|
||||||
|
private String user;
|
||||||
|
|
||||||
|
@Value("${spark.mysql.password}")
|
||||||
|
private String password;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<ProductRemRating> AllRecommend(Integer total) {
|
||||||
|
// 设置 MySQL 连接的属性
|
||||||
|
Properties properties = new Properties();
|
||||||
|
properties.put("user", user); // 修改为你的用户名
|
||||||
|
properties.put("password",password); // 修改为你的密码
|
||||||
|
properties.put("driver", "com.mysql.cj.jdbc.Driver");
|
||||||
|
|
||||||
|
// 读取 MySQL 数据
|
||||||
|
Dataset<Row> df = readFromMySQL(spark, url, table, properties);
|
||||||
|
// 使用 alias 重命名字段,确保字段名匹配
|
||||||
|
Dataset<Row> renamedDF = df
|
||||||
|
.select(
|
||||||
|
df.col("user_id").alias("userId"),
|
||||||
|
df.col("item_id").alias("productId"),
|
||||||
|
df.col("star").alias("score"),
|
||||||
|
df.col("create_time").alias("createTime")
|
||||||
|
)
|
||||||
|
.na().drop(); // 删除包含空值的行
|
||||||
|
// 使用 Spark 的 Encoders 将 Dataset<Row> 转换为 ProductRating 类型
|
||||||
|
Dataset<ProductRating> productRatingDataset = renamedDF.as(Encoders.bean(ProductRating.class));
|
||||||
|
JavaRDD<ProductRating> ratingRDD = productRatingDataset.javaRDD();
|
||||||
|
// 创建用户和产品的 RDD
|
||||||
|
JavaRDD<Integer> userRDD = ratingRDD.map(ProductRating::getUserId).distinct();
|
||||||
|
JavaRDD<Integer> productRDD = ratingRDD.map(ProductRating::getProductId).distinct();
|
||||||
|
//用户编号
|
||||||
|
// 创建训练数据集
|
||||||
|
JavaRDD<Rating> trainData =
|
||||||
|
ratingRDD.map(rating -> new Rating(rating.getUserId(), rating.getProductId(), rating.getScore())
|
||||||
|
);
|
||||||
|
// 打印训练数据集
|
||||||
|
trainData.collect().forEach(train -> {
|
||||||
|
System.out.println("UserId: " + train.user() + ", ProductId: " + train.product() + ", Rating: " + train.rating());
|
||||||
|
});
|
||||||
|
|
||||||
|
// 调用 ALS 算法训练隐语义模型
|
||||||
|
int rank = 50; // 隐语义因子的个数
|
||||||
|
int iterations = 5; // 迭代次数
|
||||||
|
double lambda = 0.01; // 正则化参数
|
||||||
|
|
||||||
|
// 训练 ALS 模型
|
||||||
|
MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(trainData), rank, iterations, lambda);
|
||||||
|
|
||||||
|
// 计算用户推荐矩阵
|
||||||
|
JavaPairRDD<Integer, Integer> userProducts = userRDD.cartesian(productRDD);
|
||||||
|
|
||||||
|
// model已训练好,把id传进去就可以得到预测评分列表RDD[Rating] (userId,productId,rating)
|
||||||
|
JavaRDD<Rating> preRatings = model.predict(userProducts);
|
||||||
|
|
||||||
|
preRatings.foreach(userId -> {
|
||||||
|
System.out.println(userId);
|
||||||
|
});
|
||||||
|
// 过滤出评分大于0的推荐
|
||||||
|
JavaRDD<Rating> filteredRatings = preRatings.filter(rating -> rating.rating() > 0);
|
||||||
|
// 按照评分降序排序,并取前10个推荐
|
||||||
|
List<Rating> topRatings = filteredRatings
|
||||||
|
.sortBy((Rating rating) -> rating.rating(), false, 1) // 按评分降序排序,false表示降序
|
||||||
|
.take(total); // 取前10个评分最高的推荐
|
||||||
|
|
||||||
|
List<ProductRemRating> list = new ArrayList<>();
|
||||||
|
// 打印前10个推荐
|
||||||
|
topRatings.forEach(rating -> {
|
||||||
|
ProductRemRating productRemRating = new ProductRemRating();
|
||||||
|
productRemRating.setUserId(rating.user());
|
||||||
|
productRemRating.setProductId(rating.product());
|
||||||
|
productRemRating.setRating(rating.rating());
|
||||||
|
list.add(productRemRating);
|
||||||
|
System.out.println("UserId: " + rating.user() + ", ProductId: " + rating.product() + ", Rating: " + rating.rating());
|
||||||
|
});
|
||||||
|
return list;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Dataset<Row> readFromMySQL(SparkSession spark, String url, String table, Properties properties) {
|
||||||
|
// 从 MySQL 读取数据并返回 Dataset
|
||||||
|
return spark.read().jdbc(url, table, properties);
|
||||||
|
}
|
||||||
|
}
|
4
src/main/java/com/util/MySQLReader.java
Normal file
4
src/main/java/com/util/MySQLReader.java
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
package com.util;
|
||||||
|
|
||||||
|
public class MySQLReader {
|
||||||
|
}
|
8
src/main/resources/application.yml
Normal file
8
src/main/resources/application.yml
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
server:
|
||||||
|
port: 9090
|
||||||
|
spark:
|
||||||
|
mysql:
|
||||||
|
url: jdbc:mysql://localhost:3306/racket
|
||||||
|
user: root
|
||||||
|
password: 123456
|
||||||
|
table: tb_order
|
4
src/main/resources/log4j.properties
Normal file
4
src/main/resources/log4j.properties
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
log4j.rootLogger=info, stdout
|
||||||
|
log4j.appender.stdout=org.apache.log4j.ConsoleAppender
|
||||||
|
log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
|
||||||
|
log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SSS} %5p --- [%50t] %-80c(line:%5L) : %m%n
|
Loading…
Reference in New Issue
Block a user