Implemented pagination for posts, comments and replies

TODO: use dataloaders to reduce amount of sql queries (figure out how to batch query nested paginated data) and add some basic unit or integrated testing
This commit is contained in:
erius 2024-06-27 04:28:18 +03:00
parent e0aa12b126
commit 8fce488888
10 changed files with 256 additions and 100 deletions

3
go.mod
View file

@ -5,6 +5,7 @@ go 1.22.4
require ( require (
github.com/99designs/gqlgen v0.17.49 github.com/99designs/gqlgen v0.17.49
github.com/vektah/gqlparser/v2 v2.5.16 github.com/vektah/gqlparser/v2 v2.5.16
github.com/vikstrous/dataloadgen v0.0.6
gorm.io/driver/postgres v1.5.9 gorm.io/driver/postgres v1.5.9
gorm.io/gorm v1.25.10 gorm.io/gorm v1.25.10
) )
@ -28,6 +29,8 @@ require (
github.com/sosodev/duration v1.3.1 // indirect github.com/sosodev/duration v1.3.1 // indirect
github.com/urfave/cli/v2 v2.27.2 // indirect github.com/urfave/cli/v2 v2.27.2 // indirect
github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 // indirect github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 // indirect
go.opentelemetry.io/otel v1.11.1 // indirect
go.opentelemetry.io/otel/trace v1.11.1 // indirect
golang.org/x/crypto v0.17.0 // indirect golang.org/x/crypto v0.17.0 // indirect
golang.org/x/mod v0.18.0 // indirect golang.org/x/mod v0.18.0 // indirect
golang.org/x/sync v0.7.0 // indirect golang.org/x/sync v0.7.0 // indirect

8
go.sum
View file

@ -18,6 +18,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g= github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g=
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
@ -61,8 +63,14 @@ github.com/urfave/cli/v2 v2.27.2 h1:6e0H+AkS+zDckwPCUrZkKX38mRaau4nL2uipkJpbkcI=
github.com/urfave/cli/v2 v2.27.2/go.mod h1:g0+79LmHHATl7DAcHO99smiR/T7uGLw84w8Y42x+4eM= github.com/urfave/cli/v2 v2.27.2/go.mod h1:g0+79LmHHATl7DAcHO99smiR/T7uGLw84w8Y42x+4eM=
github.com/vektah/gqlparser/v2 v2.5.16 h1:1gcmLTvs3JLKXckwCwlUagVn/IlV2bwqle0vJ0vy5p8= github.com/vektah/gqlparser/v2 v2.5.16 h1:1gcmLTvs3JLKXckwCwlUagVn/IlV2bwqle0vJ0vy5p8=
github.com/vektah/gqlparser/v2 v2.5.16/go.mod h1:1lz1OeCqgQbQepsGxPVywrjdBHW2T08PUS3pJqepRww= github.com/vektah/gqlparser/v2 v2.5.16/go.mod h1:1lz1OeCqgQbQepsGxPVywrjdBHW2T08PUS3pJqepRww=
github.com/vikstrous/dataloadgen v0.0.6 h1:A7s/fI3QNnH80CA9vdNbWK7AsbLjIxNHpZnV+VnOT1s=
github.com/vikstrous/dataloadgen v0.0.6/go.mod h1:8vuQVpBH0ODbMKAPUdCAPcOGezoTIhgAjgex51t4vbg=
github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 h1:+qGGcbkzsfDQNPPe9UDgpxAWQrhbbBXOYJFQDq/dtJw= github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 h1:+qGGcbkzsfDQNPPe9UDgpxAWQrhbbBXOYJFQDq/dtJw=
github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913/go.mod h1:4aEEwZQutDLsQv2Deui4iYQ6DWTxR14g6m8Wv88+Xqk= github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913/go.mod h1:4aEEwZQutDLsQv2Deui4iYQ6DWTxR14g6m8Wv88+Xqk=
go.opentelemetry.io/otel v1.11.1 h1:4WLLAmcfkmDk2ukNXJyq3/kiz/3UzCaYq6PskJsaou4=
go.opentelemetry.io/otel v1.11.1/go.mod h1:1nNhXBbWSD0nsL38H6btgnFN2k4i0sNLHNNMZMSbUGE=
go.opentelemetry.io/otel/trace v1.11.1 h1:ofxdnzsNrGBYXbP7t7zpUK281+go5rF7dvdIZXF8gdQ=
go.opentelemetry.io/otel/trace v1.11.1/go.mod h1:f/Q9G7vzk5u91PhbmKbg1Qn0rzH1LJ4vbPHFGkTPtOk=
golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k= golang.org/x/crypto v0.17.0 h1:r8bRNjWL3GshPW3gkd+RpvzWrZAwPS49OmTGZ/uhM4k=
golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4=
golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0=

View file

@ -57,7 +57,7 @@ type ComplexityRoot struct {
Author func(childComplexity int) int Author func(childComplexity int) int
Contents func(childComplexity int) int Contents func(childComplexity int) int
ID func(childComplexity int) int ID func(childComplexity int) int
Replies func(childComplexity int, first uint, after uint) int Replies func(childComplexity int, first uint, after *uint) int
} }
CommentsConnection struct { CommentsConnection struct {
@ -84,7 +84,7 @@ type ComplexityRoot struct {
Post struct { Post struct {
AllowComments func(childComplexity int) int AllowComments func(childComplexity int) int
Author func(childComplexity int) int Author func(childComplexity int) int
Comments func(childComplexity int, first uint, after uint) int Comments func(childComplexity int, first uint, after *uint) int
Contents func(childComplexity int) int Contents func(childComplexity int) int
ID func(childComplexity int) int ID func(childComplexity int) int
Title func(childComplexity int) int Title func(childComplexity int) int
@ -103,23 +103,23 @@ type ComplexityRoot struct {
Query struct { Query struct {
Comment func(childComplexity int, id uint) int Comment func(childComplexity int, id uint) int
Post func(childComplexity int, id uint) int Post func(childComplexity int, id uint) int
Posts func(childComplexity int, first uint, after uint) int Posts func(childComplexity int, first uint, after *uint) int
} }
} }
type CommentResolver interface { type CommentResolver interface {
Replies(ctx context.Context, obj *model.Comment, first uint, after uint) (*model.CommentsConnection, error) Replies(ctx context.Context, obj *model.Comment, first uint, after *uint) (*model.CommentsConnection, error)
} }
type MutationResolver interface { type MutationResolver interface {
AddPost(ctx context.Context, input model.PostInput) (*model.AddResult, error) AddPost(ctx context.Context, input model.PostInput) (*model.AddResult, error)
AddComment(ctx context.Context, input model.CommentInput) (*model.AddResult, error) AddComment(ctx context.Context, input model.CommentInput) (*model.AddResult, error)
} }
type PostResolver interface { type PostResolver interface {
Comments(ctx context.Context, obj *model.Post, first uint, after uint) (*model.CommentsConnection, error) Comments(ctx context.Context, obj *model.Post, first uint, after *uint) (*model.CommentsConnection, error)
} }
type QueryResolver interface { type QueryResolver interface {
Post(ctx context.Context, id uint) (*model.Post, error) Post(ctx context.Context, id uint) (*model.Post, error)
Posts(ctx context.Context, first uint, after uint) (*model.PostsConnection, error) Posts(ctx context.Context, first uint, after *uint) (*model.PostsConnection, error)
Comment(ctx context.Context, id uint) (*model.Comment, error) Comment(ctx context.Context, id uint) (*model.Comment, error)
} }
@ -180,7 +180,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return 0, false return 0, false
} }
return e.complexity.Comment.Replies(childComplexity, args["first"].(uint), args["after"].(uint)), true return e.complexity.Comment.Replies(childComplexity, args["first"].(uint), args["after"].(*uint)), true
case "CommentsConnection.edges": case "CommentsConnection.edges":
if e.complexity.CommentsConnection.Edges == nil { if e.complexity.CommentsConnection.Edges == nil {
@ -279,7 +279,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return 0, false return 0, false
} }
return e.complexity.Post.Comments(childComplexity, args["first"].(uint), args["after"].(uint)), true return e.complexity.Post.Comments(childComplexity, args["first"].(uint), args["after"].(*uint)), true
case "Post.contents": case "Post.contents":
if e.complexity.Post.Contents == nil { if e.complexity.Post.Contents == nil {
@ -364,7 +364,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
return 0, false return 0, false
} }
return e.complexity.Query.Posts(childComplexity, args["first"].(uint), args["after"].(uint)), true return e.complexity.Query.Posts(childComplexity, args["first"].(uint), args["after"].(*uint)), true
} }
return 0, false return 0, false
@ -506,10 +506,10 @@ func (ec *executionContext) field_Comment_replies_args(ctx context.Context, rawA
} }
} }
args["first"] = arg0 args["first"] = arg0
var arg1 uint var arg1 *uint
if tmp, ok := rawArgs["after"]; ok { if tmp, ok := rawArgs["after"]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("after")) ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("after"))
arg1, err = ec.unmarshalNID2uint(ctx, tmp) arg1, err = ec.unmarshalOID2ᚖuint(ctx, tmp)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -560,10 +560,10 @@ func (ec *executionContext) field_Post_comments_args(ctx context.Context, rawArg
} }
} }
args["first"] = arg0 args["first"] = arg0
var arg1 uint var arg1 *uint
if tmp, ok := rawArgs["after"]; ok { if tmp, ok := rawArgs["after"]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("after")) ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("after"))
arg1, err = ec.unmarshalNID2uint(ctx, tmp) arg1, err = ec.unmarshalOID2ᚖuint(ctx, tmp)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -629,10 +629,10 @@ func (ec *executionContext) field_Query_posts_args(ctx context.Context, rawArgs
} }
} }
args["first"] = arg0 args["first"] = arg0
var arg1 uint var arg1 *uint
if tmp, ok := rawArgs["after"]; ok { if tmp, ok := rawArgs["after"]; ok {
ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("after")) ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("after"))
arg1, err = ec.unmarshalNID2uint(ctx, tmp) arg1, err = ec.unmarshalOID2ᚖuint(ctx, tmp)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -869,7 +869,7 @@ func (ec *executionContext) _Comment_replies(ctx context.Context, field graphql.
}() }()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children ctx = rctx // use context from middleware stack in children
return ec.resolvers.Comment().Replies(rctx, obj, fc.Args["first"].(uint), fc.Args["after"].(uint)) return ec.resolvers.Comment().Replies(rctx, obj, fc.Args["first"].(uint), fc.Args["after"].(*uint))
}) })
if err != nil { if err != nil {
ec.Error(ctx, err) ec.Error(ctx, err)
@ -1556,7 +1556,7 @@ func (ec *executionContext) _Post_comments(ctx context.Context, field graphql.Co
}() }()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children ctx = rctx // use context from middleware stack in children
return ec.resolvers.Post().Comments(rctx, obj, fc.Args["first"].(uint), fc.Args["after"].(uint)) return ec.resolvers.Post().Comments(rctx, obj, fc.Args["first"].(uint), fc.Args["after"].(*uint))
}) })
if err != nil { if err != nil {
ec.Error(ctx, err) ec.Error(ctx, err)
@ -1934,7 +1934,7 @@ func (ec *executionContext) _Query_posts(ctx context.Context, field graphql.Coll
}() }()
resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) {
ctx = rctx // use context from middleware stack in children ctx = rctx // use context from middleware stack in children
return ec.resolvers.Query().Posts(rctx, fc.Args["first"].(uint), fc.Args["after"].(uint)) return ec.resolvers.Query().Posts(rctx, fc.Args["first"].(uint), fc.Args["after"].(*uint))
}) })
if err != nil { if err != nil {
ec.Error(ctx, err) ec.Error(ctx, err)

View file

@ -8,31 +8,51 @@ import (
const CommentLengthLimit = 2000 const CommentLengthLimit = 2000
var (
EmptyCommentsConnection = CommentsConnection{
Edges: make([]*CommentsEdge, 0),
PageInfo: &PageInfo{
StartCursor: 0,
EndCursor: 0,
HasNextPage: false,
},
}
EmptyPostsConnections = PostsConnection{
Edges: make([]*PostsEdge, 0),
PageInfo: &PageInfo{
StartCursor: 0,
EndCursor: 0,
HasNextPage: false,
},
}
)
type Comment struct { type Comment struct {
// db fields // db fields
ID uint `json:"id" gorm:"primarykey"` ID uint `json:"id" gorm:"primarykey;column:id"`
CreatedAt time.Time CreatedAt time.Time `gorm:"column:created_at"`
UpdatedAt time.Time UpdatedAt time.Time `gorm:"column:updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index"` DeletedAt gorm.DeletedAt `gorm:"index;column:deleted_at"`
PostID uint `json:"post_id"` RootComment bool `gorm:"column:root"`
Author string `json:"author"` PostID uint `json:"post_id" gorm:"column:post_id"`
Contents string `json:"contents"` Author string `json:"author" gorm:"column:author"`
Replies []*Comment `json:"replies" gorm:"many2many:comment_replies"` Contents string `json:"contents" gorm:"column:contents"`
Replies []*Comment `json:"replies" gorm:"many2many:comment_replies"`
} }
type Post struct { type Post struct {
// db fields // db fields
ID uint `json:"id" gorm:"primarykey"` ID uint `json:"id" gorm:"primarykey;column:id"`
CreatedAt time.Time CreatedAt time.Time `gorm:"column:created_at"`
UpdatedAt time.Time UpdatedAt time.Time `gorm:"column:updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index"` DeletedAt gorm.DeletedAt `gorm:"index;column:deleted_at"`
Title string `json:"title"` Title string `json:"title" gorm:"column:title"`
Author string `json:"author"` Author string `json:"author" gorm:"column:author"`
Contents string `json:"contents"` Contents string `json:"contents" gorm:"column:contents"`
Comments []*Comment `json:"comments" gorm:"foreignKey:PostID"` Comments []*Comment `json:"comments" gorm:"foreignKey:PostID"`
AllowComments bool `json:"allowComments"` AllowComments bool `json:"allowComments" gorm:"column:allow_comments"`
} }
func PostFromInput(input *PostInput) *Post { func PostFromInput(input *PostInput) *Post {

View file

@ -11,12 +11,12 @@ import (
) )
// Replies is the resolver for the replies field. // Replies is the resolver for the replies field.
func (r *commentResolver) Replies(ctx context.Context, obj *model.Comment, first uint, after uint) (*model.CommentsConnection, error) { func (r *commentResolver) Replies(ctx context.Context, obj *model.Comment, first uint, after *uint) (*model.CommentsConnection, error) {
return r.Storage.GetReplies(obj, first, after, ctx) return r.Storage.GetReplies(obj, first, after, ctx)
} }
// Comments is the resolver for the comments field. // Comments is the resolver for the comments field.
func (r *postResolver) Comments(ctx context.Context, obj *model.Post, first uint, after uint) (*model.CommentsConnection, error) { func (r *postResolver) Comments(ctx context.Context, obj *model.Post, first uint, after *uint) (*model.CommentsConnection, error) {
return r.Storage.GetComments(obj, first, after, ctx) return r.Storage.GetComments(obj, first, after, ctx)
} }
@ -26,7 +26,7 @@ func (r *queryResolver) Post(ctx context.Context, id uint) (*model.Post, error)
} }
// Posts is the resolver for the posts field. // Posts is the resolver for the posts field.
func (r *queryResolver) Posts(ctx context.Context, first uint, after uint) (*model.PostsConnection, error) { func (r *queryResolver) Posts(ctx context.Context, first uint, after *uint) (*model.PostsConnection, error) {
return r.Storage.GetPosts(first, after, ctx) return r.Storage.GetPosts(first, after, ctx)
} }

View file

@ -1,6 +1,6 @@
type Query { type Query {
post(id: ID!): Post! post(id: ID!): Post!
posts(first: ID!, after: ID!): PostsConnection! posts(first: ID!, after: ID): PostsConnection!
comment(id: ID!): Comment! comment(id: ID!): Comment!
} }
@ -9,7 +9,7 @@ type Post {
title: String! title: String!
author: String! author: String!
contents: String! contents: String!
comments(first: ID!, after: ID!): CommentsConnection! comments(first: ID!, after: ID): CommentsConnection!
allowComments: Boolean! allowComments: Boolean!
} }
@ -17,5 +17,5 @@ type Comment {
id: ID! id: ID!
author: String! author: String!
contents: String! contents: String!
replies(first: ID!, after: ID!): CommentsConnection! replies(first: ID!, after: ID): CommentsConnection!
} }

View file

@ -3,15 +3,32 @@ package db
import ( import (
"context" "context"
"fmt" "fmt"
"log"
"git.obamna.ru/erius/ozon-task/graph/model" "git.obamna.ru/erius/ozon-task/graph/model"
"github.com/99designs/gqlgen/graphql" "github.com/vikstrous/dataloadgen"
"gorm.io/gorm" "gorm.io/gorm"
) )
type database struct {
*gorm.DB
}
type Database struct { type Database struct {
db *gorm.DB db database
// loaders for batch loading and caching of objects
// to prevent n + 1 problem and redundant sql queries during nested pagination
postCommentsLoader *dataloadgen.Loader[uint, []*model.Comment]
commentRepliesLoader *dataloadgen.Loader[uint, []*model.Comment]
}
func InitDatabase(con *gorm.DB) *Database {
db := database{con}
return &Database{
db: db,
postCommentsLoader: dataloadgen.NewLoader(db.fetchComments),
commentRepliesLoader: dataloadgen.NewLoader(db.fetchReplies),
}
} }
func (s *Database) AddPost(input *model.PostInput) (*model.AddResult, error) { func (s *Database) AddPost(input *model.PostInput) (*model.AddResult, error) {
@ -22,8 +39,12 @@ func (s *Database) AddPost(input *model.PostInput) (*model.AddResult, error) {
func (s *Database) AddReplyToComment(input *model.CommentInput) (*model.AddResult, error) { func (s *Database) AddReplyToComment(input *model.CommentInput) (*model.AddResult, error) {
comment := model.CommentFromInput(input) comment := model.CommentFromInput(input)
// multiple operations performed in one transaction err := s.db.Model(&model.Comment{}).Select("post_id").Where("id = ?", *input.ParentCommentID).Scan(&comment.PostID).Error
err := s.db.Transaction(func(tx *gorm.DB) error { if err != nil {
return nil, err
}
// multiple mutable operations performed in one transaction
err = s.db.Transaction(func(tx *gorm.DB) error {
// insert comment // insert comment
err := s.db.Create(comment).Error err := s.db.Create(comment).Error
if err != nil { if err != nil {
@ -52,13 +73,13 @@ func (s *Database) AddCommentToPost(input *model.CommentInput) (*model.AddResult
if !allowComments { if !allowComments {
return nil, fmt.Errorf("author disabled comments for this post") return nil, fmt.Errorf("author disabled comments for this post")
} }
comment.RootComment = true
comment.PostID = *input.ParentPostID
err = s.db.Create(comment).Error err = s.db.Create(comment).Error
return &model.AddResult{ItemID: &comment.ID}, err return &model.AddResult{ItemID: &comment.ID}, err
} }
func (s *Database) GetPost(id uint, ctx context.Context) (*model.Post, error) { func (s *Database) GetPost(id uint, ctx context.Context) (*model.Post, error) {
log.Println(graphql.CollectAllFields(ctx))
log.Println(graphql.CollectFieldsCtx(ctx, nil))
var post model.Post var post model.Post
err := s.db.Find(&post, id).Error err := s.db.Find(&post, id).Error
return &post, err return &post, err
@ -70,14 +91,127 @@ func (s *Database) GetComment(id uint, ctx context.Context) (*model.Comment, err
return &comment, err return &comment, err
} }
func (s *Database) GetPosts(first uint, after uint, ctx context.Context) (*model.PostsConnection, error) { func (s *Database) GetPosts(first uint, cursor *uint, ctx context.Context) (*model.PostsConnection, error) {
return nil, nil offset := 0
if cursor != nil {
offset = int(*cursor)
}
var posts []*model.Post
res := s.db.Order("id").Limit(int(first)).Offset(offset).Find(&posts)
if res.Error != nil {
return nil, res.Error
}
if res.RowsAffected == 0 {
return &model.EmptyPostsConnections, nil
}
nextPage := true
if res.RowsAffected < int64(first) {
nextPage = false
}
info, edges := &model.PageInfo{
StartCursor: uint(offset),
EndCursor: posts[len(posts)-1].ID,
HasNextPage: nextPage,
}, make([]*model.PostsEdge, len(posts))
for i, p := range posts {
edges[i] = &model.PostsEdge{
Cursor: p.ID,
Node: p,
}
}
return &model.PostsConnection{
Edges: edges,
PageInfo: info,
}, nil
} }
func (s *Database) GetComments(post *model.Post, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) { func (s *Database) GetComments(post *model.Post, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) {
return nil, nil offset := 0
if cursor != nil {
offset = int(*cursor)
}
res := s.db.Where("post_id = ?", post.ID).Where("root = TRUE").Order("id").Limit(int(first)).Offset(offset).Find(&post.Comments)
if res.Error != nil {
return nil, res.Error
}
if res.RowsAffected == 0 {
return &model.EmptyCommentsConnection, nil
}
nextPage := true
if res.RowsAffected < int64(first) {
nextPage = false
}
info, edges := &model.PageInfo{
StartCursor: uint(offset),
EndCursor: post.Comments[len(post.Comments)-1].ID,
HasNextPage: nextPage,
}, make([]*model.CommentsEdge, len(post.Comments))
for i, c := range post.Comments {
edges[i] = &model.CommentsEdge{
Cursor: c.ID,
Node: c,
}
}
return &model.CommentsConnection{
Edges: edges,
PageInfo: info,
}, nil
} }
func (s *Database) GetReplies(comment *model.Comment, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) { func (s *Database) GetReplies(comment *model.Comment, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) {
offset := 0
if cursor != nil {
offset = int(*cursor)
}
res := s.db.Model(&model.Comment{}).Joins("JOIN comment_replies ON comment_replies.reply_id = id").
Where("comment_id = ?", comment.ID).Order("id").Limit(int(first)).Offset(offset).Find(&comment.Replies)
if res.Error != nil {
return nil, res.Error
}
if res.RowsAffected == 0 {
return &model.EmptyCommentsConnection, nil
}
nextPage := true
if res.RowsAffected < int64(first) {
nextPage = false
}
info, edges := &model.PageInfo{
StartCursor: uint(offset),
EndCursor: comment.Replies[len(comment.Replies)-1].ID,
HasNextPage: nextPage,
}, make([]*model.CommentsEdge, len(comment.Replies))
for i, c := range comment.Replies {
edges[i] = &model.CommentsEdge{
Cursor: c.ID,
Node: c,
}
}
return &model.CommentsConnection{
Edges: edges,
PageInfo: info,
}, nil
}
// TODO: try to fix n + 1 problem by fetching data in bulk with loaders
// this is tricky because we are getting paginated data
// might use sql window functions
func (d *database) fetchComments(ctx context.Context, postIds []uint) ([][]*model.Comment, []error) {
var comments []*model.Comment
err := d.Where("post_id IN (?)", postIds).Where("root = TRUE").Order("post_id").Find(&comments).Error
if err != nil {
return nil, []error{err}
}
postComments := make([][]*model.Comment, 0, len(postIds))
i := 0
for _, c := range comments {
if c.ID != postIds[i] {
i++
}
postComments[i] = append(postComments[i], c)
}
return postComments, nil
}
func (d *database) fetchReplies(ctx context.Context, ids []uint) ([][]*model.Comment, []error) {
return nil, nil return nil, nil
} }

View file

@ -32,12 +32,12 @@ func InitPostgres() (*Database, error) {
return nil, err return nil, err
} }
log.Println("opened connection to PostgreSQL database") log.Println("opened connection to PostgreSQL database")
log.Println("migrating model scheme to database...") log.Println("migrating model schema to database...")
err = db.AutoMigrate(&model.Post{}, &model.Comment{}) err = db.AutoMigrate(&model.Post{}, &model.Comment{})
if err != nil { if err != nil {
log.Printf("failed to automatically migrate model scheme: %s", err) log.Printf("failed to automatically migrate model schema: %s", err)
return nil, err return nil, err
} }
log.Println("finished migrating model scheme") log.Println("finished migrating model schema")
return &Database{db}, nil return InitDatabase(db), nil
} }

View file

@ -48,6 +48,7 @@ func (s *InMemory) AddCommentToPost(input *model.CommentInput) (*model.AddResult
return nil, fmt.Errorf("author disabled comments for this post") return nil, fmt.Errorf("author disabled comments for this post")
} }
comment := model.CommentFromInput(input) comment := model.CommentFromInput(input)
comment.RootComment = true
s.insertComment(comment) s.insertComment(comment)
parent.Comments = append(parent.Comments, comment) parent.Comments = append(parent.Comments, comment)
return &model.AddResult{ItemID: &comment.ID}, nil return &model.AddResult{ItemID: &comment.ID}, nil
@ -67,21 +68,25 @@ func (s *InMemory) GetComment(id uint, ctx context.Context) (*model.Comment, err
return s.comments[id], nil return s.comments[id], nil
} }
func (s *InMemory) GetPosts(first uint, after uint, ctx context.Context) (*model.PostsConnection, error) { func (s *InMemory) GetPosts(first uint, cursor *uint, ctx context.Context) (*model.PostsConnection, error) {
if !s.postExists(after) { start := uint(0)
return nil, &IDNotFoundError{objName: "post", id: after} if cursor != nil {
start = *cursor + 1
} }
nextPage, until := true, after+first if !s.postExists(start) {
return &model.EmptyPostsConnections, nil
}
nextPage, until := true, start+first
if !s.postExists(until) { if !s.postExists(until) {
nextPage = false nextPage = false
until = uint(len(s.posts)) until = uint(len(s.posts))
} }
info, edges := &model.PageInfo{ info, edges := &model.PageInfo{
StartCursor: after, StartCursor: start,
EndCursor: until - 1, EndCursor: until - 1,
HasNextPage: nextPage, HasNextPage: nextPage,
}, make([]*model.PostsEdge, until-after) }, make([]*model.PostsEdge, until-start)
for i, p := range s.posts[after:until] { for i, p := range s.posts[start:until] {
edges[i] = &model.PostsEdge{ edges[i] = &model.PostsEdge{
Cursor: p.ID, Cursor: p.ID,
Node: p, Node: p,
@ -93,47 +98,33 @@ func (s *InMemory) GetPosts(first uint, after uint, ctx context.Context) (*model
}, nil }, nil
} }
func (s *InMemory) GetComments(post *model.Post, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) { func (s *InMemory) GetComments(post *model.Post, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) {
if !s.commentExists(after) { return getCommentsFrom(post.Comments, first, cursor), nil
return nil, &IDNotFoundError{objName: "comment", id: after}
}
nextPage, until := true, after+first
if !s.commentExists(until) {
nextPage = false
until = uint(len(s.comments))
}
info, edges := &model.PageInfo{
StartCursor: after,
EndCursor: until - 1,
HasNextPage: nextPage,
}, make([]*model.CommentsEdge, until-after)
for i, c := range post.Comments[after:until] {
edges[i] = &model.CommentsEdge{
Cursor: c.ID,
Node: c,
}
}
return &model.CommentsConnection{
Edges: edges,
PageInfo: info,
}, nil
} }
func (s *InMemory) GetReplies(comment *model.Comment, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) { func (s *InMemory) GetReplies(comment *model.Comment, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error) {
if !s.commentExists(after) { return getCommentsFrom(comment.Replies, first, cursor), nil
return nil, &IDNotFoundError{objName: "comment", id: after} }
func getCommentsFrom(source []*model.Comment, first uint, cursor *uint) *model.CommentsConnection {
start := uint(0)
if cursor != nil {
start = *cursor + 1
} }
nextPage, until := true, after+first if start >= uint(len(source)) {
if !s.commentExists(until) { return &model.EmptyCommentsConnection
}
nextPage, until := true, start+first
if until >= uint(len(source)) {
nextPage = false nextPage = false
until = uint(len(s.comments)) until = uint(len(source))
} }
info, edges := &model.PageInfo{ info, edges := &model.PageInfo{
StartCursor: after, StartCursor: start,
EndCursor: until - 1, EndCursor: until - 1,
HasNextPage: nextPage, HasNextPage: nextPage,
}, make([]*model.CommentsEdge, until-after) }, make([]*model.CommentsEdge, until-start)
for i, c := range comment.Replies[after:until] { for i, c := range source[start:until] {
edges[i] = &model.CommentsEdge{ edges[i] = &model.CommentsEdge{
Cursor: c.ID, Cursor: c.ID,
Node: c, Node: c,
@ -142,7 +133,7 @@ func (s *InMemory) GetReplies(comment *model.Comment, first uint, after uint, ct
return &model.CommentsConnection{ return &model.CommentsConnection{
Edges: edges, Edges: edges,
PageInfo: info, PageInfo: info,
}, nil }
} }
func (s *InMemory) postExists(id uint) bool { func (s *InMemory) postExists(id uint) bool {

View file

@ -41,9 +41,9 @@ type Storage interface {
GetComment(id uint, ctx context.Context) (*model.Comment, error) GetComment(id uint, ctx context.Context) (*model.Comment, error)
// returns paginated data in the form of model.*Connection (passing context to prevent overfetching) // returns paginated data in the form of model.*Connection (passing context to prevent overfetching)
GetPosts(first uint, after uint, ctx context.Context) (*model.PostsConnection, error) GetPosts(first uint, cursor *uint, ctx context.Context) (*model.PostsConnection, error)
GetComments(post *model.Post, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) GetComments(post *model.Post, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error)
GetReplies(comment *model.Comment, first uint, after uint, ctx context.Context) (*model.CommentsConnection, error) GetReplies(comment *model.Comment, first uint, cursor *uint, ctx context.Context) (*model.CommentsConnection, error)
} }
func InitStorage() (Storage, error) { func InitStorage() (Storage, error) {