API 1.2+3 open ai chat completion

This commit is contained in:
Tibor Bossanyi 2023-03-12 17:06:44 +01:00
parent cae214f87e
commit 26d344777c
9 changed files with 128 additions and 10 deletions

View File

@ -48,7 +48,7 @@ dependencies {
implementation("jakarta.mail:jakarta.mail-api:2.1.1") implementation("jakarta.mail:jakarta.mail-api:2.1.1")
implementation("org.eclipse.angus:angus-mail:2.0.1") implementation("org.eclipse.angus:angus-mail:2.0.1")
implementation ("com.aallam.openai:openai-client:2.1.3") implementation ("com.aallam.openai:openai-client:3.0.0")
implementation("io.ktor:ktor-client-java:2.2.3") implementation("io.ktor:ktor-client-java:2.2.3")
runtimeOnly("mysql:mysql-connector-java") runtimeOnly("mysql:mysql-connector-java")

View File

@ -1,7 +1,11 @@
package com.aitrainer.api.controller package com.aitrainer.api.controller
import com.aallam.openai.api.BetaOpenAI
import com.aallam.openai.api.chat.ChatMessage
import com.aitrainer.api.model.OpenAI import com.aitrainer.api.model.OpenAI
import com.aitrainer.api.model.OpenAIChat
import com.aitrainer.api.openai.OpenAIService import com.aitrainer.api.openai.OpenAIService
import com.google.gson.Gson
import kotlinx.coroutines.* import kotlinx.coroutines.*
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import org.springframework.web.bind.annotation.* import org.springframework.web.bind.annotation.*
@ -41,6 +45,21 @@ class OpenAIController() {
return result return result
} }
@OptIn(BetaOpenAI::class, DelicateCoroutinesApi::class)
@PostMapping("/openai/chat_completion")
fun getOpenAIChatCompletion(@RequestBody openai: OpenAIChat) : String {
var result = ""
val openAIService = OpenAIService(openai.modelName, openai.temperature)
val deferred = GlobalScope.async {
openAIService.chatCompletion(openai.messages)
}
runBlocking {
result = deferred.await().toString()
println("Result: $result" )
}
return result
}
@OptIn(DelicateCoroutinesApi::class) @OptIn(DelicateCoroutinesApi::class)
@GetMapping("/openai/list_models") @GetMapping("/openai/list_models")
fun getOpenAIModels(): MutableList<String> { fun getOpenAIModels(): MutableList<String> {

View File

@ -0,0 +1,15 @@
package com.aitrainer.api.model
import com.aallam.openai.api.BetaOpenAI
import com.aallam.openai.api.chat.ChatMessage
import com.google.gson.annotations.Expose
import jakarta.persistence.*
import org.springframework.lang.NonNull
@Entity
data class OpenAIChat @OptIn(BetaOpenAI::class) constructor(
@Expose @Id @GeneratedValue(strategy = GenerationType.IDENTITY) @get: NonNull var id: Long = 0,
@Expose @get: NonNull var messages: String,
@Expose @get: NonNull var modelName: String? = null,
@Expose @get: NonNull var temperature: Double? = null,
)

View File

@ -1,5 +1,9 @@
package com.aitrainer.api.openai package com.aitrainer.api.openai
import com.aallam.openai.api.BetaOpenAI
import com.aallam.openai.api.chat.ChatCompletion
import com.aallam.openai.api.chat.ChatCompletionRequest
import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.client.OpenAI import com.aallam.openai.client.OpenAI
import com.aallam.openai.api.completion.CompletionRequest import com.aallam.openai.api.completion.CompletionRequest
import com.aallam.openai.api.completion.TextCompletion import com.aallam.openai.api.completion.TextCompletion
@ -7,6 +11,7 @@ import com.aallam.openai.api.logging.LogLevel
import com.aallam.openai.api.model.Model import com.aallam.openai.api.model.Model
import com.aallam.openai.api.model.ModelId import com.aallam.openai.api.model.ModelId
import com.aallam.openai.client.OpenAIConfig import com.aallam.openai.client.OpenAIConfig
import com.google.gson.Gson
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import java.util.Properties import java.util.Properties
@ -36,6 +41,12 @@ class OpenAIService(private val modelName: String?, private val temperature: Dou
} }
/* models:
gpt-3.5-turbo chat/completion
text-davinci-003 completion
*/
suspend fun completion(question: String): String { suspend fun completion(question: String): String {
return withContext(Dispatchers.IO) { return withContext(Dispatchers.IO) {
var realModelName = "text-davinci-003" var realModelName = "text-davinci-003"
@ -64,6 +75,41 @@ class OpenAIService(private val modelName: String?, private val temperature: Dou
} }
} }
@OptIn(BetaOpenAI::class)
suspend fun chatCompletion(chatMessagesJson: String): String? {
return withContext(Dispatchers.IO) {
val gson = Gson()
val messages = gson.fromJson(chatMessagesJson, Array<ChatMessage>::class.java).toList()
val lastQuestion = messages.last().content
var lengthQuestion = 0
for ( message in messages ) {
lengthQuestion += message.content.length
}
val realModelName = "gpt-3.5-turbo"
var realTemperature = 0.1
if ( temperature != null ) {
realTemperature = temperature
}
if (openAI == null) {
connect(realModelName)
}
println("OpenAI Chat Last Question: $lastQuestion")
val completionRequest = ChatCompletionRequest(
model = ModelId(realModelName),
messages = messages,
maxTokens = 4096 - lengthQuestion,
temperature = realTemperature,
)
val completion: ChatCompletion = openAI!!.chatCompletion(completionRequest)
val result = completion.choices[0].message?.content
print(result)
result
}
}
suspend fun getModels(): MutableList<String> { suspend fun getModels(): MutableList<String> {
return withContext(Dispatchers.IO) { return withContext(Dispatchers.IO) {
if (openAI == null) { if (openAI == null) {

View File

@ -1,7 +1,6 @@
#spring.config.activate.on-profile=dev,test,prod,prodtest #spring.config.activate.on-profile=dev,test,prod,prodtest
spring.config.use-legacy-processing = true spring.config.use-legacy-processing = true
## Spring DATASOURCE (DataSourceAutoConfiguration & DataSourceProperties) ## Spring DATASOURCE (DataSourceAutoConfiguration & DataSourceProperties)
#spring.datasource.url = jdbc:mysql://localhost:3306/aitrainer?autoReconnect=true&useUnicode=true&characterEncoding=UTF-8&allowMultiQueries=true&useSSL=false
spring.datasource.url = jdbc:mysql://localhost:3306/diet4you?serverTimezone=CET&useSSL=false&characterEncoding=UTF-8&allowMultiQueries=true spring.datasource.url = jdbc:mysql://localhost:3306/diet4you?serverTimezone=CET&useSSL=false&characterEncoding=UTF-8&allowMultiQueries=true
spring.datasource.username = aitrainer spring.datasource.username = aitrainer
spring.datasource.password = ENC(WZplPYr8WmrLHshesY4T6oXplK3MlUVJ) spring.datasource.password = ENC(WZplPYr8WmrLHshesY4T6oXplK3MlUVJ)

View File

@ -2,8 +2,8 @@ spring.config.activate.on-profile=dietprod
spring.config.use-legacy-processing = true spring.config.use-legacy-processing = true
## Spring DATASOURCE (DataSourceAutoConfiguration & DataSourceProperties) ## Spring DATASOURCE (DataSourceAutoConfiguration & DataSourceProperties)
spring.datasource.url = jdbc:mysql://mariadb-shared.db.svc.cluster.local:3306/diet4you?serverTimezone=CET&useSSL=true&characterEncoding=UTF-8&allowPublicKeyRetrieval=true&allowMultiQueries=true spring.datasource.url = jdbc:mysql://mariadb-diet4you.diet4you.svc.cluster.local:3306/diet4you?serverTimezone=CET&useSSL=true&characterEncoding=UTF-8&allowPublicKeyRetrieval=true&allowMultiQueries=true
spring.datasource.username = aitrainer spring.datasource.username = bossanyit
spring.datasource.password = ENC(WZplPYr8WmrLHshesY4T6oXplK3MlUVJ) spring.datasource.password = ENC(WZplPYr8WmrLHshesY4T6oXplK3MlUVJ)
# The SQL dialect makes Hibernate generate better SQL for the chosen database # The SQL dialect makes Hibernate generate better SQL for the chosen database
@ -19,4 +19,5 @@ application.version=1.2.0
jwt.secret=aitrainer jwt.secret=aitrainer
firebase.key=AIzaSyCUXBWV3_qzvV__ZWZA1siHftrrJpjDKh4 firebase.key=AIzaSyCUXBWV3_qzvV__ZWZA1siHftrrJpjDKh4
openai.key=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 openai.key=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
spring.mail.properties.mail.mime.charset=UTF-8

View File

@ -0,0 +1,6 @@
#spring.config.activate.on-profile=dev,test,prod,prodtest
spring.config.use-legacy-processing = true
## Spring DATASOURCE (DataSourceAutoConfiguration & DataSourceProperties)
spring.datasource.url = jdbc:mysql://192.168.100.98:3306/diet4you?serverTimezone=CET&useSSL=false&characterEncoding=UTF-8&allowMultiQueries=true
spring.datasource.username = aitrainer
spring.datasource.password = ENC(WZplPYr8WmrLHshesY4T6oXplK3MlUVJ)

View File

@ -28,11 +28,11 @@
</appender> </appender>
<!--<logger name="org.springframework" level="DEBUG" /> <!-- <logger name="org.springframework" level="DEBUG" />
<logger name="org.apache.tomcat" level="DEBUG"/> <logger name="org.apache.tomcat" level="DEBUG"/>
<logger name="org.apache.coyote" level="DEBUG"/> <logger name="org.apache.coyote" level="DEBUG"/>
<logger name="com.github.ulisesbocchio" level="DEBUG" /> <logger name="com.github.ulisesbocchio" level="DEBUG" />
<logger name="javax.net.ssl" level="DEBUG"/> --> <logger name="javax.net.ssl" level="DEBUG"/> -->
<logger name="com.aitrainer" level="INFO" /> <logger name="com.aitrainer" level="INFO" />
<logger name="org.hibernate" level="INFO" /> <logger name="org.hibernate" level="INFO" />

View File

@ -1,6 +1,10 @@
package com.aitrainer.api.test.openai package com.aitrainer.api.test.openai
import com.aallam.openai.api.BetaOpenAI
import com.aallam.openai.api.chat.ChatMessage
import com.aallam.openai.api.chat.ChatRole
import com.aitrainer.api.model.OpenAI import com.aitrainer.api.model.OpenAI
import com.aitrainer.api.model.OpenAIChat
import com.aitrainer.api.test.Tokenizer import com.aitrainer.api.test.Tokenizer
import com.google.gson.Gson import com.google.gson.Gson
import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.BeforeAll
@ -49,6 +53,34 @@ class OpenAITest {
}*/ }*/
@OptIn(BetaOpenAI::class)
@Test
fun `get a chat message completion`() {
val systemMsg = ChatMessage(
role = ChatRole.User,
content = "Te a Diet4You applikáció asszisztense vagy. Add meg ennek az ételnek a kalória és tápanyagadatait: Big Mac. Hány gramm egy adag ebből? A válasz ez az objektum JSON alakított formája legyen: Meal [cal: double, ch: double, fat: double, protein: double, sugar: double, portion: double]"
)
val listMessages = listOf(systemMsg)
val gson = Gson()
val messages = gson.toJson(listMessages)
val openai = OpenAIChat(
messages = messages,
modelName = "gpt-3.5-turbo",
temperature = 0.1
)
mockMvc.perform(
MockMvcRequestBuilders.post("/api/openai/chat_completion")
.contentType(MediaType.APPLICATION_JSON)
.header("Authorization", "Bearer $authToken")
.content(toJson(openai))
)
.andExpect(MockMvcResultMatchers.status().isOk)
}
@Test @Test
fun `get a question successfully with model name`() { fun `get a question successfully with model name`() {
val question = "Who the f. is Alice? Who sing that song?" val question = "Who the f. is Alice? Who sing that song?"