From 8fce4888885aa2e6560ea8b76216603df6610b8f Mon Sep 17 00:00:00 2001 From: erius Date: Thu, 27 Jun 2024 04:28:18 +0300 Subject: [PATCH] Implemented pagination for posts, comments and replies TODO: use dataloaders to reduce amount of sql queries (figure out how to batch query nested paginated data) and add some basic unit or integrated testing --- go.mod | 3 + go.sum | 8 ++ graph/generated.go | 36 ++++---- graph/model/data.go | 52 +++++++---- graph/query.resolvers.go | 6 +- graph/schema/query.graphqls | 6 +- internal/storage/db/database.go | 158 +++++++++++++++++++++++++++++--- internal/storage/db/postgres.go | 8 +- internal/storage/memory.go | 73 +++++++-------- internal/storage/storage.go | 6 +- 10 files changed, 256 insertions(+), 100 deletions(-) diff --git a/go.mod b/go.mod index 8e099bd..4399481 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22.4 require ( github.com/99designs/gqlgen v0.17.49 github.com/vektah/gqlparser/v2 v2.5.16 + github.com/vikstrous/dataloadgen v0.0.6 gorm.io/driver/postgres v1.5.9 gorm.io/gorm v1.25.10 ) @@ -28,6 +29,8 @@ require ( github.com/sosodev/duration v1.3.1 // indirect github.com/urfave/cli/v2 v2.27.2 // indirect github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 // indirect + go.opentelemetry.io/otel v1.11.1 // indirect + go.opentelemetry.io/otel/trace v1.11.1 // indirect golang.org/x/crypto v0.17.0 // indirect golang.org/x/mod v0.18.0 // indirect golang.org/x/sync v0.7.0 // indirect diff --git a/go.sum b/go.sum index ad8ad83..6549a18 100644 --- a/go.sum +++ b/go.sum @@ -18,6 +18,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g= github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= @@ -61,8 +63,14 @@ github.com/urfave/cli/v2 v2.27.2 h1:6e0H+AkS+zDckwPCUrZkKX38mRaau4nL2uipkJpbkcI= github.com/urfave/cli/v2 v2.27.2/go.mod h1:g0+79LmHHATl7DAcHO99smiR/T7uGLw84w8Y42x+4eM= github.com/vektah/gqlparser/v2 v2.5.16 h1:1gcmLTvs3JLKXckwCwlUagVn/IlV2bwqle0vJ0vy5p8= github.com/vektah/gqlparser/v2 v2.5.16/go.mod h1:1lz1OeCqgQbQepsGxPVywrjdBHW2T08PUS3pJqepRww= +github.com/vikstrous/dataloadgen v0.0.6 h1:A7s/fI3QNnH80CA9vdNbWK7AsbLjIxNHpZnV+VnOT1s= +github.com/vikstrous/dataloadgen v0.0.6/go.mod h1:8vuQVpBH0ODbMKAPUdCAPcOGezoTIhgAjgex51t4vbg= github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 h1:+qGGcbkzsfDQNPPe9UDgpxAWQrhbbBXOYJFQDq/dtJw= github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913/go.mod h1:4aEEwZQutDLsQv2Deui4iYQ6DWTxR14g6m8Wv88+Xqk= +go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4= +go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE= +go.opentelemetry.io/otel/trace v1.11.1 h1:ofxdnzsNrGBYXbP7t7zpUK281+go5rF7dvdIZXF8gdQ= +go.opentelemetry.io/otel/trace v1.11.1/go.mod h1:f/Q9G7vzk5u91PhbmKbg1Qn0rzH1LJ4vbPHFGkTPtOk= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= diff --git a/graph/generated.go b/graph/generated.go index 68193ca..6f4597f 100644 --- a/graph/generated.go +++ b/graph/generated.go @@ -57,7 +57,7 @@ type ComplexityRoot struct { Author func(childComplexity int) int Contents func(childComplexity int) int ID func(childComplexity int) int - Replies func(childComplexity int, first uint, after uint) int + Replies func(childComplexity int, first uint, after *uint) int } CommentsConnection struct { @@ -84,7 +84,7 @@ type ComplexityRoot struct { Post struct { AllowComments func(childComplexity int) int Author func(childComplexity int) int - Comments func(childComplexity int, first uint, after uint) int + Comments func(childComplexity int, first uint, after *uint) int Contents func(childComplexity int) int ID func(childComplexity int) int Title func(childComplexity int) int @@ -103,23 +103,23 @@ type ComplexityRoot struct { Query struct { Comment func(childComplexity int, id uint) int Post func(childComplexity int, id uint) int - Posts func(childComplexity int, first uint, after uint) int + Posts func(childComplexity int, first uint, after *uint) int } } type CommentResolver interface { - Replies(ctx context.Context, obj *model.Comment, first uint, after uint) (*model.CommentsConnection, error) + Replies(ctx context.Context, obj *model.Comment, first uint, after *uint) (*model.CommentsConnection, error) } type MutationResolver interface { AddPost(ctx context.Context, input model.PostInput) (*model.AddResult, error) AddComment(ctx context.Context, input model.CommentInput) (*model.AddResult, error) } type PostResolver interface { - Comments(ctx context.Context, obj *model.Post, first uint, after uint) (*model.CommentsConnection, error) + Comments(ctx context.Context, obj *model.Post, first uint, after *uint) (*model.CommentsConnection, error) } type QueryResolver interface { Post(ctx context.Context, id uint) (*model.Post, error) - Posts(ctx context.Context, first uint, after uint) (*model.PostsConnection, error) + Posts(ctx context.Context, first uint, after *uint) (*model.PostsConnection, error) Comment(ctx context.Context, id uint) (*model.Comment, error) } @@ -180,7 +180,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Comment.Replies(childComplexity, args["first"].(uint), args["after"].(uint)), true + return e.complexity.Comment.Replies(childComplexity, args["first"].(uint), args["after"].(*uint)), true case "CommentsConnection.edges": if e.complexity.CommentsConnection.Edges == nil { @@ -279,7 +279,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Post.Comments(childComplexity, args["first"].(uint), args["after"].(uint)), true + return e.complexity.Post.Comments(childComplexity, args["first"].(uint), args["after"].(*uint)), true case "Post.contents": if e.complexity.Post.Contents == nil { @@ -364,7 +364,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Query.Posts(childComplexity, args["first"].(uint), args["after"].(uint)), true + return e.complexity.Query.Posts(childComplexity, args["first"].(uint), args["after"].(*uint)), true } return 0, false @@ -506,10 +506,10 @@ func (ec *executionContext) field_Comment_replies_args(ctx context.Context, rawA } } args["first"] = arg0 - var arg1 uint + var arg1 *uint if tmp, ok := rawArgs["after"]; ok { ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("after")) - arg1, err = ec.unmarshalNID2uint(ctx, tmp) + arg1, err = ec.unmarshalOID2ᚖuint(ctx, tmp) if err != nil { return nil, err } @@ -560,10 +560,10 @@ func (ec *executionContext) field_Post_comments_args(ctx context.Context, rawArg } } args["first"] = arg0 - var arg1 uint + var arg1 *uint if tmp, ok := rawArgs["after"]; ok { ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("after")) - arg1, err = ec.unmarshalNID2uint(ctx, tmp) + arg1, err = ec.unmarshalOID2ᚖuint(ctx, tmp) if err != nil { return nil, err } @@ -629,10 +629,10 @@ func (ec *executionContext) field_Query_posts_args(ctx context.Context, rawArgs } } args["first"] = arg0 - var arg1 uint + var arg1 *uint if tmp, ok := rawArgs["after"]; ok { ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("after")) - arg1, err = ec.unmarshalNID2uint(ctx, tmp) + arg1, err = ec.unmarshalOID2ᚖuint(ctx, tmp) if err != nil { return nil, err } @@ -869,7 +869,7 @@ func (ec *executionContext) _Comment_replies(ctx context.Context, field graphql. }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Comment().Replies(rctx, obj, fc.Args["first"].(uint), fc.Args["after"].(uint)) + return ec.resolvers.Comment().Replies(rctx, obj, fc.Args["first"].(uint), fc.Args["after"].(*uint)) }) if err != nil { ec.Error(ctx, err) @@ -1556,7 +1556,7 @@ func (ec *executionContext) _Post_comments(ctx context.Context, field graphql.Co }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Post().Comments(rctx, obj, fc.Args["first"].(uint), fc.Args["after"].(uint)) + return ec.resolvers.Post().Comments(rctx, obj, fc.Args["first"].(uint), fc.Args["after"].(*uint)) }) if err != nil { ec.Error(ctx, err) @@ -1934,7 +1934,7 @@ func (ec *executionContext) _Query_posts(ctx context.Context, field graphql.Coll }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Posts(rctx, fc.Args["first"].(uint), fc.Args["after"].(uint)) + return ec.resolvers.Query().Posts(rctx, fc.Args["first"].(uint), fc.Args["after"].(*uint)) }) if err != nil { ec.Error(ctx, err) diff --git a/graph/model/data.go b/graph/model/data.go index 81632d3..61b3ad3 100644 --- a/graph/model/data.go +++ b/graph/model/data.go @@ -8,31 +8,51 @@ import ( const CommentLengthLimit = 2000 +var ( + EmptyCommentsConnection = CommentsConnection{ + Edges: make([]*CommentsEdge, 0), + PageInfo: &PageInfo{ + StartCursor: 0, + EndCursor: 0, + HasNextPage: false, + }, + } + EmptyPostsConnections = PostsConnection{ + Edges: make([]*PostsEdge, 0), + PageInfo: &PageInfo{ + StartCursor: 0, + EndCursor: 0, + HasNextPage: false, + }, + } +) + type Comment struct { // db fields - ID uint `json:"id" gorm:"primarykey"` - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt gorm.DeletedAt `gorm:"index"` + ID uint `json:"id" gorm:"primarykey;column:id"` + CreatedAt time.Time `gorm:"column:created_at"` + UpdatedAt time.Time `gorm:"column:updated_at"` + DeletedAt gorm.DeletedAt `gorm:"index;column:deleted_at"` - PostID uint `json:"post_id"` - Author string `json:"author"` - Contents string `json:"contents"` - Replies []*Comment `json:"replies" gorm:"many2many:comment_replies"` + RootComment bool `gorm:"column:root"` + PostID uint `json:"post_id" gorm:"column:post_id"` + Author string `json:"author" gorm:"column:author"` + Contents string `json:"contents" gorm:"column:contents"` + Replies []*Comment `json:"replies" gorm:"many2many:comment_replies"` } type Post struct { // db fields - ID uint `json:"id" gorm:"primarykey"` - CreatedAt time.Time - UpdatedAt time.Time - DeletedAt gorm.DeletedAt `gorm:"index"` + ID uint `json:"id" gorm:"primarykey;column:id"` + CreatedAt time.Time `gorm:"column:created_at"` + UpdatedAt time.Time `gorm:"column:updated_at"` + DeletedAt gorm.DeletedAt `gorm:"index;column:deleted_at"` - Title string `json:"title"` - Author string `json:"author"` - Contents string `json:"contents"` + Title string `json:"title" gorm:"column:title"` + Author string `json:"author" gorm:"column:author"` + Contents string `json:"contents" gorm:"column:contents"` Comments []*Comment `json:"comments" gorm:"foreignKey:PostID"` - AllowComments bool `json:"allowComments"` + AllowComments bool `json:"allowComments" gorm:"column:allow_comments"` } func PostFromInput(input *PostInput) *Post { diff --git a/graph/query.resolvers.go b/graph/query.resolvers.go index 7b53909..d1e95e4 100644 --- a/graph/query.resolvers.go +++ b/graph/query.resolvers.go @@ -11,12 +11,12 @@ import ( ) // Replies is the resolver for the replies field. -func (r *commentResolver) Replies(ctx context.Context, obj *model.Comment, first uint, after uint) (*model.CommentsConnection, error) { +func (r *commentResolver) Replies(ctx context.Context, obj *model.Comment, first uint, after *uint) (*model.CommentsConnection, error) { return r.Storage.GetReplies(obj, first, after, ctx) } // Comments is the resolver for the comments field. -func (r *postResolver) Comments(ctx context.Context, obj *model.Post, first uint, after uint) (*model.CommentsConnection, error) { +func (r *postResolver) Comments(ctx context.Context, obj *model.Post, first uint, after *uint) (*model.CommentsConnection, error) { return r.Storage.GetComments(obj, first, after, ctx) } @@ -26,7 +26,7 @@ func (r *queryResolver) Post(ctx context.Context, id uint) (*model.Post, error) } // Posts is the resolver for the posts field. -func (r *queryResolver) Posts(ctx context.Context, first uint, after uint) (*model.PostsConnection, error) { +func (r *queryResolver) Posts(ctx context.Context, first uint, after *uint) (*model.PostsConnection, error) { return r.Storage.GetPosts(first, after, ctx) } diff --git a/graph/schema/query.graphqls b/graph/schema/query.graphqls index c2e8788..747255d 100644 --- a/graph/schema/query.graphqls +++ b/graph/schema/query.graphqls @@ -1,6 +1,6 @@ type Query { post(id: ID!): Post! - posts(first: ID!, after: ID!): PostsConnection! + posts(first: ID!, after: ID): PostsConnection! comment(id: ID!): Comment! } @@ -9,7 +9,7 @@ type Post { title: String! author: String! contents: String! - comments(first: ID!, after: ID!): CommentsConnection! + comments(first: ID!, after: ID): CommentsConnection! allowComments: Boolean! } @@ -17,5 +17,5 @@ type Comment { id: ID! author: String! contents: String! - replies(first: ID!, after: ID!): CommentsConnection! + replies(first: ID!, after: ID): CommentsConnection! } diff --git a/internal/storage/db/database.go b/internal/storage/db/database.go index 1de6719..031a948 100644 --- a/internal/storage/db/database.go +++ b/internal/storage/db/database.go @@ -3,15 +3,32 @@ package db import ( "context" "fmt" - "log" "git.obamna.ru/erius/ozon-task/graph/model" - "github.com/99designs/gqlgen/graphql" + "github.com/vikstrous/dataloadgen" "gorm.io/gorm" ) +type database struct { + *gorm.DB +} + type Database struct { - db *gorm.DB + db database + + // loaders for batch loading and caching of objects + // to prevent n + 1 problem and redundant sql queries during nested pagination + postCommentsLoader *dataloadgen.Loader[uint, []*model.Comment] + commentRepliesLoader *dataloadgen.Loader[uint, []*model.Comment] +} + +func InitDatabase(con *gorm.DB) *Database { + db := database{con} + return &Database{ + db: db, + postCommentsLoader: dataloadgen.NewLoader(db.fetchComments), + commentRepliesLoader: dataloadgen.NewLoader(db.fetchReplies), + } } func (s *Database) AddPost(input *model.PostInput) (*model.AddResult, error) { @@ -22,8 +39,12 @@ func (s *Database) AddPost(input *model.PostInput) (*model.AddResult, error) { func (s *Database) AddReplyToComment(input *model.CommentInput) (*model.AddResult, error) { comment := model.CommentFromInput(input) - // multiple operations performed in one transaction - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.db.Model(&model.Comment{}).Select("post_id").Where("id = ?", *input.ParentCommentID).Scan(&comment.PostID).Error + if err != nil { + return nil, err + } + // multiple mutable operations performed in one transaction + err = s.db.Transaction(func(tx *gorm.DB) error { // insert comment err := s.db.Create(comment).Error if err != nil { @@ -52,13 +73,13 @@ func (s *Database) AddCommentToPost(input *model.CommentInput) (*model.AddResult if !allowComments { return nil, fmt.Errorf("author disabled comments for this post") } + comment.RootComment = true + comment.PostID = *input.ParentPostID err = s.db.Create(comment).Error return &model.AddResult{ItemID: &comment.ID}, err } func (s *Database) GetPost(id uint, ctx context.Context) (*model.Post, error) { - log.Println(graphql.CollectAllFields(ctx)) - log.Println(graphql.CollectFieldsCtx(ctx, nil)) var post model.Post err := s.db.Find(&post, id).Error return &post, err @@ -70,14 +91,127 @@ func (s *Database) GetComment(id uint, ctx context.Context) (*model.Comment, err return &comment, err } -func (s *Database) GetPosts(first uint, after uint, ctx context.Context) (*model.PostsConnection, error) { - return nil, nil +func (s *Database) GetPosts(first uint, cursor *uint, ctx context.Context) (*model.PostsConnection, error) { + offset := 0 + if cursor != nil { + offset = int(*cursor) + } + var posts []*model.Post + res := s.db.Order("id").Limit(int(first)).Offset(offset).Find(&posts) + if res.Error != nil { + return nil, res.Error + } + if res.RowsAffected == 0 { + return &model.EmptyPostsConnections, nil + } + nextPage := true + if res.RowsAffected < int64(first) { + nextPage = false + } + info, edges := &model.PageInfo{ + StartCursor: uint(offset), + EndCursor: posts[len(posts)-1].ID, + HasNextPage: nextPage, + }, make([]*model.PostsEdge, len(posts)) + for i, p := range posts { + edges[i] = &model.PostsEdge{ + Cursor: p.ID, + Node: p, + } + } + return &model.PostsConnection{ + Edges: edges, + PageInfo: info, + }, nil } -func (s *Database) GetComments(post *model.Post, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) { - return nil, nil +func (s *Database) GetComments(post *model.Post, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) { + offset := 0 + if cursor != nil { + offset = int(*cursor) + } + res := s.db.Where("post_id = ?", post.ID).Where("root = TRUE").Order("id").Limit(int(first)).Offset(offset).Find(&post.Comments) + if res.Error != nil { + return nil, res.Error + } + if res.RowsAffected == 0 { + return &model.EmptyCommentsConnection, nil + } + nextPage := true + if res.RowsAffected < int64(first) { + nextPage = false + } + info, edges := &model.PageInfo{ + StartCursor: uint(offset), + EndCursor: post.Comments[len(post.Comments)-1].ID, + HasNextPage: nextPage, + }, make([]*model.CommentsEdge, len(post.Comments)) + for i, c := range post.Comments { + edges[i] = &model.CommentsEdge{ + Cursor: c.ID, + Node: c, + } + } + return &model.CommentsConnection{ + Edges: edges, + PageInfo: info, + }, nil } -func (s *Database) GetReplies(comment *model.Comment, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) { +func (s *Database) GetReplies(comment *model.Comment, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) { + offset := 0 + if cursor != nil { + offset = int(*cursor) + } + res := s.db.Model(&model.Comment{}).Joins("JOIN comment_replies ON comment_replies.reply_id = id"). + Where("comment_id = ?", comment.ID).Order("id").Limit(int(first)).Offset(offset).Find(&comment.Replies) + if res.Error != nil { + return nil, res.Error + } + if res.RowsAffected == 0 { + return &model.EmptyCommentsConnection, nil + } + nextPage := true + if res.RowsAffected < int64(first) { + nextPage = false + } + info, edges := &model.PageInfo{ + StartCursor: uint(offset), + EndCursor: comment.Replies[len(comment.Replies)-1].ID, + HasNextPage: nextPage, + }, make([]*model.CommentsEdge, len(comment.Replies)) + for i, c := range comment.Replies { + edges[i] = &model.CommentsEdge{ + Cursor: c.ID, + Node: c, + } + } + return &model.CommentsConnection{ + Edges: edges, + PageInfo: info, + }, nil +} + +// TODO: try to fix n + 1 problem by fetching data in bulk with loaders +// this is tricky because we are getting paginated data +// might use sql window functions +func (d *database) fetchComments(ctx context.Context, postIds []uint) ([][]*model.Comment, []error) { + var comments []*model.Comment + err := d.Where("post_id IN (?)", postIds).Where("root = TRUE").Order("post_id").Find(&comments).Error + if err != nil { + return nil, []error{err} + } + postComments := make([][]*model.Comment, 0, len(postIds)) + i := 0 + for _, c := range comments { + if c.ID != postIds[i] { + i++ + } + postComments[i] = append(postComments[i], c) + } + return postComments, nil +} + +func (d *database) fetchReplies(ctx context.Context, ids []uint) ([][]*model.Comment, []error) { return nil, nil } diff --git a/internal/storage/db/postgres.go b/internal/storage/db/postgres.go index e19b0c3..aa2815b 100644 --- a/internal/storage/db/postgres.go +++ b/internal/storage/db/postgres.go @@ -32,12 +32,12 @@ func InitPostgres() (*Database, error) { return nil, err } log.Println("opened connection to PostgreSQL database") - log.Println("migrating model scheme to database...") + log.Println("migrating model schema to database...") err = db.AutoMigrate(&model.Post{}, &model.Comment{}) if err != nil { - log.Printf("failed to automatically migrate model scheme: %s", err) + log.Printf("failed to automatically migrate model schema: %s", err) return nil, err } - log.Println("finished migrating model scheme") - return &Database{db}, nil + log.Println("finished migrating model schema") + return InitDatabase(db), nil } diff --git a/internal/storage/memory.go b/internal/storage/memory.go index 7dc8662..606f4ae 100644 --- a/internal/storage/memory.go +++ b/internal/storage/memory.go @@ -48,6 +48,7 @@ func (s *InMemory) AddCommentToPost(input *model.CommentInput) (*model.AddResult return nil, fmt.Errorf("author disabled comments for this post") } comment := model.CommentFromInput(input) + comment.RootComment = true s.insertComment(comment) parent.Comments = append(parent.Comments, comment) return &model.AddResult{ItemID: &comment.ID}, nil @@ -67,21 +68,25 @@ func (s *InMemory) GetComment(id uint, ctx context.Context) (*model.Comment, err return s.comments[id], nil } -func (s *InMemory) GetPosts(first uint, after uint, ctx context.Context) (*model.PostsConnection, error) { - if !s.postExists(after) { - return nil, &IDNotFoundError{objName: "post", id: after} +func (s *InMemory) GetPosts(first uint, cursor *uint, ctx context.Context) (*model.PostsConnection, error) { + start := uint(0) + if cursor != nil { + start = *cursor + 1 } - nextPage, until := true, after+first + if !s.postExists(start) { + return &model.EmptyPostsConnections, nil + } + nextPage, until := true, start+first if !s.postExists(until) { nextPage = false until = uint(len(s.posts)) } info, edges := &model.PageInfo{ - StartCursor: after, + StartCursor: start, EndCursor: until - 1, HasNextPage: nextPage, - }, make([]*model.PostsEdge, until-after) - for i, p := range s.posts[after:until] { + }, make([]*model.PostsEdge, until-start) + for i, p := range s.posts[start:until] { edges[i] = &model.PostsEdge{ Cursor: p.ID, Node: p, @@ -93,47 +98,33 @@ func (s *InMemory) GetPosts(first uint, after uint, ctx context.Context) (*model }, nil } -func (s *InMemory) GetComments(post *model.Post, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) { - if !s.commentExists(after) { - return nil, &IDNotFoundError{objName: "comment", id: after} - } - nextPage, until := true, after+first - if !s.commentExists(until) { - nextPage = false - until = uint(len(s.comments)) - } - info, edges := &model.PageInfo{ - StartCursor: after, - EndCursor: until - 1, - HasNextPage: nextPage, - }, make([]*model.CommentsEdge, until-after) - for i, c := range post.Comments[after:until] { - edges[i] = &model.CommentsEdge{ - Cursor: c.ID, - Node: c, - } - } - return &model.CommentsConnection{ - Edges: edges, - PageInfo: info, - }, nil +func (s *InMemory) GetComments(post *model.Post, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) { + return getCommentsFrom(post.Comments, first, cursor), nil } -func (s *InMemory) GetReplies(comment *model.Comment, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) { - if !s.commentExists(after) { - return nil, &IDNotFoundError{objName: "comment", id: after} +func (s *InMemory) GetReplies(comment *model.Comment, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) { + return getCommentsFrom(comment.Replies, first, cursor), nil +} + +func getCommentsFrom(source []*model.Comment, first uint, cursor *uint) *model.CommentsConnection { + start := uint(0) + if cursor != nil { + start = *cursor + 1 } - nextPage, until := true, after+first - if !s.commentExists(until) { + if start >= uint(len(source)) { + return &model.EmptyCommentsConnection + } + nextPage, until := true, start+first + if until >= uint(len(source)) { nextPage = false - until = uint(len(s.comments)) + until = uint(len(source)) } info, edges := &model.PageInfo{ - StartCursor: after, + StartCursor: start, EndCursor: until - 1, HasNextPage: nextPage, - }, make([]*model.CommentsEdge, until-after) - for i, c := range comment.Replies[after:until] { + }, make([]*model.CommentsEdge, until-start) + for i, c := range source[start:until] { edges[i] = &model.CommentsEdge{ Cursor: c.ID, Node: c, @@ -142,7 +133,7 @@ func (s *InMemory) GetReplies(comment *model.Comment, first uint, after uint, ct return &model.CommentsConnection{ Edges: edges, PageInfo: info, - }, nil + } } func (s *InMemory) postExists(id uint) bool { diff --git a/internal/storage/storage.go b/internal/storage/storage.go index d82552b..52966bc 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -41,9 +41,9 @@ type Storage interface { GetComment(id uint, ctx context.Context) (*model.Comment, error) // returns paginated data in the form of model.*Connection (passing context to prevent overfetching) - GetPosts(first uint, after uint, ctx context.Context) (*model.PostsConnection, error) - GetComments(post *model.Post, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) - GetReplies(comment *model.Comment, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) + GetPosts(first uint, cursor *uint, ctx context.Context) (*model.PostsConnection, error) + GetComments(post *model.Post, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) + GetReplies(comment *model.Comment, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) } func InitStorage() (Storage, error) {