ozon-task/internal/storage/db/database.go

218 lines
6 KiB
Go
Raw Permalink Normal View History

2024-06-24 23:34:10 +00:00
package db
import (
"context"
"fmt"
2024-06-24 23:34:10 +00:00
"git.obamna.ru/erius/ozon-task/graph/model"
"github.com/vikstrous/dataloadgen"
2024-06-24 23:34:10 +00:00
"gorm.io/gorm"
)
// database is a thin wrapper that embeds *gorm.DB.
// Embedding keeps every gorm method callable directly on the wrapper while
// allowing the loader fetch functions (fetchComments, fetchReplies) to be
// declared as methods on it.
type database struct {
*gorm.DB
}
2024-06-24 23:34:10 +00:00
// Database is the gorm-backed storage used to create and read posts,
// comments and replies. Construct it with InitDatabase.
type Database struct {
// db is the wrapped gorm connection that every query in this file goes through
db database
// Loaders for batch loading and caching of objects, intended to prevent the
// n + 1 problem and redundant SQL queries during nested pagination.
// Both are keyed by uint (a post ID and a comment ID respectively) and yield
// the list of comments belonging to that key.
postCommentsLoader *dataloadgen.Loader[uint, []*model.Comment]
commentRepliesLoader *dataloadgen.Loader[uint, []*model.Comment]
}
func InitDatabase(con *gorm.DB) *Database {
db := database{con}
return &Database{
db: db,
postCommentsLoader: dataloadgen.NewLoader(db.fetchComments),
commentRepliesLoader: dataloadgen.NewLoader(db.fetchReplies),
}
2024-06-24 23:34:10 +00:00
}
// AddPost converts input into a post, inserts it and returns a result whose
// ItemID points at the post's ID field after the insert.
//
// If the insert fails the error is returned with a nil result, so callers
// never receive an ItemID that refers to a row that was not stored.
func (s *Database) AddPost(input *model.PostInput) (*model.AddResult, error) {
	post := model.PostFromInput(input)
	if err := s.db.Create(post).Error; err != nil {
		return nil, err
	}
	return &model.AddResult{ItemID: &post.ID}, nil
}
func (s *Database) AddReplyToComment(input *model.CommentInput) (*model.AddResult, error) {
2024-06-24 23:34:10 +00:00
comment := model.CommentFromInput(input)
err := s.db.Model(&model.Comment{}).Select("post_id").Where("id = ?", *input.ParentCommentID).Scan(&comment.PostID).Error
if err != nil {
return nil, err
}
// multiple mutable operations performed in one transaction
err = s.db.Transaction(func(tx *gorm.DB) error {
// insert comment
err := s.db.Create(comment).Error
2024-06-24 23:34:10 +00:00
if err != nil {
return err
2024-06-24 23:34:10 +00:00
}
// add new reply to parent
err = s.db.Table("comment_replies").Create(map[string]interface{}{
"comment_id": *input.ParentCommentID,
"reply_id": comment.ID,
}).Error
2024-06-24 23:34:10 +00:00
if err != nil {
return err
2024-06-24 23:34:10 +00:00
}
return nil
})
return &model.AddResult{ItemID: &comment.ID}, err
}
func (s *Database) AddCommentToPost(input *model.CommentInput) (*model.AddResult, error) {
comment := model.CommentFromInput(input)
var allowComments bool
err := s.db.Table("posts").Select("allow_comments").Where("id = ?", *input.ParentPostID).Scan(&allowComments).Error
if err != nil {
return nil, err
2024-06-24 23:34:10 +00:00
}
if !allowComments {
return nil, fmt.Errorf("author disabled comments for this post")
}
comment.RootComment = true
comment.PostID = *input.ParentPostID
err = s.db.Create(comment).Error
return &model.AddResult{ItemID: &comment.ID}, err
}
// GetPost loads the post with the given primary key.
//
// ctx is attached to the query so that cancellation of the request reaches the
// database; a nil ctx is treated as context.Background(). On a query error the
// result is nil.
//
// NOTE(review): the lookup uses gorm's Find, which is documented not to return
// ErrRecordNotFound, so an unknown id appears to yield a zero-valued post with
// a nil error — confirm whether callers rely on that before switching to First.
func (s *Database) GetPost(id uint, ctx context.Context) (*model.Post, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	var post model.Post
	if err := s.db.WithContext(ctx).Find(&post, id).Error; err != nil {
		return nil, err
	}
	return &post, nil
}
// GetComment loads the comment with the given primary key.
//
// ctx is attached to the query so that cancellation of the request reaches the
// database; a nil ctx is treated as context.Background(). On a query error the
// result is nil.
//
// NOTE(review): the lookup uses gorm's Find, which is documented not to return
// ErrRecordNotFound, so an unknown id appears to yield a zero-valued comment
// with a nil error — confirm whether callers rely on that before switching to
// First.
func (s *Database) GetComment(id uint, ctx context.Context) (*model.Comment, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	var comment model.Comment
	if err := s.db.WithContext(ctx).Find(&comment, id).Error; err != nil {
		return nil, err
	}
	return &comment, nil
}
// GetPosts returns one page of posts ordered by ID.
//
// first is the maximum number of posts in the page. cursor, when non-nil, is
// the ID after which the page starts: only posts with a greater ID are
// returned. The cursors handed out in the edges and in PageInfo.EndCursor are
// post IDs, so feeding EndCursor back in as cursor continues exactly after the
// last returned post even when IDs are not contiguous.
//
// One row more than requested is fetched to find out whether another page
// exists; the extra row is never returned. PageInfo.StartCursor echoes the
// incoming cursor (0 when cursor is nil).
//
// When no post is in range the shared model.EmptyPostsConnections value is
// returned. ctx is attached to the query; a nil ctx is treated as
// context.Background().
func (s *Database) GetPosts(first uint, cursor *uint, ctx context.Context) (*model.PostsConnection, error) {
	if ctx == nil {
		ctx = context.Background()
	}

	var after uint
	query := s.db.WithContext(ctx).Order("id").Limit(int(first) + 1)
	if cursor != nil {
		after = *cursor
		query = query.Where("id > ?", after)
	}

	var posts []*model.Post
	if err := query.Find(&posts).Error; err != nil {
		return nil, err
	}

	// The look-ahead row only signals that more data exists.
	hasNext := len(posts) > int(first)
	if hasNext {
		posts = posts[:first]
	}
	if len(posts) == 0 {
		return &model.EmptyPostsConnections, nil
	}

	edges := make([]*model.PostsEdge, len(posts))
	for i, p := range posts {
		edges[i] = &model.PostsEdge{
			Cursor: p.ID,
			Node:   p,
		}
	}
	return &model.PostsConnection{
		Edges: edges,
		PageInfo: &model.PageInfo{
			StartCursor: after,
			EndCursor:   posts[len(posts)-1].ID,
			HasNextPage: hasNext,
		},
	}, nil
}
// GetComments returns one page of root comments of post, ordered by ID.
//
// first is the maximum number of comments in the page. cursor, when non-nil,
// is the comment ID after which the page starts: only comments with a greater
// ID are returned. Edge cursors and PageInfo.EndCursor are comment IDs, so
// passing EndCursor back in continues right after the last returned comment.
// An SQL OFFSET cannot be used for this because comment IDs are shared across
// posts and therefore have gaps within a single post.
//
// One row more than requested is fetched to find out whether another page
// exists; the extra row is never returned. PageInfo.StartCursor echoes the
// incoming cursor (0 when cursor is nil).
//
// The fetched page is also stored in post.Comments. When no comment is in
// range the shared model.EmptyCommentsConnection value is returned. ctx is
// attached to the query; a nil ctx is treated as context.Background().
func (s *Database) GetComments(post *model.Post, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) {
	if ctx == nil {
		ctx = context.Background()
	}

	var after uint
	query := s.db.WithContext(ctx).
		Where("post_id = ?", post.ID).
		Where("root = TRUE").
		Order("id").
		Limit(int(first) + 1)
	if cursor != nil {
		after = *cursor
		query = query.Where("id > ?", after)
	}

	if err := query.Find(&post.Comments).Error; err != nil {
		return nil, err
	}

	// The look-ahead row only signals that more data exists.
	hasNext := len(post.Comments) > int(first)
	if hasNext {
		post.Comments = post.Comments[:first]
	}
	if len(post.Comments) == 0 {
		return &model.EmptyCommentsConnection, nil
	}

	edges := make([]*model.CommentsEdge, len(post.Comments))
	for i, c := range post.Comments {
		edges[i] = &model.CommentsEdge{
			Cursor: c.ID,
			Node:   c,
		}
	}
	return &model.CommentsConnection{
		Edges: edges,
		PageInfo: &model.PageInfo{
			StartCursor: after,
			EndCursor:   post.Comments[len(post.Comments)-1].ID,
			HasNextPage: hasNext,
		},
	}, nil
}
// GetReplies returns one page of replies to comment, ordered by ID.
//
// Replies are found through the comment_replies link table (comment_id is the
// parent, reply_id the reply). first is the maximum number of replies in the
// page. cursor, when non-nil, is the reply ID after which the page starts:
// only replies with a greater ID are returned. Edge cursors and
// PageInfo.EndCursor are comment IDs, so passing EndCursor back in continues
// right after the last returned reply. An SQL OFFSET cannot be used for this
// because the IDs of one comment's replies are not contiguous.
//
// One row more than requested is fetched to find out whether another page
// exists; the extra row is never returned. PageInfo.StartCursor echoes the
// incoming cursor (0 when cursor is nil).
//
// The fetched page is also stored in comment.Replies. When no reply is in
// range the shared model.EmptyCommentsConnection value is returned. ctx is
// attached to the query; a nil ctx is treated as context.Background().
func (s *Database) GetReplies(comment *model.Comment, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) {
	if ctx == nil {
		ctx = context.Background()
	}

	var after uint
	query := s.db.WithContext(ctx).
		Model(&model.Comment{}).
		Joins("JOIN comment_replies ON comment_replies.reply_id = id").
		Where("comment_id = ?", comment.ID).
		Order("id").
		Limit(int(first) + 1)
	if cursor != nil {
		after = *cursor
		query = query.Where("id > ?", after)
	}

	if err := query.Find(&comment.Replies).Error; err != nil {
		return nil, err
	}

	// The look-ahead row only signals that more data exists.
	hasNext := len(comment.Replies) > int(first)
	if hasNext {
		comment.Replies = comment.Replies[:first]
	}
	if len(comment.Replies) == 0 {
		return &model.EmptyCommentsConnection, nil
	}

	edges := make([]*model.CommentsEdge, len(comment.Replies))
	for i, c := range comment.Replies {
		edges[i] = &model.CommentsEdge{
			Cursor: c.ID,
			Node:   c,
		}
	}
	return &model.CommentsConnection{
		Edges: edges,
		PageInfo: &model.PageInfo{
			StartCursor: after,
			EndCursor:   comment.Replies[len(comment.Replies)-1].ID,
			HasNextPage: hasNext,
		},
	}, nil
}
// TODO: try to fix n + 1 problem by fetching data in bulk with loaders
// this is tricky because we are getting paginated data
// might use sql window functions
func (d *database) fetchComments(ctx context.Context, postIds []uint) ([][]*model.Comment, []error) {
var comments []*model.Comment
err := d.Where("post_id IN (?)", postIds).Where("root = TRUE").Order("post_id").Find(&comments).Error
if err != nil {
return nil, []error{err}
}
postComments := make([][]*model.Comment, 0, len(postIds))
i := 0
for _, c := range comments {
if c.ID != postIds[i] {
i++
}
postComments[i] = append(postComments[i], c)
}
return postComments, nil
2024-06-24 23:34:10 +00:00
}
func (d *database) fetchReplies(ctx context.Context, ids []uint) ([][]*model.Comment, []error) {
return nil, nil
2024-06-24 23:34:10 +00:00
}