/*
 * path: root/kattis-kth-alginda-quicksort/radix.c
 * blob: babc0ec126f06668bb487c107e9d3719f74e8b4b
 */
/* https://kth.kattis.com/problems/kth.alginda.quicksort */

#include <unistd.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <stdint.h>
/* #include <immintrin.h> */

// longest int: -2147483648
#define BUFFER_MAX (12 * 600000)
char buffer[BUFFER_MAX];
int xs[600000];
int tmp[600000];

int *radix_sort(int n) {
    int *out = xs;
    int *in = tmp;
    // we loop an even amount of times so the completely sorted array
    // will be in xs, not tmp
    for (int m = 0; m < 32; m += 8) {
        int *tmp = out;
        out = in;
        in = tmp;

        int counts[256] = { 0 };
        for (int i = 0; i < n; i++)
            counts[(uint8_t) (in[i] >> m)]++;

        int sum = 0;
        for (int i = 0; i < 256; i++) {
            sum += counts[i];
            counts[i] = sum;
        }
        for (int i = n - 1; i >= 0; i--) {
            uint8_t fx = in[i] >> m;
            counts[fx]--;
            out[counts[fx]] = in[i];
        }
    }
}

int main() {
    /* char *buffer = malloc(BUFFER_MAX); */
    /* ssize_t len = 0; */
    /* while (true) { */
    /*     if (len >= BUFFER_MAX) return 1; */
    /*     ssize_t r = read(0, &buffer[len], BUFFER_MAX - len); */
    /*     len += r; */
    /*     if (r == 0) */
    /*         break; */
    /* } */
    ssize_t len = read(0, buffer, BUFFER_MAX);

    int curr = 0;
    int n = 0;
    while (buffer[curr] != ' ' && buffer[curr] != '\n') {
        n *= 10;
        n += buffer[curr] - '0';
        curr++;
    }
    while (buffer[curr] == ' ' || buffer[curr] == '\n') curr++;

    for (int i = 0; i < n; i++) {
        int x = 0;
        bool neg = false;
        if (buffer[curr] == '-') {
            neg = true;
            curr++;
        }
        while (curr < len && buffer[curr] != ' ' && buffer[curr] != '\n') {
            x *= 10;
            x += buffer[curr] - '0';
            curr++;
        }
        while (curr < len && (buffer[curr] == ' ' || buffer[curr] == '\n')) curr++;

        if (neg) x = -x;
        xs[i] = x ^ (1 << 31);
    }

    radix_sort(n);

    len = 0;
    for (int i = 0; i < n; i++, len += 12) {
        long x = xs[i] ^ (1 << 31);
        bool neg = false;
        if (x < 0) {
            neg = true;
            x = -x;
        }
        int j = 11;
        buffer[len + j--] = '\n';
        do {
            buffer[len + j--] = (x % 10) + '0';
            x /= 10;
        } while (x > 0);
        if (neg) buffer[len + j--] = '-';
        while (j >= 0) buffer[len + j--] = ' ';
    }

    /*
    len = 0;

    __m256i highestbit = _mm256_set1_epi32(1 << 32);
    __m256i zero = _mm256_set1_epi32(0);
    __m256i lens = _mm256_set_epi32(0 * 12,
                                    1 * 12,
                                    2 * 12,
                                    3 * 12,
                                    4 * 12,
                                    5 * 12,
                                    6 * 12,
                                    7 * 12);
    __m256i lenoffset = _mm256_set1_epi(8 * 12);

    for (int i = 0; i < n; i += 8) {
        // long x = xs[i] ^ (1 << 31);
        __m256i xx = _mm256_load_si256(&xs[i]);
        xx = _mm256_xor_si256(xx, highestbit);
        // bool neg = false;
        // if (x < 0) {
        //     neg = true;
        //     x = -x;
        // }
        __m256i negs = _mm256_cmpgt(zero, xx);
        xx = _mm256_abs_epi64(xx);
        // todo...
        int j = 11;
        buffer[len + j--] = '\n';
        do {
            buffer[len + j--] = (x % 10) + '0';
            x /= 10;
        } while (x > 0);
        if (neg) buffer[len + j--] = '-';
        while (j >= 0) buffer[len + j--] = ' ';

        // len += 8 * 12
        lens = _mm256_add_epi32(lens, lenoffset);
    }
    */
    /* ssize_t c = 0; */
    /* while (c < len) c += write(1, &buffer[c], len - c); */
    write(1, buffer, len);
}