Merge pull request 'API 1.2+4 open ai chat completion fix' (#13) from tibor into master
Reviewed-on: https://git.workouttest.org/bossanyit/aitrainer_server/pulls/13
This commit is contained in:
commit
cc643eb844
@ -1,13 +1,12 @@
|
||||
package com.aitrainer.api.controller
|
||||
|
||||
import com.aallam.openai.api.BetaOpenAI
|
||||
import com.aallam.openai.api.chat.ChatMessage
|
||||
import com.aitrainer.api.model.OpenAI
|
||||
import com.aitrainer.api.model.OpenAIChat
|
||||
import com.aitrainer.api.openai.OpenAIService
|
||||
import com.google.gson.Gson
|
||||
import kotlinx.coroutines.*
|
||||
import org.slf4j.LoggerFactory
|
||||
import org.springframework.beans.factory.annotation.Value
|
||||
import org.springframework.web.bind.annotation.*
|
||||
|
||||
@RestController
|
||||
@ -17,9 +16,9 @@ class OpenAIController() {
|
||||
|
||||
@OptIn(DelicateCoroutinesApi::class)
|
||||
@PostMapping("/openai/completion")
|
||||
fun getOpenAIResponse(@RequestBody question: String) : String {
|
||||
var result = ""
|
||||
val openAIService = OpenAIService(null, null)
|
||||
fun getOpenAIResponse(@RequestBody question: String, @Value("\${openai.key}") openaiKey: String) : String {
|
||||
var result: String
|
||||
val openAIService = OpenAIService(openaiKey, null, null)
|
||||
val deferred = GlobalScope.async {
|
||||
openAIService.completion(question)
|
||||
}
|
||||
@ -32,9 +31,9 @@ class OpenAIController() {
|
||||
|
||||
@OptIn(DelicateCoroutinesApi::class)
|
||||
@PostMapping("/openai/completion_with_model")
|
||||
fun getOpenAIResponseWithModel(@RequestBody openai: OpenAI) : String {
|
||||
var result = ""
|
||||
val openAIService = OpenAIService(openai.modelName, openai.temperature)
|
||||
fun getOpenAIResponseWithModel(@RequestBody openai: OpenAI, @Value("\${openai.key}") openaiKey: String) : String {
|
||||
var result: String
|
||||
val openAIService = OpenAIService(openaiKey, openai.modelName, openai.temperature)
|
||||
val deferred = GlobalScope.async {
|
||||
openAIService.completion(openai.question)
|
||||
}
|
||||
@ -47,24 +46,23 @@ class OpenAIController() {
|
||||
|
||||
@OptIn(BetaOpenAI::class, DelicateCoroutinesApi::class)
|
||||
@PostMapping("/openai/chat_completion")
|
||||
fun getOpenAIChatCompletion(@RequestBody openai: OpenAIChat) : String {
|
||||
var result = ""
|
||||
val openAIService = OpenAIService(openai.modelName, openai.temperature)
|
||||
fun getOpenAIChatCompletion(@RequestBody openai: OpenAIChat, @Value("\${openai.key}") openaiKey: String, ) : String {
|
||||
var result: String
|
||||
val openAIService = OpenAIService(openaiKey, openai.modelName, openai.temperature)
|
||||
val deferred = GlobalScope.async {
|
||||
openAIService.chatCompletion(openai.messages)
|
||||
}
|
||||
runBlocking {
|
||||
result = deferred.await().toString()
|
||||
println("Result: $result" )
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@OptIn(DelicateCoroutinesApi::class)
|
||||
@GetMapping("/openai/list_models")
|
||||
fun getOpenAIModels(): MutableList<String> {
|
||||
var result = mutableListOf<String>()
|
||||
val openAIService = OpenAIService(null, null)
|
||||
fun getOpenAIModels(@Value("\${openai.key}") openaiKey: String): MutableList<String> {
|
||||
var result: MutableList<String>
|
||||
val openAIService = OpenAIService(openaiKey,null, null)
|
||||
val deferred = GlobalScope.async {
|
||||
openAIService.getModels()
|
||||
}
|
||||
|
@ -14,31 +14,23 @@ import com.aallam.openai.client.OpenAIConfig
|
||||
import com.google.gson.Gson
|
||||
import kotlinx.coroutines.Dispatchers
|
||||
import kotlinx.coroutines.withContext
|
||||
import java.util.Properties
|
||||
|
||||
class OpenAIService(private val modelName: String?, private val temperature: Double?) {
|
||||
|
||||
import org.springframework.beans.factory.annotation.Value
|
||||
import org.springframework.stereotype.Service
|
||||
@Service
|
||||
class OpenAIService(@Value("\${openai.key}") private val openaiKey: String, private val modelName: String?, private val temperature: Double?) {
|
||||
|
||||
private var openAI: OpenAI? = null
|
||||
var model: Model? = null
|
||||
private val properties = Properties()
|
||||
|
||||
init {
|
||||
val inputStream = ClassLoader.getSystemResourceAsStream("application.properties")
|
||||
properties.load(inputStream)
|
||||
inputStream?.close()
|
||||
}
|
||||
|
||||
private var modelId: ModelId? = null
|
||||
private suspend fun connect(modelName: String) {
|
||||
val config = OpenAIConfig(
|
||||
token = properties.getProperty("openai.key"),
|
||||
token = openaiKey,
|
||||
logLevel = LogLevel.All
|
||||
)
|
||||
openAI = OpenAI(config)
|
||||
modelId = ModelId(modelName)
|
||||
model = openAI!!.model(modelId!!)
|
||||
|
||||
}
|
||||
|
||||
|
||||
@ -112,11 +104,12 @@ class OpenAIService(private val modelName: String?, private val temperature: Dou
|
||||
|
||||
suspend fun getModels(): MutableList<String> {
|
||||
return withContext(Dispatchers.IO) {
|
||||
val list: MutableList<String> = mutableListOf()
|
||||
if (openAI == null) {
|
||||
openAI = OpenAI(token = properties.getProperty("openai.key"))
|
||||
openAI = OpenAI(token = openaiKey)
|
||||
}
|
||||
|
||||
val list: MutableList<String> = mutableListOf()
|
||||
|
||||
openAI!!.models().forEach {
|
||||
println(it)
|
||||
list.add(it.id.id)
|
||||
|
@ -3,15 +3,16 @@ package com.aitrainer.api.test
|
||||
import com.aitrainer.api.controller.OpenAIController
|
||||
import org.junit.jupiter.api.Test
|
||||
import org.junit.jupiter.api.TestInstance
|
||||
import org.springframework.beans.factory.annotation.Value
|
||||
import org.springframework.boot.test.context.SpringBootTest
|
||||
|
||||
@SpringBootTest
|
||||
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
|
||||
class CompletionTest {
|
||||
@Test
|
||||
fun testCompletion() {
|
||||
fun testCompletion(@Value("\${openai.key}") openaiKey: String ) {
|
||||
val openAIController = OpenAIController()
|
||||
val response = openAIController.getOpenAIResponse("Mennyi 3 darab közepes tükörtojás kalóriaértéke? Csak egy szám intervallumértéket adj vissza eredményként")
|
||||
val response = openAIController.getOpenAIResponse( "Mennyi 3 darab közepes tükörtojás kalóriaértéke? Csak egy szám intervallumértéket adj vissza eredményként", openaiKey)
|
||||
println(response)
|
||||
// response = openAIController.getOpenAIResponse("Az utolsó 5 kérdésemre: Mennyi 3 darab tükörtojás kalóriaértéke? nagy szórású intervallumokat adtál vissza. Hogyan tudom pontosítani a kérdést?")
|
||||
// println(response)
|
||||
|
@ -100,7 +100,7 @@ class MealTest {
|
||||
.andExpect(jsonPath("$.quantity").value(90.0))
|
||||
.andExpect(jsonPath("$.proteinMax").value(25.0))
|
||||
|
||||
val mvcResult = mockMvc.perform(
|
||||
var mvcResult = mockMvc.perform(
|
||||
MockMvcRequestBuilders.post("/api/meal")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.header("Authorization", "Bearer $authToken")
|
||||
@ -110,8 +110,8 @@ class MealTest {
|
||||
.andReturn()
|
||||
|
||||
val gson= Gson()
|
||||
val newRawMaterialJson = mvcResult.response.contentAsString
|
||||
val newMeal = gson.fromJson(newRawMaterialJson, Meal::class.java)
|
||||
val newMealJson = mvcResult.response.contentAsString
|
||||
val newMeal = gson.fromJson(newMealJson, Meal::class.java)
|
||||
|
||||
mockMvc.perform(
|
||||
MockMvcRequestBuilders.post("/api/meal")
|
||||
@ -146,9 +146,29 @@ class MealTest {
|
||||
.andExpect(jsonPath("$.name").value("Töltötttojás"))
|
||||
.andExpect(jsonPath("$.quantity").value(330.0))
|
||||
|
||||
mealRepository.delete(meal)
|
||||
mealRepository.delete(meal2)
|
||||
mealRepository.delete(meal3)
|
||||
|
||||
mvcResult = mockMvc.perform(get("/api/meal/${newMeal.id-1}")
|
||||
.header("Authorization", "Bearer $authToken")
|
||||
.contentType(MediaType.APPLICATION_JSON))
|
||||
.andExpect(status().isOk)
|
||||
. andReturn()
|
||||
|
||||
val newMealJson1 = mvcResult.response.contentAsString
|
||||
val newMeal1 = gson.fromJson(newMealJson1, Meal::class.java)
|
||||
|
||||
mealRepository.delete(newMeal1)
|
||||
mealRepository.delete(newMeal)
|
||||
|
||||
mvcResult = mockMvc.perform(get("/api/meal/${newMeal.id+1}")
|
||||
.header("Authorization", "Bearer $authToken")
|
||||
.contentType(MediaType.APPLICATION_JSON))
|
||||
.andExpect(status().isOk)
|
||||
. andReturn()
|
||||
|
||||
val newMealJson3 = mvcResult.response.contentAsString
|
||||
val newMeal3 = gson.fromJson(newMealJson3, Meal::class.java)
|
||||
|
||||
mealRepository.delete(newMeal3)
|
||||
}
|
||||
|
||||
private fun toJson(obj: Any): String {
|
||||
|
@ -63,6 +63,7 @@ class OpenAITest {
|
||||
val listMessages = listOf(systemMsg)
|
||||
val gson = Gson()
|
||||
val messages = gson.toJson(listMessages)
|
||||
println("****** json $messages ********* ")
|
||||
|
||||
|
||||
val openai = OpenAIChat(
|
||||
@ -71,16 +72,49 @@ class OpenAITest {
|
||||
temperature = 0.1
|
||||
)
|
||||
|
||||
val json: String = toJson(openai)
|
||||
println(json)
|
||||
|
||||
mockMvc.perform(
|
||||
MockMvcRequestBuilders.post("/api/openai/chat_completion")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.header("Authorization", "Bearer $authToken")
|
||||
.content(toJson(openai))
|
||||
.content(json)
|
||||
)
|
||||
.andExpect(MockMvcResultMatchers.status().isOk)
|
||||
|
||||
}
|
||||
|
||||
@OptIn(BetaOpenAI::class)
|
||||
@Test
|
||||
fun `get a chat message completion no modelname`() {
|
||||
val systemMsg = ChatMessage(
|
||||
role = ChatRole.User,
|
||||
content = "Te a Diet4You applikáció asszisztense vagy. Add meg ennek az ételnek a kalória és tápanyagadatait: 'Sült sütőtök 100g'. A válasz ez az objektum JSON alakított formája legyen: Meal [cal: double, ch: double, fat: double, protein: double, sugar: double, portion: double]. A 'portion' paraméter ezt tartalmazza: hány gramm egy átlagos adag ebből az ételből? "
|
||||
)
|
||||
val listMessages = listOf(systemMsg)
|
||||
val gson = Gson()
|
||||
val messages = gson.toJson(listMessages)
|
||||
println("****** json $messages ********* ")
|
||||
|
||||
|
||||
val openai = OpenAIChat(
|
||||
messages = messages
|
||||
)
|
||||
|
||||
val json: String = toJson(openai)
|
||||
println(json)
|
||||
|
||||
mockMvc.perform(
|
||||
MockMvcRequestBuilders.post("/api/openai/chat_completion")
|
||||
.contentType(MediaType.APPLICATION_JSON)
|
||||
.header("Authorization", "Bearer $authToken")
|
||||
.content(json)
|
||||
)
|
||||
.andExpect(MockMvcResultMatchers.status().isOk)
|
||||
|
||||
}
|
||||
|
||||
@Test
|
||||
fun `get a question successfully with model name`() {
|
||||
val question = "Who the f. is Alice? Who sing that song?"
|
||||
|
Loading…
Reference in New Issue
Block a user