cls_bpf: introduce integrated actions

Often the cls_bpf classifier is used with a single drop action attached.
Optimize this use case and let cls_bpf return both the classid and the
action. For backwards compatibility reasons, enable this feature only
under the TCA_BPF_FLAG_ACT_DIRECT flag.

Then more interesting programs like the following are easier to write:
/* Example classifier: picks a traffic class from skb->protocol and
 * returns the action code itself instead of relying on an attached action.
 */
int cls_bpf_prog(struct __sk_buff *skb)
{
  /* classify arp, ip, ipv6 into different traffic classes
   * and drop all other packets
   */
  switch (skb->protocol) {
  case htons(ETH_P_ARP):
    skb->tc_classid = 1;
    break;
  case htons(ETH_P_IP):
    skb->tc_classid = 2;
    break;
  case htons(ETH_P_IPV6):
    skb->tc_classid = 3;
    break;
  default:
    /* unmatched protocol: return the drop action right away */
    return TC_ACT_SHOT;
  }

  /* a class was chosen above; accept the packet with that classid */
  return TC_ACT_OK;
}

Joint work with Daniel Borkmann.

Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Signed-off-by: Alexei Starovoitov <ast@plumgrid.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/sched/cls_bpf.c b/net/sched/cls_bpf.c
index e5168f8..77b0ef1 100644
--- a/net/sched/cls_bpf.c
+++ b/net/sched/cls_bpf.c
@@ -38,6 +38,7 @@
 	struct bpf_prog *filter;
 	struct list_head link;
 	struct tcf_result res;
+	bool exts_integrated;
 	struct tcf_exts exts;
 	u32 handle;
 	union {
@@ -52,6 +53,7 @@
 
 static const struct nla_policy bpf_policy[TCA_BPF_MAX + 1] = {
 	[TCA_BPF_CLASSID]	= { .type = NLA_U32 },
+	[TCA_BPF_FLAGS]		= { .type = NLA_U32 },
 	[TCA_BPF_FD]		= { .type = NLA_U32 },
 	[TCA_BPF_NAME]		= { .type = NLA_NUL_STRING, .len = CLS_BPF_NAME_LEN },
 	[TCA_BPF_OPS_LEN]	= { .type = NLA_U16 },
@@ -59,6 +61,22 @@
 				    .len = sizeof(struct sock_filter) * BPF_MAXINSNS },
 };
 
+static int cls_bpf_exec_opcode(int code)	/* filter a program's return value */
+{
+	switch (code) {
+	case TC_ACT_OK:		/* listed action codes pass through unchanged */
+	case TC_ACT_RECLASSIFY:
+	case TC_ACT_SHOT:
+	case TC_ACT_PIPE:
+	case TC_ACT_STOLEN:
+	case TC_ACT_QUEUED:
+	case TC_ACT_UNSPEC:
+		return code;
+	default:		/* any other value is mapped to TC_ACT_UNSPEC */
+		return TC_ACT_UNSPEC;	/* cls_bpf_classify then moves on to the next prog */
+	}
+}
+
 static int cls_bpf_classify(struct sk_buff *skb, const struct tcf_proto *tp,
 			    struct tcf_result *res)
 {
@@ -79,6 +97,8 @@
 	list_for_each_entry_rcu(prog, &head->plist, link) {
 		int filter_res;
 
+		qdisc_skb_cb(skb)->tc_classid = prog->res.classid;
+
 		if (at_ingress) {
 			/* It is safe to push/pull even if skb_shared() */
 			__skb_push(skb, skb->mac_len);
@@ -88,6 +108,16 @@
 			filter_res = BPF_PROG_RUN(prog->filter, skb);
 		}
 
+		if (prog->exts_integrated) {
+			res->class = prog->res.class;
+			res->classid = qdisc_skb_cb(skb)->tc_classid;
+
+			ret = cls_bpf_exec_opcode(filter_res);
+			if (ret == TC_ACT_UNSPEC)
+				continue;
+			break;
+		}
+
 		if (filter_res == 0)
 			continue;
 
@@ -195,8 +225,7 @@
 	return ret;
 }
 
-static int cls_bpf_prog_from_ops(struct nlattr **tb,
-				 struct cls_bpf_prog *prog, u32 classid)
+static int cls_bpf_prog_from_ops(struct nlattr **tb, struct cls_bpf_prog *prog)
 {
 	struct sock_filter *bpf_ops;
 	struct sock_fprog_kern fprog_tmp;
@@ -230,15 +259,13 @@
 	prog->bpf_ops = bpf_ops;
 	prog->bpf_num_ops = bpf_num_ops;
 	prog->bpf_name = NULL;
-
 	prog->filter = fp;
-	prog->res.classid = classid;
 
 	return 0;
 }
 
-static int cls_bpf_prog_from_efd(struct nlattr **tb,
-				 struct cls_bpf_prog *prog, u32 classid)
+static int cls_bpf_prog_from_efd(struct nlattr **tb, struct cls_bpf_prog *prog,
+				 const struct tcf_proto *tp)
 {
 	struct bpf_prog *fp;
 	char *name = NULL;
@@ -268,9 +295,7 @@
 	prog->bpf_ops = NULL;
 	prog->bpf_fd = bpf_fd;
 	prog->bpf_name = name;
-
 	prog->filter = fp;
-	prog->res.classid = classid;
 
 	return 0;
 }
@@ -280,8 +305,8 @@
 				   unsigned long base, struct nlattr **tb,
 				   struct nlattr *est, bool ovr)
 {
+	bool is_bpf, is_ebpf, have_exts = false;
 	struct tcf_exts exts;
-	bool is_bpf, is_ebpf;
 	u32 classid;
 	int ret;
 
@@ -298,9 +323,22 @@
 		return ret;
 
 	classid = nla_get_u32(tb[TCA_BPF_CLASSID]);
+	if (tb[TCA_BPF_FLAGS]) {
+		u32 bpf_flags = nla_get_u32(tb[TCA_BPF_FLAGS]);
 
-	ret = is_bpf ? cls_bpf_prog_from_ops(tb, prog, classid) :
-		       cls_bpf_prog_from_efd(tb, prog, classid);
+		if (bpf_flags & ~TCA_BPF_FLAG_ACT_DIRECT) {
+			tcf_exts_destroy(&exts);
+			return -EINVAL;
+		}
+
+		have_exts = bpf_flags & TCA_BPF_FLAG_ACT_DIRECT;
+	}
+
+	prog->res.classid = classid;
+	prog->exts_integrated = have_exts;
+
+	ret = is_bpf ? cls_bpf_prog_from_ops(tb, prog) :
+		       cls_bpf_prog_from_efd(tb, prog, tp);
 	if (ret < 0) {
 		tcf_exts_destroy(&exts);
 		return ret;