API 1.2+4 open ai chat completion fix

This commit is contained in:
Tibor Bossanyi 2023-03-13 07:52:38 +01:00
parent 26d344777c
commit c2cc98eeb7
5 changed files with 85 additions and 39 deletions

View File

@ -1,13 +1,12 @@
package com.aitrainer.api.controller
import com.aallam.openai.api.BetaOpenAI
import com.aallam.openai.api.chat.ChatMessage
import com.aitrainer.api.model.OpenAI
import com.aitrainer.api.model.OpenAIChat
import com.aitrainer.api.openai.OpenAIService
import com.google.gson.Gson
import kotlinx.coroutines.*
import org.slf4j.LoggerFactory
import org.springframework.beans.factory.annotation.Value
import org.springframework.web.bind.annotation.*
@RestController
@ -17,9 +16,9 @@ class OpenAIController() {
@OptIn(DelicateCoroutinesApi::class)
@PostMapping("/openai/completion")
fun getOpenAIResponse(@RequestBody question: String) : String {
var result = ""
val openAIService = OpenAIService(null, null)
fun getOpenAIResponse(@RequestBody question: String, @Value("\${openai.key}") openaiKey: String) : String {
var result: String
val openAIService = OpenAIService(openaiKey, null, null)
val deferred = GlobalScope.async {
openAIService.completion(question)
}
@ -32,9 +31,9 @@ class OpenAIController() {
@OptIn(DelicateCoroutinesApi::class)
@PostMapping("/openai/completion_with_model")
fun getOpenAIResponseWithModel(@RequestBody openai: OpenAI) : String {
var result = ""
val openAIService = OpenAIService(openai.modelName, openai.temperature)
fun getOpenAIResponseWithModel(@RequestBody openai: OpenAI, @Value("\${openai.key}") openaiKey: String) : String {
var result: String
val openAIService = OpenAIService(openaiKey, openai.modelName, openai.temperature)
val deferred = GlobalScope.async {
openAIService.completion(openai.question)
}
@ -47,24 +46,23 @@ class OpenAIController() {
@OptIn(BetaOpenAI::class, DelicateCoroutinesApi::class)
@PostMapping("/openai/chat_completion")
fun getOpenAIChatCompletion(@RequestBody openai: OpenAIChat) : String {
var result = ""
val openAIService = OpenAIService(openai.modelName, openai.temperature)
fun getOpenAIChatCompletion(@RequestBody openai: OpenAIChat, @Value("\${openai.key}") openaiKey: String, ) : String {
var result: String
val openAIService = OpenAIService(openaiKey, openai.modelName, openai.temperature)
val deferred = GlobalScope.async {
openAIService.chatCompletion(openai.messages)
}
runBlocking {
result = deferred.await().toString()
println("Result: $result" )
}
return result
}
@OptIn(DelicateCoroutinesApi::class)
@GetMapping("/openai/list_models")
fun getOpenAIModels(): MutableList<String> {
var result = mutableListOf<String>()
val openAIService = OpenAIService(null, null)
fun getOpenAIModels(@Value("\${openai.key}") openaiKey: String): MutableList<String> {
var result: MutableList<String>
val openAIService = OpenAIService(openaiKey,null, null)
val deferred = GlobalScope.async {
openAIService.getModels()
}

View File

@ -14,31 +14,23 @@ import com.aallam.openai.client.OpenAIConfig
import com.google.gson.Gson
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import java.util.Properties
class OpenAIService(private val modelName: String?, private val temperature: Double?) {
import org.springframework.beans.factory.annotation.Value
import org.springframework.stereotype.Service
@Service
class OpenAIService(@Value("\${openai.key}") private val openaiKey: String, private val modelName: String?, private val temperature: Double?) {
private var openAI: OpenAI? = null
var model: Model? = null
private val properties = Properties()
init {
val inputStream = ClassLoader.getSystemResourceAsStream("application.properties")
properties.load(inputStream)
inputStream?.close()
}
private var modelId: ModelId? = null
private suspend fun connect(modelName: String) {
val config = OpenAIConfig(
token = properties.getProperty("openai.key"),
token = openaiKey,
logLevel = LogLevel.All
)
openAI = OpenAI(config)
modelId = ModelId(modelName)
model = openAI!!.model(modelId!!)
}
@ -112,11 +104,12 @@ class OpenAIService(private val modelName: String?, private val temperature: Dou
suspend fun getModels(): MutableList<String> {
return withContext(Dispatchers.IO) {
val list: MutableList<String> = mutableListOf()
if (openAI == null) {
openAI = OpenAI(token = properties.getProperty("openai.key"))
openAI = OpenAI(token = openaiKey)
}
val list: MutableList<String> = mutableListOf()
openAI!!.models().forEach {
println(it)
list.add(it.id.id)

View File

@ -3,15 +3,16 @@ package com.aitrainer.api.test
import com.aitrainer.api.controller.OpenAIController
import org.junit.jupiter.api.Test
import org.junit.jupiter.api.TestInstance
import org.springframework.beans.factory.annotation.Value
import org.springframework.boot.test.context.SpringBootTest
@SpringBootTest
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
class CompletionTest {
@Test
fun testCompletion() {
fun testCompletion(@Value("\${openai.key}") openaiKey: String ) {
val openAIController = OpenAIController()
val response = openAIController.getOpenAIResponse("Mennyi 3 darab közepes tükörtojás kalóriaértéke? Csak egy szám intervallumértéket adj vissza eredményként")
val response = openAIController.getOpenAIResponse( "Mennyi 3 darab közepes tükörtojás kalóriaértéke? Csak egy szám intervallumértéket adj vissza eredményként", openaiKey)
println(response)
// response = openAIController.getOpenAIResponse("Az utolsó 5 kérdésemre: Mennyi 3 darab tükörtojás kalóriaértéke? nagy szórású intervallumokat adtál vissza. Hogyan tudom pontosítani a kérdést?")
// println(response)

View File

@ -100,7 +100,7 @@ class MealTest {
.andExpect(jsonPath("$.quantity").value(90.0))
.andExpect(jsonPath("$.proteinMax").value(25.0))
val mvcResult = mockMvc.perform(
var mvcResult = mockMvc.perform(
MockMvcRequestBuilders.post("/api/meal")
.contentType(MediaType.APPLICATION_JSON)
.header("Authorization", "Bearer $authToken")
@ -110,8 +110,8 @@ class MealTest {
.andReturn()
val gson= Gson()
val newRawMaterialJson = mvcResult.response.contentAsString
val newMeal = gson.fromJson(newRawMaterialJson, Meal::class.java)
val newMealJson = mvcResult.response.contentAsString
val newMeal = gson.fromJson(newMealJson, Meal::class.java)
mockMvc.perform(
MockMvcRequestBuilders.post("/api/meal")
@ -146,9 +146,29 @@ class MealTest {
.andExpect(jsonPath("$.name").value("Töltötttojás"))
.andExpect(jsonPath("$.quantity").value(330.0))
mealRepository.delete(meal)
mealRepository.delete(meal2)
mealRepository.delete(meal3)
mvcResult = mockMvc.perform(get("/api/meal/${newMeal.id-1}")
.header("Authorization", "Bearer $authToken")
.contentType(MediaType.APPLICATION_JSON))
.andExpect(status().isOk)
. andReturn()
val newMealJson1 = mvcResult.response.contentAsString
val newMeal1 = gson.fromJson(newMealJson1, Meal::class.java)
mealRepository.delete(newMeal1)
mealRepository.delete(newMeal)
mvcResult = mockMvc.perform(get("/api/meal/${newMeal.id+1}")
.header("Authorization", "Bearer $authToken")
.contentType(MediaType.APPLICATION_JSON))
.andExpect(status().isOk)
. andReturn()
val newMealJson3 = mvcResult.response.contentAsString
val newMeal3 = gson.fromJson(newMealJson3, Meal::class.java)
mealRepository.delete(newMeal3)
}
private fun toJson(obj: Any): String {

View File

@ -63,6 +63,7 @@ class OpenAITest {
val listMessages = listOf(systemMsg)
val gson = Gson()
val messages = gson.toJson(listMessages)
println("****** json $messages ********* ")
val openai = OpenAIChat(
@ -71,11 +72,44 @@ class OpenAITest {
temperature = 0.1
)
val json: String = toJson(openai)
println(json)
mockMvc.perform(
MockMvcRequestBuilders.post("/api/openai/chat_completion")
.contentType(MediaType.APPLICATION_JSON)
.header("Authorization", "Bearer $authToken")
.content(toJson(openai))
.content(json)
)
.andExpect(MockMvcResultMatchers.status().isOk)
}
@OptIn(BetaOpenAI::class)
@Test
fun `get a chat message completion no modelname`() {
val systemMsg = ChatMessage(
role = ChatRole.User,
content = "Te a Diet4You applikáció asszisztense vagy. Add meg ennek az ételnek a kalória és tápanyagadatait: 'Sült sütőtök 100g'. A válasz ez az objektum JSON alakított formája legyen: Meal [cal: double, ch: double, fat: double, protein: double, sugar: double, portion: double]. A 'portion' paraméter ezt tartalmazza: hány gramm egy átlagos adag ebből az ételből? "
)
val listMessages = listOf(systemMsg)
val gson = Gson()
val messages = gson.toJson(listMessages)
println("****** json $messages ********* ")
val openai = OpenAIChat(
messages = messages
)
val json: String = toJson(openai)
println(json)
mockMvc.perform(
MockMvcRequestBuilders.post("/api/openai/chat_completion")
.contentType(MediaType.APPLICATION_JSON)
.header("Authorization", "Bearer $authToken")
.content(json)
)
.andExpect(MockMvcResultMatchers.status().isOk)