diff options
Diffstat (limited to 'rpython/rlib/rstring.py')
-rw-r--r-- | rpython/rlib/rstring.py | 152 |
1 files changed, 134 insertions, 18 deletions
diff --git a/rpython/rlib/rstring.py b/rpython/rlib/rstring.py index 29e1495381..b7bf0b2a16 100644 --- a/rpython/rlib/rstring.py +++ b/rpython/rlib/rstring.py @@ -101,9 +101,13 @@ def _split_by(value, by, maxsplit): start = 0 if bylen == 1: - # fast path: uses str.rfind(character) and str.count(character) + # fast path: uses str.find(character) and str.count(character) by = by[0] # annotator hack: string -> char cnt = count(value, by, 0, len(value)) + if cnt == 0: + if isinstance(value, str): + return [value] + return [value[0:len(value)]] if 0 <= maxsplit < cnt: cnt = maxsplit res = newlist_hint(cnt + 1) @@ -208,12 +212,12 @@ def _rsplit_by(value, by, maxsplit): @specialize.argtype(0, 1) @jit.elidable -def replace(input, sub, by, maxsplit=-1): - return replace_count(input, sub, by, maxsplit)[0] +def replace(input, sub, by, maxcount=-1): + return replace_count(input, sub, by, maxcount)[0] @specialize.ll_and_arg(4) @jit.elidable -def replace_count(input, sub, by, maxsplit=-1, isutf8=False): +def replace_count(input, sub, by, maxcount=-1, isutf8=False): if isinstance(input, str): Builder = StringBuilder elif isinstance(input, unicode): @@ -221,14 +225,14 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False): else: assert isinstance(input, list) Builder = ByteListBuilder - if maxsplit == 0: + if maxcount == 0: return input, 0 if not sub and not isutf8: upper = len(input) - if maxsplit > 0 and maxsplit < upper + 2: - upper = maxsplit - 1 + if maxcount > 0 and maxcount < upper + 2: + upper = maxcount - 1 assert upper >= 0 try: @@ -246,17 +250,27 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False): builder.append(by) builder.append_slice(input, upper, len(input)) replacements = upper + 1 + + elif isinstance(input, str) and len(sub) == 1: + if len(by) == 1: + return replace_count_str_chr_chr(input, sub[0], by[0], maxcount) + return replace_count_str_chr_str(input, sub[0], by, maxcount) + else: # First compute the exact result size if sub: cnt = count(input, sub, 0, len(input)) + if isinstance(input, str) and cnt == 0: + return input, 0 + if isinstance(input, str): + return replace_count_str_str_str(input, sub, by, cnt, maxcount) else: assert isutf8 from rpython.rlib import rutf8 cnt = rutf8.codepoints_in_utf8(input) + 1 - if cnt > maxsplit and maxsplit > 0: - cnt = maxsplit + if cnt > maxcount and maxcount > 0: + cnt = maxcount diff_len = len(by) - len(sub) try: result_size = ovfcheck(diff_len * cnt) @@ -274,26 +288,122 @@ def replace_count(input, sub, by, maxsplit=-1, isutf8=False): from rpython.rlib import rutf8 while True: builder.append(by) - maxsplit -= 1 - if start == len(input) or maxsplit == 0: + maxcount -= 1 + if start == len(input) or maxcount == 0: break next = rutf8.next_codepoint_pos(input, start) builder.append_slice(input, start, next) start = next else: - while maxsplit != 0: + while maxcount != 0: next = find(input, sub, start, len(input)) if next < 0: break builder.append_slice(input, start, next) builder.append(by) start = next + sublen - maxsplit -= 1 # NB. if it's already < 0, it stays < 0 + maxcount -= 1 # NB. if it's already < 0, it stays < 0 builder.append_slice(input, start, len(input)) return builder.build(), replacements +def replace_count_str_chr_chr(input, c1, c2, maxcount): + from rpython.rtyper.annlowlevel import llstr, hlstr + s = llstr(input) + length = len(s.chars) + start = find(input, c1, 0, len(input)) + if start < 0: + return input, 0 + newstr = s.malloc(length) + src = s.chars + dst = newstr.chars + s.copy_contents(s, newstr, 0, 0, len(input)) + dst[start] = c2 + count = 1 + start += 1 + maxcount -= 1 + while maxcount != 0: + next = find(input, c1, start, len(input)) + if next < 0: + break + dst[next] = c2 + start = next + 1 + maxcount -= 1 + count += 1 + + return hlstr(newstr), count + +def replace_count_str_chr_str(input, sub, by, maxcount): + from rpython.rtyper.annlowlevel import llstr, hlstr + cnt = count(input, sub, 0, len(input)) + if cnt == 0: + return input, 0 + if maxcount > 0 and cnt > maxcount: + cnt = maxcount + diff_len = len(by) - 1 + try: + result_size = ovfcheck(diff_len * cnt) + result_size = ovfcheck(result_size + len(input)) + except OverflowError: + raise + + s = llstr(input) + by_ll = llstr(by) + + newstr = s.malloc(result_size) + dst = 0 + start = 0 + while maxcount != 0: + next = find(input, sub, start, len(input)) + if next < 0: + break + s.copy_contents(s, newstr, start, dst, next - start) + dst += next - start + s.copy_contents(by_ll, newstr, 0, dst, len(by)) + dst += len(by) + + start = next + 1 + maxcount -= 1 # NB. if it's already < 0, it stays < 0 + + s.copy_contents(s, newstr, start, dst, len(input) - start) + assert dst - start + len(input) == result_size + return hlstr(newstr), cnt + +def replace_count_str_str_str(input, sub, by, cnt, maxcount): + from rpython.rtyper.annlowlevel import llstr, hlstr + if cnt > maxcount and maxcount > 0: + cnt = maxcount + diff_len = len(by) - len(sub) + try: + result_size = ovfcheck(diff_len * cnt) + result_size = ovfcheck(result_size + len(input)) + except OverflowError: + raise + + s = llstr(input) + by_ll = llstr(by) + newstr = s.malloc(result_size) + sublen = len(sub) + bylen = len(by) + inputlen = len(input) + dst = 0 + start = 0 + while maxcount != 0: + next = find(input, sub, start, inputlen) + if next < 0: + break + s.copy_contents(s, newstr, start, dst, next - start) + dst += next - start + s.copy_contents(by_ll, newstr, 0, dst, bylen) + dst += bylen + start = next + sublen + maxcount -= 1 # NB. if it's already < 0, it stays < 0 + s.copy_contents(s, newstr, start, dst, len(input) - start) + assert dst - start + len(input) == result_size + return hlstr(newstr), cnt + + def _normalize_start_end(length, start, end): if start < 0: start += length @@ -355,20 +465,26 @@ def count(value, other, start, end): return _search(value, other, start, end, SEARCH_COUNT) # -------------- substring searching helper ---------------- -# XXX a lot of code duplication with lltypesystem.rstr :-( SEARCH_COUNT = 0 SEARCH_FIND = 1 SEARCH_RFIND = 2 +@specialize.ll() def bloom_add(mask, c): return mask | (1 << (ord(c) & (BLOOM_WIDTH - 1))) +@specialize.ll() def bloom(mask, c): return mask & (1 << (ord(c) & (BLOOM_WIDTH - 1))) @specialize.argtype(0, 1) def _search(value, other, start, end, mode): + assert value is not None + if isinstance(value, unicode): + NUL = u'\0' + else: + NUL = '\0' if start < 0: start = 0 if end > len(value): @@ -398,7 +514,7 @@ def _search(value, other, start, end, mode): return -1 mlast = m - 1 - skip = mlast - 1 + skip = mlast mask = 0 if mode != SEARCH_RFIND: @@ -411,7 +527,7 @@ def _search(value, other, start, end, mode): i = start - 1 while i + 1 <= start + w: i += 1 - if value[i + m - 1] == other[m - 1]: + if value[i + mlast] == other[mlast]: for j in range(mlast): if value[i + j] != other[j]: break @@ -425,7 +541,7 @@ def _search(value, other, start, end, mode): if i + m < len(value): c = value[i + m] else: - c = '\0' + c = NUL if not bloom(mask, c): i += m else: @@ -434,7 +550,7 @@ def _search(value, other, start, end, mode): if i + m < len(value): c = value[i + m] else: - c = '\0' + c = NUL if not bloom(mask, c): i += m else: |