推荐算法push
This commit is contained in:
parent
9041756e80
commit
5d94fa03dc
54
.gitignore
vendored
54
.gitignore
vendored
@ -1,26 +1,38 @@
|
||||
# ---> Java
|
||||
# Compiled class file
|
||||
*.class
|
||||
target/
|
||||
!.mvn/wrapper/maven-wrapper.jar
|
||||
!**/src/main/**/target/
|
||||
!**/src/test/**/target/
|
||||
|
||||
# Log file
|
||||
*.log
|
||||
### IntelliJ IDEA ###
|
||||
.idea/modules.xml
|
||||
.idea/jarRepositories.xml
|
||||
.idea/compiler.xml
|
||||
.idea/libraries/
|
||||
*.iws
|
||||
*.iml
|
||||
*.ipr
|
||||
|
||||
# BlueJ files
|
||||
*.ctxt
|
||||
### Eclipse ###
|
||||
.apt_generated
|
||||
.classpath
|
||||
.factorypath
|
||||
.project
|
||||
.settings
|
||||
.springBeans
|
||||
.sts4-cache
|
||||
|
||||
# Mobile Tools for Java (J2ME)
|
||||
.mtj.tmp/
|
||||
### NetBeans ###
|
||||
/nbproject/private/
|
||||
/nbbuild/
|
||||
/dist/
|
||||
/nbdist/
|
||||
/.nb-gradle/
|
||||
build/
|
||||
!**/src/main/**/build/
|
||||
!**/src/test/**/build/
|
||||
|
||||
# Package Files #
|
||||
*.jar
|
||||
*.war
|
||||
*.nar
|
||||
*.ear
|
||||
*.zip
|
||||
*.tar.gz
|
||||
*.rar
|
||||
|
||||
# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml
|
||||
hs_err_pid*
|
||||
replay_pid*
|
||||
### VS Code ###
|
||||
.vscode/
|
||||
|
||||
### Mac OS ###
|
||||
.DS_Store
|
8
.idea/.gitignore
generated
vendored
Normal file
8
.idea/.gitignore
generated
vendored
Normal file
@ -0,0 +1,8 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
7
.idea/codeStyles/Project.xml
generated
Normal file
7
.idea/codeStyles/Project.xml
generated
Normal file
@ -0,0 +1,7 @@
|
||||
<component name="ProjectCodeStyleConfiguration">
|
||||
<code_scheme name="Project" version="173">
|
||||
<ScalaCodeStyleSettings>
|
||||
<option name="MULTILINE_STRING_CLOSING_QUOTES_ON_NEW_LINE" value="true" />
|
||||
</ScalaCodeStyleSettings>
|
||||
</code_scheme>
|
||||
</component>
|
5
.idea/codeStyles/codeStyleConfig.xml
generated
Normal file
5
.idea/codeStyles/codeStyleConfig.xml
generated
Normal file
@ -0,0 +1,5 @@
|
||||
<component name="ProjectCodeStyleConfiguration">
|
||||
<state>
|
||||
<option name="PREFERRED_PROJECT_CODE_STYLE" value="Default" />
|
||||
</state>
|
||||
</component>
|
6
.idea/encodings.xml
generated
Normal file
6
.idea/encodings.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Encoding">
|
||||
<file url="file://$PROJECT_DIR$/src/main/java" charset="UTF-8" />
|
||||
</component>
|
||||
</project>
|
6
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
6
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||
</profile>
|
||||
</component>
|
12
.idea/misc.xml
generated
Normal file
12
.idea/misc.xml
generated
Normal file
@ -0,0 +1,12 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ExternalStorageConfigurationManager" enabled="true" />
|
||||
<component name="MavenProjectsManager">
|
||||
<option name="originalFiles">
|
||||
<list>
|
||||
<option value="$PROJECT_DIR$/pom.xml" />
|
||||
</list>
|
||||
</option>
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="corretto-21" project-jdk-type="JavaSDK" />
|
||||
</project>
|
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
77
pom.xml
Normal file
77
pom.xml
Normal file
@ -0,0 +1,77 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0"
|
||||
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<parent>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-parent</artifactId>
|
||||
<version>2.6.2</version>
|
||||
<relativePath/>
|
||||
</parent>
|
||||
|
||||
<groupId>org.example</groupId>
|
||||
<artifactId>spark-rem</artifactId>
|
||||
<version>1.0-SNAPSHOT</version>
|
||||
|
||||
<properties>
|
||||
<maven.compiler.source>8</maven.compiler.source>
|
||||
<maven.compiler.target>8</maven.compiler.target>
|
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
</properties>
|
||||
|
||||
<!-- spark引入-->
|
||||
<dependencies>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-web</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.projectlombok</groupId>
|
||||
<artifactId>lombok</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
<artifactId>spark-core_2.12</artifactId>
|
||||
<version>3.1.2</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>org.slf4j</groupId>
|
||||
<artifactId>slf4j-log4j12</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
<artifactId>spark-sql_2.12</artifactId>
|
||||
<version>3.1.2</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.apache.spark</groupId>
|
||||
<artifactId>spark-mllib_2.12</artifactId>
|
||||
<version>3.1.2</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.codehaus.janino</groupId>
|
||||
<artifactId>janino</artifactId>
|
||||
<version>3.0.15</version> <!-- 使用适当版本 -->
|
||||
</dependency>
|
||||
|
||||
|
||||
|
||||
<!-- MySQL JDBC Driver -->
|
||||
<dependency>
|
||||
<groupId>mysql</groupId>
|
||||
<artifactId>mysql-connector-java</artifactId>
|
||||
<version>8.0.28</version> <!-- 使用适合的 MySQL 驱动版本 -->
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
</project>
|
112
src/main/java/com/OfflineRecommender.java
Normal file
112
src/main/java/com/OfflineRecommender.java
Normal file
@ -0,0 +1,112 @@
|
||||
//package com;
|
||||
//
|
||||
//import com.model.ProductRating;
|
||||
//import org.apache.spark.SparkConf;
|
||||
//import org.apache.spark.sql.Dataset;
|
||||
//import org.apache.spark.sql.Encoders;
|
||||
//import org.apache.spark.sql.SparkSession;
|
||||
//import org.apache.spark.sql.Row;
|
||||
//import org.apache.spark.api.java.JavaRDD;
|
||||
//import org.apache.spark.api.java.JavaPairRDD;
|
||||
//import org.apache.spark.mllib.recommendation.ALS;
|
||||
//import org.apache.spark.mllib.recommendation.Rating;
|
||||
//import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
|
||||
//import java.util.Properties;
|
||||
//import java.util.List;
|
||||
//
|
||||
//public class OfflineRecommender {
|
||||
//
|
||||
// // 读取 MySQL 数据的工具方法
|
||||
// public static Dataset<Row> readFromMySQL(SparkSession spark, String url, String table, Properties properties) {
|
||||
// // 从 MySQL 读取数据并返回 Dataset
|
||||
// return spark.read().jdbc(url, table, properties);
|
||||
// }
|
||||
//
|
||||
// public static void main(String[] args) {
|
||||
//
|
||||
// // 创建 SparkConf 对象并设置 Spark 配置
|
||||
// SparkConf sparkConf = new SparkConf()
|
||||
// .setMaster("local[*]")
|
||||
// .setAppName("OfflineRecommender");
|
||||
//
|
||||
// // 创建 SparkSession 对象
|
||||
// SparkSession spark = SparkSession.builder().config(sparkConf).getOrCreate();
|
||||
//
|
||||
//
|
||||
// //配置 MySQL 连接信息
|
||||
// String url = "jdbc:mysql://localhost:3306/test_rem"; // 修改为你的数据库连接 URL
|
||||
// String table = "rating"; // 替换为你要读取的表名
|
||||
//
|
||||
// // 设置 MySQL 连接的属性
|
||||
// Properties properties = new Properties();
|
||||
// properties.put("user", "root"); // 修改为你的用户名
|
||||
// properties.put("password", "123456"); // 修改为你的密码
|
||||
// properties.put("driver", "com.mysql.cj.jdbc.Driver");
|
||||
//
|
||||
// // 读取 MySQL 数据
|
||||
// Dataset<Row> df = readFromMySQL(spark, url, table, properties);
|
||||
//
|
||||
// // 使用 alias 重命名字段,确保字段名匹配
|
||||
// Dataset<Row> renamedDF = df
|
||||
// .select(
|
||||
// df.col("user_id").alias("userId"),
|
||||
// df.col("product_id").alias("productId"),
|
||||
// df.col("score"),
|
||||
// df.col("create_time").alias("createTime")
|
||||
// )
|
||||
// .na().drop(); // 删除包含空值的行
|
||||
//
|
||||
// // 使用 Spark 的 Encoders 将 Dataset<Row> 转换为 ProductRating 类型
|
||||
// Dataset<ProductRating> productRatingDataset = renamedDF.as(Encoders.bean(ProductRating.class));
|
||||
//
|
||||
// // 显示映射后的实体类数据
|
||||
// productRatingDataset.show();
|
||||
//
|
||||
// // 将 Dataset 转换为 Java RDD
|
||||
// JavaRDD<ProductRating> ratingRDD = productRatingDataset.javaRDD();
|
||||
//
|
||||
// // 创建用户和产品的 RDD
|
||||
// JavaRDD<Integer> userRDD = ratingRDD.map(ProductRating::getUserId).distinct();
|
||||
// JavaRDD<Integer> productRDD = ratingRDD.map(ProductRating::getProductId).distinct();
|
||||
//
|
||||
// // 创建训练数据集
|
||||
// JavaRDD<Rating> trainData =
|
||||
// ratingRDD.map(rating -> new Rating(rating.getUserId(), rating.getProductId(), rating.getScore())
|
||||
// );
|
||||
//
|
||||
// // 打印训练数据集
|
||||
// trainData.collect().forEach(train -> {
|
||||
// System.out.println("UserId: " + train.user() + ", ProductId: " + train.product() + ", Rating: " + train.rating());
|
||||
// });
|
||||
//
|
||||
// // 调用 ALS 算法训练隐语义模型
|
||||
// int rank = 50; // 隐语义因子的个数
|
||||
// int iterations = 5; // 迭代次数
|
||||
// double lambda = 0.01; // 正则化参数
|
||||
//
|
||||
// // 训练 ALS 模型
|
||||
// MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(trainData), rank, iterations, lambda);
|
||||
//
|
||||
// // 计算用户推荐矩阵
|
||||
// JavaPairRDD<Integer, Integer> userProducts = userRDD.cartesian(productRDD);
|
||||
//
|
||||
// // model已训练好,把id传进去就可以得到预测评分列表RDD[Rating] (userId,productId,rating)
|
||||
// JavaRDD<Rating> preRatings = model.predict(userProducts);
|
||||
//
|
||||
// preRatings.foreach(userId -> {
|
||||
// System.out.println(userId);
|
||||
// });
|
||||
// // 过滤出评分大于0的推荐
|
||||
// JavaRDD<Rating> filteredRatings = preRatings.filter(rating -> rating.rating() > 0);
|
||||
// // 按照评分降序排序,并取前10个推荐
|
||||
// List<Rating> top10Ratings = filteredRatings
|
||||
// .sortBy((Rating rating) -> rating.rating(), false, 1) // 按评分降序排序,false表示降序
|
||||
// .take(10); // 取前10个评分最高的推荐
|
||||
// // 打印前10个推荐
|
||||
// top10Ratings.forEach(rating -> {
|
||||
// System.out.println("UserId: " + rating.user() + ", ProductId: " + rating.product() + ", Rating: " + rating.rating());
|
||||
// });
|
||||
//
|
||||
// }
|
||||
//
|
||||
//}
|
12
src/main/java/com/SpringBootRecommendApplication.java
Normal file
12
src/main/java/com/SpringBootRecommendApplication.java
Normal file
@ -0,0 +1,12 @@
|
||||
package com;
|
||||
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
|
||||
@SpringBootApplication
|
||||
public class SpringBootRecommendApplication {
|
||||
public static void main(String[] args) {
|
||||
SpringApplication.run(SpringBootRecommendApplication.class, args);
|
||||
}
|
||||
}
|
||||
|
23
src/main/java/com/config/SparkConfig.java
Normal file
23
src/main/java/com/config/SparkConfig.java
Normal file
@ -0,0 +1,23 @@
|
||||
package com.config;
|
||||
|
||||
import org.apache.spark.SparkConf;
|
||||
import org.apache.spark.sql.SparkSession;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Configuration
|
||||
public class SparkConfig {
|
||||
|
||||
@Bean
|
||||
public SparkSession sparkSession() {
|
||||
// 创建 SparkConf 对象并设置 Spark 配置
|
||||
SparkConf sparkConf = new SparkConf()
|
||||
.setMaster("local[*]") // 本地模式,使用所有可用核心
|
||||
.setAppName("OfflineRecommender"); // Spark 应用名称
|
||||
|
||||
// 创建并返回 SparkSession 对象
|
||||
return SparkSession.builder()
|
||||
.config(sparkConf)
|
||||
.getOrCreate(); // 初始化 SparkSession
|
||||
}
|
||||
}
|
34
src/main/java/com/controller/RecommendController.java
Normal file
34
src/main/java/com/controller/RecommendController.java
Normal file
@ -0,0 +1,34 @@
|
||||
package com.controller;
|
||||
|
||||
|
||||
import com.model.ProductRemRating;
|
||||
import com.service.RecommendService;
|
||||
import org.apache.spark.mllib.recommendation.Rating;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* 推荐
|
||||
*/
|
||||
@RestController
|
||||
@RequestMapping("/api/recommend")
|
||||
@CrossOrigin
|
||||
public class RecommendController {
|
||||
|
||||
@Autowired private RecommendService recommendService;
|
||||
|
||||
/**
|
||||
* 推荐全部商品
|
||||
* @param total
|
||||
* @return 返回 用户Id,商品编号 评估推荐分数
|
||||
*/
|
||||
@GetMapping("all")
|
||||
public List<ProductRemRating> allUserRecommend(@RequestParam Integer total){
|
||||
List<ProductRemRating> ratings = recommendService.AllRecommend(total);
|
||||
return ratings;
|
||||
}
|
||||
|
||||
|
||||
}
|
68
src/main/java/com/model/ProductRating.java
Normal file
68
src/main/java/com/model/ProductRating.java
Normal file
@ -0,0 +1,68 @@
|
||||
package com.model;
|
||||
|
||||
/**
|
||||
* 评分表
|
||||
*/
|
||||
public class ProductRating {
|
||||
|
||||
private Integer userId;
|
||||
|
||||
private Integer productId;
|
||||
|
||||
private Double score;
|
||||
private String createTime;
|
||||
|
||||
public ProductRating() {
|
||||
// 无参构造函数
|
||||
}
|
||||
|
||||
public ProductRating(Integer userId, Integer productId, Double score, String createTime) {
|
||||
this.userId = userId;
|
||||
this.productId = productId;
|
||||
this.score = score;
|
||||
this.createTime = createTime;
|
||||
}
|
||||
|
||||
public Integer getProductId() {
|
||||
return productId;
|
||||
}
|
||||
|
||||
public Double getScore() {
|
||||
return score;
|
||||
}
|
||||
|
||||
public void setProductId(Integer productId) {
|
||||
this.productId = productId;
|
||||
}
|
||||
|
||||
public void setScore(Double score) {
|
||||
this.score = score;
|
||||
}
|
||||
|
||||
public Integer getUserId() {
|
||||
return userId;
|
||||
}
|
||||
|
||||
public String getCreateTime() {
|
||||
return createTime;
|
||||
}
|
||||
|
||||
public void setUserId(Integer userId) {
|
||||
this.userId = userId;
|
||||
}
|
||||
|
||||
public void setCreateTime(String createTime) {
|
||||
this.createTime = createTime;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "ProductRating{" +
|
||||
"userId=" + userId +
|
||||
", productId=" + productId +
|
||||
", score=" + score +
|
||||
", createTime='" + createTime + '\'' +
|
||||
'}';
|
||||
}
|
||||
|
||||
}
|
17
src/main/java/com/model/ProductRemRating.java
Normal file
17
src/main/java/com/model/ProductRemRating.java
Normal file
@ -0,0 +1,17 @@
|
||||
package com.model;
|
||||
|
||||
import lombok.Data;
|
||||
|
||||
/**
|
||||
* 推荐返回表
|
||||
*/
|
||||
@Data
|
||||
public class ProductRemRating {
|
||||
|
||||
private Integer userId;
|
||||
|
||||
private Integer productId;
|
||||
|
||||
private Double rating;
|
||||
|
||||
}
|
16
src/main/java/com/service/RecommendService.java
Normal file
16
src/main/java/com/service/RecommendService.java
Normal file
@ -0,0 +1,16 @@
|
||||
package com.service;
|
||||
|
||||
import com.model.ProductRemRating;
|
||||
import org.apache.spark.mllib.recommendation.Rating;
|
||||
import java.util.List;
|
||||
|
||||
public interface RecommendService {
|
||||
|
||||
|
||||
/**
|
||||
* 推荐全部
|
||||
* @return
|
||||
*/
|
||||
List<ProductRemRating> AllRecommend(Integer total);
|
||||
|
||||
}
|
117
src/main/java/com/service/impl/RecommendServiceImpl.java
Normal file
117
src/main/java/com/service/impl/RecommendServiceImpl.java
Normal file
@ -0,0 +1,117 @@
|
||||
package com.service.impl;
|
||||
|
||||
import com.model.ProductRating;
|
||||
import com.model.ProductRemRating;
|
||||
import com.service.RecommendService;
|
||||
import org.apache.spark.api.java.JavaPairRDD;
|
||||
import org.apache.spark.api.java.JavaRDD;
|
||||
import org.apache.spark.sql.Dataset;
|
||||
import org.apache.spark.sql.Encoders;
|
||||
import org.apache.spark.sql.Row;
|
||||
import org.apache.spark.sql.SparkSession;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.apache.spark.mllib.recommendation.Rating;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Properties;
|
||||
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
|
||||
import org.apache.spark.mllib.recommendation.ALS;
|
||||
|
||||
|
||||
@Service
|
||||
public class RecommendServiceImpl implements RecommendService {
|
||||
|
||||
@Autowired
|
||||
private SparkSession spark; // 自动注入 SparkSession
|
||||
|
||||
@Value("${spark.mysql.url}")
|
||||
private String url;
|
||||
@Value("${spark.mysql.table}")
|
||||
private String table;
|
||||
|
||||
@Value("${spark.mysql.user}")
|
||||
private String user;
|
||||
|
||||
@Value("${spark.mysql.password}")
|
||||
private String password;
|
||||
|
||||
@Override
|
||||
public List<ProductRemRating> AllRecommend(Integer total) {
|
||||
// 设置 MySQL 连接的属性
|
||||
Properties properties = new Properties();
|
||||
properties.put("user", user); // 修改为你的用户名
|
||||
properties.put("password",password); // 修改为你的密码
|
||||
properties.put("driver", "com.mysql.cj.jdbc.Driver");
|
||||
|
||||
// 读取 MySQL 数据
|
||||
Dataset<Row> df = readFromMySQL(spark, url, table, properties);
|
||||
// 使用 alias 重命名字段,确保字段名匹配
|
||||
Dataset<Row> renamedDF = df
|
||||
.select(
|
||||
df.col("user_id").alias("userId"),
|
||||
df.col("item_id").alias("productId"),
|
||||
df.col("star").alias("score"),
|
||||
df.col("create_time").alias("createTime")
|
||||
)
|
||||
.na().drop(); // 删除包含空值的行
|
||||
// 使用 Spark 的 Encoders 将 Dataset<Row> 转换为 ProductRating 类型
|
||||
Dataset<ProductRating> productRatingDataset = renamedDF.as(Encoders.bean(ProductRating.class));
|
||||
JavaRDD<ProductRating> ratingRDD = productRatingDataset.javaRDD();
|
||||
// 创建用户和产品的 RDD
|
||||
JavaRDD<Integer> userRDD = ratingRDD.map(ProductRating::getUserId).distinct();
|
||||
JavaRDD<Integer> productRDD = ratingRDD.map(ProductRating::getProductId).distinct();
|
||||
//用户编号
|
||||
// 创建训练数据集
|
||||
JavaRDD<Rating> trainData =
|
||||
ratingRDD.map(rating -> new Rating(rating.getUserId(), rating.getProductId(), rating.getScore())
|
||||
);
|
||||
// 打印训练数据集
|
||||
trainData.collect().forEach(train -> {
|
||||
System.out.println("UserId: " + train.user() + ", ProductId: " + train.product() + ", Rating: " + train.rating());
|
||||
});
|
||||
|
||||
// 调用 ALS 算法训练隐语义模型
|
||||
int rank = 50; // 隐语义因子的个数
|
||||
int iterations = 5; // 迭代次数
|
||||
double lambda = 0.01; // 正则化参数
|
||||
|
||||
// 训练 ALS 模型
|
||||
MatrixFactorizationModel model = ALS.train(JavaRDD.toRDD(trainData), rank, iterations, lambda);
|
||||
|
||||
// 计算用户推荐矩阵
|
||||
JavaPairRDD<Integer, Integer> userProducts = userRDD.cartesian(productRDD);
|
||||
|
||||
// model已训练好,把id传进去就可以得到预测评分列表RDD[Rating] (userId,productId,rating)
|
||||
JavaRDD<Rating> preRatings = model.predict(userProducts);
|
||||
|
||||
preRatings.foreach(userId -> {
|
||||
System.out.println(userId);
|
||||
});
|
||||
// 过滤出评分大于0的推荐
|
||||
JavaRDD<Rating> filteredRatings = preRatings.filter(rating -> rating.rating() > 0);
|
||||
// 按照评分降序排序,并取前10个推荐
|
||||
List<Rating> topRatings = filteredRatings
|
||||
.sortBy((Rating rating) -> rating.rating(), false, 1) // 按评分降序排序,false表示降序
|
||||
.take(total); // 取前10个评分最高的推荐
|
||||
|
||||
List<ProductRemRating> list = new ArrayList<>();
|
||||
// 打印前10个推荐
|
||||
topRatings.forEach(rating -> {
|
||||
ProductRemRating productRemRating = new ProductRemRating();
|
||||
productRemRating.setUserId(rating.user());
|
||||
productRemRating.setProductId(rating.product());
|
||||
productRemRating.setRating(rating.rating());
|
||||
list.add(productRemRating);
|
||||
System.out.println("UserId: " + rating.user() + ", ProductId: " + rating.product() + ", Rating: " + rating.rating());
|
||||
});
|
||||
return list;
|
||||
}
|
||||
|
||||
public static Dataset<Row> readFromMySQL(SparkSession spark, String url, String table, Properties properties) {
|
||||
// 从 MySQL 读取数据并返回 Dataset
|
||||
return spark.read().jdbc(url, table, properties);
|
||||
}
|
||||
}
|
4
src/main/java/com/util/MySQLReader.java
Normal file
4
src/main/java/com/util/MySQLReader.java
Normal file
@ -0,0 +1,4 @@
|
||||
package com.util;
|
||||
|
||||
public class MySQLReader {
|
||||
}
|
8
src/main/resources/application.yml
Normal file
8
src/main/resources/application.yml
Normal file
@ -0,0 +1,8 @@
|
||||
server:
|
||||
port: 9090
|
||||
spark:
|
||||
mysql:
|
||||
url: jdbc:mysql://localhost:3306/racket
|
||||
user: root
|
||||
password: 123456
|
||||
table: tb_order
|
4
src/main/resources/log4j.properties
Normal file
4
src/main/resources/log4j.properties
Normal file
@ -0,0 +1,4 @@
|
||||
log4j.rootLogger=info, stdout
|
||||
log4j.appender.stdout=org.apache.log4j.ConsoleAppender
|
||||
log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
|
||||
log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss,SSS} %5p --- [%50t] %-80c(line:%5L) : %m%n
|
Loading…
Reference in New Issue
Block a user