From a998ec43391d42180880b6d35f79ca455e5ab6ce Mon Sep 17 00:00:00 2001 From: erius Date: Thu, 23 May 2024 10:43:04 +0300 Subject: [PATCH] Implemented matrix multiplication, added tests for matrix mul and iter --- src/math/matrix/arithemtic.rs | 72 ++++++++++++++++++++++++++++++----- src/math/matrix/iter.rs | 64 ++++++++++++++++++++++++++++++- 2 files changed, 124 insertions(+), 12 deletions(-) diff --git a/src/math/matrix/arithemtic.rs b/src/math/matrix/arithemtic.rs index cbecde3..d393a5d 100644 --- a/src/math/matrix/arithemtic.rs +++ b/src/math/matrix/arithemtic.rs @@ -4,6 +4,24 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssi use num::{traits::NumAssign, Num, Signed}; +impl Matrix { + pub fn multiply(&self, rhs: &Self) -> Self { + if !self.can_mul(&rhs) { + panic!("Unable to multiply matrices with sizes {}x{} and {}x{}", + self.width, self.height(), rhs.width, rhs.height()); + } + let mut new = Matrix::new_zeroes(rhs.width, self.height()); + for (i, j) in new.indices() { + let new_elem = self.row(i).into_iter() + .zip(rhs.column(j).into_iter()) + .map(|(e1, e2)| e1 * e2) + .reduce(|acc, e| acc + e); + new[i][j] = new_elem.unwrap(); + } + return new; + } +} + impl Add for Matrix { type Output = Self; fn add(self, rhs: Self) -> Self::Output { @@ -109,15 +127,7 @@ impl DivAssign for Matrix { impl Mul for Matrix { type Output = Self; fn mul(self, rhs: Self) -> Self::Output { - if !self.can_mul(&rhs) { - panic!("Unable to multiply matrices with sizes {}x{} and {}x{}", - self.width, self.height(), rhs.width, rhs.height()); - } - let mut new = Matrix::new_zeroes(rhs.width, self.height()); - for (i, j) in new.indices() { - todo!() - } - return new; + self.multiply(&rhs) } } @@ -190,8 +200,50 @@ mod tests { #[test] #[should_panic(expected = "Unable to multiply")] - fn mul() { + fn mul_wrong_size() { + let _ = matrix![2; 1, 2, 3, 4] * matrix![2; 1, 2, 3, 4, 5, 6]; + } + #[test] + fn mul() { + let first = matrix![3; + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + ]; + let second = matrix![3; + 2, 3, 4, + 5, 6, 7, + 8, 9, 10 + ]; + assert_eq!(first.multiply(&second), matrix![3; + 36, 42, 48, + 81, 96, 111, + 126, 150, 174 + ]); + assert_eq!(second * first, matrix![3; + 42, 51, 60, + 78, 96, 114, + 114, 141, 168 + ]); + let first = matrix![3; + 1, 2, 3, + 4, 5, 6 + ]; + let second = matrix![2; + 2, 3, + 4, 5, + 6, 7 + ]; + assert_eq!(first.multiply(&second), matrix![2; + 28, 34, + 64, 79 + ]); + assert_eq!(second * first, matrix![3; + 14, 19, 24, + 24, 33, 42, + 34, 47, 60 + ]) } } diff --git a/src/math/matrix/iter.rs b/src/math/matrix/iter.rs index f8d5323..5e10d0b 100644 --- a/src/math/matrix/iter.rs +++ b/src/math/matrix/iter.rs @@ -120,7 +120,6 @@ impl<'a, T: Num> Iterator for ColumnsIter<'a, T> { } } - impl Iterator for IterIndices { type Item = (usize, usize); fn next(&mut self) -> Option { @@ -139,7 +138,7 @@ mod tests { use super::*; #[test] - fn iter() { + fn iter_elem() { let data = vec![1,2,3,4,5,6,7,8,9]; let mut matrix = Matrix::new(data.clone(), 3); for (i, e) in matrix.iter().enumerate() { @@ -153,5 +152,66 @@ mod tests { assert_eq!(data[i] + 2, e); } } + + #[test] + fn iter_indexed() { + let data = vec![1,2,3,4,5,6,7,8,9]; + let mut matrix = Matrix::new(data.clone(), 3); + let (width, height) = (matrix.width(), matrix.height()); + let mut matrix_iter = matrix.iter_indexed(); + for i in 0..height { + for j in 0..width { + let (mi, mj, e) = matrix_iter.next().unwrap(); + assert_eq!((mi, mj), (i, j)); + assert_eq!(*e, data[i * width + j]); + } + } + let mut matrix_iter = matrix.iter_indexed_mut(); + for i in 0..height { + for j in 0..width { + let (mi, mj, e) = matrix_iter.next().unwrap(); + assert_eq!((mi, mj), (i, j)); + assert_eq!(*e, data[i * width + j]); + *e += 2; + assert_eq!(*e, data[i * width + j] + 2); + } + } + } + + #[test] + fn iter_rows() { + let mut matrix = matrix![3; + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + ]; + let mut matrix_iter = matrix.iter_rows(); + assert_eq!(matrix_iter.next().unwrap(), &[1, 2, 3]); + assert_eq!(matrix_iter.next().unwrap(), &[4, 5, 6]); + assert_eq!(matrix_iter.next().unwrap(), &[7, 8, 9]); + assert_eq!(matrix_iter.next(), None); + for row in matrix.iter_rows_mut() { + row[0] += 9; + } + let mut matrix_iter = matrix.iter_rows_mut(); + assert_eq!(matrix_iter.next().unwrap(), &[10, 2, 3]); + assert_eq!(matrix_iter.next().unwrap(), &[13, 5, 6]); + assert_eq!(matrix_iter.next().unwrap(), &[16, 8, 9]); + assert_eq!(matrix_iter.next(), None); + } + + #[test] + fn iter_columns() { + let matrix = matrix![3; + 1, 2, 3, + 4, 5, 6, + 7, 8, 9 + ]; + for (j, column) in matrix.iter_columns().enumerate() { + for (i, e) in column.enumerate() { + assert_eq!(*e, matrix[i][j]); + } + } + } }