Implemented matrix multiplication, added tests for matrix mul and iter
This commit is contained in:
parent
7cad1b8dfc
commit
a998ec4339
2 changed files with 124 additions and 12 deletions
|
@ -4,6 +4,24 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssi
|
|||
|
||||
use num::{traits::NumAssign, Num, Signed};
|
||||
|
||||
impl<T: Num + Clone> Matrix<T> {
|
||||
pub fn multiply(&self, rhs: &Self) -> Self {
|
||||
if !self.can_mul(&rhs) {
|
||||
panic!("Unable to multiply matrices with sizes {}x{} and {}x{}",
|
||||
self.width, self.height(), rhs.width, rhs.height());
|
||||
}
|
||||
let mut new = Matrix::new_zeroes(rhs.width, self.height());
|
||||
for (i, j) in new.indices() {
|
||||
let new_elem = self.row(i).into_iter()
|
||||
.zip(rhs.column(j).into_iter())
|
||||
.map(|(e1, e2)| e1 * e2)
|
||||
.reduce(|acc, e| acc + e);
|
||||
new[i][j] = new_elem.unwrap();
|
||||
}
|
||||
return new;
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Num + Clone> Add for Matrix<T> {
|
||||
type Output = Self;
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
|
@ -109,15 +127,7 @@ impl<T: NumAssign + Clone> DivAssign<T> for Matrix<T> {
|
|||
impl<T: Num + Clone> Mul for Matrix<T> {
|
||||
type Output = Self;
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
if !self.can_mul(&rhs) {
|
||||
panic!("Unable to multiply matrices with sizes {}x{} and {}x{}",
|
||||
self.width, self.height(), rhs.width, rhs.height());
|
||||
}
|
||||
let mut new = Matrix::new_zeroes(rhs.width, self.height());
|
||||
for (i, j) in new.indices() {
|
||||
todo!()
|
||||
}
|
||||
return new;
|
||||
self.multiply(&rhs)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -190,8 +200,50 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
#[should_panic(expected = "Unable to multiply")]
|
||||
fn mul_wrong_size() {
|
||||
let _ = matrix![2; 1, 2, 3, 4] * matrix![2; 1, 2, 3, 4, 5, 6];
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mul() {
|
||||
|
||||
let first = matrix![3;
|
||||
1, 2, 3,
|
||||
4, 5, 6,
|
||||
7, 8, 9
|
||||
];
|
||||
let second = matrix![3;
|
||||
2, 3, 4,
|
||||
5, 6, 7,
|
||||
8, 9, 10
|
||||
];
|
||||
assert_eq!(first.multiply(&second), matrix![3;
|
||||
36, 42, 48,
|
||||
81, 96, 111,
|
||||
126, 150, 174
|
||||
]);
|
||||
assert_eq!(second * first, matrix![3;
|
||||
42, 51, 60,
|
||||
78, 96, 114,
|
||||
114, 141, 168
|
||||
]);
|
||||
let first = matrix![3;
|
||||
1, 2, 3,
|
||||
4, 5, 6
|
||||
];
|
||||
let second = matrix![2;
|
||||
2, 3,
|
||||
4, 5,
|
||||
6, 7
|
||||
];
|
||||
assert_eq!(first.multiply(&second), matrix![2;
|
||||
28, 34,
|
||||
64, 79
|
||||
]);
|
||||
assert_eq!(second * first, matrix![3;
|
||||
14, 19, 24,
|
||||
24, 33, 42,
|
||||
34, 47, 60
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -120,7 +120,6 @@ impl<'a, T: Num> Iterator for ColumnsIter<'a, T> {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
impl Iterator for IterIndices {
|
||||
type Item = (usize, usize);
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
|
@ -139,7 +138,7 @@ mod tests {
|
|||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn iter() {
|
||||
fn iter_elem() {
|
||||
let data = vec![1,2,3,4,5,6,7,8,9];
|
||||
let mut matrix = Matrix::new(data.clone(), 3);
|
||||
for (i, e) in matrix.iter().enumerate() {
|
||||
|
@ -153,5 +152,66 @@ mod tests {
|
|||
assert_eq!(data[i] + 2, e);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn iter_indexed() {
|
||||
let data = vec![1,2,3,4,5,6,7,8,9];
|
||||
let mut matrix = Matrix::new(data.clone(), 3);
|
||||
let (width, height) = (matrix.width(), matrix.height());
|
||||
let mut matrix_iter = matrix.iter_indexed();
|
||||
for i in 0..height {
|
||||
for j in 0..width {
|
||||
let (mi, mj, e) = matrix_iter.next().unwrap();
|
||||
assert_eq!((mi, mj), (i, j));
|
||||
assert_eq!(*e, data[i * width + j]);
|
||||
}
|
||||
}
|
||||
let mut matrix_iter = matrix.iter_indexed_mut();
|
||||
for i in 0..height {
|
||||
for j in 0..width {
|
||||
let (mi, mj, e) = matrix_iter.next().unwrap();
|
||||
assert_eq!((mi, mj), (i, j));
|
||||
assert_eq!(*e, data[i * width + j]);
|
||||
*e += 2;
|
||||
assert_eq!(*e, data[i * width + j] + 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn iter_rows() {
|
||||
let mut matrix = matrix![3;
|
||||
1, 2, 3,
|
||||
4, 5, 6,
|
||||
7, 8, 9
|
||||
];
|
||||
let mut matrix_iter = matrix.iter_rows();
|
||||
assert_eq!(matrix_iter.next().unwrap(), &[1, 2, 3]);
|
||||
assert_eq!(matrix_iter.next().unwrap(), &[4, 5, 6]);
|
||||
assert_eq!(matrix_iter.next().unwrap(), &[7, 8, 9]);
|
||||
assert_eq!(matrix_iter.next(), None);
|
||||
for row in matrix.iter_rows_mut() {
|
||||
row[0] += 9;
|
||||
}
|
||||
let mut matrix_iter = matrix.iter_rows_mut();
|
||||
assert_eq!(matrix_iter.next().unwrap(), &[10, 2, 3]);
|
||||
assert_eq!(matrix_iter.next().unwrap(), &[13, 5, 6]);
|
||||
assert_eq!(matrix_iter.next().unwrap(), &[16, 8, 9]);
|
||||
assert_eq!(matrix_iter.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn iter_columns() {
|
||||
let matrix = matrix![3;
|
||||
1, 2, 3,
|
||||
4, 5, 6,
|
||||
7, 8, 9
|
||||
];
|
||||
for (j, column) in matrix.iter_columns().enumerate() {
|
||||
for (i, e) in column.enumerate() {
|
||||
assert_eq!(*e, matrix[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue