emmc: enable write unifykey [1/1]

PD#SWPL-152984

Problem:
s7 does not support writing the unifykey to eMMC

Solution:
enable writing the unifykey (two backup copies plus checksum info blocks)

Verify:
s7

Change-Id: Idd6b9a5b428c6fa0906f68f8891c1a40e414d3e6
Signed-off-by: Ruixuan.li <ruixuan.li@amlogic.com>
diff --git a/drivers/amlogic/mmc/aml_emmc_partition.c b/drivers/amlogic/mmc/aml_emmc_partition.c
index 552a92c..918e6b0 100644
--- a/drivers/amlogic/mmc/aml_emmc_partition.c
+++ b/drivers/amlogic/mmc/aml_emmc_partition.c
@@ -20,6 +20,7 @@
 #include <amlogic/partition_table.h>
 #include <linux/compat.h>
 #include <u-boot/crc.h>
+#include <amlogic/aml_mmc.h>
 
 DECLARE_GLOBAL_DATA_PTR;
 /* using mbr*/
@@ -94,6 +95,14 @@
 	int count;	/* partition count in use */
 };
 
+struct aml_key_info {
+	u64 checksum;	/* sum of the key's 32-bit words, see _calc_key_checksum() */
+	u32 stamp;	/* set to mmc->key_stamp + 1 by mmc_key_write_backup() */
+	u32 magic;	/* written as 9; update_key_info() treats 0 as invalid */
+};
+/* info read from / written to the info block of each key copy: [0] first, [1] second */
+struct aml_key_info key_infos[2] = { {0, 0, 0}, {0, 0, 0} };
+
 unsigned device_boot_flag = 0xff;
 extern bool is_partition_checked;
 
@@ -1555,6 +1564,7 @@
 int mmc_partition_init(void)
 {
 	struct mmc *mmc = find_mmc_device(1);
+	struct _iptbl iptbl_inh;
 	int ret;
 
 	if (!mmc) {
@@ -1562,6 +1572,12 @@
 		return -1;
 	}
 
+	iptbl_inh.count = get_emmc_partition_arraysize();
+	if (iptbl_inh.count) {
+		iptbl_inh.partitions = emmc_partition_table;
+		_calculate_offset(mmc, &iptbl_inh, 0);
+	}
+
 	if (!p_iptbl_ept) {
 		ret = _zalloc_iptbl(&p_iptbl_ept);
 		if (ret)
@@ -1900,3 +1916,324 @@
 _out:
 	return ret;
 }
+
+/* unifykey backup distribution */
+
+/*--------------------------------------------------------
+ * offset |  0x12020  |  0x12220  |  0x12420  |  0x12421  |
+ *--------------------------------------------------------
+ *  size  | 0x200 blk | 0x200 blk |  1 block  |  1 block  |
+ *--------------------------------------------------------
+ *content |    key1   |    key2   | checksum1 | checksum2 |
+ *--------------------------------------------------------
+ */
+/* Sum the 32-bit words of addr[0..size) into a u64; BUG()s if addr is not 4-byte aligned. */
+static u64 _calc_key_checksum(void *addr, int size)
+{
+	u32 *words = (u32 *)addr;
+	u64 sum = 0;
+	int nwords = size >> 2;	/* bytes past the last whole word are ignored */
+	int idx;
+
+	if ((u64)addr % 4 != 0)
+		BUG();
+
+	for (idx = 0; idx < nwords; idx++)
+		sum += words[idx];
+
+	return sum;
+}
+
+/* Read cnt blocks starting at block blk into addr; returns 0 on success, 1 on a short read. */
+static int _key_read(struct mmc *mmc, u64 blk, u64 cnt, void * addr)
+{
+	int dev = EMMC_DTB_DEV;
+	u64 done = blk_dread(mmc_get_blk_desc(mmc), blk, cnt, addr);
+
+	if (done == cnt)
+		return 0;
+	printf("%s: dev # %d, block # %#llx, count # %#llx ERROR!\n",
+			__func__, dev, blk, cnt);
+	return 1;
+}
+
+/*
+ * Load the stored info block of key copy cpy into key_infos[cpy] and compare its checksum
+ * with the one computed over addr. Returns 0 on match, 1 on mismatch, -1 on a read error.
+ */
+static int _verify_key_checksum(struct mmc *mmc, void *addr, int cpy)
+{
+	struct virtual_partition *key_vpart = aml_get_virtual_partition_by_name(MMC_KEY_NAME);
+	struct partitions *rsv_part = aml_get_partition_by_name(MMC_RESERVED_NAME);
+	u64 key_base = rsv_part->offset + key_vpart->offset;
+	char info_blk[512] = {0};
+	u64 info_lba, calc_sum;
+
+	/* one info block per copy, stored right after the two key copies */
+	info_lba = (key_base + 2 * (key_vpart->size)) / MMC_BLOCK_SIZE + cpy;
+	if (_key_read(mmc, info_lba, 1, (void *)info_blk))
+		return -1;
+
+	/* keep the stored info so callers can inspect stamp and magic */
+	memcpy(&key_infos[cpy], info_blk, sizeof(struct aml_key_info));
+
+	calc_sum = _calc_key_checksum(addr, key_vpart->size);
+	printf("calc %llx, store %llx\n", calc_sum, key_infos[cpy].checksum);
+
+	return calc_sum != key_infos[cpy].checksum;
+}
+
+/*
+ * Read both key copies (second first, so addr ends up holding the first copy) and verify
+ * each against its stored info block. Sets mmc->key_stamp to the larger of the two stamps.
+ * Returns a mask of valid copies (bit 0 = first, bit 1 = second), or -1 on a key read error.
+ */
+static int update_key_info(struct mmc *mmc, unsigned char *addr)
+{
+	struct virtual_partition *key_vpart = aml_get_virtual_partition_by_name(MMC_KEY_NAME);
+	struct partitions *rsv_part = aml_get_partition_by_name(MMC_RESERVED_NAME);
+	u64 key_base = rsv_part->offset + key_vpart->offset;
+	u64 lba, nblks;
+	int valid_mask = 0;
+	int copy;
+
+	for (copy = 1; copy >= 0; copy--) {
+		lba = (key_base + copy * (key_vpart->size)) / MMC_BLOCK_SIZE;
+		nblks = key_vpart->size / mmc->read_bl_len;
+		if (_key_read(mmc, lba, nblks, addr)) {
+			printf("%s: block # %#llx, cnt # %#llx ERROR!\n",
+				__func__, lba, nblks);
+			return -1;
+		}
+
+		/* valid only when the checksum matches and the stored magic is non-zero */
+		if (_verify_key_checksum(mmc, addr, copy) == 0 && key_infos[copy].magic != 0)
+			valid_mask += copy + 1;
+		else
+			printf("cpy %d is not valid\n", copy);
+	}
+
+	/* remember the larger stored stamp (plain unsigned compare) */
+	if (key_infos[0].stamp > key_infos[1].stamp)
+		mmc->key_stamp = key_infos[0].stamp;
+	else
+		mmc->key_stamp = key_infos[1].stamp;
+
+	return valid_mask;
+}
+
+/* Write cnt blocks from addr starting at block blk; returns 0 on success, 1 on a short write. */
+static int _key_write(struct mmc *mmc, u64 blk, u64 cnt, void *addr)
+{
+	int dev = STORAGE_EMMC;
+	u32 written = blk_dwrite(mmc_get_blk_desc(mmc), blk, cnt, addr);
+
+	if (written == cnt)
+		return 0;
+	printf("%s: dev # %d, block # %#llx, count # %#llx ERROR!\n",
+			__func__, dev, blk, cnt);
+	return 1;
+}
+
+/*
+ * Repair one key copy from the other: read copy (valid_flag - 1) into addr, write it over the
+ * other copy, then store key_infos[valid_flag - 1] in that copy's info block.
+ * Returns 0 on success, 1 if valid_flag is not 1 or 2, -2 on read error, -4 on write error.
+ */
+static int write_invalid_key(struct mmc *mmc, void *addr, int valid_flag)
+{
+	u64 blk, cnt, key_glb_offset;
+	int ret = 0;	/* must start at 0: the success path returns it unchanged */
+	struct partitions *part = NULL;
+	struct virtual_partition *vpart = NULL;
+	char checksum_info[512] = {0};
+
+	if (valid_flag > 2 || valid_flag < 1)
+		return 1;
+
+	vpart = aml_get_virtual_partition_by_name(MMC_KEY_NAME);
+	part = aml_get_partition_by_name(MMC_RESERVED_NAME);
+	key_glb_offset = part->offset + vpart->offset;
+	blk = (key_glb_offset + (valid_flag - 1) * (vpart->size)) / MMC_BLOCK_SIZE;
+	cnt = vpart->size / mmc->read_bl_len;
+	if (_key_read(mmc, blk, cnt, addr)) {
+		printf("%s: block # %#llx,cnt # %#llx ERROR!\n", __func__, blk, cnt);
+		return -2;	/* never overwrite a copy with a buffer that failed to load */
+	}
+
+	/* overwrite the other copy: valid_flag 1 -> second slot, 2 -> first slot */
+	blk = (key_glb_offset + (valid_flag % 2) * vpart->size) / MMC_BLOCK_SIZE;
+	if (_key_write(mmc, blk, cnt, addr)) {
+		printf("%s: block # %#llx,cnt # %#llx ERROR!\n", __func__, blk, cnt);
+		ret = -4;
+	}
+
+	memcpy(checksum_info, &key_infos[valid_flag - 1], sizeof(struct aml_key_info));
+	blk = (key_glb_offset + 2 * (vpart->size)) / MMC_BLOCK_SIZE + valid_flag % 2;
+	if (_key_write(mmc, blk, 1, checksum_info)) {
+		printf("%s: block # %#llx,cnt # %#llx ERROR!\n", __func__, blk, (u64)1);
+		ret = -4;
+	}
+	return ret;
+}
+
+/*
+ * Rewrite the invalid key copy. valid_flag == 2: copy the second key over the first via
+ * write_invalid_key(). Otherwise: write addr to the second key slot and store
+ * key_infos[valid_flag - 1] in the info block picked by valid_flag % 2. Returns 0, -2 or -4.
+ */
+static int update_invalid_key(struct mmc *mmc, void *addr, int valid_flag)
+{
+	struct virtual_partition *key_vpart = aml_get_virtual_partition_by_name(MMC_KEY_NAME);
+	struct partitions *rsv_part = aml_get_partition_by_name(MMC_RESERVED_NAME);
+	u64 key_base = rsv_part->offset + key_vpart->offset;
+	u64 nblks = key_vpart->size / mmc->read_bl_len;
+	char info_blk[512] = {0};
+	int status = 0, dev = STORAGE_EMMC;
+	u64 lba;
+
+	if (valid_flag == 2) {
+		printf("update key1");
+		return write_invalid_key(mmc, addr, valid_flag) ? -2 : 0;
+	}
+
+	printf("update key2");
+	lba = (key_base + key_vpart->size) / MMC_BLOCK_SIZE;
+	if (_key_write(mmc, lba, nblks, addr)) {
+		printf("%s: dev # %d, block # %#llx,cnt # %#llx ERROR!\n",
+			__func__, dev, lba, nblks);
+		status = -2;
+	}
+
+	memcpy(info_blk, &key_infos[valid_flag - 1], sizeof(struct aml_key_info));
+	lba = (key_base + 2 * (key_vpart->size)) / MMC_BLOCK_SIZE + valid_flag % 2;
+	if (_key_write(mmc, lba, 1, info_blk)) {
+		printf("%s: block # %#llx,cnt # %#llx ERROR!\n", __func__, lba, nblks);
+		status = -4;
+	}
+
+	return status;
+}
+
+/*
+ * Resync two key copies whose stamps differ via write_invalid_key(); 0 on success, -3 on error.
+ * NOTE(review): stamp_after() is not visible here - confirm the info copy direction is right.
+ */
+int update_old_key(struct mmc *mmc, void *addr)
+{
+	int src_flag, status;
+
+	if (stamp_after(key_infos[1].stamp, key_infos[0].stamp)) {
+		key_infos[1] = key_infos[0];
+		src_flag = 2;
+	} else if (stamp_after(key_infos[0].stamp, key_infos[1].stamp)) {
+		key_infos[0] = key_infos[1];
+		src_flag = 1;
+	} else {
+		printf("do nothing\n");
+		return 0;
+	}
+
+	status = write_invalid_key(mmc, addr, src_flag) ? -3 : 0;
+	mmc->key_stamp = key_infos[0].stamp;
+	return status;
+}
+
+/* Look up mmc device STORAGE_EMMC and run mmc_init() on it; NULL if either step fails. */
+static struct mmc *_rsv_init(void)
+{
+	struct mmc *emmc = find_mmc_device(STORAGE_EMMC);
+
+	if (emmc == NULL) {
+		printf("not find mmc\n");
+		return NULL;
+	}
+	if (mmc_init(emmc) != 0) {
+		printf("mmc init failed\n");
+		emmc = NULL;
+	}
+	return emmc;
+}
+
+/*
+ * Write a new key image to both key copies and both info blocks. name/size are unused:
+ * vpart->size bytes are always taken from addr. Returns 0 on success, -1 on a NULL
+ * argument or failed partition lookup, -10 if eMMC init fails, -2 if any block write failed.
+ */
+int mmc_key_write_backup(const char *name,
+			      unsigned char *addr, unsigned int size)
+{
+	struct virtual_partition *vpart = aml_get_virtual_partition_by_name(MMC_KEY_NAME);
+	struct partitions *part = aml_get_partition_by_name(MMC_RESERVED_NAME);
+	char checksum_info[512] = {0};
+	u64 blk, cnt, key_glb_offset;
+	struct mmc *mmc;
+	int ret = 0, cpy;
+
+	if (!addr || !vpart || !part)	/* these were dereferenced unchecked before */
+		return -1;
+	key_glb_offset = part->offset + vpart->offset;
+
+	mmc = _rsv_init();
+	if (mmc == NULL)
+		return -10;
+
+	key_infos[0].stamp =  mmc->key_stamp + 1;
+	key_infos[0].magic = 9;
+	key_infos[0].checksum = _calc_key_checksum(addr, vpart->size);
+	printf("new stamp %d, checksum 0x%llx, magic %d\n",
+		key_infos[0].stamp, key_infos[0].checksum, key_infos[0].magic);
+	memcpy(checksum_info, &key_infos[0], sizeof(struct aml_key_info));
+
+	cnt = vpart->size / mmc->read_bl_len;	/* same for every copy */
+	for (cpy = 0; cpy < KEY_COPIES; cpy++) {
+		blk = (key_glb_offset + cpy * (vpart->size)) / MMC_BLOCK_SIZE;
+		ret |= _key_write(mmc, blk, cnt, addr);
+		blk = (key_glb_offset + 2 * (vpart->size)) / MMC_BLOCK_SIZE + cpy;
+		ret |= _key_write(mmc, blk, 1, checksum_info);
+	}
+	if (ret) {
+		printf("%s() %d: key write failed %d\n", __func__, __LINE__, ret);
+		ret = -2;
+	}
+	return ret;
+}
+
+/*
+ * Load a key copy into addr and repair whichever backup copy is invalid or out of sync.
+ * name and size are unused. Returns 0 on success (also when no copy verifies or a repair
+ * write fails), -10 if eMMC init fails, -1 if a key copy cannot be read.
+ */
+int mmc_key_read_backup(const char *name,
+			      unsigned char *addr, unsigned int size)
+{
+	struct mmc *mmc = _rsv_init();
+	int valid;
+
+	if (mmc == NULL)
+		return -10;
+
+	/* mask of valid copies; addr holds the first key copy afterwards */
+	valid = update_key_info(mmc, addr);
+	if (valid < 0)	/* read error: used to fall into default and hit BUG() */
+		return -1;
+	switch (valid) {
+	case 0:	/* none is valid, keep the first copy in addr for compatibility */
+		break;
+	case 1:	/* only first is valid, use it to rewrite the second */
+		update_invalid_key(mmc, addr, 1);
+		break;
+	case 2:	/* only second is valid, load it and rewrite the first */
+		update_invalid_key(mmc, addr, 2);
+		break;
+	case 3:	/* both valid, resync them if their stamps differ */
+		update_old_key(mmc, addr);
+		break;
+	default:
+		printf("impossible valid values.\n");
+		BUG();
+		break;
+	}
+	return 0;
+}
+