made a fix for the pr

This commit is contained in:
stasex 2023-10-26 12:57:52 +03:00
parent adf956b44c
commit 51aca2e270
5 changed files with 56 additions and 67 deletions

View File

@ -1,12 +0,0 @@
using ShoppingAssistantApi.Application.IServices;
using ShoppingAssistantApi.Application.Models.Dtos;
using ShoppingAssistantApi.Application.Models.ProductSearch;
using ShoppingAssistantApi.Domain.Entities;
using ShoppingAssistantApi.Infrastructure.Services;
namespace ShoppingAssistantApi.Api.Mutations;
[ExtendObjectType(OperationTypeNames.Mutation)]
public class ProductMutation
{
}

View File

@ -1,11 +0,0 @@
using HotChocolate.Authorization;
using ShoppingAssistantApi.Application.IServices;
using ShoppingAssistantApi.Domain.Entities;
namespace ShoppingAssistantApi.Api.Queries;
[ExtendObjectType(OperationTypeNames.Query)]
public class ProductQuery
{
}

View File

@ -29,8 +29,6 @@ public class ProductService : IProductService
public async IAsyncEnumerable<ServerSentEvent> SearchProductAsync(string wishlistId, MessageCreateDto message, CancellationToken cancellationToken)
{
string firstMessageForUser = "What are you looking for?";
string promptForGpt =
"You are a Shopping Assistant that helps people find product recommendations. Ask user additional questions if more context needed." +
"\nYou must return data with one of the prefixes:" +
@ -39,56 +37,43 @@ public class ProductService : IProductService
"\n[Message] - return text" +
"\n[Products] - return semicolon separated product names";
var isFirstMessage = await _messagesRepository.GetCountAsync(message=>message.WishlistId==ObjectId.Parse((wishlistId)), cancellationToken);
var countOfMessage = await _messagesRepository
.GetCountAsync(message=>message.WishlistId==ObjectId.Parse((wishlistId)), cancellationToken);
var chatRequest = new ChatCompletionRequest();
var previousMessages = await _wishlistsService
.GetMessagesPageFromPersonalWishlistAsync(wishlistId, 1, countOfMessage, cancellationToken);
if (isFirstMessage==0)
var chatRequest = new ChatCompletionRequest
{
chatRequest = new ChatCompletionRequest
Messages = new List<OpenAiMessage>
{
Messages = new List<OpenAiMessage>
new OpenAiMessage
{
new OpenAiMessage
{
Role = OpenAiRole.System.ToString().ToLower(),
Content = promptForGpt
},
new OpenAiMessage()
{
Role = OpenAiRole.System.ToString().ToLower(),
Content = firstMessageForUser
}
},
Stream = true
};
_wishlistsService.AddMessageToPersonalWishlistAsync(wishlistId, new MessageCreateDto()
{
Text = firstMessageForUser,
}, cancellationToken);
Role = OpenAiRoleExtensions.RequestConvert(OpenAiRole.System),
Content = promptForGpt
}
}
};
if (countOfMessage==1)
{
yield return new ServerSentEvent
{
Event = SearchEventType.Message,
Data = firstMessageForUser
Data = previousMessages.Items.FirstOrDefault()?.Text
};
}
if(isFirstMessage!=0)
if(countOfMessage>1)
{
var previousMessages = _wishlistsService
.GetMessagesPageFromPersonalWishlistAsync(wishlistId, 1, 50, cancellationToken).Result.Items.ToList();
var messagesForOpenAI = new List<OpenAiMessage>();
messagesForOpenAI.Add(new OpenAiMessage()
{
Role = OpenAiRole.System.ToString().ToLower(),
Role = OpenAiRoleExtensions.RequestConvert(OpenAiRole.System),
Content = promptForGpt
});
foreach (var item in previousMessages )
foreach (var item in previousMessages.Items)
{
messagesForOpenAI.Add(
new OpenAiMessage()
@ -100,15 +85,11 @@ public class ProductService : IProductService
messagesForOpenAI.Add(new OpenAiMessage()
{
Role = MessageRoles.User.ToString().ToLower(),
Role = OpenAiRoleExtensions.RequestConvert(OpenAiRole.User),
Content = message.Text
});
chatRequest = new ChatCompletionRequest
{
Messages = messagesForOpenAI,
Stream = true
};
chatRequest.Messages.AddRange(messagesForOpenAI);
var suggestionBuffer = new Suggestion();
var messageBuffer = new MessagePart();

View File

@ -239,6 +239,14 @@ public class DbInitializer
CreatedById = user2.Id,
CreatedDateUtc = DateTime.UtcNow
},
new Message
{
Text = "What are you looking for?",
Role = "system",
WishlistId = wishlistId3,
CreatedById = user2.Id,
CreatedDateUtc = DateTime.UtcNow
},
};
await messagesCollection.InsertManyAsync(messages);

View File

@ -56,11 +56,35 @@ public class ProductTests
.Returns(expectedSseData.ToAsyncEnumerable());
_messagesRepositoryMock.Setup(m => m.GetCountAsync(It.IsAny<Expression<Func<Message, bool>>>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(0);
.ReturnsAsync(1);
_wishListServiceMock.Setup(w => w.AddMessageToPersonalWishlistAsync(wishlistId, It.IsAny<MessageCreateDto>(), cancellationToken))
.Verifiable();
_wishListServiceMock
.Setup(m => m.GetMessagesPageFromPersonalWishlistAsync(
It.IsAny<string>(), // Очікуваний параметр wishlistId
It.IsAny<int>(), // Очікуваний параметр pageNumber
It.IsAny<int>(), // Очікуваний параметр pageSize
It.IsAny<CancellationToken>()) // Очікуваний параметр cancellationToken
)
.ReturnsAsync(new PagedList<MessageDto>(
new List<MessageDto>
{
new MessageDto
{
Text = "What are you looking for?",
Id = "3",
CreatedById = "User2",
Role = "User"
},
},
1,
1,
1
));
// Act
var resultStream = _productService.SearchProductAsync(wishlistId, message, cancellationToken);
@ -81,7 +105,6 @@ public class ProductTests
// Check if the actual SSE events match the expected SSE events
Assert.NotNull(actualSseEvents);
Assert.Equal(expectedMessages, receivedMessages);
_wishListServiceMock.Verify(w => w.AddMessageToPersonalWishlistAsync(wishlistId, It.IsAny<MessageCreateDto>(), cancellationToken), Times.Once);
}