mirror of
https://gitee.com/dromara/hutool.git
synced 2025-04-05 17:37:59 +08:00
fix code
This commit is contained in:
parent
5b1476c8c9
commit
e0566c17ee
hutool-ai/src/main/java/org/dromara/hutool/ai
@ -1,35 +1,71 @@
|
||||
package org.dromara.hutool.ai;
|
||||
|
||||
import org.dromara.hutool.core.exception.ExceptionUtil;
|
||||
import org.dromara.hutool.core.text.StrUtil;
|
||||
import org.dromara.hutool.core.exception.HutoolException;
|
||||
|
||||
/**
|
||||
* 异常处理类
|
||||
*/
|
||||
public class AIException extends RuntimeException {
|
||||
public class AIException extends HutoolException {
|
||||
private static final long serialVersionUID = 1L;
|
||||
|
||||
public AIException(Throwable e) {
|
||||
super(ExceptionUtil.getMessage(e), e);
|
||||
/**
|
||||
* 构造
|
||||
*
|
||||
* @param e 异常
|
||||
*/
|
||||
public AIException(final Throwable e) {
|
||||
super(e);
|
||||
}
|
||||
|
||||
public AIException(String message) {
|
||||
/**
|
||||
* 构造
|
||||
*
|
||||
* @param message 消息
|
||||
*/
|
||||
public AIException(final String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public AIException(String messageTemplate, Object... params) {
|
||||
super(StrUtil.format(messageTemplate, params));
|
||||
/**
|
||||
* 构造
|
||||
*
|
||||
* @param messageTemplate 消息模板
|
||||
* @param params 参数
|
||||
*/
|
||||
public AIException(final String messageTemplate, final Object... params) {
|
||||
super(messageTemplate, params);
|
||||
}
|
||||
|
||||
public AIException(String message, Throwable throwable) {
|
||||
super(message, throwable);
|
||||
/**
|
||||
* 构造
|
||||
*
|
||||
* @param message 消息
|
||||
* @param cause 被包装的子异常
|
||||
*/
|
||||
public AIException(final String message, final Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
|
||||
public AIException(String message, Throwable throwable, boolean enableSuppression, boolean writableStackTrace) {
|
||||
super(message, throwable, enableSuppression, writableStackTrace);
|
||||
/**
|
||||
* 构造
|
||||
*
|
||||
* @param message 消息
|
||||
* @param cause 被包装的子异常
|
||||
* @param enableSuppression 是否启用抑制
|
||||
* @param writableStackTrace 堆栈跟踪是否应该是可写的
|
||||
*/
|
||||
public AIException(final String message, final Throwable cause, final boolean enableSuppression, final boolean writableStackTrace) {
|
||||
super(message, cause, enableSuppression, writableStackTrace);
|
||||
}
|
||||
|
||||
public AIException(Throwable throwable, String messageTemplate, Object... params) {
|
||||
super(StrUtil.format(messageTemplate, params), throwable);
|
||||
/**
|
||||
* 构造
|
||||
*
|
||||
* @param cause 被包装的子异常
|
||||
* @param messageTemplate 消息模板
|
||||
* @param params 参数
|
||||
*/
|
||||
public AIException(final Throwable cause, final String messageTemplate, final Object... params) {
|
||||
super(cause, messageTemplate, params);
|
||||
}
|
||||
}
|
||||
|
@ -22,8 +22,8 @@ public class AIServiceFactory {
|
||||
|
||||
// 加载所有 AIModelProvider 实现类
|
||||
static {
|
||||
ServiceLoader<AIServiceProvider> loader = ServiceLoader.load(AIServiceProvider.class);
|
||||
for (AIServiceProvider provider : loader) {
|
||||
final ServiceLoader<AIServiceProvider> loader = ServiceLoader.load(AIServiceProvider.class);
|
||||
for (final AIServiceProvider provider : loader) {
|
||||
providers.put(provider.getServiceName().toLowerCase(), provider);
|
||||
}
|
||||
}
|
||||
@ -35,7 +35,7 @@ public class AIServiceFactory {
|
||||
* @return AI服务实例
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public static AIService getAIService(AIConfig config) {
|
||||
public static AIService getAIService(final AIConfig config) {
|
||||
return getAIService(config, AIService.class);
|
||||
}
|
||||
|
||||
@ -46,21 +46,23 @@ public class AIServiceFactory {
|
||||
* @param clazz AI服务类
|
||||
* @return clazz对应的AI服务类实例
|
||||
* @since 6.0.0
|
||||
* @param <T> AI服务类
|
||||
*/
|
||||
public static <T extends AIService> T getAIService(AIConfig config, Class<T> clazz) {
|
||||
@SuppressWarnings("unchecked")
|
||||
public static <T extends AIService> T getAIService(final AIConfig config, final Class<T> clazz) {
|
||||
//异步执行
|
||||
CompletableFuture.runAsync(() -> {
|
||||
try {
|
||||
HttpUtil.get("https://static.hutool.cn");
|
||||
} catch (Exception ignored) {
|
||||
} catch (final Exception ignored) {
|
||||
}
|
||||
});
|
||||
AIServiceProvider provider = providers.get(config.getModelName().toLowerCase());
|
||||
final AIServiceProvider provider = providers.get(config.getModelName().toLowerCase());
|
||||
if (provider == null) {
|
||||
throw new IllegalArgumentException("Unsupported model: " + config.getModelName());
|
||||
}
|
||||
|
||||
AIService service = provider.create(config);
|
||||
final AIService service = provider.create(config);
|
||||
if (!clazz.isInstance(service)) {
|
||||
throw new AIException("Model service is not of type: " + clazz.getSimpleName());
|
||||
}
|
||||
|
@ -25,8 +25,9 @@ public class AIUtil {
|
||||
* @param clazz AI模型服务类
|
||||
* @return AIModelService的实现类实例
|
||||
* @since 6.0.0
|
||||
* @param <T> AIService实现类
|
||||
*/
|
||||
public static <T extends AIService> T getAIService(AIConfig config, Class<T> clazz) {
|
||||
public static <T extends AIService> T getAIService(final AIConfig config, final Class<T> clazz) {
|
||||
return AIServiceFactory.getAIService(config, clazz);
|
||||
}
|
||||
|
||||
@ -37,7 +38,7 @@ public class AIUtil {
|
||||
* @return AIModelService 其中只有公共方法
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public static AIService getAIService(AIConfig config) {
|
||||
public static AIService getAIService(final AIConfig config) {
|
||||
return getAIService(config, AIService.class);
|
||||
}
|
||||
|
||||
@ -48,7 +49,7 @@ public class AIUtil {
|
||||
* @return DeepSeekService
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public static DeepSeekService getDeepSeekService(AIConfig config) {
|
||||
public static DeepSeekService getDeepSeekService(final AIConfig config) {
|
||||
return getAIService(config, DeepSeekService.class);
|
||||
}
|
||||
|
||||
@ -59,7 +60,7 @@ public class AIUtil {
|
||||
* @return DoubaoService
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public static DoubaoService getDoubaoService(AIConfig config) {
|
||||
public static DoubaoService getDoubaoService(final AIConfig config) {
|
||||
return getAIService(config, DoubaoService.class);
|
||||
}
|
||||
|
||||
@ -70,7 +71,7 @@ public class AIUtil {
|
||||
* @return GrokService
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public static GrokService getGrokService(AIConfig config) {
|
||||
public static GrokService getGrokService(final AIConfig config) {
|
||||
return getAIService(config, GrokService.class);
|
||||
}
|
||||
|
||||
@ -81,7 +82,7 @@ public class AIUtil {
|
||||
* @return OpenAIService
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public static OpenaiService getOpenAIService(AIConfig config) {
|
||||
public static OpenaiService getOpenAIService(final AIConfig config) {
|
||||
return getAIService(config, OpenaiService.class);
|
||||
}
|
||||
|
||||
@ -93,7 +94,7 @@ public class AIUtil {
|
||||
* @return AI模型返回的Response响应字符串
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public static String chat(AIConfig config, String prompt) {
|
||||
public static String chat(final AIConfig config, final String prompt) {
|
||||
return getAIService(config).chat(prompt);
|
||||
}
|
||||
|
||||
@ -105,7 +106,7 @@ public class AIUtil {
|
||||
* @return AI模型返回的Response响应字符串
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public static String chat(AIConfig config, List<Message> messages) {
|
||||
public static String chat(final AIConfig config, final List<Message> messages) {
|
||||
return getAIService(config).chat(messages);
|
||||
}
|
||||
|
||||
|
@ -7,17 +7,34 @@ package org.dromara.hutool.ai;
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public enum ModelName {
|
||||
/**
|
||||
* deepSeek
|
||||
*/
|
||||
DEEPSEEK("deepSeek"),
|
||||
/**
|
||||
* openai
|
||||
*/
|
||||
OPENAI("openai"),
|
||||
/**
|
||||
* doubao
|
||||
*/
|
||||
DOUBAO("doubao"),
|
||||
/**
|
||||
* grok
|
||||
*/
|
||||
GROK("grok");
|
||||
|
||||
private final String value;
|
||||
|
||||
ModelName(String value) {
|
||||
ModelName(final String value) {
|
||||
this.value = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取值
|
||||
*
|
||||
* @return 值
|
||||
*/
|
||||
public String getValue() {
|
||||
return value;
|
||||
}
|
||||
|
@ -17,18 +17,18 @@ public class AIConfigBuilder {
|
||||
*
|
||||
* @param modelName 模型厂商的名称(注意不是指具体的模型)
|
||||
*/
|
||||
public AIConfigBuilder(String modelName) {
|
||||
public AIConfigBuilder(final String modelName) {
|
||||
try {
|
||||
// 获取配置类
|
||||
Class<? extends AIConfig> configClass = AIConfigRegistry.getConfigClass(modelName);
|
||||
final Class<? extends AIConfig> configClass = AIConfigRegistry.getConfigClass(modelName);
|
||||
if (configClass == null) {
|
||||
throw new IllegalArgumentException("Unsupported model: " + modelName);
|
||||
}
|
||||
|
||||
// 使用反射创建实例
|
||||
Constructor<? extends AIConfig> constructor = configClass.getDeclaredConstructor();
|
||||
final Constructor<? extends AIConfig> constructor = configClass.getDeclaredConstructor();
|
||||
config = constructor.newInstance();
|
||||
} catch (Exception e) {
|
||||
} catch (final Exception e) {
|
||||
throw new RuntimeException("Failed to create AIConfig instance", e);
|
||||
}
|
||||
}
|
||||
@ -40,7 +40,7 @@ public class AIConfigBuilder {
|
||||
* @return config
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public synchronized AIConfigBuilder setApiKey(String apiKey) {
|
||||
public synchronized AIConfigBuilder setApiKey(final String apiKey) {
|
||||
if (apiKey != null) {
|
||||
config.setApiKey(apiKey);
|
||||
}
|
||||
@ -54,7 +54,7 @@ public class AIConfigBuilder {
|
||||
* @return config
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public synchronized AIConfigBuilder setApiUrl(String apiUrl) {
|
||||
public synchronized AIConfigBuilder setApiUrl(final String apiUrl) {
|
||||
if (apiUrl != null) {
|
||||
config.setApiUrl(apiUrl);
|
||||
}
|
||||
@ -68,7 +68,7 @@ public class AIConfigBuilder {
|
||||
* @return config
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public synchronized AIConfigBuilder setModel(String model) {
|
||||
public synchronized AIConfigBuilder setModel(final String model) {
|
||||
if (model != null) {
|
||||
config.setModel(model);
|
||||
}
|
||||
@ -83,7 +83,7 @@ public class AIConfigBuilder {
|
||||
* @return config
|
||||
* @since 6.0.0
|
||||
*/
|
||||
public AIConfigBuilder putAdditionalConfig(String key, Object value) {
|
||||
public AIConfigBuilder putAdditionalConfig(final String key, final Object value) {
|
||||
if (value != null) {
|
||||
config.putAdditionalConfigByKey(key, value);
|
||||
}
|
||||
|
@ -17,13 +17,19 @@ public class AIConfigRegistry {
|
||||
|
||||
// 加载所有 AIConfig 实现类
|
||||
static {
|
||||
ServiceLoader<AIConfig> loader = ServiceLoader.load(AIConfig.class);
|
||||
for (AIConfig config : loader) {
|
||||
final ServiceLoader<AIConfig> loader = ServiceLoader.load(AIConfig.class);
|
||||
for (final AIConfig config : loader) {
|
||||
configClasses.put(config.getModelName().toLowerCase(), config.getClass());
|
||||
}
|
||||
}
|
||||
|
||||
public static Class<? extends AIConfig> getConfigClass(String modelName) {
|
||||
/**
|
||||
* 根据模型名称获取AIConfig实现类
|
||||
*
|
||||
* @param modelName 模型名称
|
||||
* @return AIConfig实现类
|
||||
*/
|
||||
public static Class<? extends AIConfig> getConfigClass(final String modelName) {
|
||||
return configClasses.get(modelName.toLowerCase());
|
||||
}
|
||||
}
|
||||
|
@ -19,11 +19,16 @@ public class BaseAIService {
|
||||
|
||||
protected final AIConfig config;
|
||||
|
||||
public BaseAIService(AIConfig config) {
|
||||
/**
|
||||
* 构造方法
|
||||
*
|
||||
* @param config AI配置
|
||||
*/
|
||||
public BaseAIService(final AIConfig config) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
protected Response sendGet(String endpoint) {
|
||||
protected Response sendGet(final String endpoint) {
|
||||
//链式构建请求
|
||||
try {
|
||||
//设置超时3分钟
|
||||
@ -32,12 +37,12 @@ public class BaseAIService {
|
||||
.header(HeaderName.ACCEPT, "application/json")
|
||||
.header(HeaderName.AUTHORIZATION, "Bearer " + config.getApiKey())
|
||||
.send();
|
||||
} catch (AIException e) {
|
||||
} catch (final AIException e) {
|
||||
throw new AIException("Failed to send GET request: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
protected Response sendPost(String endpoint, String paramJson) {
|
||||
protected Response sendPost(final String endpoint, final String paramJson) {
|
||||
//链式构建请求
|
||||
try {
|
||||
//设置超时3分钟
|
||||
@ -48,13 +53,13 @@ public class BaseAIService {
|
||||
.header(HeaderName.AUTHORIZATION, "Bearer " + config.getApiKey())
|
||||
.body(paramJson)
|
||||
.send();
|
||||
} catch (AIException e) {
|
||||
} catch (final AIException e) {
|
||||
throw new AIException("Failed to send POST request:" + e.getMessage(), e);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
protected Response sendFormData(String endpoint, Map<String, Object> paramMap) {
|
||||
protected Response sendFormData(final String endpoint, final Map<String, Object> paramMap) {
|
||||
//链式构建请求
|
||||
try {
|
||||
//设置超时3分钟
|
||||
@ -65,7 +70,7 @@ public class BaseAIService {
|
||||
.header(HeaderName.AUTHORIZATION, "Bearer " + config.getApiKey())
|
||||
.form(paramMap)
|
||||
.send();
|
||||
} catch (AIException e) {
|
||||
} catch (final AIException e) {
|
||||
throw new AIException("Failed to send POST request:" + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
@ -23,7 +23,7 @@ public class BaseConfig implements AIConfig {
|
||||
protected Map<String, Object> additionalConfig = new SafeConcurrentHashMap<>();
|
||||
|
||||
@Override
|
||||
public void setApiKey(String apiKey) {
|
||||
public void setApiKey(final String apiKey) {
|
||||
this.apiKey = apiKey;
|
||||
}
|
||||
|
||||
@ -33,7 +33,7 @@ public class BaseConfig implements AIConfig {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setApiUrl(String apiUrl) {
|
||||
public void setApiUrl(final String apiUrl) {
|
||||
this.apiUrl = apiUrl;
|
||||
}
|
||||
|
||||
@ -43,7 +43,7 @@ public class BaseConfig implements AIConfig {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setModel(String model) {
|
||||
public void setModel(final String model) {
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
@ -53,12 +53,12 @@ public class BaseConfig implements AIConfig {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void putAdditionalConfigByKey(String key, Object value) {
|
||||
public void putAdditionalConfigByKey(final String key, final Object value) {
|
||||
this.additionalConfig.put(key, value);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Object getAdditionalConfigByKey(String key) {
|
||||
public Object getAdditionalConfigByKey(final String key) {
|
||||
return additionalConfig.get(key);
|
||||
}
|
||||
|
||||
|
@ -12,15 +12,31 @@ public class Message {
|
||||
//内容
|
||||
private final Object content;
|
||||
|
||||
public Message(String role, Object content) {
|
||||
/**
|
||||
* 构造
|
||||
*
|
||||
* @param role 角色
|
||||
* @param content 内容
|
||||
*/
|
||||
public Message(final String role, final Object content) {
|
||||
this.role = role;
|
||||
this.content = content;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取角色
|
||||
*
|
||||
* @return 角色
|
||||
*/
|
||||
public String getRole() {
|
||||
return role;
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取内容
|
||||
*
|
||||
* @return 内容
|
||||
*/
|
||||
public Object getContent() {
|
||||
return content;
|
||||
}
|
||||
|
@ -0,0 +1,22 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Hutool Team and hutool.cn
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* AI相关基础类
|
||||
*
|
||||
* @author elichow
|
||||
*/
|
||||
package org.dromara.hutool.ai.core;
|
@ -0,0 +1,22 @@
|
||||
/*
|
||||
* Copyright (c) 2025 Hutool Team and hutool.cn
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/**
|
||||
* AI模块
|
||||
*
|
||||
* @author elichow
|
||||
*/
|
||||
package org.dromara.hutool.ai;
|
Loading…
Reference in New Issue
Block a user