API 1.2+3 open ai chat completion
This commit is contained in:
parent
cae214f87e
commit
26d344777c
@ -48,7 +48,7 @@ dependencies {
|
||||
implementation("jakarta.mail:jakarta.mail-api:2.1.1")
|
||||
implementation("org.eclipse.angus:angus-mail:2.0.1")
|
||||
|
||||
implementation ("com.aallam.openai:openai-client:2.1.3")
|
||||
implementation ("com.aallam.openai:openai-client:3.0.0")
|
||||
implementation("io.ktor:ktor-client-java:2.2.3")
|
||||
|
||||
runtimeOnly("mysql:mysql-connector-java")
|
||||
|
@ -1,7 +1,11 @@
|
||||
package com.aitrainer.api.controller
|
||||
|
||||
import com.aallam.openai.api.BetaOpenAI
|
||||
import com.aallam.openai.api.chat.ChatMessage
|
||||
import com.aitrainer.api.model.OpenAI
|
||||
import com.aitrainer.api.model.OpenAIChat
|
||||
import com.aitrainer.api.openai.OpenAIService
|
||||
import com.google.gson.Gson
|
||||
import kotlinx.coroutines.*
|
||||
import org.slf4j.LoggerFactory
|
||||
import org.springframework.web.bind.annotation.*
|
||||
@ -41,6 +45,21 @@ class OpenAIController() {
|
||||
return result
|
||||
}
|
||||
|
||||
@OptIn(BetaOpenAI::class, DelicateCoroutinesApi::class)
|
||||
@PostMapping("/openai/chat_completion")
|
||||
fun getOpenAIChatCompletion(@RequestBody openai: OpenAIChat) : String {
|
||||
var result = ""
|
||||
val openAIService = OpenAIService(openai.modelName, openai.temperature)
|
||||
val deferred = GlobalScope.async {
|
||||
openAIService.chatCompletion(openai.messages)
|
||||
}
|
||||
runBlocking {
|
||||
result = deferred.await().toString()
|
||||
println("Result: $result" )
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@OptIn(DelicateCoroutinesApi::class)
|
||||
@GetMapping("/openai/list_models")
|
||||
fun getOpenAIModels(): MutableList<String> {
|
||||
|
15
src/main/kotlin/com/aitrainer/api/model/OpenAIChat.kt
Normal file
15
src/main/kotlin/com/aitrainer/api/model/OpenAIChat.kt
Normal file
@ -0,0 +1,15 @@
|
||||
package com.aitrainer.api.model
|
||||
|
||||
import com.aallam.openai.api.BetaOpenAI
|
||||
import com.aallam.openai.api.chat.ChatMessage
|
||||
import com.google.gson.annotations.Expose
|
||||
import jakarta.persistence.*
|
||||
import org.springframework.lang.NonNull
|
||||
|
||||
@Entity
|
||||
data class OpenAIChat @OptIn(BetaOpenAI::class) constructor(
|
||||
@Expose @Id @GeneratedValue(strategy = GenerationType.IDENTITY) @get: NonNull var id: Long = 0,
|
||||
@Expose @get: NonNull var messages: String,
|
||||
@Expose @get: NonNull var modelName: String? = null,
|
||||
@Expose @get: NonNull var temperature: Double? = null,
|
||||
)
|
@ -1,5 +1,9 @@
|
||||
package com.aitrainer.api.openai
|
||||
|
||||
import com.aallam.openai.api.BetaOpenAI
|
||||
import com.aallam.openai.api.chat.ChatCompletion
|
||||
import com.aallam.openai.api.chat.ChatCompletionRequest
|
||||
import com.aallam.openai.api.chat.ChatMessage
|
||||
import com.aallam.openai.client.OpenAI
|
||||
import com.aallam.openai.api.completion.CompletionRequest
|
||||
import com.aallam.openai.api.completion.TextCompletion
|
||||
@ -7,6 +11,7 @@ import com.aallam.openai.api.logging.LogLevel
|
||||
import com.aallam.openai.api.model.Model
|
||||
import com.aallam.openai.api.model.ModelId
|
||||
import com.aallam.openai.client.OpenAIConfig
|
||||
import com.google.gson.Gson
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.util.Properties
|
||||
@ -36,6 +41,12 @@ class OpenAIService(private val modelName: String?, private val temperature: Dou
|
||||
|
||||
}
|
||||
|
||||
|
||||
/* models:
|
||||
gpt-3.5-turbo chat/completion
|
||||
text-davinci-003 completion
|
||||
*/
|
||||
|
||||
suspend fun completion(question: String): String {
|
||||
return withContext(Dispatchers.IO) {
|
||||
var realModelName = "text-davinci-003"
|
||||
@ -64,6 +75,41 @@ class OpenAIService(private val modelName: String?, private val temperature: Dou
|
||||
}
|
||||
}
|
||||
|
||||
@OptIn(BetaOpenAI::class)
|
||||
suspend fun chatCompletion(chatMessagesJson: String): String? {
|
||||
return withContext(Dispatchers.IO) {
|
||||
val gson = Gson()
|
||||
val messages = gson.fromJson(chatMessagesJson, Array<ChatMessage>::class.java).toList()
|
||||
val lastQuestion = messages.last().content
|
||||
|
||||
var lengthQuestion = 0
|
||||
for ( message in messages ) {
|
||||
lengthQuestion += message.content.length
|
||||
}
|
||||
|
||||
val realModelName = "gpt-3.5-turbo"
|
||||
var realTemperature = 0.1
|
||||
if ( temperature != null ) {
|
||||
realTemperature = temperature
|
||||
}
|
||||
if (openAI == null) {
|
||||
connect(realModelName)
|
||||
}
|
||||
println("OpenAI Chat Last Question: $lastQuestion")
|
||||
val completionRequest = ChatCompletionRequest(
|
||||
model = ModelId(realModelName),
|
||||
messages = messages,
|
||||
maxTokens = 4096 - lengthQuestion,
|
||||
temperature = realTemperature,
|
||||
)
|
||||
val completion: ChatCompletion = openAI!!.chatCompletion(completionRequest)
|
||||
|
||||
val result = completion.choices[0].message?.content
|
||||
print(result)
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun getModels(): MutableList<String> {
|
||||
return withContext(Dispatchers.IO) {
|
||||
if (openAI == null) {
|
||||
|
@ -1,7 +1,6 @@
|
||||
#spring.config.activate.on-profile=dev,test,prod,prodtest
|
||||
spring.config.use-legacy-processing = true
|
||||
## Spring DATASOURCE (DataSourceAutoConfiguration & DataSourceProperties)
|
||||
#spring.datasource.url = jdbc:mysql://localhost:3306/aitrainer?autoReconnect=true&useUnicode=true&characterEncoding=UTF-8&allowMultiQueries=true&useSSL=false
|
||||
spring.datasource.url = jdbc:mysql://localhost:3306/diet4you?serverTimezone=CET&useSSL=false&characterEncoding=UTF-8&allowMultiQueries=true
|
||||
spring.datasource.username = aitrainer
|
||||
spring.datasource.password = ENC(WZplPYr8WmrLHshesY4T6oXplK3MlUVJ)
|
||||
|
@ -2,8 +2,8 @@ spring.config.activate.on-profile=dietprod
|
||||
spring.config.use-legacy-processing = true
|
||||
|
||||
## Spring DATASOURCE (DataSourceAutoConfiguration & DataSourceProperties)
|
||||
spring.datasource.url = jdbc:mysql://mariadb-shared.db.svc.cluster.local:3306/diet4you?serverTimezone=CET&useSSL=true&characterEncoding=UTF-8&allowPublicKeyRetrieval=true&allowMultiQueries=true
|
||||
spring.datasource.username = aitrainer
|
||||
spring.datasource.url = jdbc:mysql://mariadb-diet4you.diet4you.svc.cluster.local:3306/diet4you?serverTimezone=CET&useSSL=true&characterEncoding=UTF-8&allowPublicKeyRetrieval=true&allowMultiQueries=true
|
||||
spring.datasource.username = bossanyit
|
||||
spring.datasource.password = ENC(WZplPYr8WmrLHshesY4T6oXplK3MlUVJ)
|
||||
|
||||
# The SQL dialect makes Hibernate generate better SQL for the chosen database
|
||||
@ -20,3 +20,4 @@ jwt.secret=aitrainer
|
||||
|
||||
firebase.key=AIzaSyCUXBWV3_qzvV__ZWZA1siHftrrJpjDKh4
|
||||
openai.key=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
|
||||
spring.mail.properties.mail.mime.charset=UTF-8
|
6
src/main/resources/application-dietwsl.properties
Normal file
6
src/main/resources/application-dietwsl.properties
Normal file
@ -0,0 +1,6 @@
|
||||
#spring.config.activate.on-profile=dev,test,prod,prodtest
|
||||
spring.config.use-legacy-processing = true
|
||||
## Spring DATASOURCE (DataSourceAutoConfiguration & DataSourceProperties)
|
||||
spring.datasource.url = jdbc:mysql://192.168.100.98:3306/diet4you?serverTimezone=CET&useSSL=false&characterEncoding=UTF-8&allowMultiQueries=true
|
||||
spring.datasource.username = aitrainer
|
||||
spring.datasource.password = ENC(WZplPYr8WmrLHshesY4T6oXplK3MlUVJ)
|
@ -28,7 +28,7 @@
|
||||
</appender>
|
||||
|
||||
|
||||
<!--<logger name="org.springframework" level="DEBUG" />
|
||||
<!-- <logger name="org.springframework" level="DEBUG" />
|
||||
<logger name="org.apache.tomcat" level="DEBUG"/>
|
||||
<logger name="org.apache.coyote" level="DEBUG"/>
|
||||
<logger name="com.github.ulisesbocchio" level="DEBUG" />
|
||||
|
@ -1,6 +1,10 @@
|
||||
package com.aitrainer.api.test.openai
|
||||
|
||||
import com.aallam.openai.api.BetaOpenAI
|
||||
import com.aallam.openai.api.chat.ChatMessage
|
||||
import com.aallam.openai.api.chat.ChatRole
|
||||
import com.aitrainer.api.model.OpenAI
|
||||
import com.aitrainer.api.model.OpenAIChat
|
||||
import com.aitrainer.api.test.Tokenizer
|
||||
import com.google.gson.Gson
|
||||
import org.junit.jupiter.api.BeforeAll
|
||||
@ -49,6 +53,34 @@ class OpenAITest {
|
||||
|
||||
}*/
|
||||
|
||||
@OptIn(BetaOpenAI::class)
|
||||
@Test
|
||||
fun `get a chat message completion`() {
|
||||
val systemMsg = ChatMessage(
|
||||
role = ChatRole.User,
|
||||
content = "Te a Diet4You applikáció asszisztense vagy. Add meg ennek az ételnek a kalória és tápanyagadatait: Big Mac. Hány gramm egy adag ebből? A válasz ez az objektum JSON alakított formája legyen: Meal [cal: double, ch: double, fat: double, protein: double, sugar: double, portion: double]"
|
||||
)
|
||||
val listMessages = listOf(systemMsg)
|
||||
val gson = Gson()
|
||||
val messages = gson.toJson(listMessages)
|
||||
|
||||
|
||||
val openai = OpenAIChat(
|
||||
messages = messages,
|
||||
modelName = "gpt-3.5-turbo",
|
||||
temperature = 0.1
|
||||
)
|
||||
|
||||
mockMvc.perform(
|
||||
MockMvcRequestBuilders.post("/api/openai/chat_completion")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.header("Authorization", "Bearer $authToken")
|
||||
.content(toJson(openai))
|
||||
)
|
||||
.andExpect(MockMvcResultMatchers.status().isOk)
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `get a question successfully with model name`() {
|
||||
val question = "Who the f. is Alice? Who sing that song?"
|
||||
|
Loading…
Reference in New Issue
Block a user