blob: 93a018741b2263bed92ed6ec739d47b8dc825c65 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package imfg
import (
"bytes"
"encoding/binary"
"fmt"
"sort"
"strings"
"github.com/apache/mynewt-artifact/errors"
"github.com/apache/mynewt-artifact/flash"
"github.com/apache/mynewt-artifact/mfg"
"mynewt.apache.org/imgmod/iutil"
)
// NameBlobMap maps names (the flash area names used by Split and Join) to
// binary blobs.
type NameBlobMap map[string][]byte

// Union copies every entry of `from` into the receiver, replacing any entry
// that already exists under the same name.  Blob slices are shared, not
// copied, and `from` itself is left unchanged.
func (to NameBlobMap) Union(from NameBlobMap) {
	for name := range from {
		to[name] = from[name]
	}
}
// errInvalidArea returns an error that names the offending flash area and
// appends a detail message built fmt.Sprintf-style from `format` and `args`.
func errInvalidArea(areaName string, format string,
	args ...interface{}) error {

	return errors.Errorf("invalid flash area \"%s\": %s",
		areaName, fmt.Sprintf(format, args...))
}
// verifyArea validates a single flash area.  The area must start at or after
// minOffset and must not have a negative size.  The offset is checked first;
// the error for the first failing check is returned, or nil if both pass.
func verifyArea(area flash.FlashArea, minOffset int) error {
	switch {
	case area.Offset < minOffset:
		return errInvalidArea(area.Name, "invalid offset %d; expected >= %d",
			area.Offset, minOffset)
	case area.Size < 0:
		return errInvalidArea(area.Name, "invalid size %d", area.Size)
	default:
		return nil
	}
}
// VerifyAreas checks a list of flash areas for validity.  Every area must
// have a non-negative size.  The first area on each flash device must have
// a non-negative offset, and every later area on that device must start at
// or after the end (offset + size) of the area before it, so areas on one
// device may not overlap.
//
// `areas` must be sorted by device ID, then by offset.  The error for the
// first invalid area encountered is returned; nil means every area passed.
func VerifyAreas(areas []flash.FlashArea) error {
	prevDevice := -1
	// End offset of the previous area on the current device.  The next area
	// on the same device must not start before it.
	end := 0
	for _, area := range areas {
		if area.Device != prevDevice {
			// New device: start checking from offset 0 again.  Remember
			// the device so that later areas on it are compared with
			// their predecessors.
			prevDevice = area.Device
			end = 0
		}

		if err := verifyArea(area, end); err != nil {
			return err
		}

		// Track the end of this area rather than a running sum of sizes, so
		// that gaps between areas do not hide a later overlap.
		end = area.Offset + area.Size
	}

	return nil
}
// Split carves a manufacturing image into per-area blobs.
//
// Every area in `areas` that resides on flash device `deviceNum` gets an
// entry in the returned map, keyed by area name.  The entry holds the bytes
// of mfgBin covered by the area, truncated at the end of mfgBin, with
// trailing `eraseVal` bytes removed (see StripPadding).  Each blob is a
// fresh copy, so modifying or appending to it cannot alter mfgBin or another
// area's data.  An area that lies entirely past the end of mfgBin, or holds
// only padding, yields an empty (non-nil) blob.
//
// An error is returned if two areas in `areas` share a name, whatever device
// they are on, or if an area on `deviceNum` has a negative offset or size.
func Split(mfgBin []byte, deviceNum int,
	areas []flash.FlashArea, eraseVal byte) (NameBlobMap, error) {

	// Track names across all devices.  Checking only `mm` (current device)
	// would make duplicate detection depend on device and ordering.
	seen := make(map[string]struct{}, len(areas))
	mm := NameBlobMap{}
	for _, area := range areas {
		if _, ok := seen[area.Name]; ok {
			return nil, errors.Errorf(
				"two or more flash areas with same name: \"%s\"", area.Name)
		}
		seen[area.Name] = struct{}{}

		if area.Device != deviceNum {
			continue
		}

		// A negative offset or size would cause an out-of-range slice below.
		if err := verifyArea(area, 0); err != nil {
			return nil, err
		}

		var areaBin []byte
		if area.Offset < len(mfgBin) {
			end := area.Offset + area.Size
			if end > len(mfgBin) {
				end = len(mfgBin)
			}
			areaBin = mfgBin[area.Offset:end]
		}

		// Copy out of mfgBin so the blob does not alias the image.  make()
		// also guarantees a non-nil blob even when the area is empty.
		stripped := StripPadding(areaBin, eraseVal)
		blob := make([]byte, len(stripped))
		copy(blob, stripped)
		mm[area.Name] = blob
	}

	return mm, nil
}
// Join reassembles a manufacturing image from per-area blobs.
//
// Each area in `areas` whose name is a key of mm has its blob written at the
// area's offset.  Any gap between the data written so far and that offset is
// padded using mfg.AddPadding with eraseVal.  Areas whose names are not in mm
// are skipped.  Trailing eraseVal bytes are stripped from the result (see
// StripPadding).
//
// `areas` must be sorted by device ID, then by offset.
//
// An error is returned in two cases:
//   - a name in mm matches no area in `areas`;
//   - an area's offset lies before the end of the data already written, for
//     example because a preceding blob overflows into it.  Placing the blob
//     there would put it at the wrong offset.
func Join(mm NameBlobMap, eraseVal byte,
	areas []flash.FlashArea) ([]byte, error) {

	// Keep track of which areas we haven't seen yet.
	unseen := make(map[string]struct{}, len(mm))
	for name := range mm {
		unseen[name] = struct{}{}
	}

	joined := []byte{}
	for _, area := range areas {
		// Only include this area if it belongs to the mfg image we are
		// joining.  Test map membership rather than nil-ness.  A nil or
		// empty blob still belongs to the image; if it were skipped, its
		// name would be reported as unprocessed below.
		bin, ok := mm[area.Name]
		if !ok {
			continue
		}
		delete(unseen, area.Name)

		// If earlier data already extends past this area's start, appending
		// here would silently shift this area to the wrong offset.
		if len(joined) > area.Offset {
			return nil, errors.Errorf(
				"flash area \"%s\" at offset %d overlaps preceding data "+
					"(which ends at offset %d)",
				area.Name, area.Offset, len(joined))
		}

		// Pad remainder of previous area in this section.
		if padSize := area.Offset - len(joined); padSize > 0 {
			joined = mfg.AddPadding(joined, eraseVal, padSize)
		}

		// Log the first four bytes of the blob (when available) to help
		// identify what is being inserted.
		binstr := ""
		if len(bin) >= 4 {
			binstr = fmt.Sprintf("%x", bin[:4])
		}
		iutil.Printf("inserting %s (%s) at offset %d (0x%x)\n",
			area.Name, binstr, len(joined), len(joined))

		joined = append(joined, bin...)
	}

	// Ensure we processed every area in the map.
	if len(unseen) > 0 {
		names := make([]string, 0, len(unseen))
		for name := range unseen {
			names = append(names, name)
		}
		sort.Strings(names)
		return nil, errors.Errorf(
			"unprocessed flash areas: %s", strings.Join(names, ", "))
	}

	// Strip padding from the end of the joined binary.
	return StripPadding(joined, eraseVal), nil
}
// replaceKey overwrites the single occurrence of okey in mfgBin with nkey, in
// place, and returns the offset at which okey was found.
//
// At most len(okey) bytes are written.  If nkey is shorter than okey, the
// trailing bytes of the old key are left as they were.  If nkey is longer, it
// is truncated to len(okey).  Callers are responsible for checking the key
// lengths.
//
// An error is returned, and mfgBin is left unmodified, if okey is empty,
// longer than mfgBin, absent from mfgBin, or present more than once.
func replaceKey(mfgBin []byte, okey []byte, nkey []byte) (int, error) {
	// An empty key "matches" at every offset.  Reject it explicitly so the
	// caller gets an accurate message instead of "multiple instances".  For
	// an empty mfgBin it would otherwise silently "succeed" at offset 0.
	if len(okey) == 0 {
		return 0, errors.Errorf("old key is empty")
	}
	if len(okey) > len(mfgBin) {
		return 0, errors.Errorf(
			"key longer than flash section (%d > %d)", len(okey), len(mfgBin))
	}

	idx := bytes.Index(mfgBin, okey)
	if idx == -1 {
		return 0, errors.Errorf("old key not present in flash section")
	}
	// The first and last match positions differ exactly when the key occurs
	// more than once, including overlapping occurrences.
	if bytes.LastIndex(mfgBin, okey) != idx {
		return 0, errors.Errorf(
			"multiple instances of old key in flash section")
	}

	iutil.PrintfVerbose("Replacing key at offset %d\n", idx)
	copy(mfgBin[idx:idx+len(okey)], nkey)
	return idx, nil
}
// ReplaceIsk overwrites the key okey inside mfgBin with nkey, in place.  The
// two keys must be the same length, and okey must appear exactly once in
// mfgBin; otherwise an error is returned.
func ReplaceIsk(mfgBin []byte, okey []byte, nkey []byte) error {
	if oldLen, newLen := len(okey), len(nkey); newLen != oldLen {
		return errors.Errorf(
			"key lengths differ (%d != %d)", newLen, oldLen)
	}

	_, err := replaceKey(mfgBin, okey, nkey)
	return err
}
// ReplaceKek overwrites the key okey inside mfgBin with nkey, in place.  It
// also rewrites the 4-byte little-endian length field that immediately
// precedes the key so that it holds len(nkey).
//
// nkey may be shorter than okey, but not longer.  When it is shorter, only
// the first len(nkey) bytes of the old key are overwritten; the remaining
// old-key bytes stay in mfgBin.
//
// An error is returned, and mfgBin is left unmodified, in these cases:
//   - nkey is longer than okey;
//   - okey does not occur exactly once in mfgBin;
//   - there is no room for the length field before the key;
//   - the embedded length field does not equal len(okey).
func ReplaceKek(mfgBin []byte, okey []byte, nkey []byte) error {
	if len(nkey) > len(okey) {
		return errors.Errorf(
			"new key longer than old key (%d > %d)", len(nkey), len(okey))
	}

	keyIdx, err := replaceKey(mfgBin, okey, nkey)
	if err != nil {
		return err
	}

	// replaceKey found okey at keyIdx, so copying it back restores the
	// original bytes exactly.  Use this on every later failure so the
	// caller never receives a half-modified image.
	restore := func() { copy(mfgBin[keyIdx:keyIdx+len(okey)], okey) }

	// The key length is an unsigned int immediately prior to the key.
	const klSize = 4
	klIdx := keyIdx - klSize
	if klIdx < 0 {
		restore()
		return errors.Errorf(
			"no room for key length before key (key offset=%d)", keyIdx)
	}

	kl := binary.LittleEndian.Uint32(mfgBin[klIdx:keyIdx])
	if int(kl) != len(okey) {
		restore()
		return errors.Errorf(
			"embedded key length (off=%d) has unexpected value; "+
				"want=%d have=%d",
			klIdx, len(okey), kl)
	}

	binary.LittleEndian.PutUint32(mfgBin[klIdx:keyIdx], uint32(len(nkey)))
	return nil
}
// StripPadding returns b with every trailing byte equal to eraseVal removed.
// The result is a subslice of b and shares its backing array.  A nil input
// yields nil; input made up entirely of eraseVal yields an empty slice.
func StripPadding(b []byte, eraseVal byte) []byte {
	end := len(b)
	for end > 0 && b[end-1] == eraseVal {
		end--
	}
	return b[:end]
}