// blob: 755d8ba3b1bb67ef24dd142c4bbcd90f4ff00082 [file] [log] [blame]

package jsonpatch
import (
"bytes"
"encoding/json"
"fmt"
"strconv"
"strings"
)
// Discriminator values stored in lazyNode.which; they record which of a
// node's representations is currently in use.
const (
// eRaw: the node is still held as undecoded JSON in lazyNode.raw.
eRaw = iota
// eDoc: the node has been decoded as an object into lazyNode.doc.
eDoc
// eAry: the node has been decoded as an array into lazyNode.ary.
eAry
)
// lazyNode is a JSON value that is decoded on demand. It starts out as raw
// bytes and is converted into an object (doc) or array (ary) only when a
// caller asks for that view via intoDoc, intoAry, tryDoc or tryAry.
type lazyNode struct {
// raw is the undecoded JSON text; it may be nil.
raw *json.RawMessage
// doc is the decoded object view; it is used when which == eDoc.
doc partialDoc
// ary is the decoded array view; it is used when which == eAry.
ary partialArray
// which is one of eRaw, eDoc or eAry and selects the active view.
which int
}
// operation is a single patch operation, stored as the members of its JSON
// object with every member value left undecoded. The accessors in this file
// read the "op", "path", "from" and "value" members.
type operation map[string]*json.RawMessage
// Patch is an ordered collection of operations. Apply and ApplyIndent run
// the operations in slice order and stop at the first one that fails.
type Patch []operation
// partialDoc is a JSON object whose member values are lazily decoded nodes.
// Its pointer type implements container.
type partialDoc map[string]*lazyNode
// partialArray is a JSON array whose elements are lazily decoded nodes.
// Its pointer type implements container.
type partialArray []*lazyNode
// container is the common view of a JSON object or array that the patch
// operations work against. In this file it is implemented by *partialDoc
// (keys are member names) and *partialArray (keys are decimal indices, with
// "-" accepted by set and add to mean "append").
type container interface {
// get returns the node stored under key.
get(key string) (*lazyNode, error)
// set stores val under key, overwriting whatever is there.
set(key string, val *lazyNode) error
// add stores val under key; for arrays it inserts rather than overwrites.
add(key string, val *lazyNode) error
// remove deletes the entry stored under key.
remove(key string) error
}
// newLazyNode wraps raw in a node that is still in its undecoded (eRaw)
// state. raw is stored as given, without copying, and may be nil.
func newLazyNode(raw *json.RawMessage) *lazyNode {
	node := new(lazyNode)
	node.raw = raw
	node.which = eRaw
	return node
}
// MarshalJSON implements json.Marshaler by encoding whichever representation
// of the node is currently active (raw bytes, object or array). It returns an
// error if the node's discriminator holds an unrecognised value.
func (n *lazyNode) MarshalJSON() ([]byte, error) {
	if n.which == eRaw {
		return json.Marshal(n.raw)
	}
	if n.which == eDoc {
		return json.Marshal(n.doc)
	}
	if n.which == eAry {
		return json.Marshal(n.ary)
	}
	return nil, fmt.Errorf("Unknown type")
}
// UnmarshalJSON implements json.Unmarshaler. It does not decode anything: it
// keeps a private copy of data as the node's raw form and marks the node as
// undecoded. It always returns nil.
func (n *lazyNode) UnmarshalJSON(data []byte) error {
	// Copy, because the caller's buffer is not ours to keep.
	owned := append(json.RawMessage{}, data...)
	n.raw, n.which = &owned, eRaw
	return nil
}
// intoDoc returns the node's object view, decoding the raw JSON into it on
// first use. The returned pointer aliases the node's own doc field, so
// changes made through it are changes to the node. An error is returned when
// the node has no raw data or the raw data does not decode as an object.
func (n *lazyNode) intoDoc() (*partialDoc, error) {
	switch {
	case n.which == eDoc:
		// Already decoded; hand back the existing view.
		return &n.doc, nil
	case n.raw == nil:
		return nil, fmt.Errorf("Unable to unmarshal nil pointer as partial document")
	}
	if err := json.Unmarshal(*n.raw, &n.doc); err != nil {
		return nil, err
	}
	n.which = eDoc
	return &n.doc, nil
}
// intoAry returns the node's array view, decoding the raw JSON into it on
// first use. The returned pointer aliases the node's own ary field, so
// changes made through it are changes to the node. An error is returned when
// the node has no raw data or the raw data does not decode as an array.
func (n *lazyNode) intoAry() (*partialArray, error) {
	switch {
	case n.which == eAry:
		// Already decoded; hand back the existing view.
		return &n.ary, nil
	case n.raw == nil:
		return nil, fmt.Errorf("Unable to unmarshal nil pointer as partial array")
	}
	if err := json.Unmarshal(*n.raw, &n.ary); err != nil {
		return nil, err
	}
	n.which = eAry
	return &n.ary, nil
}
// compact returns the node's raw JSON with insignificant whitespace removed.
// It returns nil when the node has no raw data, and the raw bytes unchanged
// when they cannot be compacted.
func (n *lazyNode) compact() []byte {
	if n.raw == nil {
		return nil
	}
	var out bytes.Buffer
	if err := json.Compact(&out, *n.raw); err != nil {
		// Best effort: fall back to the bytes as they are.
		return *n.raw
	}
	return out.Bytes()
}
// tryDoc attempts to decode the node's raw JSON as an object. On success the
// node switches to its object view and true is returned; otherwise (no raw
// data, or the decode fails) it returns false.
func (n *lazyNode) tryDoc() bool {
	if n.raw != nil {
		if json.Unmarshal(*n.raw, &n.doc) == nil {
			n.which = eDoc
			return true
		}
	}
	return false
}
// tryAry attempts to decode the node's raw JSON as an array. On success the
// node switches to its array view and true is returned; otherwise (no raw
// data, or the decode fails) it returns false.
func (n *lazyNode) tryAry() bool {
	if n.raw != nil {
		if json.Unmarshal(*n.raw, &n.ary) == nil {
			n.which = eAry
			return true
		}
	}
	return false
}
// equal reports whether n and o describe the same JSON value. Objects are
// equal when they have the same set of keys and equal values under each key;
// arrays are equal when they have the same length and equal elements at each
// position; anything else is compared by its whitespace-compacted raw bytes.
//
// Either node may be nil: a nil node is only equal to another nil node.
// The comparison may decode n and o in place (via tryDoc/tryAry).
func (n *lazyNode) equal(o *lazyNode) bool {
	// Containers hold nil pointers for some entries (the original code already
	// special-cased a nil pair), so a nil on only one side must compare as
	// "different" instead of being dereferenced.
	if n == nil || o == nil {
		return n == nil && o == nil
	}

	if n.which == eRaw {
		if !n.tryDoc() && !n.tryAry() {
			// n is neither an object nor an array: compare as raw text.
			if o.which != eRaw {
				return false
			}
			return bytes.Equal(n.compact(), o.compact())
		}
	}

	if n.which == eDoc {
		if o.which == eRaw && !o.tryDoc() {
			return false
		}
		if o.which != eDoc {
			return false
		}
		// Without the length check, o could carry extra keys that the loop
		// below (which only walks n's keys) would never notice.
		if len(n.doc) != len(o.doc) {
			return false
		}
		for k, v := range n.doc {
			ov, ok := o.doc[k]
			if !ok || !v.equal(ov) {
				return false
			}
		}
		return true
	}

	if o.which != eAry && !o.tryAry() {
		return false
	}
	if len(n.ary) != len(o.ary) {
		return false
	}
	for idx, val := range n.ary {
		if !val.equal(o.ary[idx]) {
			return false
		}
	}
	return true
}
// kind returns the operation's "op" member decoded as a string. It returns
// "unknown" when the member is absent, is stored as a nil pointer, or does
// not decode as a JSON string.
func (o operation) kind() string {
	obj, ok := o["op"]
	// A nil entry must not be dereferenced; treat it like a missing member.
	if !ok || obj == nil {
		return "unknown"
	}
	var op string
	if err := json.Unmarshal(*obj, &op); err != nil {
		return "unknown"
	}
	return op
}
// path returns the operation's "path" member decoded as a string. It returns
// "unknown" when the member is absent, is stored as a nil pointer, or does
// not decode as a JSON string.
func (o operation) path() string {
	obj, ok := o["path"]
	// A nil entry must not be dereferenced; treat it like a missing member.
	if !ok || obj == nil {
		return "unknown"
	}
	var target string
	if err := json.Unmarshal(*obj, &target); err != nil {
		return "unknown"
	}
	return target
}
// from returns the operation's "from" member decoded as a string. It returns
// "unknown" when the member is absent, is stored as a nil pointer, or does
// not decode as a JSON string.
func (o operation) from() string {
	obj, ok := o["from"]
	// A nil entry must not be dereferenced; treat it like a missing member.
	if !ok || obj == nil {
		return "unknown"
	}
	var source string
	if err := json.Unmarshal(*obj, &source); err != nil {
		return "unknown"
	}
	return source
}
// value returns the operation's "value" member wrapped in an undecoded node,
// or nil when the operation has no such member. A new node is built on every
// call.
func (o operation) value() *lazyNode {
	obj, present := o["value"]
	if !present {
		return nil
	}
	return newLazyNode(obj)
}
// isArray reports whether buf, after skipping leading JSON whitespace, starts
// with '['. It only inspects the first significant byte and does not validate
// the rest of the input. Empty or all-whitespace input yields false.
func isArray(buf []byte) bool {
	for _, c := range buf {
		switch c {
		case ' ', '\t', '\n', '\r':
			// JSON insignificant whitespace (carriage return included).
			continue
		case '[':
			return true
		default:
			return false
		}
	}
	return false
}
// findObject walks path (a '/'-separated JSON Pointer) from the container in
// *pd and returns the container that directly holds the final reference
// token, together with that token decoded by decodePatchKey.
//
// It returns (nil, "") when path contains no '/', when an intermediate token
// cannot be looked up or has no node, or when an intermediate node cannot be
// decoded as an object or array. Intermediate nodes are decoded in place as a
// side effect.
func findObject(pd *container, path string) (container, string) {
	doc := *pd

	split := strings.Split(path, "/")
	if len(split) < 2 {
		return nil, ""
	}

	// split[0] is the text before the leading '/'; the last element is the key
	// inside the container we are looking for.
	parts := split[1 : len(split)-1]
	key := split[len(split)-1]

	for _, part := range parts {
		next, err := doc.get(decodePatchKey(part))
		if next == nil || err != nil {
			return nil, ""
		}
		// A node can exist without raw data (newLazyNode accepts a nil
		// pointer); it is not a container, so the path cannot continue.
		if next.raw == nil {
			return nil, ""
		}

		if isArray(*next.raw) {
			ary, err := next.intoAry()
			if err != nil {
				return nil, ""
			}
			doc = ary
		} else {
			obj, err := next.intoDoc()
			if err != nil {
				return nil, ""
			}
			doc = obj
		}
	}

	return doc, decodePatchKey(key)
}
// set stores val under key, replacing any existing member. It always
// returns nil.
func (d *partialDoc) set(key string, val *lazyNode) error {
	members := *d
	members[key] = val
	return nil
}
// add stores val under key. For objects this is the same as set: an existing
// member with that key is replaced. It always returns nil.
func (d *partialDoc) add(key string, val *lazyNode) error {
	members := *d
	members[key] = val
	return nil
}
// get returns the node stored under key. A key that is not present yields a
// nil node; the error is always nil.
func (d *partialDoc) get(key string) (*lazyNode, error) {
	node := (*d)[key]
	return node, nil
}
// remove deletes the member stored under key. It returns an error, and
// leaves the object untouched, when there is no such member.
func (d *partialDoc) remove(key string) error {
	if _, present := (*d)[key]; present {
		delete(*d, key)
		return nil
	}
	return fmt.Errorf("Unable to remove nonexistent key: %s", key)
}
// set stores val at the position named by key.
//
// key is either "-" (append val to the array) or a decimal index. An index
// inside the array overwrites that element. An index at or past the end grows
// the array to idx+1 elements, leaving any gap filled with nil nodes. A
// negative or non-numeric key returns an error and leaves the array unchanged.
func (d *partialArray) set(key string, val *lazyNode) error {
	if key == "-" {
		*d = append(*d, val)
		return nil
	}

	idx, err := strconv.Atoi(key)
	if err != nil {
		return err
	}
	// Indexing with a negative value would panic; report it instead.
	if idx < 0 {
		return fmt.Errorf("Unable to access invalid index: %d", idx)
	}

	sz := len(*d)
	if idx+1 > sz {
		sz = idx + 1
	}

	// Always write into a fresh slice so the previous backing array is never
	// modified behind the back of anyone still holding it.
	ary := make([]*lazyNode, sz)
	copy(ary, *d)
	ary[idx] = val

	*d = ary
	return nil
}
// add inserts val into the array at the position named by key, shifting the
// elements at and after that position one place to the right.
//
// key is either "-" (append) or a decimal index. Valid indices run from 0 to
// len, where len appends. A negative index counts back from the end of the
// array as it will be after insertion, so -1 also appends and -(len+1)
// inserts at the front. An out-of-range or non-numeric key returns an error
// and leaves the array unchanged.
func (d *partialArray) add(key string, val *lazyNode) error {
	if key == "-" {
		*d = append(*d, val)
		return nil
	}

	idx, err := strconv.Atoi(key)
	if err != nil {
		return err
	}

	cur := *d
	grown := len(cur) + 1 // length once val has been inserted

	// Validate before slicing: an out-of-range index would otherwise panic.
	if idx >= grown || idx < -grown {
		return fmt.Errorf("Unable to access invalid index: %d", idx)
	}
	if idx < 0 {
		idx += grown
	}

	ary := make([]*lazyNode, grown)
	copy(ary[:idx], cur[:idx])
	ary[idx] = val
	copy(ary[idx+1:], cur[idx:])

	*d = ary
	return nil
}
// get returns the element at the decimal index named by key. It returns an
// error when key is not a number or the index lies outside the array.
func (d *partialArray) get(key string) (*lazyNode, error) {
	idx, err := strconv.Atoi(key)
	if err != nil {
		return nil, err
	}
	// The lower bound matters too: a negative index would panic below.
	if idx < 0 || idx >= len(*d) {
		return nil, fmt.Errorf("Unable to access invalid index: %d", idx)
	}
	return (*d)[idx], nil
}
// remove deletes the element at the decimal index named by key, shifting
// later elements one place to the left. It returns an error, and leaves the
// array unchanged, when key is not a number or the index lies outside the
// array.
func (d *partialArray) remove(key string) error {
	idx, err := strconv.Atoi(key)
	if err != nil {
		return err
	}

	cur := *d
	// The lower bound matters too: a negative index would panic below.
	if idx < 0 || idx >= len(cur) {
		return fmt.Errorf("Unable to remove invalid index: %d", idx)
	}

	ary := make([]*lazyNode, len(cur)-1)
	copy(ary, cur[:idx])
	copy(ary[idx:], cur[idx+1:])

	*d = ary
	return nil
}
// add performs an "add" operation: it resolves the container addressed by the
// operation's path and adds the operation's value under the final token. It
// fails when the path's parent cannot be resolved or the container rejects
// the key.
func (p Patch) add(doc *container, op operation) error {
	target := op.path()

	parent, name := findObject(doc, target)
	if parent != nil {
		return parent.add(name, op.value())
	}
	return fmt.Errorf("jsonpatch add operation does not apply: doc is missing path: %s", target)
}
// remove performs a "remove" operation: it resolves the container addressed
// by the operation's path and removes the entry under the final token. It
// fails when the path's parent cannot be resolved or the container has no
// such entry.
func (p Patch) remove(doc *container, op operation) error {
	target := op.path()

	parent, name := findObject(doc, target)
	if parent != nil {
		return parent.remove(name)
	}
	return fmt.Errorf("jsonpatch remove operation does not apply: doc is missing path: %s", target)
}
// replace performs a "replace" operation: the entry addressed by the
// operation's path must already hold a node, and is overwritten with the
// operation's value. It fails when the path's parent cannot be resolved or
// when looking up the existing entry errors or yields no node.
func (p Patch) replace(doc *container, op operation) error {
	target := op.path()

	parent, name := findObject(doc, target)
	if parent == nil {
		return fmt.Errorf("jsonpatch replace operation does not apply: doc is missing path: %s", target)
	}

	// The target has to exist before it can be replaced.
	if existing, err := parent.get(name); err != nil || existing == nil {
		return fmt.Errorf("jsonpatch replace operation does not apply: doc is missing key: %s", target)
	}

	return parent.set(name, op.value())
}
// move performs a "move" operation: the node at the operation's "from"
// location is removed and then added at the operation's "path" location.
//
// It fails when either location's parent cannot be resolved, when the source
// cannot be read or removed, or when the destination container rejects the
// key. The source is removed before the destination is resolved, so a failure
// at the destination leaves the source already removed.
func (p Patch) move(doc *container, op operation) error {
	from := op.from()

	con, key := findObject(doc, from)
	if con == nil {
		return fmt.Errorf("jsonpatch move operation does not apply: doc is missing from path: %s", from)
	}

	val, err := con.get(key)
	if err != nil {
		return err
	}

	if err := con.remove(key); err != nil {
		return err
	}

	path := op.path()

	con, key = findObject(doc, path)
	if con == nil {
		return fmt.Errorf("jsonpatch move operation does not apply: doc is missing destination path: %s", path)
	}

	// RFC 6902 defines move as "remove" followed by "add" at the target, so
	// an array destination must insert rather than overwrite an element.
	return con.add(key, val)
}
// test performs a "test" operation: it succeeds (returns nil) when the node
// at the operation's path equals the operation's value, and returns an error
// otherwise.
//
// It also fails when the path's parent cannot be resolved, when the lookup
// errors, or when the operation has no "value" member at all. When the
// document has no node at the path, the test passes only if the operation's
// value carries no raw data either.
func (p Patch) test(doc *container, op operation) error {
	path := op.path()

	con, key := findObject(doc, path)
	if con == nil {
		return fmt.Errorf("jsonpatch test operation does not apply: is missing path: %s", path)
	}

	val, err := con.get(key)
	if err != nil {
		return err
	}

	// value() returns nil when the member is absent; RFC 6902 requires it for
	// "test", and dereferencing the nil node would panic.
	expected := op.value()
	if expected == nil {
		return fmt.Errorf("jsonpatch test operation does not apply: is missing value: %s", path)
	}

	var matched bool
	if val == nil {
		matched = expected.raw == nil
	} else {
		matched = val.equal(expected)
	}
	if matched {
		return nil
	}
	return fmt.Errorf("Testing value %s failed", path)
}
// copy performs a "copy" operation: the node at the operation's "from"
// location is duplicated and the duplicate is added at the operation's "path"
// location.
//
// It fails when either location's parent cannot be resolved, when the source
// cannot be read or re-encoded, or when the destination container rejects the
// key.
func (p Patch) copy(doc *container, op operation) error {
	from := op.from()

	con, key := findObject(doc, from)
	if con == nil {
		return fmt.Errorf("jsonpatch copy operation does not apply: doc is missing from path: %s", from)
	}

	val, err := con.get(key)
	if err != nil {
		return err
	}

	// Nodes are decoded and edited in place, so storing the same pointer in
	// two locations would make a later change through one path show up under
	// the other. Re-encode the source to give the destination its own node.
	// Nodes without raw data can never be decoded in place, so sharing them
	// is harmless.
	dup := val
	if val != nil && val.raw != nil {
		encoded, err := json.Marshal(val)
		if err != nil {
			return err
		}
		raw := json.RawMessage(encoded)
		dup = newLazyNode(&raw)
	}

	path := op.path()

	con, key = findObject(doc, path)
	if con == nil {
		return fmt.Errorf("jsonpatch copy operation does not apply: doc is missing destination path: %s", path)
	}

	// RFC 6902 defines copy as an "add" at the target, so an array
	// destination must insert rather than overwrite an element.
	return con.add(key, dup)
}
// Equal reports whether the two JSON documents a and b are structurally
// equal. The inputs are copied first, so neither slice is modified.
func Equal(a, b []byte) bool {
	left := append(json.RawMessage{}, a...)
	right := append(json.RawMessage{}, b...)

	return newLazyNode(&left).equal(newLazyNode(&right))
}
// DecodePatch decodes the passed JSON document as an RFC 6902 patch. Only the
// overall shape is decoded here; the members of each operation stay undecoded
// until the patch is applied. On a decoding error the returned Patch is nil.
func DecodePatch(buf []byte) (Patch, error) {
	var decoded Patch
	if err := json.Unmarshal(buf, &decoded); err != nil {
		return nil, err
	}
	return decoded, nil
}
// Apply runs the patch against the JSON document doc and returns the
// resulting document in compact form. It is ApplyIndent with no indentation.
func (p Patch) Apply(doc []byte) ([]byte, error) {
	const noIndent = ""
	return p.ApplyIndent(doc, noIndent)
}
// ApplyIndent runs the patch against the JSON document doc and returns the
// resulting document. When indent is non-empty the output is indented with
// it; otherwise the output is compact.
//
// doc is decoded as an array when its first non-whitespace byte is '[' and as
// an object otherwise. Operations run in order and the first failure aborts
// the whole patch, returning a nil document and that error. Empty or invalid
// input is reported as a decoding error.
func (p Patch) ApplyIndent(doc []byte, indent string) ([]byte, error) {
	var pd container
	// isArray skips leading whitespace and copes with empty input, where
	// indexing doc[0] directly would misclassify or panic.
	if isArray(doc) {
		pd = &partialArray{}
	} else {
		pd = &partialDoc{}
	}

	if err := json.Unmarshal(doc, pd); err != nil {
		return nil, err
	}

	for _, op := range p {
		var err error
		switch kind := op.kind(); kind {
		case "add":
			err = p.add(&pd, op)
		case "remove":
			err = p.remove(&pd, op)
		case "replace":
			err = p.replace(&pd, op)
		case "move":
			err = p.move(&pd, op)
		case "test":
			err = p.test(&pd, op)
		case "copy":
			err = p.copy(&pd, op)
		default:
			err = fmt.Errorf("Unexpected kind: %s", kind)
		}
		if err != nil {
			return nil, err
		}
	}

	if indent != "" {
		return json.MarshalIndent(pd, "", indent)
	}
	return json.Marshal(pd)
}
// From http://tools.ietf.org/html/rfc6901#section-4 :
//
// Evaluation of each reference token begins by decoding any escaped
// character sequence. This is performed by first transforming any
// occurrence of the sequence '~1' to '/', and then transforming any
// occurrence of the sequence '~0' to '~'.
//
// rfc6901Decoder applies both substitutions in a single left-to-right pass.
var rfc6901Decoder = strings.NewReplacer(
	"~1", "/",
	"~0", "~",
)

// decodePatchKey returns the reference token k with its RFC 6901 escape
// sequences decoded.
func decodePatchKey(k string) string {
	// Tokens without '~' contain no escapes and come back unchanged.
	if strings.IndexByte(k, '~') < 0 {
		return k
	}
	return rfc6901Decoder.Replace(k)
}