diff --git a/README.md b/README.md index d243e20..3720a1c 100644 --- a/README.md +++ b/README.md @@ -413,7 +413,7 @@ New addresses must be distinct among themselves (case-insensitive). | Status | Condition | | ------ | --------- | -| `400` | Empty `add_to`, duplicate addresses, or target message has no `pid` | +| `400` | Empty `add_to` or duplicate addresses | | `403` | Authenticated user is not an existing participant (sender or `to` recipient) | | `404` | Message not found | diff --git a/src/handlers/messages.go b/src/handlers/messages.go index c974df5..b3f0170 100644 --- a/src/handlers/messages.go +++ b/src/handlers/messages.go @@ -938,12 +938,13 @@ func (h *MessageHandler) AddRecipients(c *gin.Context) { ctx := c.Request.Context() - // Load the message to verify it exists. + // Load the message to verify it exists. This message's :id becomes the pid + // of the outgoing fmsg messages sent to the new recipients, so the message's + // own pid is irrelevant here and need not be looked up. var fromAddr string - var pid *int64 err := h.DB.Pool.QueryRow(ctx, - "SELECT from_addr, pid FROM msg WHERE id = $1", msgID, - ).Scan(&fromAddr, &pid) + "SELECT from_addr FROM msg WHERE id = $1", msgID, + ).Scan(&fromAddr) if err != nil { if errors.Is(err, pgx.ErrNoRows) { c.JSON(http.StatusNotFound, gin.H{"error": "message not found"}) @@ -954,12 +955,6 @@ func (h *MessageHandler) AddRecipients(c *gin.Context) { return } - // add_to is only valid on replies (messages with a pid). - if pid == nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "add_to is only valid when pid is supplied"}) - return - } - // Verify the requester is an existing participant (from or msg_to). if fromAddr != identity { var recipientCount int @@ -977,16 +972,41 @@ func (h *MessageHandler) AddRecipients(c *gin.Context) { return } - // Insert the new add_to recipients. + // Insert the new add_to recipients and record who added them. Both run in a + // single transaction so a partial failure leaves the message unchanged. + tx, err := h.DB.Pool.Begin(ctx) + if err != nil { + log.Printf("add recipients: begin tx for msg %d: %v", msgID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to add recipients"}) + return + } + defer tx.Rollback(ctx) + for _, addr := range input.AddTo { - if _, err = h.DB.Pool.Exec(ctx, + if _, err = tx.Exec(ctx, "INSERT INTO msg_add_to (msg_id, addr) VALUES ($1, $2) ON CONFLICT DO NOTHING", msgID, addr, ); err != nil { - log.Printf("add recipients: insert %s: %v", addr, err) + log.Printf("add recipients: insert %s into msg %d: %v", addr, msgID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to add recipients"}) + return } } + if _, err = tx.Exec(ctx, + "UPDATE msg SET add_to_from = $1 WHERE id = $2", identity, msgID, + ); err != nil { + log.Printf("add recipients: update add_to_from for msg %d: %v", msgID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to add recipients"}) + return + } + + if err = tx.Commit(ctx); err != nil { + log.Printf("add recipients: commit tx for msg %d: %v", msgID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to add recipients"}) + return + } + c.JSON(http.StatusOK, gin.H{"id": msgID, "added": len(input.AddTo)}) }