summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/Makefile2
-rw-r--r--lib/libfdt/Makefile3
-rw-r--r--lib/libfdt/fdt_addresses.c55
-rw-r--r--lib/libfdt/fdt_rw.c6
-rw-r--r--lib/libfdt/fdt_sw.c32
-rw-r--r--lib/libfdt/libfdt_internal.h6
-rw-r--r--lib/linux_compat.c47
-rw-r--r--lib/list_sort.c298
-rw-r--r--lib/lmb.c5
-rw-r--r--lib/rbtree.c684
-rw-r--r--lib/rsa/rsa-sign.c61
-rw-r--r--lib/rsa/rsa-verify.c97
12 files changed, 1023 insertions, 273 deletions
diff --git a/lib/Makefile b/lib/Makefile
index 68210a59b73..320197a5209 100644
--- a/lib/Makefile
+++ b/lib/Makefile
@@ -43,6 +43,7 @@ obj-y += strmhz.o
obj-$(CONFIG_TPM) += tpm.o
obj-$(CONFIG_RBTREE) += rbtree.o
obj-$(CONFIG_BITREVERSE) += bitrev.o
+obj-y += list_sort.o
endif
ifdef CONFIG_SPL_BUILD
@@ -58,6 +59,7 @@ obj-y += crc32.o
obj-y += ctype.o
obj-y += div64.o
obj-y += hang.o
+obj-y += linux_compat.o
obj-y += linux_string.o
obj-$(CONFIG_REGEX) += slre.o
obj-y += string.o
diff --git a/lib/libfdt/Makefile b/lib/libfdt/Makefile
index a02c9b02add..6fe79e0b06e 100644
--- a/lib/libfdt/Makefile
+++ b/lib/libfdt/Makefile
@@ -5,7 +5,8 @@
# SPDX-License-Identifier: GPL-2.0+
#
-COBJS-libfdt += fdt.o fdt_ro.o fdt_rw.o fdt_strerror.o fdt_sw.o fdt_wip.o fdt_empty_tree.o
+COBJS-libfdt += fdt.o fdt_ro.o fdt_rw.o fdt_strerror.o fdt_sw.o fdt_wip.o \
+ fdt_empty_tree.o fdt_addresses.o
obj-$(CONFIG_OF_LIBFDT) += $(COBJS-libfdt)
obj-$(CONFIG_FIT) += $(COBJS-libfdt)
diff --git a/lib/libfdt/fdt_addresses.c b/lib/libfdt/fdt_addresses.c
new file mode 100644
index 00000000000..76054d98e5f
--- /dev/null
+++ b/lib/libfdt/fdt_addresses.c
@@ -0,0 +1,55 @@
+/*
+ * libfdt - Flat Device Tree manipulation
+ * Copyright (C) 2014 David Gibson <david@gibson.dropbear.id.au>
+ * SPDX-License-Identifier: GPL-2.0+ BSD-2-Clause
+ */
+#include "libfdt_env.h"
+
+#ifndef USE_HOSTCC
+#include <fdt.h>
+#include <libfdt.h>
+#else
+#include "fdt_host.h"
+#endif
+
+#include "libfdt_internal.h"
+
+int fdt_address_cells(const void *fdt, int nodeoffset)
+{
+ const fdt32_t *ac;
+ int val;
+ int len;
+
+ ac = fdt_getprop(fdt, nodeoffset, "#address-cells", &len);
+ if (!ac)
+ return 2;
+
+ if (len != sizeof(*ac))
+ return -FDT_ERR_BADNCELLS;
+
+ val = fdt32_to_cpu(*ac);
+ if ((val <= 0) || (val > FDT_MAX_NCELLS))
+ return -FDT_ERR_BADNCELLS;
+
+ return val;
+}
+
+int fdt_size_cells(const void *fdt, int nodeoffset)
+{
+ const fdt32_t *sc;
+ int val;
+ int len;
+
+ sc = fdt_getprop(fdt, nodeoffset, "#size-cells", &len);
+ if (!sc)
+ return 2;
+
+ if (len != sizeof(*sc))
+ return -FDT_ERR_BADNCELLS;
+
+ val = fdt32_to_cpu(*sc);
+ if ((val < 0) || (val > FDT_MAX_NCELLS))
+ return -FDT_ERR_BADNCELLS;
+
+ return val;
+}
diff --git a/lib/libfdt/fdt_rw.c b/lib/libfdt/fdt_rw.c
index 6fa4f13073d..bec8b8ad891 100644
--- a/lib/libfdt/fdt_rw.c
+++ b/lib/libfdt/fdt_rw.c
@@ -43,9 +43,9 @@ static int _fdt_rw_check_header(void *fdt)
#define FDT_RW_CHECK_HEADER(fdt) \
{ \
- int err; \
- if ((err = _fdt_rw_check_header(fdt)) != 0) \
- return err; \
+ int __err; \
+ if ((__err = _fdt_rw_check_header(fdt)) != 0) \
+ return __err; \
}
static inline int _fdt_data_size(void *fdt)
diff --git a/lib/libfdt/fdt_sw.c b/lib/libfdt/fdt_sw.c
index 580b57024ff..320a914991a 100644
--- a/lib/libfdt/fdt_sw.c
+++ b/lib/libfdt/fdt_sw.c
@@ -62,6 +62,38 @@ int fdt_create(void *buf, int bufsize)
return 0;
}
+int fdt_resize(void *fdt, void *buf, int bufsize)
+{
+ size_t headsize, tailsize;
+ char *oldtail, *newtail;
+
+ FDT_SW_CHECK_HEADER(fdt);
+
+ headsize = fdt_off_dt_struct(fdt);
+ tailsize = fdt_size_dt_strings(fdt);
+
+ if ((headsize + tailsize) > bufsize)
+ return -FDT_ERR_NOSPACE;
+
+ oldtail = (char *)fdt + fdt_totalsize(fdt) - tailsize;
+ newtail = (char *)buf + bufsize - tailsize;
+
+ /* Two cases to avoid clobbering data if the old and new
+ * buffers partially overlap */
+ if (buf <= fdt) {
+ memmove(buf, fdt, headsize);
+ memmove(newtail, oldtail, tailsize);
+ } else {
+ memmove(newtail, oldtail, tailsize);
+ memmove(buf, fdt, headsize);
+ }
+
+ fdt_set_off_dt_strings(buf, bufsize);
+ fdt_set_totalsize(buf, bufsize);
+
+ return 0;
+}
+
int fdt_add_reservemap_entry(void *fdt, uint64_t addr, uint64_t size)
{
struct fdt_reserve_entry *re;
diff --git a/lib/libfdt/libfdt_internal.h b/lib/libfdt/libfdt_internal.h
index 13cbc9af2ab..9a79fe85dda 100644
--- a/lib/libfdt/libfdt_internal.h
+++ b/lib/libfdt/libfdt_internal.h
@@ -12,9 +12,9 @@
#define FDT_CHECK_HEADER(fdt) \
{ \
- int err; \
- if ((err = fdt_check_header(fdt)) != 0) \
- return err; \
+ int __err; \
+ if ((__err = fdt_check_header(fdt)) != 0) \
+ return __err; \
}
int _fdt_check_node_offset(const void *fdt, int offset);
diff --git a/lib/linux_compat.c b/lib/linux_compat.c
new file mode 100644
index 00000000000..a3d4675f7ed
--- /dev/null
+++ b/lib/linux_compat.c
@@ -0,0 +1,47 @@
+
+#include <common.h>
+#include <linux/compat.h>
+
+struct p_current cur = {
+ .pid = 1,
+};
+__maybe_unused struct p_current *current = &cur;
+
+unsigned long copy_from_user(void *dest, const void *src,
+ unsigned long count)
+{
+ memcpy((void *)dest, (void *)src, count);
+ return 0;
+}
+
+void *kmalloc(size_t size, int flags)
+{
+ return memalign(ARCH_DMA_MINALIGN, size);
+}
+
+void *kzalloc(size_t size, int flags)
+{
+ void *ptr = kmalloc(size, flags);
+ memset(ptr, 0, size);
+ return ptr;
+}
+
+void *vzalloc(unsigned long size)
+{
+ return kzalloc(size, 0);
+}
+
+struct kmem_cache *get_mem(int element_sz)
+{
+ struct kmem_cache *ret;
+
+ ret = memalign(ARCH_DMA_MINALIGN, sizeof(struct kmem_cache));
+ ret->sz = element_sz;
+
+ return ret;
+}
+
+void *kmem_cache_alloc(struct kmem_cache *obj, int flag)
+{
+ return memalign(ARCH_DMA_MINALIGN, obj->sz);
+}
diff --git a/lib/list_sort.c b/lib/list_sort.c
new file mode 100644
index 00000000000..81de0a17de8
--- /dev/null
+++ b/lib/list_sort.c
@@ -0,0 +1,298 @@
+#define __UBOOT__
+#ifndef __UBOOT__
+#include <linux/kernel.h>
+#include <linux/module.h>
+#include <linux/slab.h>
+#else
+#include <linux/compat.h>
+#include <common.h>
+#include <malloc.h>
+#endif
+#include <linux/list.h>
+#include <linux/list_sort.h>
+
+#define MAX_LIST_LENGTH_BITS 20
+
+/*
+ * Returns a list organized in an intermediate format suited
+ * to chaining of merge() calls: null-terminated, no reserved or
+ * sentinel head node, "prev" links not maintained.
+ */
+static struct list_head *merge(void *priv,
+ int (*cmp)(void *priv, struct list_head *a,
+ struct list_head *b),
+ struct list_head *a, struct list_head *b)
+{
+ struct list_head head, *tail = &head;
+
+ while (a && b) {
+ /* if equal, take 'a' -- important for sort stability */
+ if ((*cmp)(priv, a, b) <= 0) {
+ tail->next = a;
+ a = a->next;
+ } else {
+ tail->next = b;
+ b = b->next;
+ }
+ tail = tail->next;
+ }
+ tail->next = a?:b;
+ return head.next;
+}
+
+/*
+ * Combine final list merge with restoration of standard doubly-linked
+ * list structure. This approach duplicates code from merge(), but
+ * runs faster than the tidier alternatives of either a separate final
+ * prev-link restoration pass, or maintaining the prev links
+ * throughout.
+ */
+static void merge_and_restore_back_links(void *priv,
+ int (*cmp)(void *priv, struct list_head *a,
+ struct list_head *b),
+ struct list_head *head,
+ struct list_head *a, struct list_head *b)
+{
+ struct list_head *tail = head;
+
+ while (a && b) {
+ /* if equal, take 'a' -- important for sort stability */
+ if ((*cmp)(priv, a, b) <= 0) {
+ tail->next = a;
+ a->prev = tail;
+ a = a->next;
+ } else {
+ tail->next = b;
+ b->prev = tail;
+ b = b->next;
+ }
+ tail = tail->next;
+ }
+ tail->next = a ? : b;
+
+ do {
+ /*
+ * In worst cases this loop may run many iterations.
+ * Continue callbacks to the client even though no
+ * element comparison is needed, so the client's cmp()
+ * routine can invoke cond_resched() periodically.
+ */
+ (*cmp)(priv, tail->next, tail->next);
+
+ tail->next->prev = tail;
+ tail = tail->next;
+ } while (tail->next);
+
+ tail->next = head;
+ head->prev = tail;
+}
+
+/**
+ * list_sort - sort a list
+ * @priv: private data, opaque to list_sort(), passed to @cmp
+ * @head: the list to sort
+ * @cmp: the elements comparison function
+ *
+ * This function implements "merge sort", which has O(nlog(n))
+ * complexity.
+ *
+ * The comparison function @cmp must return a negative value if @a
+ * should sort before @b, and a positive value if @a should sort after
+ * @b. If @a and @b are equivalent, and their original relative
+ * ordering is to be preserved, @cmp must return 0.
+ */
+void list_sort(void *priv, struct list_head *head,
+ int (*cmp)(void *priv, struct list_head *a,
+ struct list_head *b))
+{
+ struct list_head *part[MAX_LIST_LENGTH_BITS+1]; /* sorted partial lists
+ -- last slot is a sentinel */
+ int lev; /* index into part[] */
+ int max_lev = 0;
+ struct list_head *list;
+
+ if (list_empty(head))
+ return;
+
+ memset(part, 0, sizeof(part));
+
+ head->prev->next = NULL;
+ list = head->next;
+
+ while (list) {
+ struct list_head *cur = list;
+ list = list->next;
+ cur->next = NULL;
+
+ for (lev = 0; part[lev]; lev++) {
+ cur = merge(priv, cmp, part[lev], cur);
+ part[lev] = NULL;
+ }
+ if (lev > max_lev) {
+ if (unlikely(lev >= ARRAY_SIZE(part)-1)) {
+ printk_once(KERN_DEBUG "list passed to"
+ " list_sort() too long for"
+ " efficiency\n");
+ lev--;
+ }
+ max_lev = lev;
+ }
+ part[lev] = cur;
+ }
+
+ for (lev = 0; lev < max_lev; lev++)
+ if (part[lev])
+ list = merge(priv, cmp, part[lev], list);
+
+ merge_and_restore_back_links(priv, cmp, head, part[max_lev], list);
+}
+EXPORT_SYMBOL(list_sort);
+
+#ifdef CONFIG_TEST_LIST_SORT
+
+#include <linux/random.h>
+
+/*
+ * The pattern of set bits in the list length determines which cases
+ * are hit in list_sort().
+ */
+#define TEST_LIST_LEN (512+128+2) /* not including head */
+
+#define TEST_POISON1 0xDEADBEEF
+#define TEST_POISON2 0xA324354C
+
+struct debug_el {
+ unsigned int poison1;
+ struct list_head list;
+ unsigned int poison2;
+ int value;
+ unsigned serial;
+};
+
+/* Array, containing pointers to all elements in the test list */
+static struct debug_el **elts __initdata;
+
+static int __init check(struct debug_el *ela, struct debug_el *elb)
+{
+ if (ela->serial >= TEST_LIST_LEN) {
+ printk(KERN_ERR "list_sort_test: error: incorrect serial %d\n",
+ ela->serial);
+ return -EINVAL;
+ }
+ if (elb->serial >= TEST_LIST_LEN) {
+ printk(KERN_ERR "list_sort_test: error: incorrect serial %d\n",
+ elb->serial);
+ return -EINVAL;
+ }
+ if (elts[ela->serial] != ela || elts[elb->serial] != elb) {
+ printk(KERN_ERR "list_sort_test: error: phantom element\n");
+ return -EINVAL;
+ }
+ if (ela->poison1 != TEST_POISON1 || ela->poison2 != TEST_POISON2) {
+ printk(KERN_ERR "list_sort_test: error: bad poison: %#x/%#x\n",
+ ela->poison1, ela->poison2);
+ return -EINVAL;
+ }
+ if (elb->poison1 != TEST_POISON1 || elb->poison2 != TEST_POISON2) {
+ printk(KERN_ERR "list_sort_test: error: bad poison: %#x/%#x\n",
+ elb->poison1, elb->poison2);
+ return -EINVAL;
+ }
+ return 0;
+}
+
+static int __init cmp(void *priv, struct list_head *a, struct list_head *b)
+{
+ struct debug_el *ela, *elb;
+
+ ela = container_of(a, struct debug_el, list);
+ elb = container_of(b, struct debug_el, list);
+
+ check(ela, elb);
+ return ela->value - elb->value;
+}
+
+static int __init list_sort_test(void)
+{
+ int i, count = 1, err = -EINVAL;
+ struct debug_el *el;
+ struct list_head *cur, *tmp;
+ LIST_HEAD(head);
+
+ printk(KERN_DEBUG "list_sort_test: start testing list_sort()\n");
+
+ elts = kmalloc(sizeof(void *) * TEST_LIST_LEN, GFP_KERNEL);
+ if (!elts) {
+ printk(KERN_ERR "list_sort_test: error: cannot allocate "
+ "memory\n");
+ goto exit;
+ }
+
+ for (i = 0; i < TEST_LIST_LEN; i++) {
+ el = kmalloc(sizeof(*el), GFP_KERNEL);
+ if (!el) {
+ printk(KERN_ERR "list_sort_test: error: cannot "
+ "allocate memory\n");
+ goto exit;
+ }
+ /* force some equivalencies */
+ el->value = prandom_u32() % (TEST_LIST_LEN / 3);
+ el->serial = i;
+ el->poison1 = TEST_POISON1;
+ el->poison2 = TEST_POISON2;
+ elts[i] = el;
+ list_add_tail(&el->list, &head);
+ }
+
+ list_sort(NULL, &head, cmp);
+
+ for (cur = head.next; cur->next != &head; cur = cur->next) {
+ struct debug_el *el1;
+ int cmp_result;
+
+ if (cur->next->prev != cur) {
+ printk(KERN_ERR "list_sort_test: error: list is "
+ "corrupted\n");
+ goto exit;
+ }
+
+ cmp_result = cmp(NULL, cur, cur->next);
+ if (cmp_result > 0) {
+ printk(KERN_ERR "list_sort_test: error: list is not "
+ "sorted\n");
+ goto exit;
+ }
+
+ el = container_of(cur, struct debug_el, list);
+ el1 = container_of(cur->next, struct debug_el, list);
+ if (cmp_result == 0 && el->serial >= el1->serial) {
+ printk(KERN_ERR "list_sort_test: error: order of "
+ "equivalent elements not preserved\n");
+ goto exit;
+ }
+
+ if (check(el, el1)) {
+ printk(KERN_ERR "list_sort_test: error: element check "
+ "failed\n");
+ goto exit;
+ }
+ count++;
+ }
+
+ if (count != TEST_LIST_LEN) {
+ printk(KERN_ERR "list_sort_test: error: bad list length %d",
+ count);
+ goto exit;
+ }
+
+ err = 0;
+exit:
+ kfree(elts);
+ list_for_each_safe(cur, tmp, &head) {
+ list_del(cur);
+ kfree(container_of(cur, struct debug_el, list));
+ }
+ return err;
+}
+module_init(list_sort_test);
+#endif /* CONFIG_TEST_LIST_SORT */
diff --git a/lib/lmb.c b/lib/lmb.c
index 49a3c9e01e5..41a2be46356 100644
--- a/lib/lmb.c
+++ b/lib/lmb.c
@@ -295,7 +295,10 @@ phys_addr_t __lmb_alloc_base(struct lmb *lmb, phys_size_t size, ulong align, phy
if (max_addr == LMB_ALLOC_ANYWHERE)
base = lmb_align_down(lmbbase + lmbsize - size, align);
else if (lmbbase < max_addr) {
- base = min(lmbbase + lmbsize, max_addr);
+ base = lmbbase + lmbsize;
+ if (base < lmbbase)
+ base = -1;
+ base = min(base, max_addr);
base = lmb_align_down(base - size, align);
} else
continue;
diff --git a/lib/rbtree.c b/lib/rbtree.c
index b05f1ab7f59..9e52f70d173 100644
--- a/lib/rbtree.c
+++ b/lib/rbtree.c
@@ -2,283 +2,412 @@
Red Black Trees
(C) 1999 Andrea Arcangeli <andrea@suse.de>
(C) 2002 David Woodhouse <dwmw2@infradead.org>
+ (C) 2012 Michel Lespinasse <walken@google.com>
* SPDX-License-Identifier: GPL-2.0+
linux/lib/rbtree.c
*/
+#define __UBOOT__
+#include <linux/rbtree_augmented.h>
+#ifndef __UBOOT__
+#include <linux/export.h>
+#else
#include <ubi_uboot.h>
-#include <linux/rbtree.h>
+#endif
+/*
+ * red-black trees properties: http://en.wikipedia.org/wiki/Rbtree
+ *
+ * 1) A node is either red or black
+ * 2) The root is black
+ * 3) All leaves (NULL) are black
+ * 4) Both children of every red node are black
+ * 5) Every simple path from root to leaves contains the same number
+ * of black nodes.
+ *
+ * 4 and 5 give the O(log n) guarantee, since 4 implies you cannot have two
+ * consecutive red nodes in a path and every red node is therefore followed by
+ * a black. So if B is the number of black nodes on every simple path (as per
+ * 5), then the longest possible path due to 4 is 2B.
+ *
+ * We shall indicate color with case, where black nodes are uppercase and red
+ * nodes will be lowercase. Unknown color nodes shall be drawn as red within
+ * parentheses and have some accompanying text comment.
+ */
-static void __rb_rotate_left(struct rb_node *node, struct rb_root *root)
+static inline void rb_set_black(struct rb_node *rb)
{
- struct rb_node *right = node->rb_right;
- struct rb_node *parent = rb_parent(node);
-
- if ((node->rb_right = right->rb_left))
- rb_set_parent(right->rb_left, node);
- right->rb_left = node;
-
- rb_set_parent(right, parent);
-
- if (parent)
- {
- if (node == parent->rb_left)
- parent->rb_left = right;
- else
- parent->rb_right = right;
- }
- else
- root->rb_node = right;
- rb_set_parent(node, right);
+ rb->__rb_parent_color |= RB_BLACK;
}
-static void __rb_rotate_right(struct rb_node *node, struct rb_root *root)
+static inline struct rb_node *rb_red_parent(struct rb_node *red)
{
- struct rb_node *left = node->rb_left;
- struct rb_node *parent = rb_parent(node);
-
- if ((node->rb_left = left->rb_right))
- rb_set_parent(left->rb_right, node);
- left->rb_right = node;
-
- rb_set_parent(left, parent);
+ return (struct rb_node *)red->__rb_parent_color;
+}
- if (parent)
- {
- if (node == parent->rb_right)
- parent->rb_right = left;
- else
- parent->rb_left = left;
- }
- else
- root->rb_node = left;
- rb_set_parent(node, left);
+/*
+ * Helper function for rotations:
+ * - old's parent and color get assigned to new
+ * - old gets assigned new as a parent and 'color' as a color.
+ */
+static inline void
+__rb_rotate_set_parents(struct rb_node *old, struct rb_node *new,
+ struct rb_root *root, int color)
+{
+ struct rb_node *parent = rb_parent(old);
+ new->__rb_parent_color = old->__rb_parent_color;
+ rb_set_parent_color(old, new, color);
+ __rb_change_child(old, new, parent, root);
}
-void rb_insert_color(struct rb_node *node, struct rb_root *root)
+static __always_inline void
+__rb_insert(struct rb_node *node, struct rb_root *root,
+ void (*augment_rotate)(struct rb_node *old, struct rb_node *new))
{
- struct rb_node *parent, *gparent;
-
- while ((parent = rb_parent(node)) && rb_is_red(parent))
- {
- gparent = rb_parent(parent);
-
- if (parent == gparent->rb_left)
- {
- {
- register struct rb_node *uncle = gparent->rb_right;
- if (uncle && rb_is_red(uncle))
- {
- rb_set_black(uncle);
- rb_set_black(parent);
- rb_set_red(gparent);
- node = gparent;
- continue;
- }
+ struct rb_node *parent = rb_red_parent(node), *gparent, *tmp;
+
+ while (true) {
+ /*
+ * Loop invariant: node is red
+ *
+ * If there is a black parent, we are done.
+ * Otherwise, take some corrective action as we don't
+ * want a red root or two consecutive red nodes.
+ */
+ if (!parent) {
+ rb_set_parent_color(node, NULL, RB_BLACK);
+ break;
+ } else if (rb_is_black(parent))
+ break;
+
+ gparent = rb_red_parent(parent);
+
+ tmp = gparent->rb_right;
+ if (parent != tmp) { /* parent == gparent->rb_left */
+ if (tmp && rb_is_red(tmp)) {
+ /*
+ * Case 1 - color flips
+ *
+ * G g
+ * / \ / \
+ * p u --> P U
+ * / /
+ * n N
+ *
+ * However, since g's parent might be red, and
+ * 4) does not allow this, we need to recurse
+ * at g.
+ */
+ rb_set_parent_color(tmp, gparent, RB_BLACK);
+ rb_set_parent_color(parent, gparent, RB_BLACK);
+ node = gparent;
+ parent = rb_parent(node);
+ rb_set_parent_color(node, parent, RB_RED);
+ continue;
}
- if (parent->rb_right == node)
- {
- register struct rb_node *tmp;
- __rb_rotate_left(parent, root);
- tmp = parent;
+ tmp = parent->rb_right;
+ if (node == tmp) {
+ /*
+ * Case 2 - left rotate at parent
+ *
+ * G G
+ * / \ / \
+ * p U --> n U
+ * \ /
+ * n p
+ *
+ * This still leaves us in violation of 4), the
+ * continuation into Case 3 will fix that.
+ */
+ parent->rb_right = tmp = node->rb_left;
+ node->rb_left = parent;
+ if (tmp)
+ rb_set_parent_color(tmp, parent,
+ RB_BLACK);
+ rb_set_parent_color(parent, node, RB_RED);
+ augment_rotate(parent, node);
parent = node;
- node = tmp;
+ tmp = node->rb_right;
}
- rb_set_black(parent);
- rb_set_red(gparent);
- __rb_rotate_right(gparent, root);
+ /*
+ * Case 3 - right rotate at gparent
+ *
+ * G P
+ * / \ / \
+ * p U --> n g
+ * / \
+ * n U
+ */
+ gparent->rb_left = tmp; /* == parent->rb_right */
+ parent->rb_right = gparent;
+ if (tmp)
+ rb_set_parent_color(tmp, gparent, RB_BLACK);
+ __rb_rotate_set_parents(gparent, parent, root, RB_RED);
+ augment_rotate(gparent, parent);
+ break;
} else {
- {
- register struct rb_node *uncle = gparent->rb_left;
- if (uncle && rb_is_red(uncle))
- {
- rb_set_black(uncle);
- rb_set_black(parent);
- rb_set_red(gparent);
- node = gparent;
- continue;
- }
+ tmp = gparent->rb_left;
+ if (tmp && rb_is_red(tmp)) {
+ /* Case 1 - color flips */
+ rb_set_parent_color(tmp, gparent, RB_BLACK);
+ rb_set_parent_color(parent, gparent, RB_BLACK);
+ node = gparent;
+ parent = rb_parent(node);
+ rb_set_parent_color(node, parent, RB_RED);
+ continue;
}
- if (parent->rb_left == node)
- {
- register struct rb_node *tmp;
- __rb_rotate_right(parent, root);
- tmp = parent;
+ tmp = parent->rb_left;
+ if (node == tmp) {
+ /* Case 2 - right rotate at parent */
+ parent->rb_left = tmp = node->rb_right;
+ node->rb_right = parent;
+ if (tmp)
+ rb_set_parent_color(tmp, parent,
+ RB_BLACK);
+ rb_set_parent_color(parent, node, RB_RED);
+ augment_rotate(parent, node);
parent = node;
- node = tmp;
+ tmp = node->rb_left;
}
- rb_set_black(parent);
- rb_set_red(gparent);
- __rb_rotate_left(gparent, root);
+ /* Case 3 - left rotate at gparent */
+ gparent->rb_right = tmp; /* == parent->rb_left */
+ parent->rb_left = gparent;
+ if (tmp)
+ rb_set_parent_color(tmp, gparent, RB_BLACK);
+ __rb_rotate_set_parents(gparent, parent, root, RB_RED);
+ augment_rotate(gparent, parent);
+ break;
}
}
-
- rb_set_black(root->rb_node);
}
-static void __rb_erase_color(struct rb_node *node, struct rb_node *parent,
- struct rb_root *root)
+/*
+ * Inline version for rb_erase() use - we want to be able to inline
+ * and eliminate the dummy_rotate callback there
+ */
+static __always_inline void
+____rb_erase_color(struct rb_node *parent, struct rb_root *root,
+ void (*augment_rotate)(struct rb_node *old, struct rb_node *new))
{
- struct rb_node *other;
-
- while ((!node || rb_is_black(node)) && node != root->rb_node)
- {
- if (parent->rb_left == node)
- {
- other = parent->rb_right;
- if (rb_is_red(other))
- {
- rb_set_black(other);
- rb_set_red(parent);
- __rb_rotate_left(parent, root);
- other = parent->rb_right;
- }
- if ((!other->rb_left || rb_is_black(other->rb_left)) &&
- (!other->rb_right || rb_is_black(other->rb_right)))
- {
- rb_set_red(other);
- node = parent;
- parent = rb_parent(node);
+ struct rb_node *node = NULL, *sibling, *tmp1, *tmp2;
+
+ while (true) {
+ /*
+ * Loop invariants:
+ * - node is black (or NULL on first iteration)
+ * - node is not the root (parent is not NULL)
+ * - All leaf paths going through parent and node have a
+ * black node count that is 1 lower than other leaf paths.
+ */
+ sibling = parent->rb_right;
+ if (node != sibling) { /* node == parent->rb_left */
+ if (rb_is_red(sibling)) {
+ /*
+ * Case 1 - left rotate at parent
+ *
+ * P S
+ * / \ / \
+ * N s --> p Sr
+ * / \ / \
+ * Sl Sr N Sl
+ */
+ parent->rb_right = tmp1 = sibling->rb_left;
+ sibling->rb_left = parent;
+ rb_set_parent_color(tmp1, parent, RB_BLACK);
+ __rb_rotate_set_parents(parent, sibling, root,
+ RB_RED);
+ augment_rotate(parent, sibling);
+ sibling = tmp1;
}
- else
- {
- if (!other->rb_right || rb_is_black(other->rb_right))
- {
- struct rb_node *o_left;
- if ((o_left = other->rb_left))
- rb_set_black(o_left);
- rb_set_red(other);
- __rb_rotate_right(other, root);
- other = parent->rb_right;
+ tmp1 = sibling->rb_right;
+ if (!tmp1 || rb_is_black(tmp1)) {
+ tmp2 = sibling->rb_left;
+ if (!tmp2 || rb_is_black(tmp2)) {
+ /*
+ * Case 2 - sibling color flip
+ * (p could be either color here)
+ *
+ * (p) (p)
+ * / \ / \
+ * N S --> N s
+ * / \ / \
+ * Sl Sr Sl Sr
+ *
+ * This leaves us violating 5) which
+ * can be fixed by flipping p to black
+ * if it was red, or by recursing at p.
+ * p is red when coming from Case 1.
+ */
+ rb_set_parent_color(sibling, parent,
+ RB_RED);
+ if (rb_is_red(parent))
+ rb_set_black(parent);
+ else {
+ node = parent;
+ parent = rb_parent(node);
+ if (parent)
+ continue;
+ }
+ break;
}
- rb_set_color(other, rb_color(parent));
- rb_set_black(parent);
- if (other->rb_right)
- rb_set_black(other->rb_right);
- __rb_rotate_left(parent, root);
- node = root->rb_node;
- break;
+ /*
+ * Case 3 - right rotate at sibling
+ * (p could be either color here)
+ *
+ * (p) (p)
+ * / \ / \
+ * N S --> N Sl
+ * / \ \
+ * sl Sr s
+ * \
+ * Sr
+ */
+ sibling->rb_left = tmp1 = tmp2->rb_right;
+ tmp2->rb_right = sibling;
+ parent->rb_right = tmp2;
+ if (tmp1)
+ rb_set_parent_color(tmp1, sibling,
+ RB_BLACK);
+ augment_rotate(sibling, tmp2);
+ tmp1 = sibling;
+ sibling = tmp2;
}
- }
- else
- {
- other = parent->rb_left;
- if (rb_is_red(other))
- {
- rb_set_black(other);
- rb_set_red(parent);
- __rb_rotate_right(parent, root);
- other = parent->rb_left;
- }
- if ((!other->rb_left || rb_is_black(other->rb_left)) &&
- (!other->rb_right || rb_is_black(other->rb_right)))
- {
- rb_set_red(other);
- node = parent;
- parent = rb_parent(node);
+ /*
+ * Case 4 - left rotate at parent + color flips
+ * (p and sl could be either color here.
+ * After rotation, p becomes black, s acquires
+ * p's color, and sl keeps its color)
+ *
+ * (p) (s)
+ * / \ / \
+ * N S --> P Sr
+ * / \ / \
+ * (sl) sr N (sl)
+ */
+ parent->rb_right = tmp2 = sibling->rb_left;
+ sibling->rb_left = parent;
+ rb_set_parent_color(tmp1, sibling, RB_BLACK);
+ if (tmp2)
+ rb_set_parent(tmp2, parent);
+ __rb_rotate_set_parents(parent, sibling, root,
+ RB_BLACK);
+ augment_rotate(parent, sibling);
+ break;
+ } else {
+ sibling = parent->rb_left;
+ if (rb_is_red(sibling)) {
+ /* Case 1 - right rotate at parent */
+ parent->rb_left = tmp1 = sibling->rb_right;
+ sibling->rb_right = parent;
+ rb_set_parent_color(tmp1, parent, RB_BLACK);
+ __rb_rotate_set_parents(parent, sibling, root,
+ RB_RED);
+ augment_rotate(parent, sibling);
+ sibling = tmp1;
}
- else
- {
- if (!other->rb_left || rb_is_black(other->rb_left))
- {
- register struct rb_node *o_right;
- if ((o_right = other->rb_right))
- rb_set_black(o_right);
- rb_set_red(other);
- __rb_rotate_left(other, root);
- other = parent->rb_left;
+ tmp1 = sibling->rb_left;
+ if (!tmp1 || rb_is_black(tmp1)) {
+ tmp2 = sibling->rb_right;
+ if (!tmp2 || rb_is_black(tmp2)) {
+ /* Case 2 - sibling color flip */
+ rb_set_parent_color(sibling, parent,
+ RB_RED);
+ if (rb_is_red(parent))
+ rb_set_black(parent);
+ else {
+ node = parent;
+ parent = rb_parent(node);
+ if (parent)
+ continue;
+ }
+ break;
}
- rb_set_color(other, rb_color(parent));
- rb_set_black(parent);
- if (other->rb_left)
- rb_set_black(other->rb_left);
- __rb_rotate_right(parent, root);
- node = root->rb_node;
- break;
+ /* Case 3 - right rotate at sibling */
+ sibling->rb_right = tmp1 = tmp2->rb_left;
+ tmp2->rb_left = sibling;
+ parent->rb_left = tmp2;
+ if (tmp1)
+ rb_set_parent_color(tmp1, sibling,
+ RB_BLACK);
+ augment_rotate(sibling, tmp2);
+ tmp1 = sibling;
+ sibling = tmp2;
}
+ /* Case 4 - left rotate at parent + color flips */
+ parent->rb_left = tmp2 = sibling->rb_right;
+ sibling->rb_right = parent;
+ rb_set_parent_color(tmp1, sibling, RB_BLACK);
+ if (tmp2)
+ rb_set_parent(tmp2, parent);
+ __rb_rotate_set_parents(parent, sibling, root,
+ RB_BLACK);
+ augment_rotate(parent, sibling);
+ break;
}
}
- if (node)
- rb_set_black(node);
}
+/* Non-inline version for rb_erase_augmented() use */
+void __rb_erase_color(struct rb_node *parent, struct rb_root *root,
+ void (*augment_rotate)(struct rb_node *old, struct rb_node *new))
+{
+ ____rb_erase_color(parent, root, augment_rotate);
+}
+EXPORT_SYMBOL(__rb_erase_color);
+
+/*
+ * Non-augmented rbtree manipulation functions.
+ *
+ * We use dummy augmented callbacks here, and have the compiler optimize them
+ * out of the rb_insert_color() and rb_erase() function definitions.
+ */
+
+static inline void dummy_propagate(struct rb_node *node, struct rb_node *stop) {}
+static inline void dummy_copy(struct rb_node *old, struct rb_node *new) {}
+static inline void dummy_rotate(struct rb_node *old, struct rb_node *new) {}
+
+static const struct rb_augment_callbacks dummy_callbacks = {
+ dummy_propagate, dummy_copy, dummy_rotate
+};
+
+void rb_insert_color(struct rb_node *node, struct rb_root *root)
+{
+ __rb_insert(node, root, dummy_rotate);
+}
+EXPORT_SYMBOL(rb_insert_color);
+
void rb_erase(struct rb_node *node, struct rb_root *root)
{
- struct rb_node *child, *parent;
- int color;
-
- if (!node->rb_left)
- child = node->rb_right;
- else if (!node->rb_right)
- child = node->rb_left;
- else
- {
- struct rb_node *old = node, *left;
-
- node = node->rb_right;
- while ((left = node->rb_left) != NULL)
- node = left;
- child = node->rb_right;
- parent = rb_parent(node);
- color = rb_color(node);
-
- if (child)
- rb_set_parent(child, parent);
- if (parent == old) {
- parent->rb_right = child;
- parent = node;
- } else
- parent->rb_left = child;
-
- node->rb_parent_color = old->rb_parent_color;
- node->rb_right = old->rb_right;
- node->rb_left = old->rb_left;
-
- if (rb_parent(old))
- {
- if (rb_parent(old)->rb_left == old)
- rb_parent(old)->rb_left = node;
- else
- rb_parent(old)->rb_right = node;
- } else
- root->rb_node = node;
-
- rb_set_parent(old->rb_left, node);
- if (old->rb_right)
- rb_set_parent(old->rb_right, node);
- goto color;
- }
+ struct rb_node *rebalance;
+ rebalance = __rb_erase_augmented(node, root, &dummy_callbacks);
+ if (rebalance)
+ ____rb_erase_color(rebalance, root, dummy_rotate);
+}
+EXPORT_SYMBOL(rb_erase);
- parent = rb_parent(node);
- color = rb_color(node);
-
- if (child)
- rb_set_parent(child, parent);
- if (parent)
- {
- if (parent->rb_left == node)
- parent->rb_left = child;
- else
- parent->rb_right = child;
- }
- else
- root->rb_node = child;
+/*
+ * Augmented rbtree manipulation functions.
+ *
+ * This instantiates the same __always_inline functions as in the non-augmented
+ * case, but this time with user-defined callbacks.
+ */
- color:
- if (color == RB_BLACK)
- __rb_erase_color(child, parent, root);
+void __rb_insert_augmented(struct rb_node *node, struct rb_root *root,
+ void (*augment_rotate)(struct rb_node *old, struct rb_node *new))
+{
+ __rb_insert(node, root, augment_rotate);
}
+EXPORT_SYMBOL(__rb_insert_augmented);
/*
* This function returns the first node (in sort order) of the tree.
*/
-struct rb_node *rb_first(struct rb_root *root)
+struct rb_node *rb_first(const struct rb_root *root)
{
struct rb_node *n;
@@ -289,8 +418,9 @@ struct rb_node *rb_first(struct rb_root *root)
n = n->rb_left;
return n;
}
+EXPORT_SYMBOL(rb_first);
-struct rb_node *rb_last(struct rb_root *root)
+struct rb_node *rb_last(const struct rb_root *root)
{
struct rb_node *n;
@@ -301,58 +431,68 @@ struct rb_node *rb_last(struct rb_root *root)
n = n->rb_right;
return n;
}
+EXPORT_SYMBOL(rb_last);
-struct rb_node *rb_next(struct rb_node *node)
+struct rb_node *rb_next(const struct rb_node *node)
{
struct rb_node *parent;
- if (rb_parent(node) == node)
+ if (RB_EMPTY_NODE(node))
return NULL;
- /* If we have a right-hand child, go down and then left as far
- as we can. */
+ /*
+ * If we have a right-hand child, go down and then left as far
+ * as we can.
+ */
if (node->rb_right) {
- node = node->rb_right;
+ node = node->rb_right;
while (node->rb_left)
node=node->rb_left;
- return node;
+ return (struct rb_node *)node;
}
- /* No right-hand children. Everything down and left is
- smaller than us, so any 'next' node must be in the general
- direction of our parent. Go up the tree; any time the
- ancestor is a right-hand child of its parent, keep going
- up. First time it's a left-hand child of its parent, said
- parent is our 'next' node. */
+ /*
+ * No right-hand children. Everything down and left is smaller than us,
+ * so any 'next' node must be in the general direction of our parent.
+ * Go up the tree; any time the ancestor is a right-hand child of its
+ * parent, keep going up. First time it's a left-hand child of its
+ * parent, said parent is our 'next' node.
+ */
while ((parent = rb_parent(node)) && node == parent->rb_right)
node = parent;
return parent;
}
+EXPORT_SYMBOL(rb_next);
-struct rb_node *rb_prev(struct rb_node *node)
+struct rb_node *rb_prev(const struct rb_node *node)
{
struct rb_node *parent;
- if (rb_parent(node) == node)
+ if (RB_EMPTY_NODE(node))
return NULL;
- /* If we have a left-hand child, go down and then right as far
- as we can. */
+ /*
+ * If we have a left-hand child, go down and then right as far
+ * as we can.
+ */
if (node->rb_left) {
- node = node->rb_left;
+ node = node->rb_left;
while (node->rb_right)
node=node->rb_right;
- return node;
+ return (struct rb_node *)node;
}
- /* No left-hand children. Go up till we find an ancestor which
- is a right-hand child of its parent */
+ /*
+ * No left-hand children. Go up till we find an ancestor which
+ * is a right-hand child of its parent.
+ */
while ((parent = rb_parent(node)) && node == parent->rb_left)
node = parent;
return parent;
}
+EXPORT_SYMBOL(rb_prev);
void rb_replace_node(struct rb_node *victim, struct rb_node *new,
struct rb_root *root)
@@ -360,14 +500,7 @@ void rb_replace_node(struct rb_node *victim, struct rb_node *new,
struct rb_node *parent = rb_parent(victim);
/* Set the surrounding nodes to point to the replacement */
- if (parent) {
- if (victim == parent->rb_left)
- parent->rb_left = new;
- else
- parent->rb_right = new;
- } else {
- root->rb_node = new;
- }
+ __rb_change_child(victim, new, parent, root);
if (victim->rb_left)
rb_set_parent(victim->rb_left, new);
if (victim->rb_right)
@@ -376,3 +509,44 @@ void rb_replace_node(struct rb_node *victim, struct rb_node *new,
/* Copy the pointers/colour from the victim to the replacement */
*new = *victim;
}
+EXPORT_SYMBOL(rb_replace_node);
+
+static struct rb_node *rb_left_deepest_node(const struct rb_node *node)
+{
+ for (;;) {
+ if (node->rb_left)
+ node = node->rb_left;
+ else if (node->rb_right)
+ node = node->rb_right;
+ else
+ return (struct rb_node *)node;
+ }
+}
+
+struct rb_node *rb_next_postorder(const struct rb_node *node)
+{
+ const struct rb_node *parent;
+ if (!node)
+ return NULL;
+ parent = rb_parent(node);
+
+ /* If we're sitting on node, we've already seen our children */
+ if (parent && node == parent->rb_left && parent->rb_right) {
+ /* If we are the parent's left node, go to the parent's right
+ * node then all the way down to the left */
+ return rb_left_deepest_node(parent->rb_right);
+ } else
+ /* Otherwise we are the parent's right node, and the parent
+ * should be next */
+ return (struct rb_node *)parent;
+}
+EXPORT_SYMBOL(rb_next_postorder);
+
+struct rb_node *rb_first_postorder(const struct rb_root *root)
+{
+ if (!root->rb_node)
+ return NULL;
+
+ return rb_left_deepest_node(root->rb_node);
+}
+EXPORT_SYMBOL(rb_first_postorder);
diff --git a/lib/rsa/rsa-sign.c b/lib/rsa/rsa-sign.c
index 83f5e878389..5d9716f0134 100644
--- a/lib/rsa/rsa-sign.c
+++ b/lib/rsa/rsa-sign.c
@@ -76,6 +76,7 @@ static int rsa_get_pub_key(const char *keydir, const char *name, RSA **rsap)
rsa = EVP_PKEY_get1_RSA(key);
if (!rsa) {
rsa_err("Couldn't convert to a RSA style key");
+ ret = -EINVAL;
goto err_rsa;
}
fclose(f);
@@ -261,10 +262,57 @@ err_priv:
}
/*
+ * rsa_get_exponent(): - Get the public exponent from an RSA key
+ */
+static int rsa_get_exponent(RSA *key, uint64_t *e)
+{
+ int ret;
+ BIGNUM *bn_te;
+ uint64_t te;
+
+ ret = -EINVAL;
+ bn_te = NULL;
+
+ if (!e)
+ goto cleanup;
+
+ if (BN_num_bits(key->e) > 64)
+ goto cleanup;
+
+ *e = BN_get_word(key->e);
+
+ if (BN_num_bits(key->e) < 33) {
+ ret = 0;
+ goto cleanup;
+ }
+
+ bn_te = BN_dup(key->e);
+ if (!bn_te)
+ goto cleanup;
+
+ if (!BN_rshift(bn_te, bn_te, 32))
+ goto cleanup;
+
+ if (!BN_mask_bits(bn_te, 32))
+ goto cleanup;
+
+ te = BN_get_word(bn_te);
+ te <<= 32;
+ *e |= te;
+ ret = 0;
+
+cleanup:
+ if (bn_te)
+ BN_free(bn_te);
+
+ return ret;
+}
+
+/*
* rsa_get_params(): - Get the important parameters of an RSA public key
*/
-int rsa_get_params(RSA *key, uint32_t *n0_invp, BIGNUM **modulusp,
- BIGNUM **r_squaredp)
+int rsa_get_params(RSA *key, uint64_t *exponent, uint32_t *n0_invp,
+ BIGNUM **modulusp, BIGNUM **r_squaredp)
{
BIGNUM *big1, *big2, *big32, *big2_32;
BIGNUM *n, *r, *r_squared, *tmp;
@@ -286,6 +334,9 @@ int rsa_get_params(RSA *key, uint32_t *n0_invp, BIGNUM **modulusp,
return -ENOMEM;
}
+ if (0 != rsa_get_exponent(key, exponent))
+ ret = -1;
+
if (!BN_copy(n, key->n) || !BN_set_word(big1, 1L) ||
!BN_set_word(big2, 2L) || !BN_set_word(big32, 32L))
ret = -1;
@@ -386,6 +437,7 @@ static int fdt_add_bignum(void *blob, int noffset, const char *prop_name,
int rsa_add_verify_data(struct image_sign_info *info, void *keydest)
{
BIGNUM *modulus, *r_squared;
+ uint64_t exponent;
uint32_t n0_inv;
int parent, node;
char name[100];
@@ -397,7 +449,7 @@ int rsa_add_verify_data(struct image_sign_info *info, void *keydest)
ret = rsa_get_pub_key(info->keydir, info->keyname, &rsa);
if (ret)
return ret;
- ret = rsa_get_params(rsa, &n0_inv, &modulus, &r_squared);
+ ret = rsa_get_params(rsa, &exponent, &n0_inv, &modulus, &r_squared);
if (ret)
return ret;
bits = BN_num_bits(modulus);
@@ -442,6 +494,9 @@ int rsa_add_verify_data(struct image_sign_info *info, void *keydest)
if (!ret)
ret = fdt_setprop_u32(keydest, node, "rsa,n0-inverse", n0_inv);
if (!ret) {
+ ret = fdt_setprop_u64(keydest, node, "rsa,exponent", exponent);
+ }
+ if (!ret) {
ret = fdt_add_bignum(keydest, node, "rsa,modulus", modulus,
bits);
}
diff --git a/lib/rsa/rsa-verify.c b/lib/rsa/rsa-verify.c
index bcb906368d0..4ef19b66f4b 100644
--- a/lib/rsa/rsa-verify.c
+++ b/lib/rsa/rsa-verify.c
@@ -26,6 +26,9 @@
#define get_unaligned_be32(a) fdt32_to_cpu(*(uint32_t *)a)
#define put_unaligned_be32(a, b) (*(uint32_t *)(b) = cpu_to_fdt32(a))
+/* Default public exponent for backward compatibility */
+#define RSA_DEFAULT_PUBEXP 65537
+
/**
* subtract_modulus() - subtract modulus from the given value
*
@@ -54,9 +57,9 @@ static void subtract_modulus(const struct rsa_public_key *key, uint32_t num[])
static int greater_equal_modulus(const struct rsa_public_key *key,
uint32_t num[])
{
- uint32_t i;
+ int i;
- for (i = key->len - 1; i >= 0; i--) {
+ for (i = (int)key->len - 1; i >= 0; i--) {
if (num[i] < key->modulus[i])
return 0;
if (num[i] > key->modulus[i])
@@ -123,6 +126,48 @@ static void montgomery_mul(const struct rsa_public_key *key,
}
/**
+ * num_pub_exponent_bits() - Number of bits in the public exponent
+ *
+ * @key: RSA key
+ * @num_bits: Storage for the number of public exponent bits
+ */
+static int num_public_exponent_bits(const struct rsa_public_key *key,
+ int *num_bits)
+{
+ uint64_t exponent;
+ int exponent_bits;
+ const uint max_bits = (sizeof(exponent) * 8);
+
+ exponent = key->exponent;
+ exponent_bits = 0;
+
+ if (!exponent) {
+ *num_bits = exponent_bits;
+ return 0;
+ }
+
+ for (exponent_bits = 1; exponent_bits < max_bits + 1; ++exponent_bits)
+ if (!(exponent >>= 1)) {
+ *num_bits = exponent_bits;
+ return 0;
+ }
+
+ return -EINVAL;
+}
+
+/**
+ * is_public_exponent_bit_set() - Check if a bit in the public exponent is set
+ *
+ * @key: RSA key
+ * @pos: The bit position to check
+ */
+static int is_public_exponent_bit_set(const struct rsa_public_key *key,
+ int pos)
+{
+ return key->exponent & (1ULL << pos);
+}
+
+/**
* pow_mod() - in-place public exponentiation
*
* @key: RSA key
@@ -132,6 +177,7 @@ static int pow_mod(const struct rsa_public_key *key, uint32_t *inout)
{
uint32_t *result, *ptr;
uint i;
+ int j, k;
/* Sanity check for stack size - key->len is in 32-bit words */
if (key->len > RSA_MAX_KEY_BITS / 32) {
@@ -141,18 +187,48 @@ static int pow_mod(const struct rsa_public_key *key, uint32_t *inout)
}
uint32_t val[key->len], acc[key->len], tmp[key->len];
+ uint32_t a_scaled[key->len];
result = tmp; /* Re-use location. */
/* Convert from big endian byte array to little endian word array. */
for (i = 0, ptr = inout + key->len - 1; i < key->len; i++, ptr--)
val[i] = get_unaligned_be32(ptr);
- montgomery_mul(key, acc, val, key->rr); /* axx = a * RR / R mod M */
- for (i = 0; i < 16; i += 2) {
- montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod M */
- montgomery_mul(key, acc, tmp, tmp); /* acc = tmp^2 / R mod M */
+ if (0 != num_public_exponent_bits(key, &k))
+ return -EINVAL;
+
+ if (k < 2) {
+ debug("Public exponent is too short (%d bits, minimum 2)\n",
+ k);
+ return -EINVAL;
}
- montgomery_mul(key, result, acc, val); /* result = XX * a / R mod M */
+
+ if (!is_public_exponent_bit_set(key, 0)) {
+ debug("LSB of RSA public exponent must be set.\n");
+ return -EINVAL;
+ }
+
+ /* the bit at e[k-1] is 1 by definition, so start with: C := M */
+ montgomery_mul(key, acc, val, key->rr); /* acc = a * RR / R mod n */
+ /* retain scaled version for intermediate use */
+ memcpy(a_scaled, acc, key->len * sizeof(a_scaled[0]));
+
+ for (j = k - 2; j > 0; --j) {
+ montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod n */
+
+ if (is_public_exponent_bit_set(key, j)) {
+ /* acc = tmp * val / R mod n */
+ montgomery_mul(key, acc, tmp, a_scaled);
+ } else {
+ /* e[j] == 0, copy tmp back to acc for next operation */
+ memcpy(acc, tmp, key->len * sizeof(acc[0]));
+ }
+ }
+
+ /* the bit at e[0] is always 1 */
+ montgomery_mul(key, tmp, acc, acc); /* tmp = acc^2 / R mod n */
+ montgomery_mul(key, acc, tmp, val); /* acc = tmp * a / R mod M */
+ memcpy(result, acc, key->len * sizeof(result[0]));
/* Make sure result < mod; result is at most 1x mod too large. */
if (greater_equal_modulus(key, result))
@@ -229,6 +305,8 @@ static int rsa_verify_with_keynode(struct image_sign_info *info,
const void *blob = info->fdt_blob;
struct rsa_public_key key;
const void *modulus, *rr;
+ const uint64_t *public_exponent;
+ int length;
int ret;
if (node < 0) {
@@ -241,6 +319,11 @@ static int rsa_verify_with_keynode(struct image_sign_info *info,
}
key.len = fdtdec_get_int(blob, node, "rsa,num-bits", 0);
key.n0inv = fdtdec_get_int(blob, node, "rsa,n0-inverse", 0);
+ public_exponent = fdt_getprop(blob, node, "rsa,exponent", &length);
+ if (!public_exponent || length < sizeof(*public_exponent))
+ key.exponent = RSA_DEFAULT_PUBEXP;
+ else
+ key.exponent = fdt64_to_cpu(*public_exponent);
modulus = fdt_getprop(blob, node, "rsa,modulus", NULL);
rr = fdt_getprop(blob, node, "rsa,r-squared", NULL);
if (!key.len || !modulus || !rr) {