erius
8fce488888
TODO: use dataloaders to reduce the number of SQL queries (figure out how to batch-query nested paginated data) and add some basic unit or integration tests
217 lines
6 KiB
Go
217 lines
6 KiB
Go
package db
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
|
|
"git.obamna.ru/erius/ozon-task/graph/model"
|
|
"github.com/vikstrous/dataloadgen"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type database struct {
|
|
*gorm.DB
|
|
}
|
|
|
|
type Database struct {
|
|
db database
|
|
|
|
// loaders for batch loading and caching of objects
|
|
// to prevent n + 1 problem and redundant sql queries during nested pagination
|
|
postCommentsLoader *dataloadgen.Loader[uint, []*model.Comment]
|
|
commentRepliesLoader *dataloadgen.Loader[uint, []*model.Comment]
|
|
}
|
|
|
|
func InitDatabase(con *gorm.DB) *Database {
|
|
db := database{con}
|
|
return &Database{
|
|
db: db,
|
|
postCommentsLoader: dataloadgen.NewLoader(db.fetchComments),
|
|
commentRepliesLoader: dataloadgen.NewLoader(db.fetchReplies),
|
|
}
|
|
}
|
|
|
|
func (s *Database) AddPost(input *model.PostInput) (*model.AddResult, error) {
|
|
post := model.PostFromInput(input)
|
|
err := s.db.Create(post).Error
|
|
return &model.AddResult{ItemID: &post.ID}, err
|
|
}
|
|
|
|
func (s *Database) AddReplyToComment(input *model.CommentInput) (*model.AddResult, error) {
|
|
comment := model.CommentFromInput(input)
|
|
err := s.db.Model(&model.Comment{}).Select("post_id").Where("id = ?", *input.ParentCommentID).Scan(&comment.PostID).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// multiple mutable operations performed in one transaction
|
|
err = s.db.Transaction(func(tx *gorm.DB) error {
|
|
// insert comment
|
|
err := s.db.Create(comment).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// add new reply to parent
|
|
err = s.db.Table("comment_replies").Create(map[string]interface{}{
|
|
"comment_id": *input.ParentCommentID,
|
|
"reply_id": comment.ID,
|
|
}).Error
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
return &model.AddResult{ItemID: &comment.ID}, err
|
|
}
|
|
|
|
func (s *Database) AddCommentToPost(input *model.CommentInput) (*model.AddResult, error) {
|
|
comment := model.CommentFromInput(input)
|
|
var allowComments bool
|
|
err := s.db.Table("posts").Select("allow_comments").Where("id = ?", *input.ParentPostID).Scan(&allowComments).Error
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !allowComments {
|
|
return nil, fmt.Errorf("author disabled comments for this post")
|
|
}
|
|
comment.RootComment = true
|
|
comment.PostID = *input.ParentPostID
|
|
err = s.db.Create(comment).Error
|
|
return &model.AddResult{ItemID: &comment.ID}, err
|
|
}
|
|
|
|
func (s *Database) GetPost(id uint, ctx context.Context) (*model.Post, error) {
|
|
var post model.Post
|
|
err := s.db.Find(&post, id).Error
|
|
return &post, err
|
|
}
|
|
|
|
func (s *Database) GetComment(id uint, ctx context.Context) (*model.Comment, error) {
|
|
var comment model.Comment
|
|
err := s.db.Find(&comment, id).Error
|
|
return &comment, err
|
|
}
|
|
|
|
func (s *Database) GetPosts(first uint, cursor *uint, ctx context.Context) (*model.PostsConnection, error) {
|
|
offset := 0
|
|
if cursor != nil {
|
|
offset = int(*cursor)
|
|
}
|
|
var posts []*model.Post
|
|
res := s.db.Order("id").Limit(int(first)).Offset(offset).Find(&posts)
|
|
if res.Error != nil {
|
|
return nil, res.Error
|
|
}
|
|
if res.RowsAffected == 0 {
|
|
return &model.EmptyPostsConnections, nil
|
|
}
|
|
nextPage := true
|
|
if res.RowsAffected < int64(first) {
|
|
nextPage = false
|
|
}
|
|
info, edges := &model.PageInfo{
|
|
StartCursor: uint(offset),
|
|
EndCursor: posts[len(posts)-1].ID,
|
|
HasNextPage: nextPage,
|
|
}, make([]*model.PostsEdge, len(posts))
|
|
for i, p := range posts {
|
|
edges[i] = &model.PostsEdge{
|
|
Cursor: p.ID,
|
|
Node: p,
|
|
}
|
|
}
|
|
return &model.PostsConnection{
|
|
Edges: edges,
|
|
PageInfo: info,
|
|
}, nil
|
|
}
|
|
|
|
func (s *Database) GetComments(post *model.Post, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) {
|
|
offset := 0
|
|
if cursor != nil {
|
|
offset = int(*cursor)
|
|
}
|
|
res := s.db.Where("post_id = ?", post.ID).Where("root = TRUE").Order("id").Limit(int(first)).Offset(offset).Find(&post.Comments)
|
|
if res.Error != nil {
|
|
return nil, res.Error
|
|
}
|
|
if res.RowsAffected == 0 {
|
|
return &model.EmptyCommentsConnection, nil
|
|
}
|
|
nextPage := true
|
|
if res.RowsAffected < int64(first) {
|
|
nextPage = false
|
|
}
|
|
info, edges := &model.PageInfo{
|
|
StartCursor: uint(offset),
|
|
EndCursor: post.Comments[len(post.Comments)-1].ID,
|
|
HasNextPage: nextPage,
|
|
}, make([]*model.CommentsEdge, len(post.Comments))
|
|
for i, c := range post.Comments {
|
|
edges[i] = &model.CommentsEdge{
|
|
Cursor: c.ID,
|
|
Node: c,
|
|
}
|
|
}
|
|
return &model.CommentsConnection{
|
|
Edges: edges,
|
|
PageInfo: info,
|
|
}, nil
|
|
}
|
|
|
|
func (s *Database) GetReplies(comment *model.Comment, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) {
|
|
offset := 0
|
|
if cursor != nil {
|
|
offset = int(*cursor)
|
|
}
|
|
res := s.db.Model(&model.Comment{}).Joins("JOIN comment_replies ON comment_replies.reply_id = id").
|
|
Where("comment_id = ?", comment.ID).Order("id").Limit(int(first)).Offset(offset).Find(&comment.Replies)
|
|
if res.Error != nil {
|
|
return nil, res.Error
|
|
}
|
|
if res.RowsAffected == 0 {
|
|
return &model.EmptyCommentsConnection, nil
|
|
}
|
|
nextPage := true
|
|
if res.RowsAffected < int64(first) {
|
|
nextPage = false
|
|
}
|
|
info, edges := &model.PageInfo{
|
|
StartCursor: uint(offset),
|
|
EndCursor: comment.Replies[len(comment.Replies)-1].ID,
|
|
HasNextPage: nextPage,
|
|
}, make([]*model.CommentsEdge, len(comment.Replies))
|
|
for i, c := range comment.Replies {
|
|
edges[i] = &model.CommentsEdge{
|
|
Cursor: c.ID,
|
|
Node: c,
|
|
}
|
|
}
|
|
return &model.CommentsConnection{
|
|
Edges: edges,
|
|
PageInfo: info,
|
|
}, nil
|
|
}
|
|
|
|
// TODO: try to fix n + 1 problem by fetching data in bulk with loaders
|
|
// this is tricky because we are getting paginated data
|
|
// might use sql window functions
|
|
func (d *database) fetchComments(ctx context.Context, postIds []uint) ([][]*model.Comment, []error) {
|
|
var comments []*model.Comment
|
|
err := d.Where("post_id IN (?)", postIds).Where("root = TRUE").Order("post_id").Find(&comments).Error
|
|
if err != nil {
|
|
return nil, []error{err}
|
|
}
|
|
postComments := make([][]*model.Comment, 0, len(postIds))
|
|
i := 0
|
|
for _, c := range comments {
|
|
if c.ID != postIds[i] {
|
|
i++
|
|
}
|
|
postComments[i] = append(postComments[i], c)
|
|
}
|
|
return postComments, nil
|
|
}
|
|
|
|
func (d *database) fetchReplies(ctx context.Context, ids []uint) ([][]*model.Comment, []error) {
|
|
return nil, nil
|
|
}
|