diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/Makefile | 2 | ||||
-rw-r--r-- | lib/libfdt/Makefile | 3 | ||||
-rw-r--r-- | lib/libfdt/fdt_addresses.c | 55 | ||||
-rw-r--r-- | lib/libfdt/fdt_rw.c | 6 | ||||
-rw-r--r-- | lib/libfdt/fdt_sw.c | 32 | ||||
-rw-r--r-- | lib/libfdt/libfdt_internal.h | 6 | ||||
-rw-r--r-- | lib/linux_compat.c | 47 | ||||
-rw-r--r-- | lib/list_sort.c | 298 | ||||
-rw-r--r-- | lib/lmb.c | 5 | ||||
-rw-r--r-- | lib/rbtree.c | 684 | ||||
-rw-r--r-- | lib/rsa/rsa-sign.c | 61 | ||||
-rw-r--r-- | lib/rsa/rsa-verify.c | 97 |
12 files changed, 1023 insertions, 273 deletions
diff --git a/lib/Makefile b/lib/Makefile index 68210a59b73..320197a5209 100644 --- a/lib/Makefile +++ b/lib/Makefile @@ -43,6 +43,7 @@ obj-y += strmhz.o obj-$(CONFIG_TPM) += tpm.o obj-$(CONFIG_RBTREE) += rbtree.o obj-$(CONFIG_BITREVERSE) += bitrev.o +obj-y += list_sort.o endif ifdef CONFIG_SPL_BUILD @@ -58,6 +59,7 @@ obj-y += crc32.o obj-y += ctype.o obj-y += div64.o obj-y += hang.o +obj-y += linux_compat.o obj-y += linux_string.o obj-$(CONFIG_REGEX) += slre.o obj-y += string.o diff --git a/lib/libfdt/Makefile b/lib/libfdt/Makefile index a02c9b02add..6fe79e0b06e 100644 --- a/lib/libfdt/Makefile +++ b/lib/libfdt/Makefile @@ -5,7 +5,8 @@ # SPDX-License-Identifier: GPL-2.0+ # -COBJS-libfdt += fdt.o fdt_ro.o fdt_rw.o fdt_strerror.o fdt_sw.o fdt_wip.o fdt_empty_tree.o +COBJS-libfdt += fdt.o fdt_ro.o fdt_rw.o fdt_strerror.o fdt_sw.o fdt_wip.o \ + fdt_empty_tree.o fdt_addresses.o obj-$(CONFIG_OF_LIBFDT) += $(COBJS-libfdt) obj-$(CONFIG_FIT) += $(COBJS-libfdt) diff --git a/lib/libfdt/fdt_addresses.c b/lib/libfdt/fdt_addresses.c new file mode 100644 index 00000000000..76054d98e5f --- /dev/null +++ b/lib/libfdt/fdt_addresses.c @@ -0,0 +1,55 @@ +/* + * libfdt - Flat Device Tree manipulation + * Copyright (C) 2014 David Gibson <david@gibson.dropbear.id.au> + * SPDX-License-Identifier: GPL-2.0+ BSD-2-Clause + */ +#include "libfdt_env.h" + +#ifndef USE_HOSTCC +#include <fdt.h> +#include <libfdt.h> +#else +#include "fdt_host.h" +#endif + +#include "libfdt_internal.h" + +int fdt_address_cells(const void *fdt, int nodeoffset) +{ + const fdt32_t *ac; + int val; + int len; + + ac = fdt_getprop(fdt, nodeoffset, "#address-cells", &len); + if (!ac) + return 2; + + if (len != sizeof(*ac)) + return -FDT_ERR_BADNCELLS; + + val = fdt32_to_cpu(*ac); + if ((val <= 0) || (val > FDT_MAX_NCELLS)) + return -FDT_ERR_BADNCELLS; + + return val; +} + +int fdt_size_cells(const void *fdt, int nodeoffset) +{ + const fdt32_t *sc; + int val; + int len; + + sc = fdt_getprop(fdt, nodeoffset, "#size-cells", &len); + if (!sc) + return 2; + + if (len != sizeof(*sc)) + return -FDT_ERR_BADNCELLS; + + val = fdt32_to_cpu(*sc); + if ((val < 0) || (val > FDT_MAX_NCELLS)) + return -FDT_ERR_BADNCELLS; + + return val; +} diff --git a/lib/libfdt/fdt_rw.c b/lib/libfdt/fdt_rw.c index 6fa4f13073d..bec8b8ad891 100644 --- a/lib/libfdt/fdt_rw.c +++ b/lib/libfdt/fdt_rw.c @@ -43,9 +43,9 @@ static int _fdt_rw_check_header(void *fdt) #define FDT_RW_CHECK_HEADER(fdt) \ { \ - int err; \ - if ((err = _fdt_rw_check_header(fdt)) != 0) \ - return err; \ + int __err; \ + if ((__err = _fdt_rw_check_header(fdt)) != 0) \ + return __err; \ } static inline int _fdt_data_size(void *fdt) diff --git a/lib/libfdt/fdt_sw.c b/lib/libfdt/fdt_sw.c index 580b57024ff..320a914991a 100644 --- a/lib/libfdt/fdt_sw.c +++ b/lib/libfdt/fdt_sw.c @@ -62,6 +62,38 @@ int fdt_create(void *buf, int bufsize) return 0; } +int fdt_resize(void *fdt, void *buf, int bufsize) +{ + size_t headsize, tailsize; + char *oldtail, *newtail; + + FDT_SW_CHECK_HEADER(fdt); + + headsize = fdt_off_dt_struct(fdt); + tailsize = fdt_size_dt_strings(fdt); + + if ((headsize + tailsize) > bufsize) + return -FDT_ERR_NOSPACE; + + oldtail = (char *)fdt + fdt_totalsize(fdt) - tailsize; + newtail = (char *)buf + bufsize - tailsize; + + /* Two cases to avoid clobbering data if the old and new + * buffers partially overlap */ + if (buf <= fdt) { + memmove(buf, fdt, headsize); + memmove(newtail, oldtail, tailsize); + } else { + memmove(newtail, oldtail, tailsize); + memmove(buf, fdt, headsize); + } + + fdt_set_off_dt_strings(buf, bufsize); + fdt_set_totalsize(buf, bufsize); + + return 0; +} + int fdt_add_reservemap_entry(void *fdt, uint64_t addr, uint64_t size) { struct fdt_reserve_entry *re; diff --git a/lib/libfdt/libfdt_internal.h b/lib/libfdt/libfdt_internal.h index 13cbc9af2ab..9a79fe85dda 100644 --- a/lib/libfdt/libfdt_internal.h +++ b/lib/libfdt/libfdt_internal.h @@ -12,9 +12,9 @@ #define FDT_CHECK_HEADER(fdt) \ { \ - int err; \ - if ((err = fdt_check_header(fdt)) != 0) \ - return err; \ + int __err; \ + if ((__err = fdt_check_header(fdt)) != 0) \ + return __err; \ } int _fdt_check_node_offset(const void *fdt, int offset); diff --git a/lib/linux_compat.c b/lib/linux_compat.c new file mode 100644 index 00000000000..a3d4675f7ed --- /dev/null +++ b/lib/linux_compat.c @@ -0,0 +1,47 @@ + +#include <common.h> +#include <linux/compat.h> + +struct p_current cur = { + .pid = 1, +}; +__maybe_unused struct p_current *current = &cur; + +unsigned long copy_from_user(void *dest, const void *src, + unsigned long count) +{ + memcpy((void *)dest, (void *)src, count); + return 0; +} + +void *kmalloc(size_t size, int flags) +{ + return memalign(ARCH_DMA_MINALIGN, size); +} + +void *kzalloc(size_t size, int flags) +{ + void *ptr = kmalloc(size, flags); + memset(ptr, 0, size); + return ptr; +} + +void *vzalloc(unsigned long size) +{ + return kzalloc(size, 0); +} + +struct kmem_cache *get_mem(int element_sz) +{ + struct kmem_cache *ret; + + ret = memalign(ARCH_DMA_MINALIGN, sizeof(struct kmem_cache)); + ret->sz = element_sz; + + return ret; +} + +void *kmem_cache_alloc(struct kmem_cache *obj, int flag) +{ + return memalign(ARCH_DMA_MINALIGN, obj->sz); +} diff --git a/lib/list_sort.c b/lib/list_sort.c new file mode 100644 index 00000000000..81de0a17de8 --- /dev/null +++ b/lib/list_sort.c @@ -0,0 +1,298 @@ +#define __UBOOT__ +#ifndef __UBOOT__ +#include <linux/kernel.h> +#include <linux/module.h> +#include <linux/slab.h> +#else +#include <linux/compat.h> +#include <common.h> +#include <malloc.h> +#endif +#include <linux/list.h> +#include <linux/list_sort.h> + +#define MAX_LIST_LENGTH_BITS 20 + +/* + * Returns a list organized in an intermediate format suited + * to chaining of merge() calls: null-terminated, no reserved or + * sentinel head node, "prev" links not maintained. + */ +static struct list_head *merge(void *priv, + int (*cmp)(void *priv, struct list_head *a, + struct list_head *b), + struct list_head *a, struct list_head *b) +{ + struct list_head head, *tail = &head; + + while (a && b) { + /* if equal, take 'a' -- important for sort stability */ + if ((*cmp)(priv, a, b) <= 0) { + tail->next = a; + a = a->next; + } else { + tail->next = b; + b = b->next; + } + tail = tail->next; + } + tail->next = a?:b; + return head.next; +} + +/* + * Combine final list merge with restoration of standard doubly-linked + * list structure. This approach duplicates code from merge(), but + * runs faster than the tidier alternatives of either a separate final + * prev-link restoration pass, or maintaining the prev links + * throughout. + */ +static void merge_and_restore_back_links(void *priv, + int (*cmp)(void *priv, struct list_head *a, + struct list_head *b), + struct list_head *head, + struct list_head *a, struct list_head *b) +{ + struct list_head *tail = head; + + while (a && b) { + /* if equal, take 'a' -- important for sort stability */ + if ((*cmp)(priv, a, b) <= 0) { + tail->next = a; + a->prev = tail; + a = a->next; + } else { + tail->next = b; + b->prev = tail; + b = b->next; + } + tail = tail->next; + } + tail->next = a ? : b; + + do { + /* + * In worst cases this loop may run many iterations. + * Continue callbacks to the client even though no + * element comparison is needed, so the client's cmp() + * routine can invoke cond_resched() periodically. + */ + (*cmp)(priv, tail->next, tail->next); + + tail->next->prev = tail; + tail = tail->next; + } while (tail->next); + + tail->next = head; + head->prev = tail; +} + +/** + * list_sort - sort a list + * @priv: private data, opaque to list_sort(), passed to @cmp + * @head: the list to sort + * @cmp: the elements comparison function + * + * This function implements "merge sort", which has O(nlog(n)) + * complexity. + * + * The comparison function @cmp must return a negative value if @a + * should sort before @b, and a positive value if @a should sort after + * @b. If @a and @b are equivalent, and their original relative + * ordering is to be preserved, @cmp must return 0. + */ +void list_sort(void *priv, struct list_head *head, + int (*cmp)(void *priv, struct list_head *a, + struct list_head *b)) +{ + struct list_head *part[MAX_LIST_LENGTH_BITS+1]; /* sorted partial lists + -- last slot is a sentinel */ + int lev; /* index into part[] */ + int max_lev = 0; + struct list_head *list; + + if (list_empty(head)) + return; + + memset(part, 0, sizeof(part)); + + head->prev->next = NULL; + list = head->next; + + while (list) { + struct list_head *cur = list; + list = list->next; + cur->next = NULL; + + for (lev = 0; part[lev]; lev++) { + cur = merge(priv, cmp, part[lev], cur); + part[lev] = NULL; + } + if (lev > max_lev) { + if (unlikely(lev >= ARRAY_SIZE(part)-1)) { + printk_once(KERN_DEBUG "list passed to" + " list_sort() too long for" + " efficiency\n"); + lev--; + } + max_lev = lev; + } + part[lev] = cur; + } + + for (lev = 0; lev < max_lev; lev++) + if (part[lev]) + list = merge(priv, cmp, part[lev], list); + + merge_and_restore_back_links(priv, cmp, head, part[max_lev], list); +} +EXPORT_SYMBOL(list_sort); + +#ifdef CONFIG_TEST_LIST_SORT + +#include <linux/random.h> + +/* + * The pattern of set bits in the list length determines which cases + * are hit in list_sort(). + */ +#define TEST_LIST_LEN (512+128+2) /* not including head */ + +#define TEST_POISON1 0xDEADBEEF +#define TEST_POISON2 0xA324354C + +struct debug_el { + unsigned int poison1; + struct list_head list; + unsigned int poison2; + int value; + unsigned serial; +}; + +/* Array, containing pointers to all elements in the test list */ +static struct debug_el **elts __initdata; + +static int __init check(struct debug_el *ela, struct debug_el *elb) +{ + if (ela->serial >= TEST_LIST_LEN) { + printk(KERN_ERR "list_sort_test: error: incorrect serial %d\n", + ela->serial); + return -EINVAL; + } + if (elb->serial >= TEST_LIST_LEN) { + printk(KERN_ERR "list_sort_test: error: incorrect serial %d\n", + elb->serial); + return -EINVAL; + } + if (elts[ela->serial] != ela || elts[elb->serial] != elb) { + printk(KERN_ERR "list_sort_test: error: phantom element\n"); + return -EINVAL; + } + if (ela->poison1 != TEST_POISON1 || ela->poison2 != TEST_POISON2) { + printk(KERN_ERR "list_sort_test: error: bad poison: %#x/%#x\n", + ela->poison1, ela->poison2); + return -EINVAL; + } + if (elb->poison1 != TEST_POISON1 || elb->poison2 != TEST_POISON2) { + printk(KERN_ERR "list_sort_test: error: bad poison: %#x/%#x\n", + elb->poison1, elb->poison2); + return -EINVAL; + } + return 0; +} + +static int __init cmp(void *priv, struct list_head *a, struct list_head *b) +{ + struct debug_el *ela, *elb; + + ela = container_of(a, struct debug_el, list); + elb = container_of(b, struct debug_el, list); + + check(ela, elb); + return ela->value - elb->value; +} + +static int __init list_sort_test(void) +{ + int i, count = 1, err = -EINVAL; + struct debug_el *el; + struct list_head *cur, *tmp; + LIST_HEAD(head); + + printk(KERN_DEBUG "list_sort_test: start testing list_sort()\n"); + + elts = kmalloc(sizeof(void *) * TEST_LIST_LEN, GFP_KERNEL); + if (!elts) { + printk(KERN_ERR "list_sort_test: error: cannot allocate " + "memory\n"); + goto exit; + } + + for (i = 0; i < TEST_LIST_LEN; i++) { + el = kmalloc(sizeof(*el), GFP_KERNEL); + if (!el) { + printk(KERN_ERR "list_sort_test: error: cannot " + "allocate memory\n"); + goto exit; + } + /* force some equivalencies */ + el->value = prandom_u32() % (TEST_LIST_LEN / 3); + el->serial = i; + el->poison1 = TEST_POISON1; + el->poison2 = TEST_POISON2; + elts[i] = el; + list_add_tail(&el->list, &head); + } + + list_sort(NULL, &head, cmp); + + for (cur = head.next; cur->next != &head; cur = cur->next) { + struct debug_el *el1; + int cmp_result; + + if (cur->next->prev != cur) { + printk(KERN_ERR "list_sort_test: error: list is " + "corrupted\n"); + goto exit; + } + + cmp_result = cmp(NULL, cur, cur->next); + if (cmp_result > 0) { + printk(KERN_ERR "list_sort_test: error: list is not " + "sorted\n"); + goto exit; + } + + el = container_of(cur, struct debug_el, list); + el1 = container_of(cur->next, struct debug_el, list); + if (cmp_result == 0 && el->serial >= el1->serial) { + printk(KERN_ERR "list_sort_test: error: order of " + "equivalent elements not preserved\n"); + goto exit; + } + + if (check(el, el1)) { + printk(KERN_ERR "list_sort_test: error: element check " + "failed\n"); + goto exit; + } + count++; + } + + if (count != TEST_LIST_LEN) { + printk(KERN_ERR "list_sort_test: error: bad list length %d", + count); + goto exit; + } + + err = 0; +exit: + kfree(elts); + list_for_each_safe(cur, tmp, &head) { + list_del(cur); + kfree(container_of(cur, struct debug_el, list)); + } + return err; +} +module_init(list_sort_test); +#endif /* CONFIG_TEST_LIST_SORT */ diff --git a/lib/lmb.c b/lib/lmb.c index 49a3c9e01e5..41a2be46356 100644 --- a/lib/lmb.c +++ b/lib/lmb.c @@ -295,7 +295,10 @@ phys_addr_t __lmb_alloc_base(struct lmb *lmb, phys_size_t size, ulong align, phy if (max_addr == LMB_ALLOC_ANYWHERE) base = lmb_align_down(lmbbase + lmbsize - size, align); else if (lmbbase < max_addr) { - base = min(lmbbase + lmbsize, max_addr); + base = lmbbase + lmbsize; + if (base < lmbbase) + base = -1; + base = min(base, max_addr); base = lmb_align_down(base - size, align); } else continue; diff --git a/lib/rbtree.c b/lib/rbtree.c index b05f1ab7f59..9e52f70d173 100644 --- a/lib/rbtree.c +++ b/lib/rbtree.c @@ -2,283 +2,412 @@ Red Black Trees (C) 1999 Andrea Arcangeli <andrea@suse.de> (C) 2002 David Woodhouse <dwmw2@infradead.org> + (C) 2012 Michel Lespinasse <walken@google.com> * SPDX-License-Identifier: GPL-2.0+ linux/lib/rbtree.c */ +#define __UBOOT__ +#include <linux/rbtree_augmented.h> +#ifndef __UBOOT__ +#include <linux/export.h> +#else #include <ubi_uboot.h> -#include <linux/rbtree.h> +#endif +/* + * red-black trees properties: http://en.wikipedia.org/wiki/Rbtree + * + * 1) A node is either red or black + * 2) The root is black + * 3) All leaves (NULL) are black + * 4) Both children of every red node are black + * 5) Every simple path from root to leaves contains the same number + * of black nodes. + * + * 4 and 5 give the O(log n) guarantee, since 4 implies you cannot have two + * consecutive red nodes in a path and every red node is therefore followed by + * a black. So if B is the number of black nodes on every simple path (as per + * 5), then the longest possible path due to 4 is 2B. + * + * We shall indicate color with case, where black nodes are uppercase and red + * nodes will be lowercase. Unknown color nodes shall be drawn as red within + * parentheses and have some accompanying text comment. + */ -static void __rb_rotate_left(struct rb_node *node, struct rb_root *root) +static inline void rb_set_black(struct rb_node *rb) { - struct rb_node *right = node->rb_right; - struct rb_node *parent = rb_parent(node); - - if ((node->rb_right = right->rb_left)) - rb_set_parent(right->rb_left, node); - right->rb_left = node; - - rb_set_parent(right, parent); - - if (parent) - { - if (node == parent->rb_left) - parent->rb_left = right; - else - parent->rb_right = right; - } - else - root->rb_node = right; - rb_set_parent(node, right); + rb->__rb_parent_color |= RB_BLACK; } -static void __rb_rotate_right(struct rb_node *node, struct rb_root *root) +static inline struct rb_node *rb_red_parent(struct rb_node *red) { - struct rb_node *left = node->rb_left; - struct rb_node *parent = rb_parent(node); - - if ((node->rb_left = left->rb_right)) - rb_set_parent(left->rb_right, node); - left->rb_right = node; - - rb_set_parent(left, parent); + return (struct rb_node *)red->__rb_parent_color; +} - if (parent) - { - if (node == parent->rb_right) - parent->rb_right = left; - else - parent->rb_left = left; - } - else - root->rb_node = left; - rb_set_parent(node, left); +/* + * Helper function for rotations: + * - old's parent and color get assigned to new + * - old gets assigned new as a parent and 'color' as a color. + */ +static inline void +__rb_rotate_set_parents(struct rb_node *old, struct rb_node *new, + struct rb_root *root, int color) +{ + struct rb_node *parent = rb_parent(old); + new->__rb_parent_color = old->__rb_parent_color; + rb_set_parent_color(old, new, color); + __rb_change_child(old, new, parent, root); } -void rb_insert_color(struct rb_node *node, struct rb_root *root) +static __always_inline void +__rb_insert(struct rb_node *node, struct rb_root *root, + void (*augment_rotate)(struct rb_node *old, struct rb_node *new)) { - struct rb_node *parent, *gparent; - - while ((parent = rb_parent(node)) && rb_is_red(parent)) - { - gparent = rb_parent(parent); - - if (parent == gparent->rb_left) - { - { - register struct rb_node *uncle = gparent->rb_right; - if (uncle && rb_is_red(uncle)) - { - rb_set_black(uncle); - rb_set_black(parent); - rb_set_red(gparent); - node = gparent; - continue; - } + struct rb_node *parent = rb_red_parent(node), *gparent, *tmp; + + while (true) { + /* + * Loop invariant: node is red + * + * If there is a black parent, we are done. + * Otherwise, take some corrective action as we don't + * want a red root or two consecutive red nodes. + */ + if (!parent) { + rb_set_parent_color(node, NULL, RB_BLACK); + break; + } else if (rb_is_black(parent)) + break; + + gparent = rb_red_parent(parent); + + tmp = gparent->rb_right; + if (parent != tmp) { /* parent == gparent->rb_left */ + if (tmp && rb_is_red(tmp)) { + /* + * Case 1 - color flips + * + * G g + * / \ / \ + * p u --> P U + * / / + * n N + * + * However, since g's parent might be red, and + * 4) does not allow this, we need to recurse + * at g. + */ + rb_set_parent_color(tmp, gparent, RB_BLACK); + rb_set_parent_color(parent, gparent, RB_BLACK); + node = gparent; + parent = rb_parent(node); + rb_set_parent_color(node, parent, RB_RED); + continue; } - if (parent->rb_right == node) - { - register struct rb_node *tmp; - __rb_rotate_left(parent, root); - tmp = parent; + tmp = parent->rb_right; + if (node == tmp) { + /* + * Case 2 - left rotate at parent + * + * G G + * / \ / \ + * p U --> n U + * \ / + * n p + * + * This still leaves us in violation of 4), the + * continuation into Case 3 will fix that. + */ + parent->rb_right = tmp = node->rb_left; + node->rb_left = parent; + if (tmp) + rb_set_parent_color(tmp, parent, + RB_BLACK); + rb_set_parent_color(parent, node, RB_RED); + augment_rotate(parent, node); parent = node; - node = tmp; + tmp = node->rb_right; } - rb_set_black(parent); - rb_set_red(gparent); - __rb_rotate_right(gparent, root); + /* + * Case 3 - right rotate at gparent + * + * G P + * / \ / \ + * p U --> n g + * / \ + * n U + */ + gparent->rb_left = tmp; /* == parent->rb_right */ + parent->rb_right = gparent; + if (tmp) + rb_set_parent_color(tmp, gparent, RB_BLACK); + __rb_rotate_set_parents(gparent, parent, root, RB_RED); + augment_rotate(gparent, parent); + break; } else { - { - register struct rb_node *uncle = gparent->rb_left; - if (uncle && rb_is_red(uncle)) - { - rb_set_black(uncle); - rb_set_black(parent); - rb_set_red(gparent); - node = gparent; - continue; - } + tmp = gparent->rb_left; + if (tmp && rb_is_red(tmp)) { + /* Case 1 - color flips */ + rb_set_parent_color(tmp, gparent, RB_BLACK); + rb_set_parent_color(parent, gparent, RB_BLACK); + node = gparent; + parent = rb_parent(node); + rb_set_parent_color(node, parent, RB_RED); + continue; } - if (parent->rb_left == node) - { - register struct rb_node *tmp; - __rb_rotate_right(parent, root); - tmp = parent; + tmp = parent->rb_left; + if (node == tmp) { + /* Case 2 - right rotate at parent */ + parent->rb_left = tmp = node->rb_right; + node->rb_right = parent; + if (tmp) + rb_set_parent_color(tmp, parent, + RB_BLACK); + rb_set_parent_color(parent, node, RB_RED); + augment_rotate(parent, node); parent = node; - node = tmp; + tmp = node->rb_left; } - rb_set_black(parent); - rb_set_red(gparent); - __rb_rotate_left(gparent, root); + /* Case 3 - left rotate at gparent */ + gparent->rb_right = tmp; /* == parent->rb_left */ + parent->rb_left = gparent; + if (tmp) + rb_set_parent_color(tmp, gparent, RB_BLACK); + __rb_rotate_set_parents(gparent, parent, root, RB_RED); + augment_rotate(gparent, parent); + break; } } - - rb_set_black(root->rb_node); } -static void __rb_erase_color(struct rb_node *node, struct rb_node *parent, - struct rb_root *root) +/* + * Inline version for rb_erase() use - we want to be able to inline + * and eliminate the dummy_rotate callback there + */ +static __always_inline void +____rb_erase_color(struct rb_node *parent, struct rb_root *root, + void (*augment_rotate)(struct rb_node *old, struct rb_node *new)) { - struct rb_node *other; - - while ((!node || rb_is_black(node)) && node != root->rb_node) - { - if (parent->rb_left == node) - { - other = parent->rb_right; - if (rb_is_red(other)) - { - rb_set_black(other); - rb_set_red(parent); - __rb_rotate_left(parent, root); - other = parent->rb_right; - } - if ((!other->rb_left || rb_is_black(other->rb_left)) && - (!other->rb_right || rb_is_black(other->rb_right))) - { - rb_set_red(other); - node = parent; - parent = rb_parent(node); + struct rb_node *node = NULL, *sibling, *tmp1, *tmp2; + + while (true) { + /* + * Loop invariants: + * - node is black (or NULL on first iteration) + * - node is not the root (parent is not NULL) + * - All leaf paths going through parent and node have a + * black node count that is 1 lower than other leaf paths. + */ + sibling = parent->rb_right; + if (node != sibling) { /* node == parent->rb_left */ + if (rb_is_red(sibling)) { + /* + * Case 1 - left rotate at parent + * + * P S + * / \ / \ + * N s --> p Sr + * / \ / \ + * Sl Sr N Sl + */ + parent->rb_right = tmp1 = sibling->rb_left; + sibling->rb_left = parent; + rb_set_parent_color(tmp1, parent, RB_BLACK); + __rb_rotate_set_parents(parent, sibling, root, + RB_RED); + augment_rotate(parent, sibling); + sibling = tmp1; } - else - { - if (!other->rb_right || rb_is_black(other->rb_right)) - { - struct rb_node *o_left; - if ((o_left = other->rb_left)) - rb_set_black(o_left); - rb_set_red(other); - __rb_rotate_right(other, root); - other = parent->rb_right; + tmp1 = sibling->rb_right; + if (!tmp1 || rb_is_black(tmp1)) { + tmp2 = sibling->rb_left; + if (!tmp2 || rb_is_black(tmp2)) { + /* + * Case 2 - sibling color flip + * (p could be either color here) + * + * (p) (p) + * / \ / \ + * N S --> N s + * / \ / \ + * Sl Sr Sl Sr + * + * This leaves us violating 5) which + * can be fixed by flipping p to black + * if it was red, or by recursing at p. + * p is red when coming from Case 1. + */ + rb_set_parent_color(sibling, parent, + RB_RED); + if (rb_is_red(parent)) + rb_set_black(parent); + else { + node = parent; + parent = rb_parent(node); + if (parent) + continue; + } + break; } - rb_set_color(other, rb_color(parent)); - rb_set_black(parent); - if (other->rb_right) - rb_set_black(other->rb_right); - __rb_rotate_left(parent, root); - node = root->rb_node; - break; + /* + * Case 3 - right rotate at sibling + * (p could be either color here) + * + * (p) (p) + * / \ / \ + * N S --> N Sl + * / \ \ + * sl Sr s + * \ + * Sr + */ + sibling->rb_left = tmp1 = tmp2->rb_right; + tmp2->rb_right = sibling; + parent->rb_right = tmp2; + if (tmp1) + rb_set_parent_color(tmp1, sibling, + RB_BLACK); + augment_rotate(sibling, tmp2); + tmp1 = sibling; + sibling = tmp2; } - } - else - { - other = parent->rb_left; - if (rb_is_red(other)) - { - rb_set_black(other); - rb_set_red(parent); - __rb_rotate_right(parent, root); - other = parent->rb_left; - } - if ((!other->rb_left || rb_is_black(other->rb_left)) && - (!other->rb_right || rb_is_black(other->rb_right))) - { - rb_set_red(other); - node = parent; - parent = rb_parent(node); + /* + * Case 4 - left rotate at parent + color flips + * (p and sl could be either color here. + * After rotation, p becomes black, s acquires + * p's color, and sl keeps its color) + * + * (p) (s) + * / \ / \ + * N S --> P Sr + * / \ / \ + * (sl) sr N (sl) + */ + parent->rb_right = tmp2 = sibling->rb_left; + sibling->rb_left = parent; + rb_set_parent_color(tmp1, sibling, RB_BLACK); + if (tmp2) + rb_set_parent(tmp2, parent); + __rb_rotate_set_parents(parent, sibling, root, + RB_BLACK); + augment_rotate(parent, sibling); + break; + } else { + sibling = parent->rb_left; + if (rb_is_red(sibling)) { + /* Case 1 - right rotate at parent */ + parent->rb_left = tmp1 = sibling->rb_right; + sibling->rb_right = parent; + rb_set_parent_color(tmp1, parent, RB_BLACK); + __rb_rotate_set_parents(parent, sibling, root, + RB_RED); + augment_rotate(parent, sibling); + sibling = tmp1; } - else - { - if (!other->rb_left || rb_is_black(other->rb_left)) - { - register struct rb_node *o_right; - if ((o_right = other->rb_right)) - rb_set_black(o_right); - rb_set_red(other); - __rb_rotate_left(other, root); - other = parent->rb_left; + tmp1 = sibling->rb_left; + if (!tmp1 || rb_is_black(tmp1)) { + tmp2 = sibling->rb_right; + if (!tmp2 || rb_is_black(tmp2)) { + /* Case 2 - sibling color flip */ + rb_set_parent_color(sibling, parent, + RB_RED); + if (rb_is_red(parent)) + rb_set_black(parent); + else { + node = parent; + parent = rb_parent(node); + if (parent) + continue; + } + break; } - rb_set_color(other, rb_color(parent)); - rb_set_black(parent); - if (other->rb_left) - rb_set_black(other->rb_left); - __rb_rotate_right(parent, root); - node = root->rb_node; - break; + /* Case 3 - right rotate at sibling */ + sibling->rb_right = tmp1 = tmp2->rb_left; + tmp2->rb_left = sibling; + parent->rb_left = tmp2; + if (tmp1) + rb_set_parent_color(tmp1, sibling, + RB_BLACK); + augment_rotate(sibling, tmp2); + tmp1 = sibling; + sibling = tmp2; } + /* Case 4 - left rotate at parent + color flips */ + parent->rb_left = tmp2 = sibling->rb_right; + sibling->rb_right = parent; + rb_set_parent_color(tmp1, sibling, RB_BLACK); + if (tmp2) + rb_set_parent(tmp2, parent); + __rb_rotate_set_parents(parent, sibling, root, + RB_BLACK); + augment_rotate(parent, sibling); + break; } } - if (node) - rb_set_black(node); } +/* Non-inline version for rb_erase_augmented() use */ +void __rb_erase_color(struct rb_node *parent, struct rb_root *root, + void (*augment_rotate)(struct rb_node *old, struct rb_node *new)) +{ + ____rb_erase_color(parent, root, augment_rotate); +} +EXPORT_SYMBOL(__rb_erase_color); + +/* + * Non-augmented rbtree manipulation functions. + * + * We use dummy augmented callbacks here, and have the compiler optimize them + * out of the rb_insert_color() and rb_erase() function definitions. + */ + +static inline void dummy_propagate(struct rb_node *node, struct rb_node *stop) {} +static inline void dummy_copy(struct rb_node *old, struct rb_node *new) {} +static inline void dummy_rotate(struct rb_node *old, struct rb_node *new) {} + +static const struct rb_augment_callbacks dummy_callbacks = { + dummy_propagate, dummy_copy, dummy_rotate +}; + +void rb_insert_color(struct rb_node *node, struct rb_root *root) +{ + __rb_insert(node, root, dummy_rotate); +} +EXPORT_SYMBOL(rb_insert_color); + void rb_erase(struct rb_node *node, struct rb_root *root) { - struct rb_node *child, *parent; - int color; - - if (!node->rb_left) - child = node->rb_right; - else if (!node->rb_right) - child = node->rb_left; - else - { - struct rb_node *old = node, *left; - - node = node->rb_right; - while ((left = node->rb_left) != NULL) - node = left; - child = node->rb_right; - parent = rb_parent(node); - color = rb_color(node); - - if (child) - rb_set_parent(child, parent); - if (parent == old) { - parent->rb_right = child; - parent = node; - } else - parent->rb_left = child; - - node->rb_parent_color = old->rb_parent_color; - node->rb_right = old->rb_right; - node->rb_left = old->rb_left; - - if (rb_parent(old)) - { - if (rb_parent(old)->rb_left == old) - rb_parent(old)->rb_left = node; - else - rb_parent(old)->rb_right = node; - } else - root->rb_node = node; - - rb_set_parent(old->rb_left, node); - if (old->rb_right) - rb_set_parent(old->rb_right, node); - goto color; - } + struct rb_node *rebalance; + rebalance = __rb_erase_augmented(node, root, &dummy_callbacks); + if (rebalance) + ____rb_erase_color(rebalance, root, dummy_rotate); +} +EXPORT_SYMBOL(rb_erase); - parent = rb_parent(node); - color = rb_color(node); - - if (child) - rb_set_parent(child, parent); - if (parent) - { - if (parent->rb_left == node) - parent->rb_left = child; - else - parent->rb_right = child; - } - else - root->rb_node = child; +/* + * Augmented rbtree manipulation functions. + * + * This instantiates the same __always_inline functions as in the non-augmented + * case, but this time with user-defined callbacks. + */ - color: - if (color == RB_BLACK) - __rb_erase_color(child, parent, root); +void __rb_insert_augmented(struct rb_node *node, struct rb_root *root, + void (*augment_rotate)(struct rb_node *old, struct rb_node *new)) +{ + __rb_insert(node, root, augment_rotate); } +EXPORT_SYMBOL(__rb_insert_augmented); /* * This function returns the first node (in sort order) of the tree. */ -struct rb_node *rb_first(struct rb_root *root) +struct rb_node *rb_first(const struct rb_root *root) { struct rb_node *n; @@ -289,8 +418,9 @@ struct rb_node *rb_first(struct rb_root *root) n = n->rb_left; return n; } +EXPORT_SYMBOL(rb_first); -struct rb_node *rb_last(struct rb_root *root) +struct rb_node *rb_last(const struct rb_root *root) { struct rb_node *n; @@ -301,58 +431,68 @@ struct rb_node *rb_last(struct rb_root *root) n = n->rb_right; return n; } +EXPORT_SYMBOL(rb_last); -struct rb_node *rb_next(struct rb_node *node) +struct rb_node *rb_next(const struct rb_node *node) { struct rb_node *parent; - if (rb_parent(node) == node) + if (RB_EMPTY_NODE(node)) return NULL; - /* If we have a right-hand child, go down and then left as far - as we can. */ + /* + * If we have a right-hand child, go down and then left as far + * as we can. + */ if (node->rb_right) { - node = node->rb_right; + node = node->rb_right; while (node->rb_left) node=node->rb_left; - return node; + return (struct rb_node *)node; } - /* No right-hand children. Everything down and left is - smaller than us, so any 'next' node must be in the general - direction of our parent. Go up the tree; any time the - ancestor is a right-hand child of its parent, keep going - up. First time it's a left-hand child of its parent, said - parent is our 'next' node. */ + /* + * No right-hand children. Everything down and left is smaller than us, + * so any 'next' node must be in the general direction of our parent. + * Go up the tree; any time the ancestor is a right-hand child of its + * parent, keep going up. First time it's a left-hand child of its + * parent, said parent is our 'next' node. + */ while ((parent = rb_parent(node)) && node == parent->rb_right) node = parent; return parent; } +EXPORT_SYMBOL(rb_next); -struct rb_node *rb_prev(struct rb_node *node) +struct rb_node *rb_prev(const struct rb_node *node) { struct rb_node *parent; - if (rb_parent(node) == node) + if (RB_EMPTY_NODE(node)) return NULL; - /* If we have a left-hand child, go down and then right as far - as we can. */ + /* + * If we have a left-hand child, go down and then right as far + * as we can. + */ if (node->rb_left) { - node = node->rb_left; + node = node->rb_left; while (node->rb_right) node=node->rb_right; - return node; + return (struct rb_node *)node; } - /* No left-hand children. Go up till we find an ancestor which - is a right-hand child of its parent */ + /* + * No left-hand children. Go up till we find an ancestor which + * is a right-hand child of its parent. + */ while ((parent = rb_parent(node)) && node == parent->rb_left) node = parent; return parent; } +EXPORT_SYMBOL(rb_prev); void rb_replace_node(struct rb_node *victim, struct rb_node *new, struct rb_root *root) @@ -360,14 +500,7 @@ void rb_replace_node(struct rb_node *victim, struct rb_node *new, struct rb_node *parent = rb_parent(victim); /* Set the surrounding nodes to point to the replacement */ - if (parent) { - if (victim == parent->rb_left) - parent->rb_left = new; - else - parent->rb_right = new; - } else { - root->rb_node = new; - } + __rb_change_child(victim, new, parent, root); if (victim->rb_left) rb_set_parent(victim->rb_left, new); if (victim->rb_right) @@ -376,3 +509,44 @@ void rb_replace_node(struct rb_node *victim, struct rb_node *new, /* Copy the pointers/colour from the victim to the replacement */ *new = *victim; } +EXPORT_SYMBOL(rb_replace_node); + +static struct rb_node *rb_left_deepest_node(const struct rb_node *node) +{ + for (;;) { + if (node->rb_left) + node = node->rb_left; + else if (node->rb_right) + node = node->rb_right; + else + return (struct rb_node *)node; + } +} + +struct rb_node *rb_next_postorder(const struct rb_node *node) +{ + const struct rb_node *parent; + if (!node) + return NULL; + parent = rb_parent(node); + + /* If we're sitting on node, we've already seen our children */ + if (parent && node == parent->rb_left && parent->rb_right) { + /* If we are the parent's left node, go to the parent's right + * node then all the way down to the left */ + return rb_left_deepest_node(parent->rb_right); + } else + /* Otherwise we are the parent's right node, and the parent + * should be next */ + return (struct rb_node *)parent; +} +EXPORT_SYMBOL(rb_next_postorder); + +struct rb_node *rb_first_postorder(const struct rb_root *root) +{ + if (!root->rb_node) + return NULL; + + return rb_left_deepest_node(root->rb_node); +} +EXPORT_SYMBOL(rb_first_postorder); diff --git a/lib/rsa/rsa-sign.c b/lib/rsa/rsa-sign.c index 83f5e878389..5d9716f0134 100644 --- a/lib/rsa/rsa-sign.c +++ b/lib/rsa/rsa-sign.c @@ -76,6 +76,7 @@ static int rsa_get_pub_key(const char *keydir, const char *name, RSA **rsap) rsa = EVP_PKEY_get1_RSA(key); if (!rsa) { rsa_err("Couldn't convert to a RSA style key"); + ret = -EINVAL; goto err_rsa; } fclose(f); @@ -261,10 +262,57 @@ err_priv: } /* + * rsa_get_exponent(): - Get the public exponent from an RSA key + */ +static int rsa_get_exponent(RSA *key, uint64_t *e) +{ + int ret; + BIGNUM *bn_te; + uint64_t te; + + ret = -EINVAL; + bn_te = NULL; + + if (!e) + goto cleanup; + + if (BN_num_bits(key->e) > 64) + goto cleanup; + + *e = BN_get_word(key->e); + + if (BN_num_bits(key->e) < 33) { + ret = 0; + goto cleanup; + } + + bn_te = BN_dup(key->e); + if (!bn_te) + goto cleanup; + + if (!BN_rshift(bn_te, bn_te, 32)) + goto cleanup; + + if (!BN_mask_bits(bn_te, 32)) + goto cleanup; + + te = BN_get_word(bn_te); + te <<= 32; + *e |= te; + ret = 0; + +cleanup: + if (bn_te) + BN_free(bn_te); + + return ret; +} + +/* * rsa_get_params(): - Get the important parameters of an RSA public key */ -int rsa_get_params(RSA *key, uint32_t *n0_invp, BIGNUM **modulusp, - BIGNUM **r_squaredp) +int rsa_get_params(RSA *key, uint64_t *exponent, uint32_t *n0_invp, + BIGNUM **modulusp, BIGNUM **r_squaredp) { BIGNUM *big1, *big2, *big32, *big2_32; BIGNUM *n, *r, *r_squared, *tmp; @@ -286,6 +334,9 @@ int rsa_get_params(RSA *key, uint32_t *n0_invp, BIGNUM **modulusp, return -ENOMEM; } + if (0 != rsa_get_exponent(key, exponent)) + ret = -1; + if (!BN_copy(n, key->n) || !BN_set_word(big1, 1L) || !BN_set_word(big2, 2L) || !BN_set_word(big32, 32L)) ret = -1; @@ -386,6 +437,7 @@ static int fdt_add_bignum(void *blob, int noffset, const char *prop_name, int rsa_add_verify_data(struct image_sign_info *info, void *keydest) { BIGNUM *modulus, *r_squared; + uint64_t exponent; uint32_t n0_inv; int parent, node; char name[100]; @@ -397,7 +449,7 @@ int rsa_add_verify_data(struct image_sign_info *info, void *keydest) ret = rsa_get_pub_key(info->keydir, info->keyname, &rsa); if (ret) return ret; - ret = rsa_get_params(rsa, &n0_inv, &modulus, &r_squared); + ret = rsa_get_params(rsa, &exponent, &n0_inv, &modulus, &r_squared); if (ret) return ret; bits = BN_num_bits(modulus); @@ -442,6 +494,9 @@ int rsa_add_verify_data(struct image_sign_info *info, void *keydest) if (!ret) ret = fdt_setprop_u32(keydest, node, "rsa,n0-inverse", n0_inv); if (!ret) { + ret = fdt_setprop_u64(keydest, node, "rsa,exponent", exponent); + } + if (!ret) { ret = fdt_add_bignum(keydest, node, "rsa,modulus", modulus, bits); } diff --git a/lib/rsa/rsa-verify.c b/lib/rsa/rsa-verify.c index bcb906368d0..4ef19b66f4b 100644 --- a/lib/rsa/rsa-verify.c +++ b/lib/rsa/rsa-verify.c @@ -26,6 +26,9 @@ #define get_unaligned_be32(a) fdt32_to_cpu(*(uint32_t *)a) #define put_unaligned_be32(a, b) (*(uint32_t *)(b) = cpu_to_fdt32(a)) +/* Default public exponent for backward compatibility */ +#define RSA_DEFAULT_PUBEXP 65537 + /** * subtract_modulus() - subtract modulus from the given value * @@ -54,9 +57,9 @@ static void subtract_modulus(const struct rsa_public_key *key, uint32_t num[]) static int greater_equal_modulus(const struct rsa_public_key *key, uint32_t num[]) { - uint32_t i; + int i; - for (i = key->len - 1; i >= 0; i--) { + for (i = (int)key->len - 1; i >= 0; i--) { if (num[i] < key->modulus[i]) return 0; if (num[i] > key->modulus[i]) @@ -123,6 +126,48 @@ static void montgomery_mul(const struct rsa_public_key *key, } /** + * num_pub_exponent_bits() - Number of bits in the public exponent + * + * @key: RSA key + * @num_bits: Storage for the number of public exponent bits + */ +static int num_public_exponent_bits(const struct rsa_public_key *key, + int *num_bits) +{ + uint64_t exponent; + int exponent_bits; + const uint max_bits = (sizeof(exponent) * 8); + + exponent = key->exponent; + exponent_bits = 0; + + if (!exponent) { + *num_bits = exponent_bits; + return 0; + } + + for (exponent_bits = 1; exponent_bits < max_bits + 1; ++exponent_bits) + if (!(exponent >>= 1)) { + *num_bits = exponent_bits; + return 0; + } + + return -EINVAL; +} + +/** + * is_public_exponent_bit_set() - Check if a bit in the public exponent is set + * + * @key: RSA key + * @pos: The bit position to check + */ +static int is_public_exponent_bit_set(const struct rsa_public_key *key, + int pos) +{ + return key->exponent & (1ULL << pos); +} + +/** * pow_mod() - in-place public exponentiation * * @key: RSA key @@ -132,6 +177,7 @@ static int pow_mod(const struct rsa_public_key *key, uint32_t *inout) { uint32_t *result, *ptr; uint i; + int j, k; /* Sanity check for stack size - key->len is in 32-bit words */ if (key->len > RSA_MAX_KEY_BITS / 32) { @@ -141,18 +187,48 @@ static int pow_mod(const struct rsa_public_key *key, uint32_t *inout) } uint32_t val[key->len], acc[key->len], tmp[key->len]; + uint32_t a_scaled[key->len]; result = tmp; /* Re-use location. */ /* Convert from big endian byte array to little endian word array. */ for (i = 0, ptr = inout + key->len - 1; i < key->len; i++, ptr--) val[i] = get_unaligned_be32(ptr); - montgomery_mul(key, acc, val, key->rr); /* axx = a * RR / R mod M */ - for (i = 0; i < 16; i += 2) { - montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod M */ - montgomery_mul(key, acc, tmp, tmp); /* acc = tmp^2 / R mod M */ + if (0 != num_public_exponent_bits(key, &k)) + return -EINVAL; + + if (k < 2) { + debug("Public exponent is too short (%d bits, minimum 2)\n", + k); + return -EINVAL; } - montgomery_mul(key, result, acc, val); /* result = XX * a / R mod M */ + + if (!is_public_exponent_bit_set(key, 0)) { + debug("LSB of RSA public exponent must be set.\n"); + return -EINVAL; + } + + /* the bit at e[k-1] is 1 by definition, so start with: C := M */ + montgomery_mul(key, acc, val, key->rr); /* acc = a * RR / R mod n */ + /* retain scaled version for intermediate use */ + memcpy(a_scaled, acc, key->len * sizeof(a_scaled[0])); + + for (j = k - 2; j > 0; --j) { + montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod n */ + + if (is_public_exponent_bit_set(key, j)) { + /* acc = tmp * val / R mod n */ + montgomery_mul(key, acc, tmp, a_scaled); + } else { + /* e[j] == 0, copy tmp back to acc for next operation */ + memcpy(acc, tmp, key->len * sizeof(acc[0])); + } + } + + /* the bit at e[0] is always 1 */ + montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod n */ + montgomery_mul(key, acc, tmp, val); /* acc = tmp * a / R mod M */ + memcpy(result, acc, key->len * sizeof(result[0])); /* Make sure result < mod; result is at most 1x mod too large. */ if (greater_equal_modulus(key, result)) @@ -229,6 +305,8 @@ static int rsa_verify_with_keynode(struct image_sign_info *info, const void *blob = info->fdt_blob; struct rsa_public_key key; const void *modulus, *rr; + const uint64_t *public_exponent; + int length; int ret; if (node < 0) { @@ -241,6 +319,11 @@ static int rsa_verify_with_keynode(struct image_sign_info *info, } key.len = fdtdec_get_int(blob, node, "rsa,num-bits", 0); key.n0inv = fdtdec_get_int(blob, node, "rsa,n0-inverse", 0); + public_exponent = fdt_getprop(blob, node, "rsa,exponent", &length); + if (!public_exponent || length < sizeof(*public_exponent)) + key.exponent = RSA_DEFAULT_PUBEXP; + else + key.exponent = fdt64_to_cpu(*public_exponent); modulus = fdt_getprop(blob, node, "rsa,modulus", NULL); rr = fdt_getprop(blob, node, "rsa,r-squared", NULL); if (!key.len || !modulus || !rr) { |