Rewrite this function to use numba and nopython mode. You can assume that the input is a 2d matrix correctly formatted. Call it rref still. def rref(matrix, graphlen): PH = 16 extra = 0 col = 0 pivots = [] used_for_row = [set([i]) for i in range(matrix.shape[0])] matrix = matrix%PH while col+extra < matrix.shape[1]-1 and col < matrix.shape[0]: if matrix[col,col+extra] == 0: if np.all(matrix[:,col] == 0): extra += 1 continue other = np.argwhere(matrix[:,col+extra] != 0).flatten()[-1] if other < col: extra += 1 continue matrix[col], matrix[other] = list(matrix[other]), list(matrix[col]) used_for_row[col], used_for_row[other] = used_for_row[other], used_for_row[col] pivots.append(col+extra) pivot = matrix[col,col+extra] if col+extra < graphlen: assert np.abs(pivot) == 1 or np.abs(pivot) == PH-1 else: assert np.abs(pivot) == 2 or np.abs(pivot) == PH-2 pivot //= 2 matrix[col] *= pivot matrix[col] %= PH others = np.argwhere(matrix[:,col+extra]).flatten() for i in others: if i == col: continue used_for_row[i] |= used_for_row[col] if col < graphlen: matrix[i] -= matrix[col]*matrix[i,col+extra] else: while matrix[i,col+extra] != 0: matrix[i] = (matrix[i]-matrix[col])%PH matrix[i] %= PH col += 1 matrix = np.array(matrix)%PH return [sorted(x) for x in used_for_row]
Rewrite this function to use numba and nopython mode. You can assume that the input is a 2d matrix correctly formatted. Call it rref still. def rref(matrix, graphlen): PH = 16 extra = 0 col = 0 pivots = [] used_for_row = [set([i]) for i in range(matrix.shape[0])] matrix = matrix%PH while col+extra < matrix.shape[1]-1 and col < matrix.shape[0]: if matrix[col,col+extra] == 0: if np.all(matrix[:,col] == 0): extra += 1 continue other = np.argwhere(matrix[:,col+extra] != 0).flatten()[-1] if other < col: extra += 1 continue matrix[col], matrix[other] = list(matrix[other]), list(matrix[col]) used_for_row[col], used_for_row[other] = used_for_row[other], used_for_row[col] pivots.append(col+extra) pivot = matrix[col,col+extra] if col+extra < graphlen: assert np.abs(pivot) == 1 or np.abs(pivot) == PH-1 else: assert np.abs(pivot) == 2 or np.abs(pivot) == PH-2 pivot //= 2 matrix[col] *= pivot matrix[col] %= PH others = np.argwhere(matrix[:,col+extra]).flatten() for i in others: if i == col: continue used_for_row[i] |= used_for_row[col] if col < graphlen: matrix[i] -= matrix[col]*matrix[i,col+extra] else: while matrix[i,col+extra] != 0: matrix[i] = (matrix[i]-matrix[col])%PH matrix[i] %= PH col += 1 matrix = np.array(matrix)%PH return [sorted(x) for x in used_for_row]
Here's a numba-compatible version of the function using `@njit` (nopython mode). Note that some modifications were necessary since numba doesn't support Python sets or certain NumPy operations:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 | from numba import njit import numpy as np @njit def rref(matrix, graphlen): PH = 16 rows, cols = matrix.shape # Initialize variables extra = 0 col = 0 pivots = np.zeros(cols, dtype=np.int64) pivot_count = 0 # Create used_for_row as a boolean matrix instead of sets used_for_row = np.zeros((rows, rows), dtype=np.int64) for i in range(rows): used_for_row[i, i] = 1 # Work with a copy of the matrix matrix = matrix % PH while col + extra < cols - 1 and col < rows: if matrix[col, col + extra] == 0: # Check if column is all zeros all_zeros = True for i in range(rows): if matrix[i, col + extra] != 0: all_zeros = False break if all_zeros: extra += 1 continue # Find last non-zero entry other = -1 for i in range(rows-1, -1, -1): if matrix[i, col + extra] != 0: other = i break if other < col: extra += 1 continue # Swap rows for j in range(cols): temp = matrix[col, j] matrix[col, j] = matrix[other, j] matrix[other, j] = temp # Swap used_for_row for j in range(rows): temp = used_for_row[col, j] used_for_row[col, j] = used_for_row[other, j] used_for_row[other, j] = temp pivots[pivot_count] = col + extra pivot_count += 1 pivot = matrix[col, col + extra] if col + extra < graphlen: if not (abs(pivot) == 1 or abs(pivot) == PH-1): raise ValueError("Invalid pivot value") else: if not (abs(pivot) == 2 or abs(pivot) == PH-2): raise ValueError("Invalid pivot value") pivot //= 2 # Multiply row by pivot for j in range(cols): matrix[col, j] = (matrix[col, j] * pivot) % PH # Process other rows for i in range(rows): if i == col: continue if matrix[i, col + extra] != 0: # Update used_for_row for k in range(rows): if used_for_row[col, k] == 1: used_for_row[i, k] = 1 if col < graphlen: mult = matrix[i, col + extra] for j in range(cols): matrix[i, j] = (matrix[i, j] - mult * matrix[col, j]) % PH else: while matrix[i, col + extra] != 0: for j in range(cols): matrix[i, j] = (matrix[i, j] - matrix[col, j]) % PH col += 1 matrix = matrix % PH # Convert used_for_row back to list format result = [] for i in range(rows): row_indices = [] for j in range(rows): if used_for_row[i, j] == 1: row_indices.append(j) result.append(row_indices) return result |
1 2 3 4 5 6 7 | import numpy as np from numba import njit # Example usage: matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int64) graphlen = 2 result = rref(matrix, graphlen) |
I extracted the following code from that output:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 | from numba import njit import numpy as np @njit def rref(matrix, graphlen): PH = 16 rows, cols = matrix.shape # Initialize variables extra = 0 col = 0 pivots = np.zeros(cols, dtype=np.int64) pivot_count = 0 # Create used_for_row as a boolean matrix instead of sets used_for_row = np.zeros((rows, rows), dtype=np.int64) for i in range(rows): used_for_row[i, i] = 1 # Work with a copy of the matrix matrix = matrix % PH while col + extra < cols - 1 and col < rows: if matrix[col, col + extra] == 0: # Check if column is all zeros all_zeros = True for i in range(rows): if matrix[i, col + extra] != 0: all_zeros = False break if all_zeros: extra += 1 continue # Find last non-zero entry other = -1 for i in range(rows-1, -1, -1): if matrix[i, col + extra] != 0: other = i break if other < col: extra += 1 continue # Swap rows for j in range(cols): temp = matrix[col, j] matrix[col, j] = matrix[other, j] matrix[other, j] = temp # Swap used_for_row for j in range(rows): temp = used_for_row[col, j] used_for_row[col, j] = used_for_row[other, j] used_for_row[other, j] = temp pivots[pivot_count] = col + extra pivot_count += 1 pivot = matrix[col, col + extra] if col + extra < graphlen: if not (abs(pivot) == 1 or abs(pivot) == PH-1): raise ValueError("Invalid pivot value") else: if not (abs(pivot) == 2 or abs(pivot) == PH-2): raise ValueError("Invalid pivot value") pivot //= 2 # Multiply row by pivot for j in range(cols): matrix[col, j] = (matrix[col, j] * pivot) % PH # Process other rows for i in range(rows): if i == col: continue if matrix[i, col + extra] != 0: # Update used_for_row for k in range(rows): if used_for_row[col, k] == 1: used_for_row[i, k] = 1 if col < graphlen: mult = matrix[i, col + extra] for j in range(cols): matrix[i, j] = (matrix[i, j] - mult * matrix[col, j]) % PH else: while matrix[i, col + extra] != 0: for j in range(cols): matrix[i, j] = (matrix[i, j] - matrix[col, j]) % PH col += 1 matrix = matrix % PH # Convert used_for_row back to list format result = [] for i in range(rows): row_indices = [] for j in range(rows): if used_for_row[i, j] == 1: row_indices.append(j) result.append(row_indices) return result |
Running the following program:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 | from numba import njit import numpy as np @njit def rref(matrix, graphlen): PH = 16 rows, cols = matrix.shape # Initialize variables extra = 0 col = 0 pivots = np.zeros(cols, dtype=np.int64) pivot_count = 0 # Create used_for_row as a boolean matrix instead of sets used_for_row = np.zeros((rows, rows), dtype=np.int64) for i in range(rows): used_for_row[i, i] = 1 # Work with a copy of the matrix matrix = matrix % PH while col + extra < cols - 1 and col < rows: if matrix[col, col + extra] == 0: # Check if column is all zeros all_zeros = True for i in range(rows): if matrix[i, col + extra] != 0: all_zeros = False break if all_zeros: extra += 1 continue # Find last non-zero entry other = -1 for i in range(rows-1, -1, -1): if matrix[i, col + extra] != 0: other = i break if other < col: extra += 1 continue # Swap rows for j in range(cols): temp = matrix[col, j] matrix[col, j] = matrix[other, j] matrix[other, j] = temp # Swap used_for_row for j in range(rows): temp = used_for_row[col, j] used_for_row[col, j] = used_for_row[other, j] used_for_row[other, j] = temp pivots[pivot_count] = col + extra pivot_count += 1 pivot = matrix[col, col + extra] if col + extra < graphlen: if not (abs(pivot) == 1 or abs(pivot) == PH-1): raise ValueError("Invalid pivot value") else: if not (abs(pivot) == 2 or abs(pivot) == PH-2): raise ValueError("Invalid pivot value") pivot //= 2 # Multiply row by pivot for j in range(cols): matrix[col, j] = (matrix[col, j] * pivot) % PH # Process other rows for i in range(rows): if i == col: continue if matrix[i, col + extra] != 0: # Update used_for_row for k in range(rows): if used_for_row[col, k] == 1: used_for_row[i, k] = 1 if col < graphlen: mult = matrix[i, col + extra] for j in range(cols): matrix[i, j] = (matrix[i, j] - mult * matrix[col, j]) % PH else: while matrix[i, col + extra] != 0: for j in range(cols): matrix[i, j] = (matrix[i, j] - matrix[col, j]) % PH col += 1 matrix = matrix % PH # Convert used_for_row back to list format result = [] for i in range(rows): row_indices = [] for j in range(rows): if used_for_row[i, j] == 1: row_indices.append(j) result.append(row_indices) return result arr = np.array([[1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -20], [0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, -20], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, -30], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, -20], [1, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -14], [0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, -30], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, -1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, -20], [0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -9], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, -20], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, -40], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) answer = rref(arr, 37) expected = [[0, 38], [0, 38], [0, 6, 28, 35, 38], [0, 6, 18, 21, 28, 35, 38], [0, 6, 28, 35, 38], [0, 6, 21, 28, 35, 38], [0, 5, 10, 23, 38], [0, 5, 10, 23, 38], [0, 5, 10, 23, 38], [0, 5, 9, 10, 23, 38], [0, 5, 9, 10, 13, 17, 23, 30, 38], [0, 5, 9, 10, 11, 23, 27, 38], [0, 5, 9, 10, 11, 23, 27, 38], [0, 5, 9, 10, 11, 23, 25, 27, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 24, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 5, 9, 10, 13, 17, 23, 30, 38], [0, 5, 9, 10, 13, 17, 23, 30, 38], [0, 5, 9, 10, 13, 17, 23, 30, 32, 38], [0, 5, 9, 10, 13, 17, 23, 30, 32, 33, 38], [0, 22, 29, 38], [0, 22, 29, 38], [0, 6, 28, 35, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [14, 18, 21, 28, 35]] assert answer == expected, f'Wrong answer; got {answer} instead of {expected}' print('All tests passed') |
All tests passed
Testing if the previous output contains the string All tests passed
: True
Testing if the previous output contains the string njit
: True
Testing if the previous output contains the string nopython
: False
Testing if the previous output contains the string import numba
: False
Testing if the previous output contains the string from numba
: True
Rewrite this function to use numba and nopython mode. You can assume that the input is a 2d matrix correctly formatted. Call it rref still. def rref(matrix, graphlen): PH = 16 extra = 0 col = 0 pivots = [] used_for_row = [set([i]) for i in range(matrix.shape[0])] matrix = matrix%PH while col+extra < matrix.shape[1]-1 and col < matrix.shape[0]: if matrix[col,col+extra] == 0: if np.all(matrix[:,col] == 0): extra += 1 continue other = np.argwhere(matrix[:,col+extra] != 0).flatten()[-1] if other < col: extra += 1 continue matrix[col], matrix[other] = list(matrix[other]), list(matrix[col]) used_for_row[col], used_for_row[other] = used_for_row[other], used_for_row[col] pivots.append(col+extra) pivot = matrix[col,col+extra] if col+extra < graphlen: assert np.abs(pivot) == 1 or np.abs(pivot) == PH-1 else: assert np.abs(pivot) == 2 or np.abs(pivot) == PH-2 pivot //= 2 matrix[col] *= pivot matrix[col] %= PH others = np.argwhere(matrix[:,col+extra]).flatten() for i in others: if i == col: continue used_for_row[i] |= used_for_row[col] if col < graphlen: matrix[i] -= matrix[col]*matrix[i,col+extra] else: while matrix[i,col+extra] != 0: matrix[i] = (matrix[i]-matrix[col])%PH matrix[i] %= PH col += 1 matrix = np.array(matrix)%PH return [sorted(x) for x in used_for_row]
Rewrite this function to use numba and nopython mode. You can assume that the input is a 2d matrix correctly formatted. Call it rref still. def rref(matrix, graphlen): PH = 16 extra = 0 col = 0 pivots = [] used_for_row = [set([i]) for i in range(matrix.shape[0])] matrix = matrix%PH while col+extra < matrix.shape[1]-1 and col < matrix.shape[0]: if matrix[col,col+extra] == 0: if np.all(matrix[:,col] == 0): extra += 1 continue other = np.argwhere(matrix[:,col+extra] != 0).flatten()[-1] if other < col: extra += 1 continue matrix[col], matrix[other] = list(matrix[other]), list(matrix[col]) used_for_row[col], used_for_row[other] = used_for_row[other], used_for_row[col] pivots.append(col+extra) pivot = matrix[col,col+extra] if col+extra < graphlen: assert np.abs(pivot) == 1 or np.abs(pivot) == PH-1 else: assert np.abs(pivot) == 2 or np.abs(pivot) == PH-2 pivot //= 2 matrix[col] *= pivot matrix[col] %= PH others = np.argwhere(matrix[:,col+extra]).flatten() for i in others: if i == col: continue used_for_row[i] |= used_for_row[col] if col < graphlen: matrix[i] -= matrix[col]*matrix[i,col+extra] else: while matrix[i,col+extra] != 0: matrix[i] = (matrix[i]-matrix[col])%PH matrix[i] %= PH col += 1 matrix = np.array(matrix)%PH return [sorted(x) for x in used_for_row]
Here's the numba-optimized version of the function using nopython mode. Note that we need to make some adjustments since numba doesn't support certain Python operations like sets:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 | import numba import numpy as np @numba.jit(nopython=True) def rref(matrix, graphlen): PH = 16 rows, cols = matrix.shape # Initialize tracking arrays extra = 0 col = 0 # Create array to track used rows (instead of sets) # Using a 2D boolean array where each row represents what rows were used used_for_row = np.zeros((rows, rows), dtype=np.bool_) for i in range(rows): used_for_row[i, i] = True matrix = matrix % PH while col + extra < cols - 1 and col < rows: if matrix[col, col + extra] == 0: # Check if column is all zeros all_zeros = True for i in range(rows): if matrix[i, col + extra] != 0: all_zeros = False break if all_zeros: extra += 1 continue # Find last non-zero entry other = -1 for i in range(rows-1, -1, -1): if matrix[i, col + extra] != 0: other = i break if other < col: extra += 1 continue # Swap rows temp_row = matrix[col].copy() matrix[col] = matrix[other] matrix[other] = temp_row # Swap used_for_row tracking temp_used = used_for_row[col].copy() used_for_row[col] = used_for_row[other] used_for_row[other] = temp_used pivot = matrix[col, col + extra] # Multiply row by inverse of pivot if col + extra < graphlen: if not (abs(pivot) == 1 or abs(pivot) == PH-1): raise ValueError("Invalid pivot value") else: if not (abs(pivot) == 2 or abs(pivot) == PH-2): raise ValueError("Invalid pivot value") pivot //= 2 matrix[col] = (matrix[col] * pivot) % PH # Eliminate in other rows for i in range(rows): if i != col and matrix[i, col + extra] != 0: # Update used_for_row for j in range(rows): if used_for_row[col, j]: used_for_row[i, j] = True if col < graphlen: mult = matrix[i, col + extra] matrix[i] = (matrix[i] - matrix[col] * mult) % PH else: while matrix[i, col + extra] != 0: matrix[i] = (matrix[i] - matrix[col]) % PH col += 1 matrix = matrix % PH # Convert used_for_row to list of sorted indices result = [] for i in range(rows): row_indices = [] for j in range(rows): if used_for_row[i, j]: row_indices.append(j) result.append(row_indices) return result |
I extracted the following code from that output:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 | import numba import numpy as np @numba.jit(nopython=True) def rref(matrix, graphlen): PH = 16 rows, cols = matrix.shape # Initialize tracking arrays extra = 0 col = 0 # Create array to track used rows (instead of sets) # Using a 2D boolean array where each row represents what rows were used used_for_row = np.zeros((rows, rows), dtype=np.bool_) for i in range(rows): used_for_row[i, i] = True matrix = matrix % PH while col + extra < cols - 1 and col < rows: if matrix[col, col + extra] == 0: # Check if column is all zeros all_zeros = True for i in range(rows): if matrix[i, col + extra] != 0: all_zeros = False break if all_zeros: extra += 1 continue # Find last non-zero entry other = -1 for i in range(rows-1, -1, -1): if matrix[i, col + extra] != 0: other = i break if other < col: extra += 1 continue # Swap rows temp_row = matrix[col].copy() matrix[col] = matrix[other] matrix[other] = temp_row # Swap used_for_row tracking temp_used = used_for_row[col].copy() used_for_row[col] = used_for_row[other] used_for_row[other] = temp_used pivot = matrix[col, col + extra] # Multiply row by inverse of pivot if col + extra < graphlen: if not (abs(pivot) == 1 or abs(pivot) == PH-1): raise ValueError("Invalid pivot value") else: if not (abs(pivot) == 2 or abs(pivot) == PH-2): raise ValueError("Invalid pivot value") pivot //= 2 matrix[col] = (matrix[col] * pivot) % PH # Eliminate in other rows for i in range(rows): if i != col and matrix[i, col + extra] != 0: # Update used_for_row for j in range(rows): if used_for_row[col, j]: used_for_row[i, j] = True if col < graphlen: mult = matrix[i, col + extra] matrix[i] = (matrix[i] - matrix[col] * mult) % PH else: while matrix[i, col + extra] != 0: matrix[i] = (matrix[i] - matrix[col]) % PH col += 1 matrix = matrix % PH # Convert used_for_row to list of sorted indices result = [] for i in range(rows): row_indices = [] for j in range(rows): if used_for_row[i, j]: row_indices.append(j) result.append(row_indices) return result |
Running the following program:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 | import numba import numpy as np @numba.jit(nopython=True) def rref(matrix, graphlen): PH = 16 rows, cols = matrix.shape # Initialize tracking arrays extra = 0 col = 0 # Create array to track used rows (instead of sets) # Using a 2D boolean array where each row represents what rows were used used_for_row = np.zeros((rows, rows), dtype=np.bool_) for i in range(rows): used_for_row[i, i] = True matrix = matrix % PH while col + extra < cols - 1 and col < rows: if matrix[col, col + extra] == 0: # Check if column is all zeros all_zeros = True for i in range(rows): if matrix[i, col + extra] != 0: all_zeros = False break if all_zeros: extra += 1 continue # Find last non-zero entry other = -1 for i in range(rows-1, -1, -1): if matrix[i, col + extra] != 0: other = i break if other < col: extra += 1 continue # Swap rows temp_row = matrix[col].copy() matrix[col] = matrix[other] matrix[other] = temp_row # Swap used_for_row tracking temp_used = used_for_row[col].copy() used_for_row[col] = used_for_row[other] used_for_row[other] = temp_used pivot = matrix[col, col + extra] # Multiply row by inverse of pivot if col + extra < graphlen: if not (abs(pivot) == 1 or abs(pivot) == PH-1): raise ValueError("Invalid pivot value") else: if not (abs(pivot) == 2 or abs(pivot) == PH-2): raise ValueError("Invalid pivot value") pivot //= 2 matrix[col] = (matrix[col] * pivot) % PH # Eliminate in other rows for i in range(rows): if i != col and matrix[i, col + extra] != 0: # Update used_for_row for j in range(rows): if used_for_row[col, j]: used_for_row[i, j] = True if col < graphlen: mult = matrix[i, col + extra] matrix[i] = (matrix[i] - matrix[col] * mult) % PH else: while matrix[i, col + extra] != 0: matrix[i] = (matrix[i] - matrix[col]) % PH col += 1 matrix = matrix % PH # Convert used_for_row to list of sorted indices result = [] for i in range(rows): row_indices = [] for j in range(rows): if used_for_row[i, j]: row_indices.append(j) result.append(row_indices) return result arr = np.array([[1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -20], [0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, -20], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, -30], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, -20], [1, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -14], [0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, -30], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, -1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, -20], [0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -9], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, -20], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, -40], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) answer = rref(arr, 37) expected = [[0, 38], [0, 38], [0, 6, 28, 35, 38], [0, 6, 18, 21, 28, 35, 38], [0, 6, 28, 35, 38], [0, 6, 21, 28, 35, 38], [0, 5, 10, 23, 38], [0, 5, 10, 23, 38], [0, 5, 10, 23, 38], [0, 5, 9, 10, 23, 38], [0, 5, 9, 10, 13, 17, 23, 30, 38], [0, 5, 9, 10, 11, 23, 27, 38], [0, 5, 9, 10, 11, 23, 27, 38], [0, 5, 9, 10, 11, 23, 25, 27, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 24, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 5, 9, 10, 13, 17, 23, 30, 38], [0, 5, 9, 10, 13, 17, 23, 30, 38], [0, 5, 9, 10, 13, 17, 23, 30, 32, 38], [0, 5, 9, 10, 13, 17, 23, 30, 32, 33, 38], [0, 22, 29, 38], [0, 22, 29, 38], [0, 6, 28, 35, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [14, 18, 21, 28, 35]] assert answer == expected, f'Wrong answer; got {answer} instead of {expected}' print('All tests passed') |
All tests passed
Testing if the previous output contains the string All tests passed
: True
Testing if the previous output contains the string njit
: False
Testing if the previous output contains the string nopython
: True
Testing if the previous output contains the string import numba
: True
Testing if the previous output contains the string from numba
: False
Rewrite this function to use numba and nopython mode. You can assume that the input is a 2d matrix correctly formatted. Call it rref still. def rref(matrix, graphlen): PH = 16 extra = 0 col = 0 pivots = [] used_for_row = [set([i]) for i in range(matrix.shape[0])] matrix = matrix%PH while col+extra < matrix.shape[1]-1 and col < matrix.shape[0]: if matrix[col,col+extra] == 0: if np.all(matrix[:,col] == 0): extra += 1 continue other = np.argwhere(matrix[:,col+extra] != 0).flatten()[-1] if other < col: extra += 1 continue matrix[col], matrix[other] = list(matrix[other]), list(matrix[col]) used_for_row[col], used_for_row[other] = used_for_row[other], used_for_row[col] pivots.append(col+extra) pivot = matrix[col,col+extra] if col+extra < graphlen: assert np.abs(pivot) == 1 or np.abs(pivot) == PH-1 else: assert np.abs(pivot) == 2 or np.abs(pivot) == PH-2 pivot //= 2 matrix[col] *= pivot matrix[col] %= PH others = np.argwhere(matrix[:,col+extra]).flatten() for i in others: if i == col: continue used_for_row[i] |= used_for_row[col] if col < graphlen: matrix[i] -= matrix[col]*matrix[i,col+extra] else: while matrix[i,col+extra] != 0: matrix[i] = (matrix[i]-matrix[col])%PH matrix[i] %= PH col += 1 matrix = np.array(matrix)%PH return [sorted(x) for x in used_for_row]
Rewrite this function to use numba and nopython mode. You can assume that the input is a 2d matrix correctly formatted. Call it rref still. def rref(matrix, graphlen): PH = 16 extra = 0 col = 0 pivots = [] used_for_row = [set([i]) for i in range(matrix.shape[0])] matrix = matrix%PH while col+extra < matrix.shape[1]-1 and col < matrix.shape[0]: if matrix[col,col+extra] == 0: if np.all(matrix[:,col] == 0): extra += 1 continue other = np.argwhere(matrix[:,col+extra] != 0).flatten()[-1] if other < col: extra += 1 continue matrix[col], matrix[other] = list(matrix[other]), list(matrix[col]) used_for_row[col], used_for_row[other] = used_for_row[other], used_for_row[col] pivots.append(col+extra) pivot = matrix[col,col+extra] if col+extra < graphlen: assert np.abs(pivot) == 1 or np.abs(pivot) == PH-1 else: assert np.abs(pivot) == 2 or np.abs(pivot) == PH-2 pivot //= 2 matrix[col] *= pivot matrix[col] %= PH others = np.argwhere(matrix[:,col+extra]).flatten() for i in others: if i == col: continue used_for_row[i] |= used_for_row[col] if col < graphlen: matrix[i] -= matrix[col]*matrix[i,col+extra] else: while matrix[i,col+extra] != 0: matrix[i] = (matrix[i]-matrix[col])%PH matrix[i] %= PH col += 1 matrix = np.array(matrix)%PH return [sorted(x) for x in used_for_row]
Here's the numba version of the function using nopython mode. Note that since numba doesn't support Python sets, I've modified the function to use arrays instead:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 | import numba import numpy as np @numba.jit(nopython=True) def rref(matrix, graphlen): PH = 16 rows, cols = matrix.shape # Initialize arrays instead of sets for tracking used rows used_for_row = np.zeros((rows, rows), dtype=np.int64) for i in range(rows): used_for_row[i, i] = 1 extra = 0 col = 0 pivots = [] matrix = matrix % PH while col + extra < cols - 1 and col < rows: if matrix[col, col + extra] == 0: # Check if column is all zeros all_zeros = True for i in range(rows): if matrix[i, col + extra] != 0: all_zeros = False break if all_zeros: extra += 1 continue # Find last non-zero entry other = -1 for i in range(rows-1, -1, -1): if matrix[i, col + extra] != 0: other = i break if other < col: extra += 1 continue # Swap rows temp_row = matrix[col].copy() matrix[col] = matrix[other] matrix[other] = temp_row temp_used = used_for_row[col].copy() used_for_row[col] = used_for_row[other] used_for_row[other] = temp_used pivots.append(col + extra) pivot = matrix[col, col + extra] if col + extra < graphlen: if not (abs(pivot) == 1 or abs(pivot) == PH-1): raise ValueError("Invalid pivot") else: if not (abs(pivot) == 2 or abs(pivot) == PH-2): raise ValueError("Invalid pivot") pivot //= 2 # Multiply row by pivot inverse for j in range(cols): matrix[col, j] = (matrix[col, j] * pivot) % PH # Find other rows to reduce for i in range(rows): if i != col and matrix[i, col + extra] != 0: # Update used_for_row for k in range(rows): if used_for_row[col, k] == 1: used_for_row[i, k] = 1 if col < graphlen: mult = matrix[i, col + extra] for j in range(cols): matrix[i, j] = (matrix[i, j] - mult * matrix[col, j]) % PH else: while matrix[i, col + extra] != 0: for j in range(cols): matrix[i, j] = (matrix[i, j] - matrix[col, j]) % PH col += 1 matrix = matrix % PH # Convert used_for_row back to list format result = [] for i in range(rows): row_list = [] for j in range(rows): if used_for_row[i, j] == 1: row_list.append(j) result.append(row_list) return result |
I extracted the following code from that output:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 | import numba import numpy as np @numba.jit(nopython=True) def rref(matrix, graphlen): PH = 16 rows, cols = matrix.shape # Initialize arrays instead of sets for tracking used rows used_for_row = np.zeros((rows, rows), dtype=np.int64) for i in range(rows): used_for_row[i, i] = 1 extra = 0 col = 0 pivots = [] matrix = matrix % PH while col + extra < cols - 1 and col < rows: if matrix[col, col + extra] == 0: # Check if column is all zeros all_zeros = True for i in range(rows): if matrix[i, col + extra] != 0: all_zeros = False break if all_zeros: extra += 1 continue # Find last non-zero entry other = -1 for i in range(rows-1, -1, -1): if matrix[i, col + extra] != 0: other = i break if other < col: extra += 1 continue # Swap rows temp_row = matrix[col].copy() matrix[col] = matrix[other] matrix[other] = temp_row temp_used = used_for_row[col].copy() used_for_row[col] = used_for_row[other] used_for_row[other] = temp_used pivots.append(col + extra) pivot = matrix[col, col + extra] if col + extra < graphlen: if not (abs(pivot) == 1 or abs(pivot) == PH-1): raise ValueError("Invalid pivot") else: if not (abs(pivot) == 2 or abs(pivot) == PH-2): raise ValueError("Invalid pivot") pivot //= 2 # Multiply row by pivot inverse for j in range(cols): matrix[col, j] = (matrix[col, j] * pivot) % PH # Find other rows to reduce for i in range(rows): if i != col and matrix[i, col + extra] != 0: # Update used_for_row for k in range(rows): if used_for_row[col, k] == 1: used_for_row[i, k] = 1 if col < graphlen: mult = matrix[i, col + extra] for j in range(cols): matrix[i, j] = (matrix[i, j] - mult * matrix[col, j]) % PH else: while matrix[i, col + extra] != 0: for j in range(cols): matrix[i, j] = (matrix[i, j] - matrix[col, j]) % PH col += 1 matrix = matrix % PH # Convert used_for_row back to list format result = [] for i in range(rows): row_list = [] for j in range(rows): if used_for_row[i, j] == 1: row_list.append(j) result.append(row_list) return result |
Running the following program:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 | import numba import numpy as np @numba.jit(nopython=True) def rref(matrix, graphlen): PH = 16 rows, cols = matrix.shape # Initialize arrays instead of sets for tracking used rows used_for_row = np.zeros((rows, rows), dtype=np.int64) for i in range(rows): used_for_row[i, i] = 1 extra = 0 col = 0 pivots = [] matrix = matrix % PH while col + extra < cols - 1 and col < rows: if matrix[col, col + extra] == 0: # Check if column is all zeros all_zeros = True for i in range(rows): if matrix[i, col + extra] != 0: all_zeros = False break if all_zeros: extra += 1 continue # Find last non-zero entry other = -1 for i in range(rows-1, -1, -1): if matrix[i, col + extra] != 0: other = i break if other < col: extra += 1 continue # Swap rows temp_row = matrix[col].copy() matrix[col] = matrix[other] matrix[other] = temp_row temp_used = used_for_row[col].copy() used_for_row[col] = used_for_row[other] used_for_row[other] = temp_used pivots.append(col + extra) pivot = matrix[col, col + extra] if col + extra < graphlen: if not (abs(pivot) == 1 or abs(pivot) == PH-1): raise ValueError("Invalid pivot") else: if not (abs(pivot) == 2 or abs(pivot) == PH-2): raise ValueError("Invalid pivot") pivot //= 2 # Multiply row by pivot inverse for j in range(cols): matrix[col, j] = (matrix[col, j] * pivot) % PH # Find other rows to reduce for i in range(rows): if i != col and matrix[i, col + extra] != 0: # Update used_for_row for k in range(rows): if used_for_row[col, k] == 1: used_for_row[i, k] = 1 if col < graphlen: mult = matrix[i, col + extra] for j in range(cols): matrix[i, j] = (matrix[i, j] - mult * matrix[col, j]) % PH else: while matrix[i, col + extra] != 0: for j in range(cols): matrix[i, j] = (matrix[i, j] - matrix[col, j]) % PH col += 1 matrix = matrix % PH # Convert used_for_row back to list format result = [] for i in range(rows): row_list = [] for j in range(rows): if used_for_row[i, j] == 1: row_list.append(j) result.append(row_list) return result arr = np.array([[1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -20], [0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, -20], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, -30], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, -20], [1, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -14], [0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, -30], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, -1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, -20], [0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -9], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, -20], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, -40], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) answer = rref(arr, 37) expected = [[0, 38], [0, 38], [0, 6, 28, 35, 38], [0, 6, 18, 21, 28, 35, 38], [0, 6, 28, 35, 38], [0, 6, 21, 28, 35, 38], [0, 5, 10, 23, 38], [0, 5, 10, 23, 38], [0, 5, 10, 23, 38], [0, 5, 9, 10, 23, 38], [0, 5, 9, 10, 13, 17, 23, 30, 38], [0, 5, 9, 10, 11, 23, 27, 38], [0, 5, 9, 10, 11, 23, 27, 38], [0, 5, 9, 10, 11, 23, 25, 27, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 24, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 5, 9, 10, 13, 17, 23, 30, 38], [0, 5, 9, 10, 13, 17, 23, 30, 38], [0, 5, 9, 10, 13, 17, 23, 30, 32, 38], [0, 5, 9, 10, 13, 17, 23, 30, 32, 33, 38], [0, 22, 29, 38], [0, 22, 29, 38], [0, 6, 28, 35, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [14, 18, 21, 28, 35]] assert answer == expected, f'Wrong answer; got {answer} instead of {expected}' print('All tests passed') |
All tests passed
Testing if the previous output contains the string All tests passed
: True
Testing if the previous output contains the string njit
: False
Testing if the previous output contains the string nopython
: True
Testing if the previous output contains the string import numba
: True
Testing if the previous output contains the string from numba
: False
Rewrite this function to use numba and nopython mode. You can assume that the input is a 2d matrix correctly formatted. Call it rref still. def rref(matrix, graphlen): PH = 16 extra = 0 col = 0 pivots = [] used_for_row = [set([i]) for i in range(matrix.shape[0])] matrix = matrix%PH while col+extra < matrix.shape[1]-1 and col < matrix.shape[0]: if matrix[col,col+extra] == 0: if np.all(matrix[:,col] == 0): extra += 1 continue other = np.argwhere(matrix[:,col+extra] != 0).flatten()[-1] if other < col: extra += 1 continue matrix[col], matrix[other] = list(matrix[other]), list(matrix[col]) used_for_row[col], used_for_row[other] = used_for_row[other], used_for_row[col] pivots.append(col+extra) pivot = matrix[col,col+extra] if col+extra < graphlen: assert np.abs(pivot) == 1 or np.abs(pivot) == PH-1 else: assert np.abs(pivot) == 2 or np.abs(pivot) == PH-2 pivot //= 2 matrix[col] *= pivot matrix[col] %= PH others = np.argwhere(matrix[:,col+extra]).flatten() for i in others: if i == col: continue used_for_row[i] |= used_for_row[col] if col < graphlen: matrix[i] -= matrix[col]*matrix[i,col+extra] else: while matrix[i,col+extra] != 0: matrix[i] = (matrix[i]-matrix[col])%PH matrix[i] %= PH col += 1 matrix = np.array(matrix)%PH return [sorted(x) for x in used_for_row]
Rewrite this function to use numba and nopython mode. You can assume that the input is a 2d matrix correctly formatted. Call it rref still. def rref(matrix, graphlen): PH = 16 extra = 0 col = 0 pivots = [] used_for_row = [set([i]) for i in range(matrix.shape[0])] matrix = matrix%PH while col+extra < matrix.shape[1]-1 and col < matrix.shape[0]: if matrix[col,col+extra] == 0: if np.all(matrix[:,col] == 0): extra += 1 continue other = np.argwhere(matrix[:,col+extra] != 0).flatten()[-1] if other < col: extra += 1 continue matrix[col], matrix[other] = list(matrix[other]), list(matrix[col]) used_for_row[col], used_for_row[other] = used_for_row[other], used_for_row[col] pivots.append(col+extra) pivot = matrix[col,col+extra] if col+extra < graphlen: assert np.abs(pivot) == 1 or np.abs(pivot) == PH-1 else: assert np.abs(pivot) == 2 or np.abs(pivot) == PH-2 pivot //= 2 matrix[col] *= pivot matrix[col] %= PH others = np.argwhere(matrix[:,col+extra]).flatten() for i in others: if i == col: continue used_for_row[i] |= used_for_row[col] if col < graphlen: matrix[i] -= matrix[col]*matrix[i,col+extra] else: while matrix[i,col+extra] != 0: matrix[i] = (matrix[i]-matrix[col])%PH matrix[i] %= PH col += 1 matrix = np.array(matrix)%PH return [sorted(x) for x in used_for_row]
Here's the numba-optimized version of the function using nopython mode. Note that we need to make some adjustments since numba doesn't support certain Python operations like sets directly:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 | import numba import numpy as np from numba import jit @jit(nopython=True) def rref(matrix, graphlen): PH = 16 rows, cols = matrix.shape # Initialize variables extra = 0 col = 0 pivots = [] # Create used_for_row as a boolean matrix instead of sets used_for_row = np.zeros((rows, rows), dtype=np.bool_) for i in range(rows): used_for_row[i, i] = True # Work with a copy of the matrix matrix = matrix.copy() % PH while col + extra < cols - 1 and col < rows: if matrix[col, col + extra] == 0: # Check if column is all zeros all_zeros = True for i in range(rows): if matrix[i, col + extra] != 0: all_zeros = False break if all_zeros: extra += 1 continue # Find last non-zero entry other = -1 for i in range(rows-1, -1, -1): if matrix[i, col + extra] != 0: other = i break if other < col: extra += 1 continue # Swap rows temp_row = matrix[col].copy() matrix[col] = matrix[other] matrix[other] = temp_row temp_used = used_for_row[col].copy() used_for_row[col] = used_for_row[other] used_for_row[other] = temp_used pivots.append(col + extra) pivot = matrix[col, col + extra] # Verify pivot value if col + extra < graphlen: if not (abs(pivot) == 1 or abs(pivot) == PH-1): raise ValueError("Invalid pivot value") else: if not (abs(pivot) == 2 or abs(pivot) == PH-2): raise ValueError("Invalid pivot value") pivot //= 2 # Multiply row by pivot inverse matrix[col] = (matrix[col] * pivot) % PH # Eliminate in other rows for i in range(rows): if i != col and matrix[i, col + extra] != 0: # Update used_for_row for j in range(rows): if used_for_row[col, j]: used_for_row[i, j] = True if col < graphlen: mult = matrix[i, col + extra] matrix[i] = (matrix[i] - matrix[col] * mult) % PH else: while matrix[i, col + extra] != 0: matrix[i] = (matrix[i] - matrix[col]) % PH col += 1 matrix = matrix % PH # Convert used_for_row back to list format result = [] for i in range(rows): row_indices = [] for j in range(rows): if used_for_row[i, j]: row_indices.append(j) result.append(row_indices) return result |
1 2 3 4 5 6 7 | import numpy as np from numba import jit # Example usage: matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int64) graphlen = 2 result = rref(matrix, graphlen) |
I extracted the following code from that output:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 | import numba import numpy as np from numba import jit @jit(nopython=True) def rref(matrix, graphlen): PH = 16 rows, cols = matrix.shape # Initialize variables extra = 0 col = 0 pivots = [] # Create used_for_row as a boolean matrix instead of sets used_for_row = np.zeros((rows, rows), dtype=np.bool_) for i in range(rows): used_for_row[i, i] = True # Work with a copy of the matrix matrix = matrix.copy() % PH while col + extra < cols - 1 and col < rows: if matrix[col, col + extra] == 0: # Check if column is all zeros all_zeros = True for i in range(rows): if matrix[i, col + extra] != 0: all_zeros = False break if all_zeros: extra += 1 continue # Find last non-zero entry other = -1 for i in range(rows-1, -1, -1): if matrix[i, col + extra] != 0: other = i break if other < col: extra += 1 continue # Swap rows temp_row = matrix[col].copy() matrix[col] = matrix[other] matrix[other] = temp_row temp_used = used_for_row[col].copy() used_for_row[col] = used_for_row[other] used_for_row[other] = temp_used pivots.append(col + extra) pivot = matrix[col, col + extra] # Verify pivot value if col + extra < graphlen: if not (abs(pivot) == 1 or abs(pivot) == PH-1): raise ValueError("Invalid pivot value") else: if not (abs(pivot) == 2 or abs(pivot) == PH-2): raise ValueError("Invalid pivot value") pivot //= 2 # Multiply row by pivot inverse matrix[col] = (matrix[col] * pivot) % PH # Eliminate in other rows for i in range(rows): if i != col and matrix[i, col + extra] != 0: # Update used_for_row for j in range(rows): if used_for_row[col, j]: used_for_row[i, j] = True if col < graphlen: mult = matrix[i, col + extra] matrix[i] = (matrix[i] - matrix[col] * mult) % PH else: while matrix[i, col + extra] != 0: matrix[i] = (matrix[i] - matrix[col]) % PH col += 1 matrix = matrix % PH # Convert used_for_row back to list format result = [] for i in range(rows): row_indices = [] for j in range(rows): if used_for_row[i, j]: row_indices.append(j) result.append(row_indices) return result |
Running the following program:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 | import numba import numpy as np from numba import jit @jit(nopython=True) def rref(matrix, graphlen): PH = 16 rows, cols = matrix.shape # Initialize variables extra = 0 col = 0 pivots = [] # Create used_for_row as a boolean matrix instead of sets used_for_row = np.zeros((rows, rows), dtype=np.bool_) for i in range(rows): used_for_row[i, i] = True # Work with a copy of the matrix matrix = matrix.copy() % PH while col + extra < cols - 1 and col < rows: if matrix[col, col + extra] == 0: # Check if column is all zeros all_zeros = True for i in range(rows): if matrix[i, col + extra] != 0: all_zeros = False break if all_zeros: extra += 1 continue # Find last non-zero entry other = -1 for i in range(rows-1, -1, -1): if matrix[i, col + extra] != 0: other = i break if other < col: extra += 1 continue # Swap rows temp_row = matrix[col].copy() matrix[col] = matrix[other] matrix[other] = temp_row temp_used = used_for_row[col].copy() used_for_row[col] = used_for_row[other] used_for_row[other] = temp_used pivots.append(col + extra) pivot = matrix[col, col + extra] # Verify pivot value if col + extra < graphlen: if not (abs(pivot) == 1 or abs(pivot) == PH-1): raise ValueError("Invalid pivot value") else: if not (abs(pivot) == 2 or abs(pivot) == PH-2): raise ValueError("Invalid pivot value") pivot //= 2 # Multiply row by pivot inverse matrix[col] = (matrix[col] * pivot) % PH # Eliminate in other rows for i in range(rows): if i != col and matrix[i, col + extra] != 0: # Update used_for_row for j in range(rows): if used_for_row[col, j]: used_for_row[i, j] = True if col < graphlen: mult = matrix[i, col + extra] matrix[i] = (matrix[i] - matrix[col] * mult) % PH else: while matrix[i, col + extra] != 0: matrix[i] = (matrix[i] - matrix[col]) % PH col += 1 matrix = matrix % PH # Convert used_for_row back to list format result = [] for i in range(rows): row_indices = [] for j in range(rows): if used_for_row[i, j]: row_indices.append(j) result.append(row_indices) return result arr = np.array([[1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -20], [0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, -20], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, -30], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, -20], [1, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -14], [0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, -30], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, -1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, -20], [0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -9], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, -20], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, -40], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) answer = rref(arr, 37) expected = [[0, 38], [0, 38], [0, 6, 28, 35, 38], [0, 6, 18, 21, 28, 35, 38], [0, 6, 28, 35, 38], [0, 6, 21, 28, 35, 38], [0, 5, 10, 23, 38], [0, 5, 10, 23, 38], [0, 5, 10, 23, 38], [0, 5, 9, 10, 23, 38], [0, 5, 9, 10, 13, 17, 23, 30, 38], [0, 5, 9, 10, 11, 23, 27, 38], [0, 5, 9, 10, 11, 23, 27, 38], [0, 5, 9, 10, 11, 23, 25, 27, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 24, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 5, 9, 10, 13, 17, 23, 30, 38], [0, 5, 9, 10, 13, 17, 23, 30, 38], [0, 5, 9, 10, 13, 17, 23, 30, 32, 38], [0, 5, 9, 10, 13, 17, 23, 30, 32, 33, 38], [0, 22, 29, 38], [0, 22, 29, 38], [0, 6, 28, 35, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [14, 18, 21, 28, 35]] assert answer == expected, f'Wrong answer; got {answer} instead of {expected}' print('All tests passed') |
All tests passed
Testing if the previous output contains the string All tests passed
: True
Testing if the previous output contains the string njit
: False
Testing if the previous output contains the string nopython
: True
Testing if the previous output contains the string import numba
: True
Testing if the previous output contains the string from numba
: True
Rewrite this function to use numba and nopython mode. You can assume that the input is a 2d matrix correctly formatted. Call it rref still. def rref(matrix, graphlen): PH = 16 extra = 0 col = 0 pivots = [] used_for_row = [set([i]) for i in range(matrix.shape[0])] matrix = matrix%PH while col+extra < matrix.shape[1]-1 and col < matrix.shape[0]: if matrix[col,col+extra] == 0: if np.all(matrix[:,col] == 0): extra += 1 continue other = np.argwhere(matrix[:,col+extra] != 0).flatten()[-1] if other < col: extra += 1 continue matrix[col], matrix[other] = list(matrix[other]), list(matrix[col]) used_for_row[col], used_for_row[other] = used_for_row[other], used_for_row[col] pivots.append(col+extra) pivot = matrix[col,col+extra] if col+extra < graphlen: assert np.abs(pivot) == 1 or np.abs(pivot) == PH-1 else: assert np.abs(pivot) == 2 or np.abs(pivot) == PH-2 pivot //= 2 matrix[col] *= pivot matrix[col] %= PH others = np.argwhere(matrix[:,col+extra]).flatten() for i in others: if i == col: continue used_for_row[i] |= used_for_row[col] if col < graphlen: matrix[i] -= matrix[col]*matrix[i,col+extra] else: while matrix[i,col+extra] != 0: matrix[i] = (matrix[i]-matrix[col])%PH matrix[i] %= PH col += 1 matrix = np.array(matrix)%PH return [sorted(x) for x in used_for_row]
Rewrite this function to use numba and nopython mode. You can assume that the input is a 2d matrix correctly formatted. Call it rref still. def rref(matrix, graphlen): PH = 16 extra = 0 col = 0 pivots = [] used_for_row = [set([i]) for i in range(matrix.shape[0])] matrix = matrix%PH while col+extra < matrix.shape[1]-1 and col < matrix.shape[0]: if matrix[col,col+extra] == 0: if np.all(matrix[:,col] == 0): extra += 1 continue other = np.argwhere(matrix[:,col+extra] != 0).flatten()[-1] if other < col: extra += 1 continue matrix[col], matrix[other] = list(matrix[other]), list(matrix[col]) used_for_row[col], used_for_row[other] = used_for_row[other], used_for_row[col] pivots.append(col+extra) pivot = matrix[col,col+extra] if col+extra < graphlen: assert np.abs(pivot) == 1 or np.abs(pivot) == PH-1 else: assert np.abs(pivot) == 2 or np.abs(pivot) == PH-2 pivot //= 2 matrix[col] *= pivot matrix[col] %= PH others = np.argwhere(matrix[:,col+extra]).flatten() for i in others: if i == col: continue used_for_row[i] |= used_for_row[col] if col < graphlen: matrix[i] -= matrix[col]*matrix[i,col+extra] else: while matrix[i,col+extra] != 0: matrix[i] = (matrix[i]-matrix[col])%PH matrix[i] %= PH col += 1 matrix = np.array(matrix)%PH return [sorted(x) for x in used_for_row]
Here's the numba-optimized version of the function using nopython mode. Note that sets aren't supported in numba, so I've modified the function to use arrays instead:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 | import numba import numpy as np @numba.jit(nopython=True) def rref(matrix, graphlen): PH = 16 rows, cols = matrix.shape # Initialize variables extra = 0 col = 0 pivots = np.zeros(cols, dtype=np.int64) pivot_count = 0 # Create used_for_row as a 2D array instead of list of sets used_for_row = np.zeros((rows, rows), dtype=np.int64) for i in range(rows): used_for_row[i, 0] = i used_for_row[i, 1:] = -1 # Use -1 to mark unused entries # Convert matrix to modulo PH matrix = matrix % PH while col + extra < cols - 1 and col < rows: if matrix[col, col + extra] == 0: # Check if column is all zeros all_zeros = True for i in range(rows): if matrix[i, col] != 0: all_zeros = False break if all_zeros: extra += 1 continue # Find last non-zero entry other = -1 for i in range(rows-1, -1, -1): if matrix[i, col + extra] != 0: other = i break if other < col: extra += 1 continue # Swap rows temp_row = matrix[col].copy() matrix[col] = matrix[other] matrix[other] = temp_row temp_used = used_for_row[col].copy() used_for_row[col] = used_for_row[other] used_for_row[other] = temp_used pivots[pivot_count] = col + extra pivot_count += 1 pivot = matrix[col, col + extra] if col + extra < graphlen: if not (abs(pivot) == 1 or abs(pivot) == PH-1): raise ValueError("Invalid pivot value") else: if not (abs(pivot) == 2 or abs(pivot) == PH-2): raise ValueError("Invalid pivot value") pivot //= 2 # Multiply row by pivot inverse for j in range(cols): matrix[col, j] = (matrix[col, j] * pivot) % PH # Eliminate in other rows for i in range(rows): if i != col and matrix[i, col + extra] != 0: # Merge used_for_row entries temp_used = np.zeros(rows, dtype=np.int64) - 1 idx = 0 # Add entries from both rows for k in range(rows): if used_for_row[i, k] != -1: temp_used[idx] = used_for_row[i, k] idx += 1 for k in range(rows): if used_for_row[col, k] != -1: val = used_for_row[col, k] if val not in temp_used: temp_used[idx] = val idx += 1 used_for_row[i] = temp_used if col < graphlen: mult = matrix[i, col + extra] for j in range(cols): matrix[i, j] = (matrix[i, j] - mult * matrix[col, j]) % PH else: while matrix[i, col + extra] != 0: for j in range(cols): matrix[i, j] = (matrix[i, j] - matrix[col, j]) % PH col += 1 # Convert used_for_row to sorted lists result = [] for i in range(rows): row = [] for j in range(rows): if used_for_row[i, j] != -1: row.append(used_for_row[i, j]) row.sort() result.append(row) return result |
1 2 3 4 | import numpy as np matrix = np.array([[...]]) # Your input matrix graphlen = ... # Your graphlen value result = rref(matrix, graphlen) |
I extracted the following code from that output:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 | import numba import numpy as np @numba.jit(nopython=True) def rref(matrix, graphlen): PH = 16 rows, cols = matrix.shape # Initialize variables extra = 0 col = 0 pivots = np.zeros(cols, dtype=np.int64) pivot_count = 0 # Create used_for_row as a 2D array instead of list of sets used_for_row = np.zeros((rows, rows), dtype=np.int64) for i in range(rows): used_for_row[i, 0] = i used_for_row[i, 1:] = -1 # Use -1 to mark unused entries # Convert matrix to modulo PH matrix = matrix % PH while col + extra < cols - 1 and col < rows: if matrix[col, col + extra] == 0: # Check if column is all zeros all_zeros = True for i in range(rows): if matrix[i, col] != 0: all_zeros = False break if all_zeros: extra += 1 continue # Find last non-zero entry other = -1 for i in range(rows-1, -1, -1): if matrix[i, col + extra] != 0: other = i break if other < col: extra += 1 continue # Swap rows temp_row = matrix[col].copy() matrix[col] = matrix[other] matrix[other] = temp_row temp_used = used_for_row[col].copy() used_for_row[col] = used_for_row[other] used_for_row[other] = temp_used pivots[pivot_count] = col + extra pivot_count += 1 pivot = matrix[col, col + extra] if col + extra < graphlen: if not (abs(pivot) == 1 or abs(pivot) == PH-1): raise ValueError("Invalid pivot value") else: if not (abs(pivot) == 2 or abs(pivot) == PH-2): raise ValueError("Invalid pivot value") pivot //= 2 # Multiply row by pivot inverse for j in range(cols): matrix[col, j] = (matrix[col, j] * pivot) % PH # Eliminate in other rows for i in range(rows): if i != col and matrix[i, col + extra] != 0: # Merge used_for_row entries temp_used = np.zeros(rows, dtype=np.int64) - 1 idx = 0 # Add entries from both rows for k in range(rows): if used_for_row[i, k] != -1: temp_used[idx] = used_for_row[i, k] idx += 1 for k in range(rows): if used_for_row[col, k] != -1: val = used_for_row[col, k] if val not in temp_used: temp_used[idx] = val idx += 1 used_for_row[i] = temp_used if col < graphlen: mult = matrix[i, col + extra] for j in range(cols): matrix[i, j] = (matrix[i, j] - mult * matrix[col, j]) % PH else: while matrix[i, col + extra] != 0: for j in range(cols): matrix[i, j] = (matrix[i, j] - matrix[col, j]) % PH col += 1 # Convert used_for_row to sorted lists result = [] for i in range(rows): row = [] for j in range(rows): if used_for_row[i, j] != -1: row.append(used_for_row[i, j]) row.sort() result.append(row) return result |
Running the following program:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 | import numba import numpy as np @numba.jit(nopython=True) def rref(matrix, graphlen): PH = 16 rows, cols = matrix.shape # Initialize variables extra = 0 col = 0 pivots = np.zeros(cols, dtype=np.int64) pivot_count = 0 # Create used_for_row as a 2D array instead of list of sets used_for_row = np.zeros((rows, rows), dtype=np.int64) for i in range(rows): used_for_row[i, 0] = i used_for_row[i, 1:] = -1 # Use -1 to mark unused entries # Convert matrix to modulo PH matrix = matrix % PH while col + extra < cols - 1 and col < rows: if matrix[col, col + extra] == 0: # Check if column is all zeros all_zeros = True for i in range(rows): if matrix[i, col] != 0: all_zeros = False break if all_zeros: extra += 1 continue # Find last non-zero entry other = -1 for i in range(rows-1, -1, -1): if matrix[i, col + extra] != 0: other = i break if other < col: extra += 1 continue # Swap rows temp_row = matrix[col].copy() matrix[col] = matrix[other] matrix[other] = temp_row temp_used = used_for_row[col].copy() used_for_row[col] = used_for_row[other] used_for_row[other] = temp_used pivots[pivot_count] = col + extra pivot_count += 1 pivot = matrix[col, col + extra] if col + extra < graphlen: if not (abs(pivot) == 1 or abs(pivot) == PH-1): raise ValueError("Invalid pivot value") else: if not (abs(pivot) == 2 or abs(pivot) == PH-2): raise ValueError("Invalid pivot value") pivot //= 2 # Multiply row by pivot inverse for j in range(cols): matrix[col, j] = (matrix[col, j] * pivot) % PH # Eliminate in other rows for i in range(rows): if i != col and matrix[i, col + extra] != 0: # Merge used_for_row entries temp_used = np.zeros(rows, dtype=np.int64) - 1 idx = 0 # Add entries from both rows for k in range(rows): if used_for_row[i, k] != -1: temp_used[idx] = used_for_row[i, k] idx += 1 for k in range(rows): if used_for_row[col, k] != -1: val = used_for_row[col, k] if val not in temp_used: temp_used[idx] = val idx += 1 used_for_row[i] = temp_used if col < graphlen: mult = matrix[i, col + extra] for j in range(cols): matrix[i, j] = (matrix[i, j] - mult * matrix[col, j]) % PH else: while matrix[i, col + extra] != 0: for j in range(cols): matrix[i, j] = (matrix[i, j] - matrix[col, j]) % PH col += 1 # Convert used_for_row to sorted lists result = [] for i in range(rows): row = [] for j in range(rows): if used_for_row[i, j] != -1: row.append(used_for_row[i, j]) row.sort() result.append(row) return result arr = np.array([[1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -20], [0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0, -20], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, 0, -30], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, 0, -20], [1, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -14], [0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, 0, -30], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, -1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, 0, -20], [0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0, -10], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -13], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -11], [0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -9], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, 0, -20], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2, -40], [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) answer = rref(arr, 37) expected = [[0, 38], [0, 38], [0, 6, 28, 35, 38], [0, 6, 18, 21, 28, 35, 38], [0, 6, 28, 35, 38], [0, 6, 21, 28, 35, 38], [0, 5, 10, 23, 38], [0, 5, 10, 23, 38], [0, 5, 10, 23, 38], [0, 5, 9, 10, 23, 38], [0, 5, 9, 10, 13, 17, 23, 30, 38], [0, 5, 9, 10, 11, 23, 27, 38], [0, 5, 9, 10, 11, 23, 27, 38], [0, 5, 9, 10, 11, 23, 25, 27, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 24, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 5, 7, 8, 9, 10, 11, 12, 15, 16, 19, 23, 25, 26, 27, 34, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [0, 5, 9, 10, 13, 17, 23, 30, 38], [0, 5, 9, 10, 13, 17, 23, 30, 38], [0, 5, 9, 10, 13, 17, 23, 30, 32, 38], [0, 5, 9, 10, 13, 17, 23, 30, 32, 33, 38], [0, 22, 29, 38], [0, 22, 29, 38], [0, 6, 28, 35, 38], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 15, 16, 19, 20, 21, 23, 25, 26, 27, 28, 31, 34, 35, 36, 37, 38], [14, 18, 21, 28, 35]] assert answer == expected, f'Wrong answer; got {answer} instead of {expected}' print('All tests passed') |
All tests passed
Testing if the previous output contains the string All tests passed
: True
Testing if the previous output contains the string njit
: False
Testing if the previous output contains the string nopython
: True
Testing if the previous output contains the string import numba
: True
Testing if the previous output contains the string from numba
: False