1 // SPDX-License-Identifier: GPL-2.0 2 #include "comm.h" 3 #include "util.h" 4 #include <errno.h> 5 #include <stdlib.h> 6 #include <stdio.h> 7 #include <string.h> 8 #include <linux/refcount.h> 9 #include <linux/rbtree.h> 10 #include "rwsem.h" 11 12 struct comm_str { 13 char *str; 14 struct rb_node rb_node; 15 refcount_t refcnt; 16 }; 17 18 /* Should perhaps be moved to struct machine */ 19 static struct rb_root comm_str_root; 20 static struct rw_semaphore comm_str_lock = {.lock = PTHREAD_RWLOCK_INITIALIZER,}; 21 22 static struct comm_str *comm_str__get(struct comm_str *cs) 23 { 24 if (cs && refcount_inc_not_zero(&cs->refcnt)) 25 return cs; 26 27 return NULL; 28 } 29 30 static void comm_str__put(struct comm_str *cs) 31 { 32 if (cs && refcount_dec_and_test(&cs->refcnt)) { 33 down_write(&comm_str_lock); 34 rb_erase(&cs->rb_node, &comm_str_root); 35 up_write(&comm_str_lock); 36 zfree(&cs->str); 37 free(cs); 38 } 39 } 40 41 static struct comm_str *comm_str__alloc(const char *str) 42 { 43 struct comm_str *cs; 44 45 cs = zalloc(sizeof(*cs)); 46 if (!cs) 47 return NULL; 48 49 cs->str = strdup(str); 50 if (!cs->str) { 51 free(cs); 52 return NULL; 53 } 54 55 refcount_set(&cs->refcnt, 1); 56 57 return cs; 58 } 59 60 static 61 struct comm_str *__comm_str__findnew(const char *str, struct rb_root *root) 62 { 63 struct rb_node **p = &root->rb_node; 64 struct rb_node *parent = NULL; 65 struct comm_str *iter, *new; 66 int cmp; 67 68 while (*p != NULL) { 69 parent = *p; 70 iter = rb_entry(parent, struct comm_str, rb_node); 71 72 /* 73 * If we race with comm_str__put, iter->refcnt is 0 74 * and it will be removed within comm_str__put call 75 * shortly, ignore it in this search. 76 */ 77 cmp = strcmp(str, iter->str); 78 if (!cmp && comm_str__get(iter)) 79 return iter; 80 81 if (cmp < 0) 82 p = &(*p)->rb_left; 83 else 84 p = &(*p)->rb_right; 85 } 86 87 new = comm_str__alloc(str); 88 if (!new) 89 return NULL; 90 91 rb_link_node(&new->rb_node, parent, p); 92 rb_insert_color(&new->rb_node, root); 93 94 return new; 95 } 96 97 static struct comm_str *comm_str__findnew(const char *str, struct rb_root *root) 98 { 99 struct comm_str *cs; 100 101 down_write(&comm_str_lock); 102 cs = __comm_str__findnew(str, root); 103 up_write(&comm_str_lock); 104 105 return cs; 106 } 107 108 struct comm *comm__new(const char *str, u64 timestamp, bool exec) 109 { 110 struct comm *comm = zalloc(sizeof(*comm)); 111 112 if (!comm) 113 return NULL; 114 115 comm->start = timestamp; 116 comm->exec = exec; 117 118 comm->comm_str = comm_str__findnew(str, &comm_str_root); 119 if (!comm->comm_str) { 120 free(comm); 121 return NULL; 122 } 123 124 return comm; 125 } 126 127 int comm__override(struct comm *comm, const char *str, u64 timestamp, bool exec) 128 { 129 struct comm_str *new, *old = comm->comm_str; 130 131 new = comm_str__findnew(str, &comm_str_root); 132 if (!new) 133 return -ENOMEM; 134 135 comm_str__put(old); 136 comm->comm_str = new; 137 comm->start = timestamp; 138 if (exec) 139 comm->exec = true; 140 141 return 0; 142 } 143 144 void comm__free(struct comm *comm) 145 { 146 comm_str__put(comm->comm_str); 147 free(comm); 148 } 149 150 const char *comm__str(const struct comm *comm) 151 { 152 return comm->comm_str->str; 153 } 154