Implemented matrix multiplication, added tests for matrix mul and iter

This commit is contained in:
Egor 2024-05-23 10:43:04 +03:00
parent 7cad1b8dfc
commit a998ec4339
2 changed files with 124 additions and 12 deletions

View file

@ -4,6 +4,24 @@ use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssi
use num::{traits::NumAssign, Num, Signed}; use num::{traits::NumAssign, Num, Signed};
impl<T: Num + Clone> Matrix<T> {
pub fn multiply(&self, rhs: &Self) -> Self {
if !self.can_mul(&rhs) {
panic!("Unable to multiply matrices with sizes {}x{} and {}x{}",
self.width, self.height(), rhs.width, rhs.height());
}
let mut new = Matrix::new_zeroes(rhs.width, self.height());
for (i, j) in new.indices() {
let new_elem = self.row(i).into_iter()
.zip(rhs.column(j).into_iter())
.map(|(e1, e2)| e1 * e2)
.reduce(|acc, e| acc + e);
new[i][j] = new_elem.unwrap();
}
return new;
}
}
impl<T: Num + Clone> Add for Matrix<T> { impl<T: Num + Clone> Add for Matrix<T> {
type Output = Self; type Output = Self;
fn add(self, rhs: Self) -> Self::Output { fn add(self, rhs: Self) -> Self::Output {
@ -109,15 +127,7 @@ impl<T: NumAssign + Clone> DivAssign<T> for Matrix<T> {
impl<T: Num + Clone> Mul for Matrix<T> { impl<T: Num + Clone> Mul for Matrix<T> {
type Output = Self; type Output = Self;
fn mul(self, rhs: Self) -> Self::Output { fn mul(self, rhs: Self) -> Self::Output {
if !self.can_mul(&rhs) { self.multiply(&rhs)
panic!("Unable to multiply matrices with sizes {}x{} and {}x{}",
self.width, self.height(), rhs.width, rhs.height());
}
let mut new = Matrix::new_zeroes(rhs.width, self.height());
for (i, j) in new.indices() {
todo!()
}
return new;
} }
} }
@ -190,8 +200,50 @@ mod tests {
#[test] #[test]
#[should_panic(expected = "Unable to multiply")] #[should_panic(expected = "Unable to multiply")]
fn mul_wrong_size() {
let _ = matrix![2; 1, 2, 3, 4] * matrix![2; 1, 2, 3, 4, 5, 6];
}
#[test]
fn mul() { fn mul() {
let first = matrix![3;
1, 2, 3,
4, 5, 6,
7, 8, 9
];
let second = matrix![3;
2, 3, 4,
5, 6, 7,
8, 9, 10
];
assert_eq!(first.multiply(&second), matrix![3;
36, 42, 48,
81, 96, 111,
126, 150, 174
]);
assert_eq!(second * first, matrix![3;
42, 51, 60,
78, 96, 114,
114, 141, 168
]);
let first = matrix![3;
1, 2, 3,
4, 5, 6
];
let second = matrix![2;
2, 3,
4, 5,
6, 7
];
assert_eq!(first.multiply(&second), matrix![2;
28, 34,
64, 79
]);
assert_eq!(second * first, matrix![3;
14, 19, 24,
24, 33, 42,
34, 47, 60
])
} }
} }

View file

@ -120,7 +120,6 @@ impl<'a, T: Num> Iterator for ColumnsIter<'a, T> {
} }
} }
impl Iterator for IterIndices { impl Iterator for IterIndices {
type Item = (usize, usize); type Item = (usize, usize);
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
@ -139,7 +138,7 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn iter() { fn iter_elem() {
let data = vec![1,2,3,4,5,6,7,8,9]; let data = vec![1,2,3,4,5,6,7,8,9];
let mut matrix = Matrix::new(data.clone(), 3); let mut matrix = Matrix::new(data.clone(), 3);
for (i, e) in matrix.iter().enumerate() { for (i, e) in matrix.iter().enumerate() {
@ -153,5 +152,66 @@ mod tests {
assert_eq!(data[i] + 2, e); assert_eq!(data[i] + 2, e);
} }
} }
#[test]
fn iter_indexed() {
let data = vec![1,2,3,4,5,6,7,8,9];
let mut matrix = Matrix::new(data.clone(), 3);
let (width, height) = (matrix.width(), matrix.height());
let mut matrix_iter = matrix.iter_indexed();
for i in 0..height {
for j in 0..width {
let (mi, mj, e) = matrix_iter.next().unwrap();
assert_eq!((mi, mj), (i, j));
assert_eq!(*e, data[i * width + j]);
}
}
let mut matrix_iter = matrix.iter_indexed_mut();
for i in 0..height {
for j in 0..width {
let (mi, mj, e) = matrix_iter.next().unwrap();
assert_eq!((mi, mj), (i, j));
assert_eq!(*e, data[i * width + j]);
*e += 2;
assert_eq!(*e, data[i * width + j] + 2);
}
}
}
#[test]
fn iter_rows() {
let mut matrix = matrix![3;
1, 2, 3,
4, 5, 6,
7, 8, 9
];
let mut matrix_iter = matrix.iter_rows();
assert_eq!(matrix_iter.next().unwrap(), &[1, 2, 3]);
assert_eq!(matrix_iter.next().unwrap(), &[4, 5, 6]);
assert_eq!(matrix_iter.next().unwrap(), &[7, 8, 9]);
assert_eq!(matrix_iter.next(), None);
for row in matrix.iter_rows_mut() {
row[0] += 9;
}
let mut matrix_iter = matrix.iter_rows_mut();
assert_eq!(matrix_iter.next().unwrap(), &[10, 2, 3]);
assert_eq!(matrix_iter.next().unwrap(), &[13, 5, 6]);
assert_eq!(matrix_iter.next().unwrap(), &[16, 8, 9]);
assert_eq!(matrix_iter.next(), None);
}
#[test]
fn iter_columns() {
let matrix = matrix![3;
1, 2, 3,
4, 5, 6,
7, 8, 9
];
for (j, column) in matrix.iter_columns().enumerate() {
for (i, e) in column.enumerate() {
assert_eq!(*e, matrix[i][j]);
}
}
}
} }