diff --git a/ShoppingAssistantApi.Infrastructure/Services/ProductService.cs b/ShoppingAssistantApi.Infrastructure/Services/ProductService.cs index 32fb2fc..8c4d567 100644 --- a/ShoppingAssistantApi.Infrastructure/Services/ProductService.cs +++ b/ShoppingAssistantApi.Infrastructure/Services/ProductService.cs @@ -55,127 +55,110 @@ public class ProductService : IProductService } }; - if (countOfMessage==1) + + var messagesForOpenAI = new List(); + + foreach (var item in previousMessages.Items) { - yield return new ServerSentEvent - { - Event = SearchEventType.Message, - Data = previousMessages.Items.FirstOrDefault()?.Text - }; - } - - if(countOfMessage>1) - { - var messagesForOpenAI = new List(); - messagesForOpenAI.Add(new OpenAiMessage() - { - Role = OpenAiRoleExtensions.RequestConvert(OpenAiRole.System), - Content = promptForGpt - }); - - foreach (var item in previousMessages.Items) - { - messagesForOpenAI.Add( - new OpenAiMessage() + messagesForOpenAI + .Add(new OpenAiMessage() { Role = item.Role.ToLower(), - Content = item.Text + Content = item.Text }); + } + + messagesForOpenAI.Add(new OpenAiMessage() + { + Role = OpenAiRoleExtensions.RequestConvert(OpenAiRole.User), + Content = message.Text + }); + + chatRequest.Messages.AddRange(messagesForOpenAI); + + var suggestionBuffer = new Suggestion(); + var messageBuffer = new MessagePart(); + var productBuffer = new ProductName(); + var currentDataType = SearchEventType.Wishlist; + var dataTypeHolder = string.Empty; + + await foreach (var data in _openAiService.GetChatCompletionStream(chatRequest, cancellationToken)) + { + if (data.Contains("[")) + { + if (dataTypeHolder=="[Message]" && messageBuffer.Text!=null) + { + _wishlistsService.AddMessageToPersonalWishlistAsync(wishlistId, new MessageCreateDto() + { + Text = messageBuffer.Text, + }, cancellationToken); + } + dataTypeHolder = string.Empty; + dataTypeHolder += data; } - - messagesForOpenAI.Add(new OpenAiMessage() - { - Role = OpenAiRoleExtensions.RequestConvert(OpenAiRole.User), - Content = message.Text - }); - - chatRequest.Messages.AddRange(messagesForOpenAI); - - var suggestionBuffer = new Suggestion(); - var messageBuffer = new MessagePart(); - var productBuffer = new ProductName(); - var currentDataType = SearchEventType.Wishlist; - var dataTypeHolder = string.Empty; - await foreach (var data in _openAiService.GetChatCompletionStream(chatRequest, cancellationToken)) + else if (data.Contains("]")) { - if (data.Contains("[")) + dataTypeHolder += data; + currentDataType = DetermineDataType(dataTypeHolder); + } + + else if (dataTypeHolder=="[" && !data.Contains("[")) + { + dataTypeHolder += data; + } + + else + { + switch (currentDataType) { - if (dataTypeHolder=="[Message]" && messageBuffer.Text!=null) - { - _wishlistsService.AddMessageToPersonalWishlistAsync(wishlistId, new MessageCreateDto() + case SearchEventType.Message: + yield return new ServerSentEvent { - Text = messageBuffer.Text, - }, cancellationToken); - } - dataTypeHolder = string.Empty; - dataTypeHolder += data; - } + Event = SearchEventType.Message, + Data = data + }; + messageBuffer.Text += data; + break; - else if (data.Contains("]")) - { - dataTypeHolder += data; - currentDataType = DetermineDataType(dataTypeHolder); - } - - else if (dataTypeHolder=="[" && !data.Contains("[")) - { - dataTypeHolder += data; - } - - else - { - switch (currentDataType) - { - case SearchEventType.Message: + case SearchEventType.Suggestion: + if (data.Contains(";")) + { yield return new ServerSentEvent { - Event = SearchEventType.Message, - Data = data + Event = SearchEventType.Suggestion, + Data = suggestionBuffer.Text }; - messageBuffer.Text += data; - break; - - case SearchEventType.Suggestion: - if (data.Contains(";")) - { - yield return new ServerSentEvent - { - Event = SearchEventType.Suggestion, - Data = suggestionBuffer.Text - }; - suggestionBuffer.Text = string.Empty; - break; - } - suggestionBuffer.Text += data; + suggestionBuffer.Text = string.Empty; break; + } + suggestionBuffer.Text += data; + break; - case SearchEventType.Product: - if (data.Contains(";")) + case SearchEventType.Product: + if (data.Contains(";")) + { + yield return new ServerSentEvent { - yield return new ServerSentEvent - { - Event = SearchEventType.Product, - Data = productBuffer.Name - }; - productBuffer.Name = string.Empty; + Event = SearchEventType.Product, + Data = productBuffer.Name + }; + productBuffer.Name = string.Empty; - //a complete description of the entity when the Amazon API is connected - await _wishlistsService.AddProductToPersonalWishlistAsync(wishlistId, new ProductCreateDto() - { - Url = "", - Name = productBuffer.Name, - Rating = 0, - Description = "", - ImagesUrls = new []{"", ""}, - WasOpened = false - }, cancellationToken); - - break; - } - productBuffer.Name += data; - break; - } + //a complete description of the entity when the Amazon API is connected + await _wishlistsService.AddProductToPersonalWishlistAsync(wishlistId, new ProductCreateDto() + { + Url = "", + Name = productBuffer.Name, + Rating = 0, + Description = "", + ImagesUrls = new []{"", ""}, + WasOpened = false + }, cancellationToken); + break; + } + productBuffer.Name += data; + break; } } } diff --git a/ShoppingAssistantApi.Tests/TestExtentions/DbInitializer.cs b/ShoppingAssistantApi.Tests/TestExtentions/DbInitializer.cs index f66a9b1..02929da 100644 --- a/ShoppingAssistantApi.Tests/TestExtentions/DbInitializer.cs +++ b/ShoppingAssistantApi.Tests/TestExtentions/DbInitializer.cs @@ -220,13 +220,8 @@ public class DbInitializer }, new Message { - Text = "You are a Shopping Assistant that helps people find product recommendations. Ask user additional questions if more context needed." + - "\nYou must return data with one of the prefixes:" + - "\n[Question] - return question" + - "\n[Suggestions] - return semicolon separated suggestion how to answer to a question" + - "\n[Message] - return text" + - "\n[Products] - return semicolon separated product names", - Role = "system", + Text = "What are you looking for?", + Role = "assistant", WishlistId = wishlistId4, CreatedById = user2.Id, CreatedDateUtc = DateTime.UtcNow @@ -234,15 +229,7 @@ public class DbInitializer new Message { Text = "What are you looking for?", - Role = "system", - WishlistId = wishlistId4, - CreatedById = user2.Id, - CreatedDateUtc = DateTime.UtcNow - }, - new Message - { - Text = "What are you looking for?", - Role = "system", + Role = "assistant", WishlistId = wishlistId3, CreatedById = user2.Id, CreatedDateUtc = DateTime.UtcNow diff --git a/ShoppingAssistantApi.Tests/Tests/ProductsTests.cs b/ShoppingAssistantApi.Tests/Tests/ProductsTests.cs index a017f06..f9756ec 100644 --- a/ShoppingAssistantApi.Tests/Tests/ProductsTests.cs +++ b/ShoppingAssistantApi.Tests/Tests/ProductsTests.cs @@ -57,7 +57,7 @@ public class ProductsTests : TestsBase if (eventName == "event: Message") { foundMessageEvent = true; - Assert.Equal("\"What are you looking for?\"", eventData); + Assert.NotNull(eventData); break; } } diff --git a/ShoppingAssistantApi.UnitTests/ProductTests.cs b/ShoppingAssistantApi.UnitTests/ProductTests.cs index 8a6fb99..9887d90 100644 --- a/ShoppingAssistantApi.UnitTests/ProductTests.cs +++ b/ShoppingAssistantApi.UnitTests/ProductTests.cs @@ -49,7 +49,8 @@ public class ProductTests " 3070TI", " ;", " GTX", " 4070TI", " ;", " ?" }; - var expectedMessages = new List { "What are you looking for?" }; + var expectedMessages = new List { " What", " u", " want", " ?" }; + var expectedSuggestion = new List { " USB-C", " Keyboard ultra", " USB-C" }; // Mock the GetChatCompletionStream method to provide the expected SSE data _openAiServiceMock.Setup(x => x.GetChatCompletionStream(It.IsAny(), cancellationToken)) @@ -63,10 +64,10 @@ public class ProductTests _wishListServiceMock .Setup(m => m.GetMessagesPageFromPersonalWishlistAsync( - It.IsAny(), // Очікуваний параметр wishlistId - It.IsAny(), // Очікуваний параметр pageNumber - It.IsAny(), // Очікуваний параметр pageSize - It.IsAny()) // Очікуваний параметр cancellationToken + It.IsAny(), + It.IsAny(), + It.IsAny(), + It.IsAny()) ) .ReturnsAsync(new PagedList( new List @@ -78,7 +79,6 @@ public class ProductTests CreatedById = "User2", Role = "User" }, - }, 1, 1, @@ -105,6 +105,7 @@ public class ProductTests // Check if the actual SSE events match the expected SSE events Assert.NotNull(actualSseEvents); Assert.Equal(expectedMessages, receivedMessages); + Assert.Equal(expectedSuggestion, receivedSuggestions); }