// Package db implements persistence of posts and comments on top of GORM.
package db

import (
	"context"
	"errors"
	"fmt"

	"git.obamna.ru/erius/ozon-task/graph/model"
	"github.com/vikstrous/dataloadgen"
	"gorm.io/gorm"
)

// database embeds the GORM connection; the dataloader fetch functions are
// defined on it.
type database struct {
	*gorm.DB
}

// Database provides post and comment storage backed by GORM.
type Database struct {
	db database
	// loaders for batch loading and caching of objects
	// to prevent n + 1 problem and redundant sql queries during nested pagination
	postCommentsLoader   *dataloadgen.Loader[uint, []*model.Comment]
	commentRepliesLoader *dataloadgen.Loader[uint, []*model.Comment]
}

// InitDatabase wraps an open GORM connection and sets up the dataloaders for
// root comments of posts and for replies of comments.
func InitDatabase(con *gorm.DB) *Database {
	db := database{con}
	return &Database{
		db:                   db,
		postCommentsLoader:   dataloadgen.NewLoader(db.fetchComments),
		commentRepliesLoader: dataloadgen.NewLoader(db.fetchReplies),
	}
}

// AddPost inserts a post built from input and returns its generated ID.
func (s *Database) AddPost(input *model.PostInput) (*model.AddResult, error) {
	post := model.PostFromInput(input)
	if err := s.db.Create(post).Error; err != nil {
		return nil, err
	}
	return &model.AddResult{ItemID: &post.ID}, nil
}

// AddReplyToComment inserts a reply to the comment identified by
// input.ParentCommentID and records the parent/reply pair in the
// comment_replies table. The reply is attached to the same post as its parent.
//
// It returns an error wrapping gorm.ErrRecordNotFound when the parent comment
// does not exist.
func (s *Database) AddReplyToComment(input *model.CommentInput) (*model.AddResult, error) {
	if input.ParentCommentID == nil {
		return nil, errors.New("parent comment id is required")
	}
	parentID := *input.ParentCommentID
	comment := model.CommentFromInput(input)

	// Every statement goes through tx so the lookup, the insert and the link
	// row commit or roll back together.
	err := s.db.Transaction(func(tx *gorm.DB) error {
		res := tx.Model(&model.Comment{}).Select("post_id").Where("id = ?", parentID).Scan(&comment.PostID)
		if res.Error != nil {
			return res.Error
		}
		if res.RowsAffected == 0 {
			return fmt.Errorf("comment %d: %w", parentID, gorm.ErrRecordNotFound)
		}

		if err := tx.Create(comment).Error; err != nil {
			return err
		}

		// add new reply to parent
		return tx.Table("comment_replies").Create(map[string]any{
			"comment_id": parentID,
			"reply_id":   comment.ID,
		}).Error
	})
	if err != nil {
		return nil, err
	}
	return &model.AddResult{ItemID: &comment.ID}, nil
}

// AddCommentToPost inserts a root comment on the post identified by
// input.ParentPostID.
//
// It returns an error wrapping gorm.ErrRecordNotFound when the post does not
// exist, and an error when the post's allow_comments flag is false.
func (s *Database) AddCommentToPost(input *model.CommentInput) (*model.AddResult, error) {
	if input.ParentPostID == nil {
		return nil, errors.New("parent post id is required")
	}
	postID := *input.ParentPostID
	comment := model.CommentFromInput(input)

	var allowComments bool
	res := s.db.Table("posts").Select("allow_comments").Where("id = ?", postID).Scan(&allowComments)
	if res.Error != nil {
		return nil, res.Error
	}
	// Without this check a missing post would leave allowComments false and
	// be reported as "comments disabled".
	if res.RowsAffected == 0 {
		return nil, fmt.Errorf("post %d: %w", postID, gorm.ErrRecordNotFound)
	}
	if !allowComments {
		return nil, errors.New("author disabled comments for this post")
	}

	comment.RootComment = true
	comment.PostID = postID
	if err := s.db.Create(comment).Error; err != nil {
		return nil, err
	}
	return &model.AddResult{ItemID: &comment.ID}, nil
}

// GetPost loads the post with the given id. It returns an error wrapping
// gorm.ErrRecordNotFound when no such post exists.
func (s *Database) GetPost(id uint, ctx context.Context) (*model.Post, error) {
	var post model.Post
	if err := s.db.WithContext(ctx).Take(&post, id).Error; err != nil {
		return nil, fmt.Errorf("post %d: %w", id, err)
	}
	return &post, nil
}

// GetComment loads the comment with the given id. It returns an error
// wrapping gorm.ErrRecordNotFound when no such comment exists.
func (s *Database) GetComment(id uint, ctx context.Context) (*model.Comment, error) {
	var comment model.Comment
	if err := s.db.WithContext(ctx).Take(&comment, id).Error; err != nil {
		return nil, fmt.Errorf("comment %d: %w", id, err)
	}
	return &comment, nil
}

// GetPosts returns up to first posts ordered by ID. Pagination is keyset
// based: edge cursors are post IDs, and a non-nil cursor returns only posts
// whose ID is strictly greater than *cursor. PageInfo.EndCursor can therefore
// be passed back as cursor to fetch the next page.
func (s *Database) GetPosts(first uint, cursor *uint, ctx context.Context) (*model.PostsConnection, error) {
	q := afterCursor(s.db.WithContext(ctx), "id", cursor)

	var posts []*model.Post
	if err := q.Order("id").Limit(pageLimit(first)).Find(&posts).Error; err != nil {
		return nil, err
	}

	posts, hasNext := trimPage(posts, first)
	if len(posts) == 0 {
		return &model.EmptyPostsConnections, nil
	}
	return postsConnection(posts, hasNext), nil
}

// GetComments returns up to first root comments of post, ordered by ID, using
// the same keyset cursor semantics as GetPosts.
//
// Side effect: post.Comments is set to the returned page.
func (s *Database) GetComments(post *model.Post, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) {
	// NOTE(review): assumes model.Comment.RootComment is stored in column
	// "root" — confirm against the model's gorm tags.
	q := s.db.WithContext(ctx).Where("post_id = ?", post.ID).Where("root = TRUE")
	q = afterCursor(q, "id", cursor)

	var comments []*model.Comment
	if err := q.Order("id").Limit(pageLimit(first)).Find(&comments).Error; err != nil {
		return nil, err
	}

	comments, hasNext := trimPage(comments, first)
	post.Comments = comments
	if len(comments) == 0 {
		return &model.EmptyCommentsConnection, nil
	}
	return commentsConnection(comments, hasNext), nil
}

// GetReplies returns up to first direct replies of comment (as recorded in
// comment_replies), ordered by ID, using the same keyset cursor semantics as
// GetPosts.
//
// Side effect: comment.Replies is set to the returned page.
func (s *Database) GetReplies(comment *model.Comment, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) {
	q := s.db.WithContext(ctx).Model(&model.Comment{}).
		Joins("JOIN comment_replies ON comment_replies.reply_id = id").
		Where("comment_id = ?", comment.ID)
	q = afterCursor(q, "id", cursor)

	var replies []*model.Comment
	if err := q.Order("id").Limit(pageLimit(first)).Find(&replies).Error; err != nil {
		return nil, err
	}

	replies, hasNext := trimPage(replies, first)
	comment.Replies = replies
	if len(replies) == 0 {
		return &model.EmptyCommentsConnection, nil
	}
	return commentsConnection(replies, hasNext), nil
}

// afterCursor restricts q to rows whose column value is strictly greater than
// *cursor. A nil cursor leaves q unchanged. column must be a trusted literal.
func afterCursor(q *gorm.DB, column string, cursor *uint) *gorm.DB {
	if cursor == nil {
		return q
	}
	return q.Where(column+" > ?", *cursor)
}

// pageLimit is the LIMIT used to fetch a page of first items. One extra row is
// requested so HasNextPage can be decided without a separate COUNT query.
func pageLimit(first uint) int {
	return int(first) + 1
}

// trimPage cuts a look-ahead result down to at most first items. It reports
// whether the look-ahead row was present, i.e. whether another page exists.
func trimPage[T any](items []T, first uint) ([]T, bool) {
	if uint(len(items)) > first {
		return items[:first], true
	}
	return items, false
}

// postsConnection builds a connection over a non-empty page of posts. Start
// and end cursors are the IDs of the first and last post on the page.
func postsConnection(posts []*model.Post, hasNext bool) *model.PostsConnection {
	edges := make([]*model.PostsEdge, len(posts))
	for i, p := range posts {
		edges[i] = &model.PostsEdge{Cursor: p.ID, Node: p}
	}
	return &model.PostsConnection{
		Edges: edges,
		PageInfo: &model.PageInfo{
			StartCursor: posts[0].ID,
			EndCursor:   posts[len(posts)-1].ID,
			HasNextPage: hasNext,
		},
	}
}

// commentsConnection builds a connection over a non-empty page of comments.
// Start and end cursors are the IDs of the first and last comment on the page.
func commentsConnection(comments []*model.Comment, hasNext bool) *model.CommentsConnection {
	edges := make([]*model.CommentsEdge, len(comments))
	for i, c := range comments {
		edges[i] = &model.CommentsEdge{Cursor: c.ID, Node: c}
	}
	return &model.CommentsConnection{
		Edges: edges,
		PageInfo: &model.PageInfo{
			StartCursor: comments[0].ID,
			EndCursor:   comments[len(comments)-1].ID,
			HasNextPage: hasNext,
		},
	}
}

// TODO: try to fix n + 1 problem by fetching data in bulk with loaders
// this is tricky because we are getting paginated data
// might use sql window functions

// fetchComments is the batch function for postCommentsLoader. It loads all
// root comments of the given posts in one query. It returns one slice per key,
// in the same order as postIDs (nil for a post without comments), as
// dataloadgen requires.
func (d *database) fetchComments(ctx context.Context, postIDs []uint) ([][]*model.Comment, []error) {
	var comments []*model.Comment
	err := d.WithContext(ctx).
		Where("post_id IN ?", postIDs).
		Where("root = TRUE").
		Order("post_id, id").
		Find(&comments).Error
	if err != nil {
		return nil, []error{err}
	}

	byPost := make(map[uint][]*model.Comment, len(postIDs))
	for _, c := range comments {
		byPost[c.PostID] = append(byPost[c.PostID], c)
	}

	result := make([][]*model.Comment, len(postIDs))
	for i, id := range postIDs {
		result[i] = byPost[id]
	}
	return result, nil
}

// replyLink is one row of the comment_replies table.
type replyLink struct {
	CommentID uint
	ReplyID   uint
}

// fetchReplies is the batch function for commentRepliesLoader. It loads the
// direct replies of the given comments with two queries: one for the
// comment_replies links and one for the reply rows. It returns one slice per
// key, in the same order as ids (nil for a comment without replies), with
// replies ordered by ID.
func (d *database) fetchReplies(ctx context.Context, ids []uint) ([][]*model.Comment, []error) {
	db := d.WithContext(ctx)
	result := make([][]*model.Comment, len(ids))

	var links []replyLink
	err := db.Table("comment_replies").
		Select("comment_id, reply_id").
		Where("comment_id IN ?", ids).
		Order("reply_id").
		Scan(&links).Error
	if err != nil {
		return nil, []error{err}
	}
	if len(links) == 0 {
		return result, nil
	}

	replyIDs := make([]uint, len(links))
	for i, l := range links {
		replyIDs[i] = l.ReplyID
	}
	var replies []*model.Comment
	if err := db.Where("id IN ?", replyIDs).Find(&replies).Error; err != nil {
		return nil, []error{err}
	}
	byID := make(map[uint]*model.Comment, len(replies))
	for _, r := range replies {
		byID[r.ID] = r
	}

	byParent := make(map[uint][]*model.Comment, len(ids))
	for _, l := range links {
		if r, ok := byID[l.ReplyID]; ok {
			byParent[l.CommentID] = append(byParent[l.CommentID], r)
		}
	}
	for i, id := range ids {
		result[i] = byParent[id]
	}
	return result, nil
}