Implemented pagination for posts, comments and replies

TODO: use dataloaders to reduce amount of sql queries (figure out how to batch query nested paginated data) and add some basic unit or integrated testing
This commit is contained in:
Egor 2024-06-27 04:28:18 +03:00
parent e0aa12b126
commit 8fce488888
10 changed files with 256 additions and 100 deletions

3
go.mod
View file

@ -5,6 +5,7 @@ go 1.22.4
require (
github.com/99designs/gqlgen v0.17.49
github.com/vektah/gqlparser/v2 v2.5.16
github.com/vikstrous/dataloadgen v0.0.6
gorm.io/driver/postgres v1.5.9
gorm.io/gorm v1.25.10
)
@ -28,6 +29,8 @@ require (
github.com/sosodev/duration v1.3.1 // indirect
github.com/urfave/cli/v2 v2.27.2 // indirect
github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 // indirect
go.opentelemetry.io/otel v1.11.1 // indirect
go.opentelemetry.io/otel/trace v1.11.1 // indirect
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/mod v0.18.0 // indirect
golang.org/x/sync v0.7.0 // indirect

8
go.sum
View file

@ -18,6 +18,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g=
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
@ -61,8 +63,14 @@ github.com/urfave/cli/v2 v2.27.2 h1:6e0H+AkS+zDckwPCUrZkKX38mRaau4nL2uipkJpbkcI=
github.com/urfave/cli/v2 v2.27.2/go.mod h1:g0+79LmHHATl7DAcHO99smiR/T7uGLw84w8Y42x+4eM=
github.com/vektah/gqlparser/v2 v2.5.16 h1:1gcmLTvs3JLKXckwCwlUagVn/IlV2bwqle0vJ0vy5p8=
github.com/vektah/gqlparser/v2 v2.5.16/go.mod h1:1lz1OeCqgQbQepsGxPVywrjdBHW2T08PUS3pJqepRww=
github.com/vikstrous/dataloadgen v0.0.6 h1:A7s/fI3QNnH80CA9vdNbWK7AsbLjIxNHpZnV+VnOT1s=
github.com/vikstrous/dataloadgen v0.0.6/go.mod h1:8vuQVpBH0ODbMKAPUdCAPcOGezoTIhgAjgex51t4vbg=
github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 h1:+qGGcbkzsfDQNPPe9UDgpxAWQrhbbBXOYJFQDq/dtJw=
github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913/go.mod h1:4aEEwZQutDLsQv2Deui4iYQ6DWTxR14g6m8Wv88+Xqk=
go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4=
go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE=
go.opentelemetry.io/otel/trace v1.11.1 h1:ofxdnzsNrGBYXbP7t7zpUK281+go5rF7dvdIZXF8gdQ=
go.opentelemetry.io/otel/trace v1.11.1/go.mod h1:f/Q9G7vzk5u91PhbmKbg1Qn0rzH1LJ4vbPHFGkTPtOk=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=

View file

@ -57,7 +57,7 @@ type ComplexityRoot struct {
Author func(childComplexity int) int
Contents func(childComplexity int) int
ID func(childComplexity int) int
Replies func(childComplexity int, first uint, after uint) int
Replies func(childComplexity int, first uint, after *uint) int
}
CommentsConnection struct {
@ -84,7 +84,7 @@ type ComplexityRoot struct {
Post struct {
AllowComments func(childComplexity int) int
Author func(childComplexity int) int
Comments func(childComplexity int, first uint, after uint) int
Comments func(childComplexity int, first uint, after *uint) int
Contents func(childComplexity int) int
ID func(childComplexity int) int
Title func(childComplexity int) int
@ -103,23 +103,23 @@ type ComplexityRoot struct {
Query struct {
Comment func(childComplexity int, id uint) int
Post func(childComplexity int, id uint) int
Posts func(childComplexity int, first uint, after uint) int
Posts func(childComplexity int, first uint, after *uint) int
}
}
type CommentResolver interface {
Replies(ctx context.Context, obj *model.Comment, first uint, after uint) (*model.CommentsConnection, error)
Replies(ctx context.Context, obj *model.Comment, first uint, after *uint) (*model.CommentsConnection, error)
}
type MutationResolver interface {
AddPost(ctx context.Context, input model.PostInput) (*model.AddResult, error)
AddComment(ctx context.Context, input model.CommentInput) (*model.AddResult, error)
}
type PostResolver interface {
Comments(ctx context.Context, obj *model.Post, first uint, after uint) (*model.CommentsConnection, error)
Comments(ctx context.Context, obj *model.Post, first uint, after *uint) (*model.CommentsConnection, error)
}
type QueryResolver interface {
Post(ctx context.Context, id uint) (*model.Post, error)
Posts(ctx context.Context, first uint, after uint) (*model.PostsConnection, error)
Posts(ctx context.Context, first uint, after *uint) (*model.PostsConnection, error)
Comment(ctx context.Context, id uint) (*model.Comment, error)
}
@ -180,7 +180,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return 0, false
}
return e.complexity.Comment.Replies(childComplexity, args["first"].(uint), args["after"].(uint)), true
return e.complexity.Comment.Replies(childComplexity, args["first"].(uint), args["after"].(*uint)), true
case "CommentsConnection.edges":
if e.complexity.CommentsConnection.Edges == nil {
@ -279,7 +279,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return 0, false
}
return e.complexity.Post.Comments(childComplexity, args["first"].(uint), args["after"].(uint)), true
return e.complexity.Post.Comments(childComplexity, args["first"].(uint), args["after"].(*uint)), true
case "Post.contents":
if e.complexity.Post.Contents == nil {
@ -364,7 +364,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return 0, false
}
return e.complexity.Query.Posts(childComplexity, args["first"].(uint), args["after"].(uint)), true
return e.complexity.Query.Posts(childComplexity, args["first"].(uint), args["after"].(*uint)), true
}
return 0, false
@ -506,10 +506,10 @@ func (ec *executionContext) field_Comment_replies_args(ctx context.Context, rawA
}
}
args["first"] = arg0
var arg1 uint
var arg1 *uint
if tmp, ok := rawArgs["after"]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("after"))
arg1, err = ec.unmarshalNID2uint(ctx, tmp)
arg1, err = ec.unmarshalOID2ᚖuint(ctx, tmp)
if err != nil {
return nil, err
}
@ -560,10 +560,10 @@ func (ec *executionContext) field_Post_comments_args(ctx context.Context, rawArg
}
}
args["first"] = arg0
var arg1 uint
var arg1 *uint
if tmp, ok := rawArgs["after"]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("after"))
arg1, err = ec.unmarshalNID2uint(ctx, tmp)
arg1, err = ec.unmarshalOID2ᚖuint(ctx, tmp)
if err != nil {
return nil, err
}
@ -629,10 +629,10 @@ func (ec *executionContext) field_Query_posts_args(ctx context.Context, rawArgs
}
}
args["first"] = arg0
var arg1 uint
var arg1 *uint
if tmp, ok := rawArgs["after"]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("after"))
arg1, err = ec.unmarshalNID2uint(ctx, tmp)
arg1, err = ec.unmarshalOID2ᚖuint(ctx, tmp)
if err != nil {
return nil, err
}
@ -869,7 +869,7 @@ func (ec *executionContext) _Comment_replies(ctx context.Context, field graphql.
}()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return ec.resolvers.Comment().Replies(rctx, obj, fc.Args["first"].(uint), fc.Args["after"].(uint))
return ec.resolvers.Comment().Replies(rctx, obj, fc.Args["first"].(uint), fc.Args["after"].(*uint))
})
if err != nil {
ec.Error(ctx, err)
@ -1556,7 +1556,7 @@ func (ec *executionContext) _Post_comments(ctx context.Context, field graphql.Co
}()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return ec.resolvers.Post().Comments(rctx, obj, fc.Args["first"].(uint), fc.Args["after"].(uint))
return ec.resolvers.Post().Comments(rctx, obj, fc.Args["first"].(uint), fc.Args["after"].(*uint))
})
if err != nil {
ec.Error(ctx, err)
@ -1934,7 +1934,7 @@ func (ec *executionContext) _Query_posts(ctx context.Context, field graphql.Coll
}()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children
return ec.resolvers.Query().Posts(rctx, fc.Args["first"].(uint), fc.Args["after"].(uint))
return ec.resolvers.Query().Posts(rctx, fc.Args["first"].(uint), fc.Args["after"].(*uint))
})
if err != nil {
ec.Error(ctx, err)

View file

@ -8,31 +8,51 @@ import (
const CommentLengthLimit = 2000
var (
EmptyCommentsConnection = CommentsConnection{
Edges: make([]*CommentsEdge, 0),
PageInfo: &PageInfo{
StartCursor: 0,
EndCursor: 0,
HasNextPage: false,
},
}
EmptyPostsConnections = PostsConnection{
Edges: make([]*PostsEdge, 0),
PageInfo: &PageInfo{
StartCursor: 0,
EndCursor: 0,
HasNextPage: false,
},
}
)
type Comment struct {
// db fields
ID uint `json:"id" gorm:"primarykey"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
ID uint `json:"id" gorm:"primarykey;column:id"`
CreatedAt time.Time `gorm:"column:created_at"`
UpdatedAt time.Time `gorm:"column:updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index;column:deleted_at"`
PostID uint `json:"post_id"`
Author string `json:"author"`
Contents string `json:"contents"`
RootComment bool `gorm:"column:root"`
PostID uint `json:"post_id" gorm:"column:post_id"`
Author string `json:"author" gorm:"column:author"`
Contents string `json:"contents" gorm:"column:contents"`
Replies []*Comment `json:"replies" gorm:"many2many:comment_replies"`
}
type Post struct {
// db fields
ID uint `json:"id" gorm:"primarykey"`
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt gorm.DeletedAt `gorm:"index"`
ID uint `json:"id" gorm:"primarykey;column:id"`
CreatedAt time.Time `gorm:"column:created_at"`
UpdatedAt time.Time `gorm:"column:updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index;column:deleted_at"`
Title string `json:"title"`
Author string `json:"author"`
Contents string `json:"contents"`
Title string `json:"title" gorm:"column:title"`
Author string `json:"author" gorm:"column:author"`
Contents string `json:"contents" gorm:"column:contents"`
Comments []*Comment `json:"comments" gorm:"foreignKey:PostID"`
AllowComments bool `json:"allowComments"`
AllowComments bool `json:"allowComments" gorm:"column:allow_comments"`
}
func PostFromInput(input *PostInput) *Post {

View file

@ -11,12 +11,12 @@ import (
)
// Replies is the resolver for the replies field.
func (r *commentResolver) Replies(ctx context.Context, obj *model.Comment, first uint, after uint) (*model.CommentsConnection, error) {
func (r *commentResolver) Replies(ctx context.Context, obj *model.Comment, first uint, after *uint) (*model.CommentsConnection, error) {
return r.Storage.GetReplies(obj, first, after, ctx)
}
// Comments is the resolver for the comments field.
func (r *postResolver) Comments(ctx context.Context, obj *model.Post, first uint, after uint) (*model.CommentsConnection, error) {
func (r *postResolver) Comments(ctx context.Context, obj *model.Post, first uint, after *uint) (*model.CommentsConnection, error) {
return r.Storage.GetComments(obj, first, after, ctx)
}
@ -26,7 +26,7 @@ func (r *queryResolver) Post(ctx context.Context, id uint) (*model.Post, error)
}
// Posts is the resolver for the posts field.
func (r *queryResolver) Posts(ctx context.Context, first uint, after uint) (*model.PostsConnection, error) {
func (r *queryResolver) Posts(ctx context.Context, first uint, after *uint) (*model.PostsConnection, error) {
return r.Storage.GetPosts(first, after, ctx)
}

View file

@ -1,6 +1,6 @@
type Query {
post(id: ID!): Post!
posts(first: ID!, after: ID!): PostsConnection!
posts(first: ID!, after: ID): PostsConnection!
comment(id: ID!): Comment!
}
@ -9,7 +9,7 @@ type Post {
title: String!
author: String!
contents: String!
comments(first: ID!, after: ID!): CommentsConnection!
comments(first: ID!, after: ID): CommentsConnection!
allowComments: Boolean!
}
@ -17,5 +17,5 @@ type Comment {
id: ID!
author: String!
contents: String!
replies(first: ID!, after: ID!): CommentsConnection!
replies(first: ID!, after: ID): CommentsConnection!
}

View file

@ -3,15 +3,32 @@ package db
import (
"context"
"fmt"
"log"
"git.obamna.ru/erius/ozon-task/graph/model"
"github.com/99designs/gqlgen/graphql"
"github.com/vikstrous/dataloadgen"
"gorm.io/gorm"
)
type database struct {
*gorm.DB
}
type Database struct {
db *gorm.DB
db database
// loaders for batch loading and caching of objects
// to prevent n + 1 problem and redundant sql queries during nested pagination
postCommentsLoader *dataloadgen.Loader[uint, []*model.Comment]
commentRepliesLoader *dataloadgen.Loader[uint, []*model.Comment]
}
func InitDatabase(con *gorm.DB) *Database {
db := database{con}
return &Database{
db: db,
postCommentsLoader: dataloadgen.NewLoader(db.fetchComments),
commentRepliesLoader: dataloadgen.NewLoader(db.fetchReplies),
}
}
func (s *Database) AddPost(input *model.PostInput) (*model.AddResult, error) {
@ -22,8 +39,12 @@ func (s *Database) AddPost(input *model.PostInput) (*model.AddResult, error) {
func (s *Database) AddReplyToComment(input *model.CommentInput) (*model.AddResult, error) {
comment := model.CommentFromInput(input)
// multiple operations performed in one transaction
err := s.db.Transaction(func(tx *gorm.DB) error {
err := s.db.Model(&model.Comment{}).Select("post_id").Where("id = ?", *input.ParentCommentID).Scan(&comment.PostID).Error
if err != nil {
return nil, err
}
// multiple mutable operations performed in one transaction
err = s.db.Transaction(func(tx *gorm.DB) error {
// insert comment
err := s.db.Create(comment).Error
if err != nil {
@ -52,13 +73,13 @@ func (s *Database) AddCommentToPost(input *model.CommentInput) (*model.AddResult
if !allowComments {
return nil, fmt.Errorf("author disabled comments for this post")
}
comment.RootComment = true
comment.PostID = *input.ParentPostID
err = s.db.Create(comment).Error
return &model.AddResult{ItemID: &comment.ID}, err
}
func (s *Database) GetPost(id uint, ctx context.Context) (*model.Post, error) {
log.Println(graphql.CollectAllFields(ctx))
log.Println(graphql.CollectFieldsCtx(ctx, nil))
var post model.Post
err := s.db.Find(&post, id).Error
return &post, err
@ -70,14 +91,127 @@ func (s *Database) GetComment(id uint, ctx context.Context) (*model.Comment, err
return &comment, err
}
func (s *Database) GetPosts(first uint, after uint, ctx context.Context) (*model.PostsConnection, error) {
return nil, nil
func (s *Database) GetPosts(first uint, cursor *uint, ctx context.Context) (*model.PostsConnection, error) {
offset := 0
if cursor != nil {
offset = int(*cursor)
}
var posts []*model.Post
res := s.db.Order("id").Limit(int(first)).Offset(offset).Find(&posts)
if res.Error != nil {
return nil, res.Error
}
if res.RowsAffected == 0 {
return &model.EmptyPostsConnections, nil
}
nextPage := true
if res.RowsAffected < int64(first) {
nextPage = false
}
info, edges := &model.PageInfo{
StartCursor: uint(offset),
EndCursor: posts[len(posts)-1].ID,
HasNextPage: nextPage,
}, make([]*model.PostsEdge, len(posts))
for i, p := range posts {
edges[i] = &model.PostsEdge{
Cursor: p.ID,
Node: p,
}
}
return &model.PostsConnection{
Edges: edges,
PageInfo: info,
}, nil
}
func (s *Database) GetComments(post *model.Post, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) {
return nil, nil
func (s *Database) GetComments(post *model.Post, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) {
offset := 0
if cursor != nil {
offset = int(*cursor)
}
res := s.db.Where("post_id = ?", post.ID).Where("root = TRUE").Order("id").Limit(int(first)).Offset(offset).Find(&post.Comments)
if res.Error != nil {
return nil, res.Error
}
if res.RowsAffected == 0 {
return &model.EmptyCommentsConnection, nil
}
nextPage := true
if res.RowsAffected < int64(first) {
nextPage = false
}
info, edges := &model.PageInfo{
StartCursor: uint(offset),
EndCursor: post.Comments[len(post.Comments)-1].ID,
HasNextPage: nextPage,
}, make([]*model.CommentsEdge, len(post.Comments))
for i, c := range post.Comments {
edges[i] = &model.CommentsEdge{
Cursor: c.ID,
Node: c,
}
}
return &model.CommentsConnection{
Edges: edges,
PageInfo: info,
}, nil
}
func (s *Database) GetReplies(comment *model.Comment, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) {
func (s *Database) GetReplies(comment *model.Comment, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) {
offset := 0
if cursor != nil {
offset = int(*cursor)
}
res := s.db.Model(&model.Comment{}).Joins("JOIN comment_replies ON comment_replies.reply_id = id").
Where("comment_id = ?", comment.ID).Order("id").Limit(int(first)).Offset(offset).Find(&comment.Replies)
if res.Error != nil {
return nil, res.Error
}
if res.RowsAffected == 0 {
return &model.EmptyCommentsConnection, nil
}
nextPage := true
if res.RowsAffected < int64(first) {
nextPage = false
}
info, edges := &model.PageInfo{
StartCursor: uint(offset),
EndCursor: comment.Replies[len(comment.Replies)-1].ID,
HasNextPage: nextPage,
}, make([]*model.CommentsEdge, len(comment.Replies))
for i, c := range comment.Replies {
edges[i] = &model.CommentsEdge{
Cursor: c.ID,
Node: c,
}
}
return &model.CommentsConnection{
Edges: edges,
PageInfo: info,
}, nil
}
// TODO: try to fix n + 1 problem by fetching data in bulk with loaders
// this is tricky because we are getting paginated data
// might use sql window functions
func (d *database) fetchComments(ctx context.Context, postIds []uint) ([][]*model.Comment, []error) {
var comments []*model.Comment
err := d.Where("post_id IN (?)", postIds).Where("root = TRUE").Order("post_id").Find(&comments).Error
if err != nil {
return nil, []error{err}
}
postComments := make([][]*model.Comment, 0, len(postIds))
i := 0
for _, c := range comments {
if c.ID != postIds[i] {
i++
}
postComments[i] = append(postComments[i], c)
}
return postComments, nil
}
func (d *database) fetchReplies(ctx context.Context, ids []uint) ([][]*model.Comment, []error) {
return nil, nil
}

View file

@ -32,12 +32,12 @@ func InitPostgres() (*Database, error) {
return nil, err
}
log.Println("opened connection to PostgreSQL database")
log.Println("migrating model scheme to database...")
log.Println("migrating model schema to database...")
err = db.AutoMigrate(&model.Post{}, &model.Comment{})
if err != nil {
log.Printf("failed to automatically migrate model scheme: %s", err)
log.Printf("failed to automatically migrate model schema: %s", err)
return nil, err
}
log.Println("finished migrating model scheme")
return &Database{db}, nil
log.Println("finished migrating model schema")
return InitDatabase(db), nil
}

View file

@ -48,6 +48,7 @@ func (s *InMemory) AddCommentToPost(input *model.CommentInput) (*model.AddResult
return nil, fmt.Errorf("author disabled comments for this post")
}
comment := model.CommentFromInput(input)
comment.RootComment = true
s.insertComment(comment)
parent.Comments = append(parent.Comments, comment)
return &model.AddResult{ItemID: &comment.ID}, nil
@ -67,21 +68,25 @@ func (s *InMemory) GetComment(id uint, ctx context.Context) (*model.Comment, err
return s.comments[id], nil
}
func (s *InMemory) GetPosts(first uint, after uint, ctx context.Context) (*model.PostsConnection, error) {
if !s.postExists(after) {
return nil, &IDNotFoundError{objName: "post", id: after}
func (s *InMemory) GetPosts(first uint, cursor *uint, ctx context.Context) (*model.PostsConnection, error) {
start := uint(0)
if cursor != nil {
start = *cursor + 1
}
nextPage, until := true, after+first
if !s.postExists(start) {
return &model.EmptyPostsConnections, nil
}
nextPage, until := true, start+first
if !s.postExists(until) {
nextPage = false
until = uint(len(s.posts))
}
info, edges := &model.PageInfo{
StartCursor: after,
StartCursor: start,
EndCursor: until - 1,
HasNextPage: nextPage,
}, make([]*model.PostsEdge, until-after)
for i, p := range s.posts[after:until] {
}, make([]*model.PostsEdge, until-start)
for i, p := range s.posts[start:until] {
edges[i] = &model.PostsEdge{
Cursor: p.ID,
Node: p,
@ -93,47 +98,33 @@ func (s *InMemory) GetPosts(first uint, after uint, ctx context.Context) (*model
}, nil
}
func (s *InMemory) GetComments(post *model.Post, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) {
if !s.commentExists(after) {
return nil, &IDNotFoundError{objName: "comment", id: after}
}
nextPage, until := true, after+first
if !s.commentExists(until) {
nextPage = false
until = uint(len(s.comments))
}
info, edges := &model.PageInfo{
StartCursor: after,
EndCursor: until - 1,
HasNextPage: nextPage,
}, make([]*model.CommentsEdge, until-after)
for i, c := range post.Comments[after:until] {
edges[i] = &model.CommentsEdge{
Cursor: c.ID,
Node: c,
}
}
return &model.CommentsConnection{
Edges: edges,
PageInfo: info,
}, nil
func (s *InMemory) GetComments(post *model.Post, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) {
return getCommentsFrom(post.Comments, first, cursor), nil
}
func (s *InMemory) GetReplies(comment *model.Comment, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) {
if !s.commentExists(after) {
return nil, &IDNotFoundError{objName: "comment", id: after}
func (s *InMemory) GetReplies(comment *model.Comment, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) {
return getCommentsFrom(comment.Replies, first, cursor), nil
}
func getCommentsFrom(source []*model.Comment, first uint, cursor *uint) *model.CommentsConnection {
start := uint(0)
if cursor != nil {
start = *cursor + 1
}
nextPage, until := true, after+first
if !s.commentExists(until) {
if start >= uint(len(source)) {
return &model.EmptyCommentsConnection
}
nextPage, until := true, start+first
if until >= uint(len(source)) {
nextPage = false
until = uint(len(s.comments))
until = uint(len(source))
}
info, edges := &model.PageInfo{
StartCursor: after,
StartCursor: start,
EndCursor: until - 1,
HasNextPage: nextPage,
}, make([]*model.CommentsEdge, until-after)
for i, c := range comment.Replies[after:until] {
}, make([]*model.CommentsEdge, until-start)
for i, c := range source[start:until] {
edges[i] = &model.CommentsEdge{
Cursor: c.ID,
Node: c,
@ -142,7 +133,7 @@ func (s *InMemory) GetReplies(comment *model.Comment, first uint, after uint, ct
return &model.CommentsConnection{
Edges: edges,
PageInfo: info,
}, nil
}
}
func (s *InMemory) postExists(id uint) bool {

View file

@ -41,9 +41,9 @@ type Storage interface {
GetComment(id uint, ctx context.Context) (*model.Comment, error)
// returns paginated data in the form of model.*Connection (passing context to prevent overfetching)
GetPosts(first uint, after uint, ctx context.Context) (*model.PostsConnection, error)
GetComments(post *model.Post, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error)
GetReplies(comment *model.Comment, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error)
GetPosts(first uint, cursor *uint, ctx context.Context) (*model.PostsConnection, error)
GetComments(post *model.Post, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error)
GetReplies(comment *model.Comment, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error)
}
func InitStorage() (Storage, error) {